first commit

This commit is contained in:
austin 2024-12-28 18:01:52 -05:00
commit cec13a218d
37 changed files with 2194 additions and 0 deletions

10
.gitignore vendored Normal file
View File

@ -0,0 +1,10 @@
.mypy_cache
.venv
dist
target
src/trpg.egg-info
# ml model files
data/semantic/model
**.keras
**.h5

5
.vscode/extensions.json vendored Normal file
View File

@ -0,0 +1,5 @@
{
"recommendations": [
"ms-python.python",
]
}

21
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,21 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "java",
"name": "Launch Java Program",
"request": "launch",
"mainClass": ""
},
{
"name": "Python Debugger: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
}
]
}

4
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,4 @@
{
"java.configuration.updateBuildConfiguration": "interactive",
"java.compile.nullAnalysis.mode": "disabled"
}

24
README.md Normal file
View File

@ -0,0 +1,24 @@
```
python -m venv .venv
```
```
source ./.venv/Scripts/activate
```
```
pip install tensorflow mypy numpy
```
```
pip install keras==2.15.0
```
```
pip install types-tensorflow
```
In settings, make sure to check `Run Using Active Interpreter` in mypy settings

39
current_goal.txt Normal file
View File

@ -0,0 +1,39 @@
context:
user is a barbarian
chatbot is a wizard
sitting in a tavern by a fireplace
interaction:
the user says hello
the chatbot replies and queries the user's name (query for name relation to other participant)
the user answers and then asks the chatbot about the color of its hat (associate name relation to other participant)
the chatbot replies with correct information about the hat (it is blue) (lookup complex association and synthesize response)
things needed:
knowledge web
- a person
- a place
- a thing
- a relationship
- between persons (ie the lack of relationship between the user and chatbot)
- between a person and a thing (the hat the wizard is wearing)
- information about a thing
- color
routines for handling small chat
- introductions
- initial goal routine to query conversation partner for information about themself
information queries
- determine target of query
- determine quality/quantity being queried
information transfers
- determine information that is being transferred
- associate the new information with the knowledge web
distinguish between sentence types (query vs transfer vs command, etc)

View File

@ -0,0 +1,8 @@
fdsafdsa
fdsafdsafdsaf
dsafd
saf
dsafdfds
afds
i
self

View File

@ -0,0 +1,8 @@
1
1
1
1
0
0
0
0

3
data/tokens.txt Normal file
View File

@ -0,0 +1,3 @@
self
i
I

View File

@ -0,0 +1,38 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>studiorailgun</groupId>
<artifactId>Renderer</artifactId>
<version>0.1.1</version>
<build>
<plugins>
<plugin>
<artifactId>maven-shade-plugin</artifactId>
<version>3.2.4</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer>
<mainClass>org.studiorailgun.Main</mainClass>
</transformer>
<transformer />
</transformers>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
<properties>
<maven.compiler.target>17</maven.compiler.target>
<maven.compiler.source>17</maven.compiler.source>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<dl4j-master.version>1.0.0-M2</dl4j-master.version>
<nd4j.backend>nd4j-native</nd4j.backend>
</properties>
</project>

1
improvement_ideas.txt Normal file
View File

@ -0,0 +1 @@
summarize previous statements to provide context instead of using full statement

27
mypy.ini Normal file
View File

@ -0,0 +1,27 @@
# Global options:
[mypy]
install_types = True
disallow_untyped_defs = True
no_implicit_optional = True
check_untyped_defs = True
warn_return_any = True
show_error_codes = True
warn_unused_ignores = True
warn_unused_configs = True
plugins = numpy.typing.mypy_plugin
disallow_incomplete_defs = True
disallow_untyped_decorators = False
disallow_any_unimported = False
# Per-module options:
[mypy-mycode.foo.*]
disallow_untyped_defs = True
[mypy-mycode.bar]
warn_return_any = False
[mypy-somelibrary]
ignore_missing_imports = True

92
pom.xml Normal file
View File

@ -0,0 +1,92 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>studiorailgun</groupId>
<artifactId>Renderer</artifactId>
<version>0.1.1</version>
<packaging>jar</packaging>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>21</maven.compiler.source>
<maven.compiler.target>21</maven.compiler.target>
<dl4j-master.version>1.0.0-M2</dl4j-master.version>
<!-- Change the nd4j.backend property to nd4j-cuda-X-platform to use CUDA GPUs -->
<nd4j.backend>nd4j-native</nd4j.backend>
</properties>
<dependencies>
<!--DeepLearning4J-->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<!--Tensorflow Java-->
<!-- <dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.15.0</version>
</dependency> -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-api</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-native</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.6</version>
</dependency>
</dependencies>
<!-- MAIN BUILD -->
<build>
<plugins>
<!--Shade the jar (pack all dependencies)-->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.2.4</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>org.studiorailgun.Main</mainClass>
</transformer>
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer">
</transformer>
</transformers>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>

17
pyproject.toml Normal file
View File

@ -0,0 +1,17 @@
[build-system]
requires = ["setuptools>=68", "setuptools_scm[toml]>=8"]
build-backend = "setuptools.build_meta"
[project]
name = "trpg"
requires-python = ">=3.8"
dynamic = ["version"]
dependencies = [
# Add runtime dependencies here
"numpy",
"keras",
"tensorflow"
]
# Enables the usage of setuptools_scm
[tool.setuptools_scm]

View File

