commit cec13a218d043d0354d86e81341edfd4296392f2 Author: austin Date: Sat Dec 28 18:01:52 2024 -0500 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..246763b --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +.mypy_cache +.venv +dist +target +src/trpg.egg-info + +# ml model files +data/semantic/model +**.keras +**.h5 \ No newline at end of file diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..67e0b2e --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,5 @@ +{ + "recommendations": [ + "ms-python.python", + ] +} \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..2978ffd --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,21 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "java", + "name": "Launch Java Program", + "request": "launch", + "mainClass": "" + }, + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..0153b31 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "java.configuration.updateBuildConfiguration": "interactive", + "java.compile.nullAnalysis.mode": "disabled" +} \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..975a79b --- /dev/null +++ b/README.md @@ -0,0 +1,24 @@ + + + +``` +python -m venv .venv +``` + +``` +source ./.venv/Scripts/activate +``` + +``` +pip install tensorflow mypy numpy +``` + +``` +pip install keras==2.15.0 +``` + +``` +pip install types-tensorflow +``` + +In settings, make sure to check `Run Using Active Interpreter` in mypy settings \ No newline at end of file diff --git a/current_goal.txt b/current_goal.txt new file mode 100644 index 0000000..166945a --- /dev/null +++ b/current_goal.txt @@ -0,0 +1,39 @@ +context: +user is a barbarian +chatbot is a wizard +sitting in a tavern by a fireplace + + +interaction: +the user says hello +the chatbot replies and queries the user's name (query for name relation to other participant) +the user answers and then asks the chatbot about the color of its hat (associate name relation to other participant) +the chatbot replies with correct information about the hat (it is blue) (lookup complex association and synthesize response) + + + +things needed: +knowledge web + - a person + - a place + - a thing + - a relationship + - between persons (ie the lack of relationship between the user and chatbot) + - between a person and a thing (the hat the wizard is wearing) + - information about a thing + - color + +routines for handling small chat + - introductions + - initial goal routine to query conversation partner for information about themself + +information queries + - determine target of query + - determine quality/quantity being queried + +information transfers + - determine information that is being transferred + - associate the new information with the knowledge web + +distinguish between sentence types (query vs transfer vs command, etc) + diff --git a/data/semantic/subject.txt b/data/semantic/subject.txt new file mode 100644 index 0000000..5159c18 --- /dev/null +++ b/data/semantic/subject.txt @@ -0,0 +1,8 @@ +fdsafdsa +fdsafdsafdsaf +dsafd +saf +dsafdfds +afds +i +self \ No newline at end of file diff --git a/data/semantic/subject_label.txt b/data/semantic/subject_label.txt new file mode 100644 index 0000000..6653c2c --- /dev/null +++ b/data/semantic/subject_label.txt @@ -0,0 +1,8 @@ +1 +1 +1 +1 +0 +0 +0 +0 \ No newline at end of file diff --git a/data/tokens.txt b/data/tokens.txt new file mode 100644 index 0000000..db2b8f0 --- /dev/null +++ b/data/tokens.txt @@ -0,0 +1,3 @@ +self +i +I \ No newline at end of file diff --git a/dependency-reduced-pom.xml b/dependency-reduced-pom.xml new file mode 100644 index 0000000..89d568d --- /dev/null +++ b/dependency-reduced-pom.xml @@ -0,0 +1,38 @@ + + + 4.0.0 + studiorailgun + Renderer + 0.1.1 + + + + maven-shade-plugin + 3.2.4 + + + package + + shade + + + + + org.studiorailgun.Main + + + + + + + + + + + 17 + 17 + UTF-8 + 1.0.0-M2 + nd4j-native + + diff --git a/improvement_ideas.txt b/improvement_ideas.txt new file mode 100644 index 0000000..f325292 --- /dev/null +++ b/improvement_ideas.txt @@ -0,0 +1 @@ +summarize previous statements to provide context instead of using full statement \ No newline at end of file diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..a9337fb --- /dev/null +++ b/mypy.ini @@ -0,0 +1,27 @@ +# Global options: + +[mypy] +install_types = True +disallow_untyped_defs = True +no_implicit_optional = True +check_untyped_defs = True +warn_return_any = True +show_error_codes = True +warn_unused_ignores = True +warn_unused_configs = True +plugins = numpy.typing.mypy_plugin + +disallow_incomplete_defs = True +disallow_untyped_decorators = False +disallow_any_unimported = False + +# Per-module options: + +[mypy-mycode.foo.*] +disallow_untyped_defs = True + +[mypy-mycode.bar] +warn_return_any = False + +[mypy-somelibrary] +ignore_missing_imports = True \ No newline at end of file diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..f4d6481 --- /dev/null +++ b/pom.xml @@ -0,0 +1,92 @@ + + + 4.0.0 + studiorailgun + Renderer + 0.1.1 + jar + + UTF-8 + 21 + 21 + 1.0.0-M2 + + nd4j-native + + + + + + + org.deeplearning4j + deeplearning4j-core + ${dl4j-master.version} + + + org.deeplearning4j + deeplearning4j-nlp + ${dl4j-master.version} + + + + + + org.tensorflow + tensorflow-core-api + 1.0.0 + + + org.tensorflow + tensorflow-core-platform + 1.0.0 + + + org.tensorflow + tensorflow-core-native + 1.0.0 + + + + + + + com.google.code.gson + gson + 2.8.6 + + + + + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.2.4 + + + package + + shade + + + + + org.studiorailgun.Main + + + + + + + + + + + \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..05bdb5d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,17 @@ +[build-system] +requires = ["setuptools>=68", "setuptools_scm[toml]>=8"] +build-backend = "setuptools.build_meta" + +[project] +name = "trpg" +requires-python = ">=3.8" +dynamic = ["version"] +dependencies = [ + # Add runtime dependencies here + "numpy", + "keras", + "tensorflow" +] + +# Enables the usage of setuptools_scm +[tool.setuptools_scm] \ No newline at end of file diff --git a/src/main/java/org/studiorailgun/FileUtils.java b/src/main/java/org/studiorailgun/FileUtils.java new file mode 100644 index 0000000..8cdd831 --- /dev/null +++ b/src/main/java/org/studiorailgun/FileUtils.java @@ -0,0 +1,280 @@ +package org.studiorailgun; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; + +import java.io.BufferedReader; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +/** + * Utilities for dealing with files + */ +public class FileUtils { + + /** + * Creates the gson instance + */ + static { + //init gson + GsonBuilder gsonBuilder = new GsonBuilder(); + gson = gsonBuilder.create(); + } + + //used for serialization/deserialization in file operations + static Gson gson; + + //maximum number of attempt to read the file + static final int maxReadFails = 3; + //Timeout duration between read attempts + static final int READ_TIMEOUT_DURATION = 5; + /** + * Reads a file to a string + * @param f The file + * @return The string + */ + public static String readFileToString(File f){ + String rVal = ""; + BufferedReader reader; + try { + reader = Files.newBufferedReader(f.toPath()); + int failCounter = 0; + boolean reading = true; + StringBuilder builder = new StringBuilder(""); + while(reading){ + if(reader.ready()){ + failCounter = 0; + int nextValue = reader.read(); + if(nextValue == -1){ + reading = false; + } else { + builder.append((char)nextValue); + } + } else { + failCounter++; + if(failCounter > maxReadFails){ + reading = false; + } else { + try { + TimeUnit.MILLISECONDS.sleep(READ_TIMEOUT_DURATION); + } catch (InterruptedException ex) { + ex.printStackTrace(); + } + } + } + } + rVal = builder.toString(); + } catch (IOException ex) { + ex.printStackTrace(); + } + return rVal; + } + + + public static String readStreamToString(InputStream resourceInputStream){ + String rVal = ""; + BufferedReader reader; + try { + reader = new BufferedReader(new InputStreamReader(resourceInputStream)); + int failCounter = 0; + boolean reading = true; + StringBuilder builder = new StringBuilder(""); + while(reading){ + if(reader.ready()){ + failCounter = 0; + int nextValue = reader.read(); + if(nextValue == -1){ + reading = false; + } else { + builder.append((char)nextValue); + } + } else { + failCounter++; + if(failCounter > maxReadFails){ + reading = false; + } else { + try { + TimeUnit.MILLISECONDS.sleep(READ_TIMEOUT_DURATION); + } catch (InterruptedException ex) { + ex.printStackTrace(); + } + } + } + } + rVal = builder.toString(); + } catch (IOException ex) { + ex.printStackTrace(); + } + return rVal; + } + + + /** + * Sanitizes a relative file path, guaranteeing that the initial slash is correct + * @param filePath The raw file path + * @return The sanitized file path + */ + public static String sanitizeFilePath(String filePath){ + String rVal = new String(filePath); + rVal = rVal.trim(); + if(rVal.startsWith("./")){ + return rVal; + } else if(!rVal.startsWith("/")){ + rVal = "/" + rVal; + } + return rVal; + } + + /** + * Serializes an object to a filepath + * @param filePath The filepath + * @param object The object + */ + public static void serializeObjectToFilePath(String filePath, Object object){ + Path path = new File(filePath).toPath(); + try { + Files.write(path, gson.toJson(object).getBytes()); + } catch (IOException ex) { + ex.printStackTrace(); + } + } + + /** + * Loads an object from the assets folder + * @param The type of object + * @param file The file to load from + * @param className The class of the object inside the file + * @return The file + */ + public static T loadObjectFromFile(File file, Class className){ + T rVal = null; + try { + rVal = gson.fromJson(Files.newBufferedReader(file.toPath()), className); + } catch (IOException ex) { + ex.printStackTrace(); + } + return rVal; + } + + /** + * Checks if a directory exists + * @param fileName + * @return true if directory exists, false otherwise + */ + public static boolean checkFileExists(String fileName){ + File targetDir = new File(sanitizeFilePath(fileName)); + if(targetDir.exists()){ + return true; + } else { + return false; + } + } + + + /** + * Trys to create a directory + * @param directoryName + * @return true if directory was created, false if it was not + */ + public static boolean createDirectory(String directoryName){ + String sanitizedPath = sanitizeFilePath(directoryName); + File targetDir = new File(sanitizedPath); + if(targetDir.exists()){ + return false; + } else { + return targetDir.mkdirs(); + } + } + + /** + * Lists the files in a directory + * @param directoryName The path of the directory + * @return A list containing the names of all files inside that directory + */ + public static List listDirectory(String directoryName){ + List rVal = new LinkedList(); + String sanitizedPath = sanitizeFilePath(directoryName); + File targetDir = new File(sanitizedPath); + String[] files = targetDir.list(); + for(String name : files){ + rVal.add(name); + } + return rVal; + } + + /** + * Recursively deletes a path + * @param path The path + */ + public static void recursivelyDelete(String path){ + File file = new File(path); + if(file.isDirectory()){ + for(File child : file.listFiles()){ + recursivelyDelete(child.getAbsolutePath()); + } + } + if(file.exists()){ + try { + Files.delete(file.toPath()); + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + } + + /** + * Gets a file's path relative to a given directory + * @param file The file + * @param directory The directory + * @return The relative path + */ + public static String relativize(File file, File directory){ + return directory.toURI().relativize(file.toURI()).getPath(); + } + + + + /** + * Computes the checksum of an object + * @param object The object + * @return The checksum + * @throws IOException Thrown on io errors reading the file + * @throws NoSuchAlgorithmException Thrown if MD5 isn't supported + */ + public static String getChecksum(Serializable object) throws IOException, NoSuchAlgorithmException { + ByteArrayOutputStream baos = null; + ObjectOutputStream oos = null; + try { + baos = new ByteArrayOutputStream(); + oos = new ObjectOutputStream(baos); + oos.writeObject(object); + MessageDigest md = MessageDigest.getInstance("MD5"); + byte[] bytes = md.digest(baos.toByteArray()); + StringBuffer builder = new StringBuffer(); + for(byte byteCurr : bytes){ + builder.append(String.format("%02x",byteCurr)); + } + return builder.toString(); + } finally { + oos.close(); + baos.close(); + } + } + + + + +} diff --git a/src/main/java/org/studiorailgun/Globals.java b/src/main/java/org/studiorailgun/Globals.java new file mode 100644 index 0000000..7a9557a --- /dev/null +++ b/src/main/java/org/studiorailgun/Globals.java @@ -0,0 +1,25 @@ +package org.studiorailgun; + +import java.io.File; + +import org.studiorailgun.knowledge.KnowledgeWeb; + +/** + * Global variables + */ +public class Globals { + + /** + * The knowledge web + */ + public static KnowledgeWeb web; + + /** + * Initializes the knowledge web + */ + public static void init(){ + web = FileUtils.loadObjectFromFile(new File("web.json"), KnowledgeWeb.class); + web.initLinks(); + } + +} diff --git a/src/main/java/org/studiorailgun/Main.java b/src/main/java/org/studiorailgun/Main.java new file mode 100644 index 0000000..a73f21e --- /dev/null +++ b/src/main/java/org/studiorailgun/Main.java @@ -0,0 +1,82 @@ +package org.studiorailgun; + +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Tensor; +import org.tensorflow.types.TInt64; + +/** + * The main class + */ +public class Main { + + /** + * The main method + */ + public static void main(String[] args){ + // AgentLoop.main(); + // try (Graph graph = new Graph()) { + // prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"input": inputs}, {"output":output}); + SavedModelBundle model = SavedModelBundle.load("./data/semantic/model"); + System.out.println(model.functions()); + System.out.println(model.functions().get(0).signature()); + System.out.println(model.functions().get(1).signature()); + System.out.println(model.functions().get(2).signature()); + System.out.println(model.functions().get(0).signature().getInputs()); + System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape); + //init + // model.functions().get(2).session().runner() + // .fetch("__saved_model_init_op") + // .run(); + //run predict + TInt64 tensor = TInt64.scalarOf(10); + org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128); + tensor = TInt64.tensorOf(shape); + // NdArray arr = NdArrays.vectorOfObjects(NdArrays.vectorOf((long)0, (long)0, (long)0, (long)0)); + // System.out.println(arr.shape()); + // tensor = TInt64.tensorOf(arr.shape()); + + Tensor result = model.function("serve").call(tensor); + // model.functions().get(0).session().runner() + // .feed("keras_tensor", tensor) + // .fetch("output_0") + // .run(); + + // Graph graph = model.graph(); + // Session session = model.session(); + + // Tensor inputTensor = Tensor.of(Long.class, model.functions().get(0).signature().getInputs().values().iterator().next().shape); + // NdArrays.ofLongs(shape(2, 3, 2)); + // TInt64 tensor = TInt64.vectorOf(10,100); + // // TInt64 tensor = TInt64.tensorOf(model.functions().get(0).signature().getInputs().values().iterator().next().shape); + // Tensor result = session.runner() + // .feed("keras_tensor", tensor) + // .fetch("output_0") + // .run().get(0); + + // Result result = model.call(tensors); + // byte[] graphDef = Files.readAllBytes(Paths.get("./data/semantic/model.keras")); + // GraphDef def = GraphDef.newBuilder() + // .mergeFrom(graphDef) + // .build(); + // graph.importGraphDef(def); + + // try (Session session = new Session(graph)) { + // // // Prepare input data as a Tensor + // // Tensor inputTensor = Tensor.create(inputData); + + // // // Run inference + // // Tensor result = session.runner() + // // .feed("input_node_name", inputTensor) // Replace with actual input tensor name + // // .fetch("output_node_name") // Replace with actual output tensor name + // // .run().get(0); + + // // // Process the result + // // float[][] output = result.copyTo(new float[1][outputSize]); + // // System.out.println(Arrays.toString(output[0])); + // } + // } catch (IOException e) { + // e.printStackTrace(); + // } + } + +} diff --git a/src/main/java/org/studiorailgun/conversation/AgentLoop.java b/src/main/java/org/studiorailgun/conversation/AgentLoop.java new file mode 100644 index 0000000..2f74ddd --- /dev/null +++ b/src/main/java/org/studiorailgun/conversation/AgentLoop.java @@ -0,0 +1,59 @@ +package org.studiorailgun.conversation; + +import java.util.Scanner; + +import org.studiorailgun.Globals; +import org.studiorailgun.conversation.parser.CommandParser; +import org.studiorailgun.conversation.tracking.Actor; +import org.studiorailgun.conversation.tracking.Conversation; +import org.studiorailgun.conversation.tracking.Statement; +import org.studiorailgun.kobold.KoboldPrinter; +import org.studiorailgun.kobold.KoboldRequest; + +public class AgentLoop { + + /** + * The size of the context available + */ + static final int CONTEXT_SIZE = 32 * 1024; + + /** + * Controls whether the main parser loop is running + */ + public static boolean running = true; + + /** + * The main method + */ + public static void main(){ + Globals.init(); + + //setup actual input + Scanner scan = new Scanner(System.in); + String prompt = ""; + + //setup conversation tracking + Conversation convo = new Conversation(); + Actor player = new Actor("John"); + Actor ai = new Actor("Dave"); + convo.addParticipant(player); + convo.addParticipant(ai); + + //actual main loop + while(running){ + + //handle player statement + prompt = scan.nextLine(); + if(CommandParser.parseCommands(prompt)){ + continue; + } + convo.addStatement(new Statement(player, prompt)); + + //handle ai statement + throw new UnsupportedOperationException("asdf"); + } + + scan.close(); + } + +} diff --git a/src/main/java/org/studiorailgun/conversation/LLMLoop.java b/src/main/java/org/studiorailgun/conversation/LLMLoop.java new file mode 100644 index 0000000..97a18ca --- /dev/null +++ b/src/main/java/org/studiorailgun/conversation/LLMLoop.java @@ -0,0 +1,62 @@ +package org.studiorailgun.conversation; + +import java.util.Scanner; + +import org.studiorailgun.conversation.parser.CommandParser; +import org.studiorailgun.conversation.tracking.Actor; +import org.studiorailgun.conversation.tracking.Conversation; +import org.studiorailgun.conversation.tracking.Statement; +import org.studiorailgun.kobold.KoboldPrinter; +import org.studiorailgun.kobold.KoboldRequest; + +public class LLMLoop { + + /** + * The size of the context available + */ + static final int CONTEXT_SIZE = 32 * 1024; + + /** + * Controls whether the main parser loop is running + */ + public static boolean running = true; + + /** + * The main method + */ + public static void main(){ + + //setup actual input + Scanner scan = new Scanner(System.in); + String prompt = ""; + + //setup conversation tracking + Conversation convo = new Conversation(); + Actor player = new Actor("John"); + Actor ai = new Actor("Dave"); + convo.addParticipant(player); + convo.addParticipant(ai); + + //actual main loop + while(running){ + + //handle player statement + prompt = scan.nextLine(); + if(CommandParser.parseCommands(prompt)){ + continue; + } + convo.addStatement(new Statement(player, prompt)); + + //handle ai statement + KoboldRequest request = convo.generateStatementRequest(ai); + request.fire(); + KoboldPrinter printer = new KoboldPrinter(request); + printer.print(); + System.out.println("[System] Created prompt of size: " + request.getPrompt().length() + " (" + (request.getPrompt().length() / (float)CONTEXT_SIZE) + ")"); + convo.pushbackRequest(ai, request); + } + + scan.close(); + } + +} diff --git a/src/main/java/org/studiorailgun/conversation/categorization/ConvCategorizationNet.java b/src/main/java/org/studiorailgun/conversation/categorization/ConvCategorizationNet.java new file mode 100644 index 0000000..465c498 --- /dev/null +++ b/src/main/java/org/studiorailgun/conversation/categorization/ConvCategorizationNet.java @@ -0,0 +1,116 @@ +package org.studiorailgun.conversation.categorization; + +import java.io.File; +import java.io.IOException; + +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.BackpropType; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.optimize.api.InvocationType; +import org.deeplearning4j.optimize.listeners.EvaluativeListener; +import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.learning.config.Sgd; + +/** + * Computational model for categorizing conversations + */ +public class ConvCategorizationNet { + + /** + * The configuration for the network + */ + MultiLayerConfiguration conf; + + /** + * The actual model created from the configuration + */ + MultiLayerNetwork model; + + /** + * The number of epochs to try to train for + */ + int targetEpochs; + + /** + * Initializes the model + */ + public void init(){ + this.conf = new NeuralNetConfiguration.Builder() + //setup base config + .weightInit(WeightInit.XAVIER) + .activation(Activation.RELU) + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .updater(new Sgd(0.05)) + .list() + //add layers + .layer(0, new DenseLayer.Builder() + .nIn(4096) + .nOut(4096) + .build() + ) + //add back propagation + .backpropType(BackpropType.Standard) + //finalize + .build(); + + this.model = new MultiLayerNetwork(conf); + model.init(); + } + + public void apply(String input){ + // this.model.output(); + } + + /** + * Trains the model on a set of data + * @param trainData The training data + * @param testData The test data + */ + public void train(DataSetIterator trainData, DataSetIterator testData){ + System.out.println("Train model.."); + model.setListeners(new ScoreIterationListener(1), new EvaluativeListener(testData, 1, InvocationType.EPOCH_END)); + model.fit(trainData, targetEpochs); + } + + /** + * Evaluates the model against the test data + * @param testData The test data + */ + public void evaluate(DataSetIterator testData){ + System.out.println("Evaluate model.."); + Evaluation eval = model.evaluate(testData); + System.out.println(eval.stats()); + } + + /** + * Saves the model to disk + * @param location The file to save the model into + */ + public void save(File location){ + try { + model.save(location, true); + } catch (IOException e) { + e.printStackTrace(); + } + } + + /** + * Loads the model from disk + * @param location The location of the model file + */ + public void load(File location){ + try { + this.model = MultiLayerNetwork.load(location, true); + } catch (IOException e) { + e.printStackTrace(); + } + } + +} diff --git a/src/main/java/org/studiorailgun/conversation/categorization/SentenceFunctionCategorizor.java b/src/main/java/org/studiorailgun/conversation/categorization/SentenceFunctionCategorizor.java new file mode 100644 index 0000000..e9c7bf5 --- /dev/null +++ b/src/main/java/org/studiorailgun/conversation/categorization/SentenceFunctionCategorizor.java @@ -0,0 +1,36 @@ +package org.studiorailgun.conversation.categorization; + +/** + * Categorizes sentences based on function + */ +public class SentenceFunctionCategorizor { + + /** + * The function of the sentence + */ + public static enum SentenceFunction { + /** + * Transfers information to the other party + * ie, declaring facts, declaring perspective of the world (how you feel), etc + * "The ball is red". "I don't like that". + */ + TRANSFER_INFORMATION, + + /** + * Query information from the other party + * ie, asking for a fact or perspective + * "Is that ball red?" "What color is your hat?" "How do you feel about that?" + */ + QUERY_INFORMATION, + } + + /** + * Categorizes the sentence by function + * @param input The input sentence + * @return The function of the sentence + */ + public static SentenceFunction categorize(String input){ + return SentenceFunction.QUERY_INFORMATION; + } + +} diff --git a/src/main/java/org/studiorailgun/conversation/parser/CommandParser.java b/src/main/java/org/studiorailgun/conversation/parser/CommandParser.java new file mode 100644 index 0000000..061afec --- /dev/null +++ b/src/main/java/org/studiorailgun/conversation/parser/CommandParser.java @@ -0,0 +1,30 @@ +package org.studiorailgun.conversation.parser; + +import org.studiorailgun.conversation.AgentLoop; +import org.studiorailgun.conversation.LLMLoop; + +/** + * Parses player commands to execute + */ +public class CommandParser { + + /** + * Parses the input for commands. If a valid command is recognized, returns true. Otherwise returns false. + * @param input The input string + * @return true if a command was executed, false otherwise + */ + public static boolean parseCommands(String input){ + if(input.equals("exit")){ + LLMLoop.running = false; + AgentLoop.running = false; + } else if(input.equals("s hypo")){ //save a response failing to consider a hypothetical + + } else if(input.equals("s q")){ //save a question + + } else if(input.equals("s format")){ //save a formatting error + + } + return false; + } + +} diff --git a/src/main/java/org/studiorailgun/conversation/semantic/SentenceSubjectParser.java b/src/main/java/org/studiorailgun/conversation/semantic/SentenceSubjectParser.java new file mode 100644 index 0000000..026572c --- /dev/null +++ b/src/main/java/org/studiorailgun/conversation/semantic/SentenceSubjectParser.java @@ -0,0 +1,53 @@ +package org.studiorailgun.conversation.semantic; + +import java.io.File; +import java.io.IOException; + +import org.deeplearning4j.models.word2vec.Word2Vec; +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.BackpropType; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.optimize.api.InvocationType; +import org.deeplearning4j.optimize.listeners.EvaluativeListener; +import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator; +import org.deeplearning4j.text.sentenceiterator.SentenceIterator; +import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; +import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; +import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; +import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.learning.config.Sgd; + +/** + * Parses the subject of a sentence + */ +public class SentenceSubjectParser { + /** + * The configuration for the network + */ + MultiLayerConfiguration conf; + + /** + * The actual model created from the configuration + */ + MultiLayerNetwork model; + + /** + * The number of epochs to try to train for + */ + int targetEpochs; + + /** + * Initializes the model + */ + public void init(){ + + } +} diff --git a/src/main/java/org/studiorailgun/conversation/tracking/Actor.java b/src/main/java/org/studiorailgun/conversation/tracking/Actor.java new file mode 100644 index 0000000..90b8692 --- /dev/null +++ b/src/main/java/org/studiorailgun/conversation/tracking/Actor.java @@ -0,0 +1,39 @@ +package org.studiorailgun.conversation.tracking; + +/** + * A character in a conversation who can say statements + */ +public class Actor { + + /** + * The name of the actor + */ + String name; + + /** + * Creates an actor + * @param name The name + */ + public Actor(String name){ + this.name = name; + } + + /** + * Gets the name of the actor + * @return The name + */ + public String getName() { + return name; + } + + /** + * Sets the name of the actor + * @param name The name + */ + public void setName(String name) { + this.name = name; + } + + + +} diff --git a/src/main/java/org/studiorailgun/conversation/tracking/Conversation.java b/src/main/java/org/studiorailgun/conversation/tracking/Conversation.java new file mode 100644 index 0000000..25264fc --- /dev/null +++ b/src/main/java/org/studiorailgun/conversation/tracking/Conversation.java @@ -0,0 +1,113 @@ +package org.studiorailgun.conversation.tracking; + +import java.util.LinkedList; + +import org.studiorailgun.kobold.KoboldRequest; +import org.studiorailgun.kobold.KoboldSymbols; + +public class Conversation { + + /** + * The participants in the conversation + */ + LinkedList participants; + + /** + * The statements of the conversation + */ + LinkedList statements; + + /** + * Number of statements to send in memory + */ + static final int STATEMENTS_TO_SEND_IN_MEMORY = 10; + + /** + * Constructor + */ + public Conversation(){ + this.participants = new LinkedList(); + this.statements = new LinkedList(); + } + + /** + * Adds a participant to the conversation + * @param actor The new actor + */ + public void addParticipant(Actor actor){ + this.participants.add(actor); + } + + /** + * Adds a statement to the conversation + * @param statement The statement + */ + public void addStatement(Statement statement){ + this.statements.add(statement); + statement.setConversation(this); + } + + /** + * Requests a statement for the actor + * @param actor The actor + */ + public KoboldRequest generateStatementRequest(Actor actor){ + KoboldRequest rVal = null; + String prompt = ""; + + //memory portion + prompt = prompt + "<|system|>\\n"; + prompt = prompt + "You are named " + actor.getName() + "\\n"; + prompt = prompt + "You are in a conversation with \\n"; + for(Actor participant : this.participants){ + if(participant != actor){ + prompt = prompt + participant.getName() + "\\n"; + } + } + prompt = prompt + "The most recent dialog in the conversation is\\n"; + int increments = 0; + for(int i = this.statements.size() - 1; i >= 0 && increments < STATEMENTS_TO_SEND_IN_MEMORY; i--){ + Statement statement = this.statements.get(i); + prompt = prompt + statement.getActor().getName() + ": " + statement.getContent() + "\\n"; + increments++; + } + prompt = prompt + "<|end|>\\n"; + + //request new statement + prompt = prompt + "<|user|>\\n"; + prompt = prompt + "What is the next thing " + actor.getName() + " should say?\\n"; + prompt = prompt + "Give me exactly the quote " + actor.getName() + " should say and nothing more.\\n"; + // prompt = prompt + "Do not describe the tone.\\n"; + prompt = prompt + "Do not explain the quote.\\n"; + prompt = prompt + "Do not give me dialog for any other character.\\n"; + prompt = prompt + "Do not describe what other characters should do.\\n"; + prompt = prompt + "Do not add notes.\\n"; + prompt = prompt + "Do not include ways to continue the conversation.\\n"; + prompt = prompt + "Do not include quotation marks.\\n"; + prompt = prompt + "<|end|>\\n"; + prompt = prompt + "<|assistant|>\\n"; + + //construct actual request from prompt + rVal = new KoboldRequest(prompt); + return rVal; + } + + /** + * Pushes the result of a kobold request into the conversation stack + * @param actor The actor that said the statement + * @param request The request + */ + public void pushbackRequest(Actor actor, KoboldRequest request){ + String rawResponse = request.getResponse(); + String parsed = rawResponse; + //remove notes if they are detected + if(KoboldSymbols.containsBannedSymbol(parsed)){ + parsed = KoboldSymbols.removeBannedSymbol(parsed); + } + + //actually create and store statement + Statement statement = new Statement(actor, parsed); + this.addStatement(statement); + } + +} diff --git a/src/main/java/org/studiorailgun/conversation/tracking/Statement.java b/src/main/java/org/studiorailgun/conversation/tracking/Statement.java new file mode 100644 index 0000000..74f17af --- /dev/null +++ b/src/main/java/org/studiorailgun/conversation/tracking/Statement.java @@ -0,0 +1,79 @@ +package org.studiorailgun.conversation.tracking; + +/** + * A statement by a character in a conversation + */ +public class Statement { + + /** + * Incrementer for ids + */ + static long idIncrementer = 0; + + /** + * The id of this statement + */ + long id; + + /** + * The raw content of the statement + */ + String content; + + /** + * The actor who said the statement + */ + Actor actor; + + /** + * The conversation this statement was uttered in + */ + Conversation conversation; + + /** + * Creates a statement + * @param actor The actor who said the statement + * @param content The content of the statement + */ + public Statement(Actor actor, String content){ + this.id = idIncrementer; + idIncrementer++; + this.actor = actor; + this.content = content; + } + + public long getId() { + return id; + } + + public void setId(long id) { + this.id = id; + } + + public String getContent() { + return content; + } + + public void setContent(String content) { + this.content = content; + } + + public Actor getActor() { + return actor; + } + + public void setActor(Actor actor) { + this.actor = actor; + } + + public Conversation getConversation() { + return conversation; + } + + public void setConversation(Conversation conversation) { + this.conversation = conversation; + } + + + +} diff --git a/src/main/java/org/studiorailgun/knowledge/KnowledgeWeb.java b/src/main/java/org/studiorailgun/knowledge/KnowledgeWeb.java new file mode 100644 index 0000000..08a8242 --- /dev/null +++ b/src/main/java/org/studiorailgun/knowledge/KnowledgeWeb.java @@ -0,0 +1,99 @@ +package org.studiorailgun.knowledge; + +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +/** + * A knowledge web + */ +public class KnowledgeWeb { + + + /** + * The nodes; + */ + Map nodes; + + /** + * The relations + */ + Map relations; + + /** + * Map of relation type -> list of relations of that type + */ + transient Map> typeRelationLookup; + + public KnowledgeWeb(){ + this.nodes = new HashMap(); + this.relations = new HashMap(); + this.typeRelationLookup = new HashMap>(); + } + + /** + * Initializes the linked datastructures between items + */ + public void initLinks(){ + //set id incrementers + for(Entry entry : this.nodes.entrySet()){ + if(entry.getValue().getId() >= Node.idIncrementer){ + Node.idIncrementer = entry.getValue().getId() + 1; + } + if(entry.getValue().getId() != entry.getKey()){ + throw new Error("Node id doesn't match map value! " + entry.getValue().getId() + " " + entry.getValue().getName() + " " + entry.getKey()); + } + } + for(Entry entry : this.relations.entrySet()){ + Relation relation = entry.getValue(); + if(relation.getId() >= Relation.idIncrementer){ + Relation.idIncrementer = relation.getId() + 1; + } + if(this.typeRelationLookup.containsKey(relation.getName())){ + this.typeRelationLookup.get(relation.getName()).add(relation.getId()); + } else { + LinkedList list = new LinkedList(); + list.add(relation.getId()); + this.typeRelationLookup.put(relation.getName(),list); + } + if(relation.getId() != entry.getKey()){ + throw new Error("Relation id doesn't match map value! " + relation.getId() + " " + relation.getName() + " " + entry.getKey()); + } + } + } + + public Collection getNodes() { + return nodes.values(); + } + + public Collection getRelations() { + return relations.values(); + } + + public Node getNode(int id){ + return nodes.get(id); + } + + /** + * Gets a singleton relationship with a given parent + * @param relationType The relationship type + * @param parent The parent + * @return The relation if it exists, null otherwise + */ + public Relation getSingletonRelation(String relationType, Node parent){ + List relations = this.typeRelationLookup.get(relationType); + if(relations != null){ + for(int relationId : relations){ + if(relationId == parent.getId()){ + return this.relations.get(relationId); + } + } + } + return null; + } + + +} diff --git a/src/main/java/org/studiorailgun/knowledge/Node.java b/src/main/java/org/studiorailgun/knowledge/Node.java new file mode 100644 index 0000000..df50e92 --- /dev/null +++ b/src/main/java/org/studiorailgun/knowledge/Node.java @@ -0,0 +1,46 @@ +package org.studiorailgun.knowledge; + +/** + * A node in the knowledge web + */ +public class Node { + + /** + * The id incremented + */ + public static int idIncrementer = 0; + + /** + * The id of the node + */ + int id; + + /** + * The name of the node + */ + String name; + + + /** + * Constructor + * @param name The name of the node + */ + public Node(String name){ + this.id = idIncrementer; + idIncrementer++; + this.name = name; + } + + + public String getName() { + return name; + } + + + public int getId() { + return id; + } + + + +} diff --git a/src/main/java/org/studiorailgun/knowledge/Relation.java b/src/main/java/org/studiorailgun/knowledge/Relation.java new file mode 100644 index 0000000..ace476a --- /dev/null +++ b/src/main/java/org/studiorailgun/knowledge/Relation.java @@ -0,0 +1,64 @@ +package org.studiorailgun.knowledge; + +import org.studiorailgun.Globals; + +public class Relation { + + /** + * The id incremented + */ + public static int idIncrementer = 0; + + /** + * The id of the relationship + */ + int id; + + /** + * The name of the relationship + */ + String name; + + /** + * The parent of the relationship + */ + int parent; + + /** + * The child of the relationship + */ + int child; + + /** + * Constructor + * @param name The name of the relation + * @param parent The parent of the relation + * @param child The child of the relation + */ + public Relation(String name, Node parent, Node child){ + this.id = idIncrementer; + idIncrementer++; + this.name = name; + this.parent = parent.getId(); + this.child = child.getId(); + } + + public int getId() { + return id; + } + + public String getName() { + return name; + } + + public Node getParent() { + return Globals.web.getNode(this.parent); + } + + public Node getChild() { + return Globals.web.getNode(this.child); + } + + + +} diff --git a/src/main/java/org/studiorailgun/kobold/KoboldPrinter.java b/src/main/java/org/studiorailgun/kobold/KoboldPrinter.java new file mode 100644 index 0000000..0f63bfb --- /dev/null +++ b/src/main/java/org/studiorailgun/kobold/KoboldPrinter.java @@ -0,0 +1,68 @@ +package org.studiorailgun.kobold; + +import java.util.concurrent.TimeUnit; + +/** + * Prints the response from a kobold request + */ +public class KoboldPrinter { + + /** + * The number of milliseconds to wait between printing characters + */ + static final int CHAR_DELAY = 10; + + /** + * The interval to check at in milliseconds + */ + static final int CHECK_INTERVAL = 300; + + /** + * The request to print + */ + KoboldRequest request; + + /** + * Position of the printer (how many characters has it printed) + */ + int position = 0; + + /** + * Constructor + */ + public KoboldPrinter(KoboldRequest request){ + this.request = request; + } + + /** + * Prints the request. Blocks the thread until the printing has finished. + */ + public void print(){ + int i = 0; + while(!this.request.isCompleted() || position < this.request.getResponse().length()){ + try { + TimeUnit.MILLISECONDS.sleep(CHAR_DELAY); + } catch (InterruptedException e) { + e.printStackTrace(); + } + if(!this.request.isCompleted() && i % (CHECK_INTERVAL / CHAR_DELAY) == 0){ + this.request.check(); + } + if(this.position < this.request.getResponse().length()){ + System.out.print(this.request.getResponse().charAt(this.position)); + this.position++; + } + //check for occurrences of immediate dialong ending symbols + if(KoboldSymbols.containsBannedSymbol(this.request.getResponse())){ + String message = KoboldSymbols.removeBannedSymbol(this.request.getResponse()); + this.request.markCompleted(); + System.out.print("\r" + message); + break; + } + i++; + } + //make sure we're on a fresh line + System.out.println(); + } + +} diff --git a/src/main/java/org/studiorailgun/kobold/KoboldRequest.java b/src/main/java/org/studiorailgun/kobold/KoboldRequest.java new file mode 100644 index 0000000..d93ae05 --- /dev/null +++ b/src/main/java/org/studiorailgun/kobold/KoboldRequest.java @@ -0,0 +1,188 @@ +package org.studiorailgun.kobold; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.MalformedURLException; +import java.net.ProtocolException; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; +import java.net.URLConnection; +import java.nio.charset.StandardCharsets; +import java.util.UUID; + +/** + * A request to koboldcpp + */ +public class KoboldRequest { + + /** + * Base Kobold URL + */ + static final String BASE_URL = "http://localhost:5001"; + + /** + * Endpoint to start generating + */ + static final String STREAM_ENDPOINT = "/api/v1/generate"; + + /** + * Endpoint to query generation status + */ + static final String CHECK_ENDPOINT = "/api/extra/generate/check"; + + /** + * The prompt + */ + String prompt; + + /** + * The gen key for streaming + */ + String genKey; + + /** + * The response string + */ + String response = ""; + + /** + * True if has received the complete generation response from kobold, false otherwise + */ + boolean completed = false; + + /** + * Constructor + * @param prompt The prompt + */ + public KoboldRequest(String prompt){ + this.prompt = prompt; + this.genKey = UUID.randomUUID().toString(); + } + + + /** + * Fires the request to generate + */ + public void fire(){ + try { + String urlPath = BASE_URL + STREAM_ENDPOINT; + URL url = new URI(urlPath).toURL(); + URLConnection con = url.openConnection(); + HttpURLConnection http = (HttpURLConnection)con; + http.setRequestMethod("POST"); // PUT is another valid option + http.setDoOutput(true); + + String body = "{\"prompt\": \"" + this.prompt + "\", \"genkey\": \"" + genKey + "\"}"; + byte[] out = body.getBytes(StandardCharsets.UTF_8); + int length = out.length; + + http.setFixedLengthStreamingMode(length); + http.setRequestProperty("Content-Type", "application/json; charset=UTF-8"); + http.connect(); + try(OutputStream os = http.getOutputStream()) { + os.write(out); + } + http.disconnect(); + //if uncommented, will block until response is fully generated + // try(InputStream is = http.getInputStream()){ + // byte[] response = is.readAllBytes(); + // String responseParsed = new String(response, StandardCharsets.UTF_8); + // System.out.println("parsed: " + responseParsed); + // } + } catch (MalformedURLException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } catch (ProtocolException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } catch (URISyntaxException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + + /** + * Checks for new text generated from the prompt + */ + public void check(){ + try { + String urlPath = BASE_URL + CHECK_ENDPOINT; + URL url = new URI(urlPath).toURL(); + URLConnection con = url.openConnection(); + HttpURLConnection http = (HttpURLConnection)con; + http.setRequestMethod("POST"); // PUT is another valid option + http.setDoOutput(true); + http.setDoInput(true); + String body = "{\"genkey\": \"" + genKey + "\"}"; + byte[] out = body.getBytes(StandardCharsets.UTF_8); + int length = out.length; + + http.setFixedLengthStreamingMode(length); + http.setRequestProperty("Content-Type", "application/json; charset=UTF-8"); + http.connect(); + try(OutputStream os = http.getOutputStream()) { + os.write(out); + } + try(InputStream is = http.getInputStream()){ + byte[] responseBytes = is.readAllBytes(); + String responseParsed = new String(responseBytes, StandardCharsets.UTF_8); + String responseBody = responseParsed.substring(23, responseParsed.length() - 4); + if(this.response.equals(responseBody) && this.response.length() > 0){ + this.completed = true; + } else { + this.response = responseBody; + } + } + } catch (MalformedURLException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } catch (ProtocolException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } catch (URISyntaxException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + + /** + * Gets the response + * @return The response + */ + public String getResponse(){ + return response; + } + + /** + * Gets whether this request has completed generation or not + * @return true if has completed generation, false otherwise + */ + public boolean isCompleted(){ + return completed; + } + + /** + * Flags the request as completed + */ + public void markCompleted(){ + this.completed = true; + } + + /** + * Gets the prompt of the request + * @return The prompt + */ + public String getPrompt(){ + return prompt; + } + +} diff --git a/src/main/java/org/studiorailgun/kobold/KoboldSymbols.java b/src/main/java/org/studiorailgun/kobold/KoboldSymbols.java new file mode 100644 index 0000000..6bfd608 --- /dev/null +++ b/src/main/java/org/studiorailgun/kobold/KoboldSymbols.java @@ -0,0 +1,55 @@ +package org.studiorailgun.kobold; + +/** + * Symbol tables for kobold parsing + */ +public class KoboldSymbols { + + /** + * A list of symbols we do not want to appear in the output + */ + public static final String[] BANNED_SYMBOLS = new String[]{ + "Note:", + "**Note**:", + "support:", + "Follow-up Questions:", + "Answer:", + "Prompt:", + "---", + "\\n\\n", + "Response to the instruction:", + }; + + /** + * Checks if the input contains a banned symbol + * @param input The input + * @return true if it contains a banned symbol, false otherwise + */ + public static boolean containsBannedSymbol(String input){ + for(String symbol : BANNED_SYMBOLS){ + if(input.contains(symbol)){ + return true; + } + } + return false; + } + + /** + * Removes banned symbols from the input + * @param input The input + * @return The input with banned symbols removed + */ + public static String removeBannedSymbol(String input){ + String target = input; + while(KoboldSymbols.containsBannedSymbol(target)){ + for(String symbol : BANNED_SYMBOLS){ + if(input.contains(symbol)){ + target = target.split(symbol)[0]; + break; + } + } + } + return target; + } + +} diff --git a/src/main/java/org/studiorailgun/schedule/Event.java b/src/main/java/org/studiorailgun/schedule/Event.java new file mode 100644 index 0000000..4b6b96f --- /dev/null +++ b/src/main/java/org/studiorailgun/schedule/Event.java @@ -0,0 +1,68 @@ +package org.studiorailgun.schedule; + +import java.util.LinkedList; + +/** + * An event + */ +public class Event implements Comparable{ + + /** + * The time when this event will happen + */ + Time time; + + /** + * The type of the event + */ + String type; + + /** + * The data of the event + */ + LinkedList data = new LinkedList(); + + /** + * Creates an event + * @param time The time the event occurs + * @param type The type of the event + * @param params The data of the event + */ + public Event(Time time, String type, Object ... params){ + this.time = time; + this.type = type; + for(Object el : params){ + data.add(el); + } + } + + /** + * Gets the time when the event occurs + * @return The time + */ + public Time getTime(){ + return time; + } + + /** + * Gets the type of the event + * @return The type of the event + */ + public String getType(){ + return type; + } + + /** + * Gets the data for the event + * @return The data for the event + */ + public LinkedList getData(){ + return data; + } + + @Override + public int compareTo(Event other) { + return this.time.compareTo(other.time); + } + +} diff --git a/src/main/java/org/studiorailgun/schedule/EventScheduler.java b/src/main/java/org/studiorailgun/schedule/EventScheduler.java new file mode 100644 index 0000000..45e444c --- /dev/null +++ b/src/main/java/org/studiorailgun/schedule/EventScheduler.java @@ -0,0 +1,56 @@ +package org.studiorailgun.schedule; + +import java.util.PriorityQueue; + +/** + * Maintains the schedule of events + */ +public class EventScheduler { + + /** + * The queue of events + */ + PriorityQueue queue = new PriorityQueue(); + + + /** + * Adds an event to the scheduler + * @param event The event + */ + public void add(Event event){ + queue.add(event); + } + + /** + * Peeks at the next event in the queue + * @return The next event + */ + public Event peek(){ + return queue.peek(); + } + + /** + * Gets the next scheduled event + * @return The next scheduled event + */ + public Event get(){ + return queue.poll(); + } + + /** + * Removes the event from the schedule + * @param event The event + */ + public void remove(Event event){ + queue.remove(event); + } + + /** + * Gets the queue of events + * @return The queue of events + */ + public PriorityQueue getQueue(){ + return queue; + } + +} diff --git a/src/main/java/org/studiorailgun/schedule/Time.java b/src/main/java/org/studiorailgun/schedule/Time.java new file mode 100644 index 0000000..a366289 --- /dev/null +++ b/src/main/java/org/studiorailgun/schedule/Time.java @@ -0,0 +1,149 @@ +package org.studiorailgun.schedule; + +/** + * A time in the simulation timeline + */ +public class Time implements Comparable