@ -0,0 +1,280 @@
package org.studiorailgun;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.TimeUnit;
/**
* Utilities for dealing with files
*/
public class FileUtils {
/**
* Creates the gson instance
*/
static {
//init gson
GsonBuilder gsonBuilder = new GsonBuilder();
gson = gsonBuilder.create();
}
//used for serialization/deserialization in file operations
static Gson gson;
//maximum number of attempt to read the file
static final int maxReadFails = 3;
//Timeout duration between read attempts
static final int READ_TIMEOUT_DURATION = 5;
/**
* Reads a file to a string
* @param f The file
* @return The string
*/
public static String readFileToString(File f){
String rVal = "";
BufferedReader reader;
try {
reader = Files.newBufferedReader(f.toPath());
int failCounter = 0;
boolean reading = true;
StringBuilder builder = new StringBuilder("");
while(reading){
if(reader.ready()){
failCounter = 0;
int nextValue = reader.read();
if(nextValue == -1){
reading = false;
} else {
builder.append((char)nextValue);
}
} else {
failCounter++;
if(failCounter > maxReadFails){
reading = false;
} else {
try {
TimeUnit.MILLISECONDS.sleep(READ_TIMEOUT_DURATION);
} catch (InterruptedException ex) {
ex.printStackTrace();
}
}
}
}
rVal = builder.toString();
} catch (IOException ex) {
ex.printStackTrace();
}
return rVal;
}
public static String readStreamToString(InputStream resourceInputStream){
String rVal = "";
BufferedReader reader;
try {
reader = new BufferedReader(new InputStreamReader(resourceInputStream));
int failCounter = 0;
boolean reading = true;
StringBuilder builder = new StringBuilder("");
while(reading){
if(reader.ready()){
failCounter = 0;
int nextValue = reader.read();
if(nextValue == -1){
reading = false;
} else {
builder.append((char)nextValue);
}
} else {
failCounter++;
if(failCounter > maxReadFails){
reading = false;
} else {
try {
TimeUnit.MILLISECONDS.sleep(READ_TIMEOUT_DURATION);
} catch (InterruptedException ex) {
ex.printStackTrace();
}
}
}
}
rVal = builder.toString();
} catch (IOException ex) {
ex.printStackTrace();
}
return rVal;
}
/**
* Sanitizes a relative file path, guaranteeing that the initial slash is correct
* @param filePath The raw file path
* @return The sanitized file path
*/
public static String sanitizeFilePath(String filePath){
String rVal = new String(filePath);
rVal = rVal.trim();
if(rVal.startsWith("./")){
return rVal;
} else if(!rVal.startsWith("/")){
rVal = "/" + rVal;
}
return rVal;
}
/**
* Serializes an object to a filepath
* @param filePath The filepath
* @param object The object
*/
public static void serializeObjectToFilePath(String filePath, Object object){
Path path = new File(filePath).toPath();
try {
Files.write(path, gson.toJson(object).getBytes());
} catch (IOException ex) {
ex.printStackTrace();
}
}
/**
* Loads an object from the assets folder
* @param <T> The type of object
* @param file The file to load from
* @param className The class of the object inside the file
* @return The file
*/
public static <T>T loadObjectFromFile(File file, Class<T> className){
T rVal = null;
try {
rVal = gson.fromJson(Files.newBufferedReader(file.toPath()), className);
} catch (IOException ex) {
ex.printStackTrace();
}
return rVal;
}
/**
* Checks if a directory exists
* @param fileName
* @return true if directory exists, false otherwise
*/
public static boolean checkFileExists(String fileName){
File targetDir = new File(sanitizeFilePath(fileName));
if(targetDir.exists()){
return true;
} else {
return false;
}
}
/**
* Trys to create a directory
* @param directoryName
* @return true if directory was created, false if it was not
*/
public static boolean createDirectory(String directoryName){
String sanitizedPath = sanitizeFilePath(directoryName);
File targetDir = new File(sanitizedPath);
if(targetDir.exists()){
return false;
} else {
return targetDir.mkdirs();
}
}
/**
* Lists the files in a directory
* @param directoryName The path of the directory
* @return A list containing the names of all files inside that directory
*/
public static List<String> listDirectory(String directoryName){
List<String> rVal = new LinkedList<String>();
String sanitizedPath = sanitizeFilePath(directoryName);
File targetDir = new File(sanitizedPath);
String[] files = targetDir.list();
for(String name : files){
rVal.add(name);
}
return rVal;
}
/**
* Recursively deletes a path
* @param path The path
*/
public static void recursivelyDelete(String path){
File file = new File(path);
if(file.isDirectory()){
for(File child : file.listFiles()){
recursivelyDelete(child.getAbsolutePath());
}
}
if(file.exists()){
try {
Files.delete(file.toPath());
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
/**
* Gets a file's path relative to a given directory
* @param file The file
* @param directory The directory
* @return The relative path
*/
public static String relativize(File file, File directory){
return directory.toURI().relativize(file.toURI()).getPath();
}
/**
* Computes the checksum of an object
* @param object The object
* @return The checksum
* @throws IOException Thrown on io errors reading the file
* @throws NoSuchAlgorithmException Thrown if MD5 isn't supported
*/
public static String getChecksum(Serializable object) throws IOException, NoSuchAlgorithmException {
ByteArrayOutputStream baos = null;
ObjectOutputStream oos = null;
try {
baos = new ByteArrayOutputStream();
oos = new ObjectOutputStream(baos);
oos.writeObject(object);
MessageDigest md = MessageDigest.getInstance("MD5");
byte[] bytes = md.digest(baos.toByteArray());
StringBuffer builder = new StringBuffer();
for(byte byteCurr : bytes){
builder.append(String.format("%02x",byteCurr));
}
return builder.toString();
} finally {
oos.close();
baos.close();
}
}
}

View File

@ -0,0 +1,25 @@
package org.studiorailgun;
import java.io.File;
import org.studiorailgun.knowledge.KnowledgeWeb;
/**
* Global variables
*/
public class Globals {
/**
* The knowledge web
*/
public static KnowledgeWeb web;
/**
* Initializes the knowledge web
*/
public static void init(){
web = FileUtils.loadObjectFromFile(new File("web.json"), KnowledgeWeb.class);
web.initLinks();
}
}

View File

@ -0,0 +1,82 @@
package org.studiorailgun;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.types.TInt64;
/**
* The main class
*/
public class Main {
/**
* The main method
*/
public static void main(String[] args){
// AgentLoop.main();
// try (Graph graph = new Graph()) {
// prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"input": inputs}, {"output":output});
SavedModelBundle model = SavedModelBundle.load("./data/semantic/model");
System.out.println(model.functions());
System.out.println(model.functions().get(0).signature());
System.out.println(model.functions().get(1).signature());
System.out.println(model.functions().get(2).signature());
System.out.println(model.functions().get(0).signature().getInputs());
System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
//init
// model.functions().get(2).session().runner()
// .fetch("__saved_model_init_op")
// .run();
//run predict
TInt64 tensor = TInt64.scalarOf(10);
org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128);
tensor = TInt64.tensorOf(shape);
// NdArray<LongNdArray> arr = NdArrays.vectorOfObjects(NdArrays.vectorOf((long)0, (long)0, (long)0, (long)0));
// System.out.println(arr.shape());
// tensor = TInt64.tensorOf(arr.shape());
Tensor result = model.function("serve").call(tensor);
// model.functions().get(0).session().runner()
// .feed("keras_tensor", tensor)
// .fetch("output_0")
// .run();
// Graph graph = model.graph();
// Session session = model.session();
// Tensor<Long> inputTensor = Tensor.of(Long.class, model.functions().get(0).signature().getInputs().values().iterator().next().shape);
// NdArrays.ofLongs(shape(2, 3, 2));
// TInt64 tensor = TInt64.vectorOf(10,100);
// // TInt64 tensor = TInt64.tensorOf(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
// Tensor result = session.runner()
// .feed("keras_tensor", tensor)
// .fetch("output_0")
// .run().get(0);
// Result result = model.call(tensors);
// byte[] graphDef = Files.readAllBytes(Paths.get("./data/semantic/model.keras"));
// GraphDef def = GraphDef.newBuilder()
// .mergeFrom(graphDef)
// .build();
// graph.importGraphDef(def);
// try (Session session = new Session(graph)) {
// // // Prepare input data as a Tensor
// // Tensor inputTensor = Tensor.create(inputData);
// // // Run inference
// // Tensor result = session.runner()
// // .feed("input_node_name", inputTensor) // Replace with actual input tensor name
// // .fetch("output_node_name") // Replace with actual output tensor name
// // .run().get(0);
// // // Process the result
// // float[][] output = result.copyTo(new float[1][outputSize]);
// // System.out.println(Arrays.toString(output[0]));
// }
// } catch (IOException e) {
// e.printStackTrace();
// }
}
}

View File

@ -0,0 +1,59 @@
package org.studiorailgun.conversation;
import java.util.Scanner;
import org.studiorailgun.Globals;
import org.studiorailgun.conversation.parser.CommandParser;
import org.studiorailgun.conversation.tracking.Actor;
import org.studiorailgun.conversation.tracking.Conversation;
import org.studiorailgun.conversation.tracking.Statement;
import org.studiorailgun.kobold.KoboldPrinter;
import org.studiorailgun.kobold.KoboldRequest;
public class AgentLoop {
/**
* The size of the context available
*/
static final int CONTEXT_SIZE = 32 * 1024;
/**
* Controls whether the main parser loop is running
*/
public static boolean running = true;
/**
* The main method
*/
public static void main(){
Globals.init();
//setup actual input
Scanner scan = new Scanner(System.in);
String prompt = "";
//setup conversation tracking
Conversation convo = new Conversation();
Actor player = new Actor("John");
Actor ai = new Actor("Dave");
convo.addParticipant(player);
convo.addParticipant(ai);
//actual main loop
while(running){
//handle player statement
prompt = scan.nextLine();
if(CommandParser.parseCommands(prompt)){
continue;
}
convo.addStatement(new Statement(player, prompt));
//handle ai statement
throw new UnsupportedOperationException("asdf");
}
scan.close();
}
}

View File

@ -0,0 +1,62 @@
package org.studiorailgun.conversation;
import java.util.Scanner;
import org.studiorailgun.conversation.parser.CommandParser;
import org.studiorailgun.conversation.tracking.Actor;
import org.studiorailgun.conversation.tracking.Conversation;
import org.studiorailgun.conversation.tracking.Statement;
import org.studiorailgun.kobold.KoboldPrinter;
import org.studiorailgun.kobold.KoboldRequest;
public class LLMLoop {
/**
* The size of the context available
*/
static final int CONTEXT_SIZE = 32 * 1024;
/**
* Controls whether the main parser loop is running
*/
public static boolean running = true;
/**
* The main method
*/
public static void main(){
//setup actual input
Scanner scan = new Scanner(System.in);
String prompt = "";
//setup conversation tracking
Conversation convo = new Conversation();
Actor player = new Actor("John");
Actor ai = new Actor("Dave");
convo.addParticipant(player);
convo.addParticipant(ai);
//actual main loop
while(running){
//handle player statement
prompt = scan.nextLine();
if(CommandParser.parseCommands(prompt)){
continue;
}
convo.addStatement(new Statement(player, prompt));
//handle ai statement
KoboldRequest request = convo.generateStatementRequest(ai);
request.fire();
KoboldPrinter printer = new KoboldPrinter(request);
printer.print();
System.out.println("[System] Created prompt of size: " + request.getPrompt().length() + " (" + (request.getPrompt().length() / (float)CONTEXT_SIZE) + ")");
convo.pushbackRequest(ai, request);
}
scan.close();
}
}

View File

@ -0,0 +1,116 @@
package org.studiorailgun.conversation.categorization;
import java.io.File;
import java.io.IOException;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Sgd;
/**
* Computational model for categorizing conversations
*/
public class ConvCategorizationNet {
/**
* The configuration for the network
*/
MultiLayerConfiguration conf;
/**
* The actual model created from the configuration
*/
MultiLayerNetwork model;
/**
* The number of epochs to try to train for
*/
int targetEpochs;
/**
* Initializes the model
*/
public void init(){
this.conf = new NeuralNetConfiguration.Builder()
//setup base config
.weightInit(WeightInit.XAVIER)
.activation(Activation.RELU)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.05))
.list()
//add layers
.layer(0, new DenseLayer.Builder()
.nIn(4096)
.nOut(4096)
.build()
)
//add back propagation
.backpropType(BackpropType.Standard)
//finalize
.build();
this.model = new MultiLayerNetwork(conf);
model.init();
}
public void apply(String input){
// this.model.output();
}
/**
* Trains the model on a set of data
* @param trainData The training data
* @param testData The test data
*/
public void train(DataSetIterator trainData, DataSetIterator testData){
System.out.println("Train model..");
model.setListeners(new ScoreIterationListener(1), new EvaluativeListener(testData, 1, InvocationType.EPOCH_END));
model.fit(trainData, targetEpochs);
}
/**
* Evaluates the model against the test data
* @param testData The test data
*/
public void evaluate(DataSetIterator testData){
System.out.println("Evaluate model..");
Evaluation eval = model.evaluate(testData);
System.out.println(eval.stats());
}
/**
* Saves the model to disk
* @param location The file to save the model into
*/
public void save(File location){
try {
model.save(location, true);
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* Loads the model from disk
* @param location The location of the model file
*/
public void load(File location){
try {
this.model = MultiLayerNetwork.load(location, true);
} catch (IOException e) {
e.printStackTrace();
}
}
}

View File

@ -0,0 +1,36 @@
package org.studiorailgun.conversation.categorization;
/**
* Categorizes sentences based on function
*/
public class SentenceFunctionCategorizor {
/**
* The function of the sentence
*/
public static enum SentenceFunction {
/**
* Transfers information to the other party
* ie, declaring facts, declaring perspective of the world (how you feel), etc
* "The ball is red". "I don't like that".
*/
TRANSFER_INFORMATION,
/**
* Query information from the other party
* ie, asking for a fact or perspective
* "Is that ball red?" "What color is your hat?" "How do you feel about that?"
*/
QUERY_INFORMATION,
}
/**
* Categorizes the sentence by function
* @param input The input sentence
* @return The function of the sentence
*/
public static SentenceFunction categorize(String input){
return SentenceFunction.QUERY_INFORMATION;
}
}

View File

@ -0,0 +1,30 @@
package org.studiorailgun.conversation.parser;
import org.studiorailgun.conversation.AgentLoop;
import org.studiorailgun.conversation.LLMLoop;
/**
* Parses player commands to execute
*/
public class CommandParser {
/**
* Parses the input for commands. If a valid command is recognized, returns true. Otherwise returns false.
* @param input The input string
* @return true if a command was executed, false otherwise
*/
public static boolean parseCommands(String input){
if(input.equals("exit")){
LLMLoop.running = false;
AgentLoop.running = false;
} else if(input.equals("s hypo")){ //save a response failing to consider a hypothetical
} else if(input.equals("s q")){ //save a question
} else if(input.equals("s format")){ //save a formatting error
}
return false;
}
}

View File

@ -0,0 +1,53 @@
package org.studiorailgun.conversation.semantic;
import java.io.File;
import java.io.IOException;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Sgd;
/**
* Parses the subject of a sentence
*/
public class SentenceSubjectParser {
/**
* The configuration for the network
*/
MultiLayerConfiguration conf;
/**
* The actual model created from the configuration
*/
MultiLayerNetwork model;
/**
* The number of epochs to try to train for
*/
int targetEpochs;
/**
* Initializes the model
*/
public void init(){
}
}

View File

@ -0,0 +1,39 @@
package org.studiorailgun.conversation.tracking;
/**
* A character in a conversation who can say statements
*/
public class Actor {
/**
* The name of the actor
*/
String name;
/**
* Creates an actor
* @param name The name
*/
public Actor(String name){
this.name = name;
}
/**
* Gets the name of the actor
* @return The name
*/
public String getName() {
return name;
}
/**
* Sets the name of the actor
* @param name The name
*/
public void setName(String name) {
this.name = name;
}
}

View File

@ -0,0 +1,113 @@
package org.studiorailgun.conversation.tracking;
import java.util.LinkedList;
import org.studiorailgun.kobold.KoboldRequest;
import org.studiorailgun.kobold.KoboldSymbols;
public class Conversation {
/**
* The participants in the conversation
*/
LinkedList<Actor> participants;
/**
* The statements of the conversation
*/
LinkedList<Statement> statements;
/**
* Number of statements to send in memory
*/
static final int STATEMENTS_TO_SEND_IN_MEMORY = 10;
/**
* Constructor
*/
public Conversation(){
this.participants = new LinkedList<Actor>();
this.statements = new LinkedList<Statement>();
}
/**
* Adds a participant to the conversation
* @param actor The new actor
*/
public void addParticipant(Actor actor){
this.participants.add(actor);
}
/**
* Adds a statement to the conversation
* @param statement The statement
*/
public void addStatement(Statement statement){
this.statements.add(statement);
statement.setConversation(this);
}
/**
* Requests a statement for the actor
* @param actor The actor
*/
public KoboldRequest generateStatementRequest(Actor actor){
KoboldRequest rVal = null;
String prompt = "";
//memory portion
prompt = prompt + "<|system|>\\n";
prompt = prompt + "You are named " + actor.getName() + "\\n";
prompt = prompt + "You are in a conversation with \\n";
for(Actor participant : this.participants){
if(participant != actor){
prompt = prompt + participant.getName() + "\\n";
}
}
prompt = prompt + "The most recent dialog in the conversation is\\n";
int increments = 0;
for(int i = this.statements.size() - 1; i >= 0 && increments < STATEMENTS_TO_SEND_IN_MEMORY; i--){
Statement statement = this.statements.get(i);
prompt = prompt + statement.getActor().getName() + ": " + statement.getContent() + "\\n";
increments++;
}
prompt = prompt + "<|end|>\\n";
//request new statement
prompt = prompt + "<|user|>\\n";
prompt = prompt + "What is the next thing " + actor.getName() + " should say?\\n";
prompt = prompt + "Give me exactly the quote " + actor.getName() + " should say and nothing more.\\n";
// prompt = prompt + "Do not describe the tone.\\n";
prompt = prompt + "Do not explain the quote.\\n";
prompt = prompt + "Do not give me dialog for any other character.\\n";
prompt = prompt + "Do not describe what other characters should do.\\n";
prompt = prompt + "Do not add notes.\\n";
prompt = prompt + "Do not include ways to continue the conversation.\\n";
prompt = prompt + "Do not include quotation marks.\\n";
prompt = prompt + "<|end|>\\n";
prompt = prompt + "<|assistant|>\\n";
//construct actual request from prompt
rVal = new KoboldRequest(prompt);
return rVal;
}
/**
* Pushes the result of a kobold request into the conversation stack
* @param actor The actor that said the statement
* @param request The request
*/
public void pushbackRequest(Actor actor, KoboldRequest request){
String rawResponse = request.getResponse();
String parsed = rawResponse;
//remove notes if they are detected
if(KoboldSymbols.containsBannedSymbol(parsed)){
parsed = KoboldSymbols.removeBannedSymbol(parsed);
}
//actually create and store statement
Statement statement = new Statement(actor, parsed);
this.addStatement(statement);
}
}

View File

@ -0,0 +1,79 @@
package org.studiorailgun.conversation.tracking;
/**
* A statement by a character in a conversation
*/
public class Statement {
/**
* Incrementer for ids
*/
static long idIncrementer = 0;
/**
* The id of this statement
*/
long id;
/**
* The raw content of the statement
*/
String content;
/**
* The actor who said the statement
*/
Actor actor;
/**
* The conversation this statement was uttered in
*/
Conversation conversation;
/**
* Creates a statement
* @param actor The actor who said the statement
* @param content The content of the statement
*/
public Statement(Actor actor, String content){
this.id = idIncrementer;
idIncrementer++;
this.actor = actor;
this.content = content;
}
public long getId() {
return id;
}
public void setId(long id) {
this.id = id;
}
public String getContent() {
return content;
}
public void setContent(String content) {
this.content = content;
}
public Actor getActor() {
return actor;
}
public void setActor(Actor actor) {
this.actor = actor;
}
public Conversation getConversation() {
return conversation;
}
public void setConversation(Conversation conversation) {
this.conversation = conversation;
}
}

View File

@ -0,0 +1,99 @@
package org.studiorailgun.knowledge;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
/**
* A knowledge web
*/
public class KnowledgeWeb {
/**
* The nodes;
*/
Map<Integer,Node> nodes;
/**
* The relations
*/
Map<Integer,Relation> relations;
/**
* Map of relation type -> list of relations of that type
*/
transient Map<String,List<Integer>> typeRelationLookup;
public KnowledgeWeb(){
this.nodes = new HashMap<Integer,Node>();
this.relations = new HashMap<Integer,Relation>();
this.typeRelationLookup = new HashMap<String,List<Integer>>();
}
/**
* Initializes the linked datastructures between items
*/
public void initLinks(){
//set id incrementers
for(Entry<Integer,Node> entry : this.nodes.entrySet()){
if(entry.getValue().getId() >= Node.idIncrementer){
Node.idIncrementer = entry.getValue().getId() + 1;
}
if(entry.getValue().getId() != entry.getKey()){
throw new Error("Node id doesn't match map value! " + entry.getValue().getId() + " " + entry.getValue().getName() + " " + entry.getKey());
}
}
for(Entry<Integer,Relation> entry : this.relations.entrySet()){
Relation relation = entry.getValue();
if(relation.getId() >= Relation.idIncrementer){
Relation.idIncrementer = relation.getId() + 1;
}
if(this.typeRelationLookup.containsKey(relation.getName())){
this.typeRelationLookup.get(relation.getName()).add(relation.getId());
} else {
LinkedList<Integer> list = new LinkedList<Integer>();
list.add(relation.getId());
this.typeRelationLookup.put(relation.getName(),list);
}
if(relation.getId() != entry.getKey()){
throw new Error("Relation id doesn't match map value! " + relation.getId() + " " + relation.getName() + " " + entry.getKey());
}
}
}
public Collection<Node> getNodes() {
return nodes.values();
}
public Collection<Relation> getRelations() {
return relations.values();
}
public Node getNode(int id){
return nodes.get(id);
}
/**
* Gets a singleton relationship with a given parent
* @param relationType The relationship type
* @param parent The parent
* @return The relation if it exists, null otherwise
*/
public Relation getSingletonRelation(String relationType, Node parent){
List<Integer> relations = this.typeRelationLookup.get(relationType);
if(relations != null){
for(int relationId : relations){
if(relationId == parent.getId()){
return this.relations.get(relationId);
}
}
}
return null;
}
}

View File

@ -0,0 +1,46 @@
package org.studiorailgun.knowledge;
/**
* A node in the knowledge web
*/
public class Node {
/**
* The id incremented
*/
public static int idIncrementer = 0;
/**
* The id of the node
*/
int id;
/**
* The name of the node
*/
String name;
/**
* Constructor
* @param name The name of the node
*/
public Node(String name){
this.id = idIncrementer;
idIncrementer++;
this.name = name;
}
public String getName() {
return name;
}
public int getId() {
return id;
}
}

View File

@ -0,0 +1,64 @@
package org.studiorailgun.knowledge;
import org.studiorailgun.Globals;
public class Relation {
/**
* The id incremented
*/
public static int idIncrementer = 0;
/**
* The id of the relationship
*/
int id;
/**
* The name of the relationship
*/
String name;
/**
* The parent of the relationship
*/
int parent;
/**
* The child of the relationship
*/
int child;
/**
* Constructor
* @param name The name of the relation
* @param parent The parent of the relation
* @param child The child of the relation
*/
public Relation(String name, Node parent, Node child){
this.id = idIncrementer;
idIncrementer++;
this.name = name;
this.parent = parent.getId();
this.child = child.getId();
}
public int getId() {
return id;
}
public String getName() {
return name;
}
public Node getParent() {
return Globals.web.getNode(this.parent);
}
public Node getChild() {
return Globals.web.getNode(this.child);
}
}

View File

@ -0,0 +1,68 @@
package org.studiorailgun.kobold;
import java.util.concurrent.TimeUnit;
/**
* Prints the response from a kobold request
*/
public class KoboldPrinter {
/**
* The number of milliseconds to wait between printing characters
*/
static final int CHAR_DELAY = 10;
/**
* The interval to check at in milliseconds
*/
static final int CHECK_INTERVAL = 300;
/**
* The request to print
*/
KoboldRequest request;
/**
* Position of the printer (how many characters has it printed)
*/
int position = 0;
/**
* Constructor
*/
public KoboldPrinter(KoboldRequest request){
this.request = request;
}
/**
* Prints the request. Blocks the thread until the printing has finished.
*/
public void print(){
int i = 0;
while(!this.request.isCompleted() || position < this.request.getResponse().length()){
try {
TimeUnit.MILLISECONDS.sleep(CHAR_DELAY);
} catch (InterruptedException e) {
e.printStackTrace();
}
if(!this.request.isCompleted() && i % (CHECK_INTERVAL / CHAR_DELAY) == 0){
this.request.check();
}
if(this.position < this.request.getResponse().length()){
System.out.print(this.request.getResponse().charAt(this.position));
this.position++;
}
//check for occurrences of immediate dialong ending symbols
if(KoboldSymbols.containsBannedSymbol(this.request.getResponse())){
String message = KoboldSymbols.removeBannedSymbol(this.request.getResponse());
this.request.markCompleted();
System.out.print("\r" + message);
break;
}
i++;
}
//make sure we're on a fresh line
System.out.println();
}
}

View File

@ -0,0 +1,188 @@
package org.studiorailgun.kobold;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.ProtocolException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLConnection;
import java.nio.charset.StandardCharsets;
import java.util.UUID;
/**
* A request to koboldcpp
*/
public class KoboldRequest {
/**
* Base Kobold URL
*/
static final String BASE_URL = "http://localhost:5001";
/**
* Endpoint to start generating
*/
static final String STREAM_ENDPOINT = "/api/v1/generate";
/**
* Endpoint to query generation status
*/
static final String CHECK_ENDPOINT = "/api/extra/generate/check";
/**
* The prompt
*/
String prompt;
/**
* The gen key for streaming
*/
String genKey;
/**
* The response string
*/
String response = "";
/**
* True if has received the complete generation response from kobold, false otherwise
*/
boolean completed = false;
/**
* Constructor
* @param prompt The prompt
*/
public KoboldRequest(String prompt){
this.prompt = prompt;
this.genKey = UUID.randomUUID().toString();
}
/**
* Fires the request to generate
*/
public void fire(){
try {
String urlPath = BASE_URL + STREAM_ENDPOINT;
URL url = new URI(urlPath).toURL();
URLConnection con = url.openConnection();
HttpURLConnection http = (HttpURLConnection)con;
http.setRequestMethod("POST"); // PUT is another valid option
http.setDoOutput(true);
String body = "{\"prompt\": \"" + this.prompt + "\", \"genkey\": \"" + genKey + "\"}";
byte[] out = body.getBytes(StandardCharsets.UTF_8);
int length = out.length;
http.setFixedLengthStreamingMode(length);
http.setRequestProperty("Content-Type", "application/json; charset=UTF-8");
http.connect();
try(OutputStream os = http.getOutputStream()) {
os.write(out);
}
http.disconnect();
//if uncommented, will block until response is fully generated
// try(InputStream is = http.getInputStream()){
// byte[] response = is.readAllBytes();
// String responseParsed = new String(response, StandardCharsets.UTF_8);
// System.out.println("parsed: " + responseParsed);
// }
} catch (MalformedURLException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (ProtocolException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (URISyntaxException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
/**
* Checks for new text generated from the prompt
*/
public void check(){
try {
String urlPath = BASE_URL + CHECK_ENDPOINT;
URL url = new URI(urlPath).toURL();
URLConnection con = url.openConnection();
HttpURLConnection http = (HttpURLConnection)con;
http.setRequestMethod("POST"); // PUT is another valid option
http.setDoOutput(true);
http.setDoInput(true);
String body = "{\"genkey\": \"" + genKey + "\"}";
byte[] out = body.getBytes(StandardCharsets.UTF_8);
int length = out.length;
http.setFixedLengthStreamingMode(length);
http.setRequestProperty("Content-Type", "application/json; charset=UTF-8");
http.connect();
try(OutputStream os = http.getOutputStream()) {
os.write(out);
}
try(InputStream is = http.getInputStream()){
byte[] responseBytes = is.readAllBytes();
String responseParsed = new String(responseBytes, StandardCharsets.UTF_8);
String responseBody = responseParsed.substring(23, responseParsed.length() - 4);
if(this.response.equals(responseBody) && this.response.length() > 0){
this.completed = true;
} else {
this.response = responseBody;
}
}
} catch (MalformedURLException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (ProtocolException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (URISyntaxException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
/**
* Gets the response
* @return The response
*/
public String getResponse(){
return response;
}
/**
* Gets whether this request has completed generation or not
* @return true if has completed generation, false otherwise
*/
public boolean isCompleted(){
return completed;
}
/**
* Flags the request as completed
*/
public void markCompleted(){
this.completed = true;
}
/**
* Gets the prompt of the request
* @return The prompt
*/
public String getPrompt(){
return prompt;
}
}

View File

@ -0,0 +1,55 @@
package org.studiorailgun.kobold;
/**
* Symbol tables for kobold parsing
*/
public class KoboldSymbols {
/**
* A list of symbols we do not want to appear in the output
*/
public static final String[] BANNED_SYMBOLS = new String[]{
"Note:",
"**Note**:",
"support:",
"Follow-up Questions:",
"Answer:",
"Prompt:",
"---",
"\\n\\n",
"Response to the instruction:",
};
/**
* Checks if the input contains a banned symbol
* @param input The input
* @return true if it contains a banned symbol, false otherwise
*/
public static boolean containsBannedSymbol(String input){
for(String symbol : BANNED_SYMBOLS){
if(input.contains(symbol)){
return true;
}
}
return false;
}
/**
* Removes banned symbols from the input
* @param input The input
* @return The input with banned symbols removed
*/
public static String removeBannedSymbol(String input){
String target = input;
while(KoboldSymbols.containsBannedSymbol(target)){
for(String symbol : BANNED_SYMBOLS){
if(input.contains(symbol)){
target = target.split(symbol)[0];
break;
}
}
}
return target;
}
}

View File

@ -0,0 +1,68 @@
package org.studiorailgun.schedule;
import java.util.LinkedList;
/**
* An event
*/
public class Event implements Comparable<Event>{
/**
* The time when this event will happen
*/
Time time;
/**
* The type of the event
*/
String type;
/**
* The data of the event
*/
LinkedList<Object> data = new LinkedList<Object>();
/**
* Creates an event
* @param time The time the event occurs
* @param type The type of the event
* @param params The data of the event
*/
public Event(Time time, String type, Object ... params){
this.time = time;
this.type = type;
for(Object el : params){
data.add(el);
}
}
/**
* Gets the time when the event occurs
* @return The time
*/
public Time getTime(){
return time;
}
/**
* Gets the type of the event
* @return The type of the event
*/
public String getType(){
return type;
}
/**
* Gets the data for the event
* @return The data for the event
*/
public LinkedList<Object> getData(){
return data;
}
@Override
public int compareTo(Event other) {
return this.time.compareTo(other.time);
}
}

View File

@ -0,0 +1,56 @@
package org.studiorailgun.schedule;
import java.util.PriorityQueue;
/**
* Maintains the schedule of events
*/
public class EventScheduler {
/**
* The queue of events
*/
PriorityQueue<Event> queue = new PriorityQueue<Event>();
/**
* Adds an event to the scheduler
* @param event The event
*/
public void add(Event event){
queue.add(event);
}
/**
* Peeks at the next event in the queue
* @return The next event
*/
public Event peek(){
return queue.peek();
}
/**
* Gets the next scheduled event
* @return The next scheduled event
*/
public Event get(){
return queue.poll();
}
/**
* Removes the event from the schedule
* @param event The event
*/
public void remove(Event event){
queue.remove(event);
}
/**
* Gets the queue of events
* @return The queue of events
*/
public PriorityQueue<Event> getQueue(){
return queue;
}
}

View File

@ -0,0 +1,149 @@
package org.studiorailgun.schedule;
/**
* A time in the simulation timeline
*/
public class Time implements Comparable<Time> {
/**
* Number of months in a year
*/
static final int MONTHS_IN_YEAR = 12;
/**
* Number of days in a month
*/
static final int DAYS_IN_MONTH = 30;
/**
* Number of hours in a day
*/
static final int HOURS_IN_DAY = 24;
/**
* Number of minutes in an hour
*/
static final int MINUTES_IN_HOUR = 60;
/**
* Number of seconds in a minute
*/
static final int SECONDS_IN_MINUTE = 60;
/**
* The value of the time
*/
long value;
/**
* Creates a time with a raw value
* @param value The value
*/
public Time(long value){
this.value = value;
}
/**
* Creates a time at a given year, month, and day
* @param year The year
* @param month The month
* @param day The day
* @param hour The hour
* @param minute The minute
* @param second The second
*/
public Time(int year, int month, int day, int hour, int minute, int second){
this.value = Time.getOffset(year, month, day, hour, minute, second);
}
/**
* Gets a time offset from this time
* @param year The number of years to offset
* @param month The number of months to offset
* @param day The number of days to offset
* @param hour The number of hours to offset
* @param minute The number of minutes to offset
* @param second The number of seconds to offset
* @return The time containing the offset value
*/
public Time offset(int year, int month, int day, int hour, int minute, int second){
Time rVal = new Time(this.value);
rVal.value = rVal.value + Time.getOffset(year, month, day, hour, minute, second);
return rVal;
}
/**
* Gets the year of this time
* @return The year
*/
public int year(){
return (int)this.value / (SECONDS_IN_MINUTE * MINUTES_IN_HOUR * HOURS_IN_DAY * DAYS_IN_MONTH * MONTHS_IN_YEAR);
}
/**
* Gets the month of this time
* @return The month
*/
public int month(){
return (int)this.value / (SECONDS_IN_MINUTE * MINUTES_IN_HOUR * HOURS_IN_DAY * DAYS_IN_MONTH);
}
/**
* Gets the day of this time
* @return The day
*/
public int day(){
return (int)this.value / (SECONDS_IN_MINUTE * MINUTES_IN_HOUR * HOURS_IN_DAY);
}
/**
* Gets the hour of this time
* @return The hour
*/
public int hour(){
return (int)this.value / (SECONDS_IN_MINUTE * MINUTES_IN_HOUR);
}
/**
* Gets the minute of this time
* @return The minute
*/
public int minute(){
return (int)this.value / SECONDS_IN_MINUTE;
}
/**
* Gets the second of this time
* @return The second
*/
public int second(){
return (int)this.value % SECONDS_IN_MINUTE;
}
/**
* Gets an offset given a set of legible time values
* @param year The number of years
* @param month The number of months
* @param day The number of days
* @param hour The number of hours
* @param minute The number of minutes
* @param second The number of seconds
* @return The offset
*/
private static long getOffset(int year, int month, int day, int hour, int minute, int second){
return
second +
minute * SECONDS_IN_MINUTE +
hour * SECONDS_IN_MINUTE * MINUTES_IN_HOUR +
day * SECONDS_IN_MINUTE * MINUTES_IN_HOUR * HOURS_IN_DAY +
month * SECONDS_IN_MINUTE * MINUTES_IN_HOUR * HOURS_IN_DAY * DAYS_IN_MONTH +
year * SECONDS_IN_MINUTE * MINUTES_IN_HOUR * HOURS_IN_DAY * DAYS_IN_MONTH * MONTHS_IN_YEAR
;
}
@Override
public int compareTo(Time other) {
return (int)(this.value - other.value);
}
}

View File

@ -0,0 +1,86 @@
import tensorflow as tf
keras = tf.keras
from tensorflow import Tensor
from keras.api.layers import TextVectorization, Embedding, LSTM, Dense, Input
from keras.api.models import Sequential
import numpy as np
import numpy.typing as npt
# Model constants.
max_features: int = 20000
embedding_dim: int = 128
sequence_length: int = 500
epochs: int = 50
max_tokens: int = 5000
output_sequence_length: int = 4
# read sentences
data_path: str = './data/semantic/subject.txt'
data_raw: str = open(data_path).read()
vocab: list[str] = data_raw.split('\n')
# read labels
label_data_path: str = './data/semantic/subject_label.txt'
label_data_raw: str = open(label_data_path).read()
labels: list[int] = list(map(int,label_data_raw.split()))
# init vectorizer
textVec: TextVectorization = TextVectorization(
max_tokens=max_tokens,
output_mode='int',
output_sequence_length=output_sequence_length,
pad_to_max_tokens=True)
# Add the vocab to the tokenizer
textVec.adapt(vocab)
input_data: list[str] = vocab
data: Tensor = textVec.call(input_data)
# construct model
model: Sequential = Sequential([
keras.Input(shape=(None,), dtype="int64"),
Embedding(max_features + 1, embedding_dim),
LSTM(64),
Dense(1, activation='sigmoid')
])
#compile the model
# model.build(keras.Input(shape=(None,), dtype="int64"))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# fit the training data
npData = np.array(data)
npLabel = np.array(labels)
model.fit(npData,npLabel,epochs=epochs)
# evaluate here
# predict
predictTargetRaw: list[str] = ['saf']
predictTargetToken: list[int] = textVec.call(predictTargetRaw)
npPredict: npt.NDArray[np.complex64] = np.array(predictTargetToken)
# print(npPredict)
result: list[int] = model.predict(npPredict)
print("predict result:")
print(predictTargetToken)
print(result)
print(data)
print(labels)
# save the model so keras can reload
# savePath: str = './data/semantic/model.keras'
# model.save(savePath)
# export the model so java can leverage it
exportPath: str = './data/semantic/model'
model.export(exportPath)
# tf.keras.utils.get_file('asdf')
# asdf: str = 'a'

44
web.json Normal file
View File

@ -0,0 +1,44 @@
{
"nodes" : {
"0" : {
"id" : 0,
"name" : "Bert"
},
"1" : {
"id" : 1,
"name" : "Name"
},
"2" : {
"id" : 2,
"name" : "Self"
},
"3" : {
"id" : 3,
"name" : "Person"
},
"4" : {
"id" : 4,
"name" : "ConversationParticipant"
}
},
"relations" : {
"0" : {
"id" : 0,
"name" : "Name",
"parent" : 0,
"child" : 2
},
"1" : {
"id" : 1,
"name" : "InstanceOf",
"parent" : 3,
"child" : 2
},
"2" : {
"id" : 2,
"name" : "InstanceOf",
"parent" : 1,
"child" : 0
}
}
}