first commit
This commit is contained in:
commit
cec13a218d
10
.gitignore
vendored
Normal file
10
.gitignore
vendored
Normal file
@ -0,0 +1,10 @@
|
||||
.mypy_cache
|
||||
.venv
|
||||
dist
|
||||
target
|
||||
src/trpg.egg-info
|
||||
|
||||
# ml model files
|
||||
data/semantic/model
|
||||
**.keras
|
||||
**.h5
|
||||
5
.vscode/extensions.json
vendored
Normal file
5
.vscode/extensions.json
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
{
|
||||
"recommendations": [
|
||||
"ms-python.python",
|
||||
]
|
||||
}
|
||||
21
.vscode/launch.json
vendored
Normal file
21
.vscode/launch.json
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"type": "java",
|
||||
"name": "Launch Java Program",
|
||||
"request": "launch",
|
||||
"mainClass": ""
|
||||
},
|
||||
{
|
||||
"name": "Python Debugger: Current File",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${file}",
|
||||
"console": "integratedTerminal"
|
||||
}
|
||||
]
|
||||
}
|
||||
4
.vscode/settings.json
vendored
Normal file
4
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
{
|
||||
"java.configuration.updateBuildConfiguration": "interactive",
|
||||
"java.compile.nullAnalysis.mode": "disabled"
|
||||
}
|
||||
24
README.md
Normal file
24
README.md
Normal file
@ -0,0 +1,24 @@
|
||||
|
||||
|
||||
|
||||
```
|
||||
python -m venv .venv
|
||||
```
|
||||
|
||||
```
|
||||
source ./.venv/Scripts/activate
|
||||
```
|
||||
|
||||
```
|
||||
pip install tensorflow mypy numpy
|
||||
```
|
||||
|
||||
```
|
||||
pip install keras==2.15.0
|
||||
```
|
||||
|
||||
```
|
||||
pip install types-tensorflow
|
||||
```
|
||||
|
||||
In settings, make sure to check `Run Using Active Interpreter` in mypy settings
|
||||
39
current_goal.txt
Normal file
39
current_goal.txt
Normal file
@ -0,0 +1,39 @@
|
||||
context:
|
||||
user is a barbarian
|
||||
chatbot is a wizard
|
||||
sitting in a tavern by a fireplace
|
||||
|
||||
|
||||
interaction:
|
||||
the user says hello
|
||||
the chatbot replies and queries the user's name (query for name relation to other participant)
|
||||
the user answers and then asks the chatbot about the color of its hat (associate name relation to other participant)
|
||||
the chatbot replies with correct information about the hat (it is blue) (lookup complex association and synthesize response)
|
||||
|
||||
|
||||
|
||||
things needed:
|
||||
knowledge web
|
||||
- a person
|
||||
- a place
|
||||
- a thing
|
||||
- a relationship
|
||||
- between persons (ie the lack of relationship between the user and chatbot)
|
||||
- between a person and a thing (the hat the wizard is wearing)
|
||||
- information about a thing
|
||||
- color
|
||||
|
||||
routines for handling small chat
|
||||
- introductions
|
||||
- initial goal routine to query conversation partner for information about themself
|
||||
|
||||
information queries
|
||||
- determine target of query
|
||||
- determine quality/quantity being queried
|
||||
|
||||
information transfers
|
||||
- determine information that is being transferred
|
||||
- associate the new information with the knowledge web
|
||||
|
||||
distinguish between sentence types (query vs transfer vs command, etc)
|
||||
|
||||
8
data/semantic/subject.txt
Normal file
8
data/semantic/subject.txt
Normal file
@ -0,0 +1,8 @@
|
||||
fdsafdsa
|
||||
fdsafdsafdsaf
|
||||
dsafd
|
||||
saf
|
||||
dsafdfds
|
||||
afds
|
||||
i
|
||||
self
|
||||
8
data/semantic/subject_label.txt
Normal file
8
data/semantic/subject_label.txt
Normal file
@ -0,0 +1,8 @@
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
3
data/tokens.txt
Normal file
3
data/tokens.txt
Normal file
@ -0,0 +1,3 @@
|
||||
self
|
||||
i
|
||||
I
|
||||
38
dependency-reduced-pom.xml
Normal file
38
dependency-reduced-pom.xml
Normal file
@ -0,0 +1,38 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<groupId>studiorailgun</groupId>
|
||||
<artifactId>Renderer</artifactId>
|
||||
<version>0.1.1</version>
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<artifactId>maven-shade-plugin</artifactId>
|
||||
<version>3.2.4</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>shade</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<transformers>
|
||||
<transformer>
|
||||
<mainClass>org.studiorailgun.Main</mainClass>
|
||||
</transformer>
|
||||
<transformer />
|
||||
</transformers>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
<properties>
|
||||
<maven.compiler.target>17</maven.compiler.target>
|
||||
<maven.compiler.source>17</maven.compiler.source>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<dl4j-master.version>1.0.0-M2</dl4j-master.version>
|
||||
<nd4j.backend>nd4j-native</nd4j.backend>
|
||||
</properties>
|
||||
</project>
|
||||
1
improvement_ideas.txt
Normal file
1
improvement_ideas.txt
Normal file
@ -0,0 +1 @@
|
||||
summarize previous statements to provide context instead of using full statement
|
||||
27
mypy.ini
Normal file
27
mypy.ini
Normal file
@ -0,0 +1,27 @@
|
||||
# Global options:
|
||||
|
||||
[mypy]
|
||||
install_types = True
|
||||
disallow_untyped_defs = True
|
||||
no_implicit_optional = True
|
||||
check_untyped_defs = True
|
||||
warn_return_any = True
|
||||
show_error_codes = True
|
||||
warn_unused_ignores = True
|
||||
warn_unused_configs = True
|
||||
plugins = numpy.typing.mypy_plugin
|
||||
|
||||
disallow_incomplete_defs = True
|
||||
disallow_untyped_decorators = False
|
||||
disallow_any_unimported = False
|
||||
|
||||
# Per-module options:
|
||||
|
||||
[mypy-mycode.foo.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-mycode.bar]
|
||||
warn_return_any = False
|
||||
|
||||
[mypy-somelibrary]
|
||||
ignore_missing_imports = True
|
||||
92
pom.xml
Normal file
92
pom.xml
Normal file
@ -0,0 +1,92 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<groupId>studiorailgun</groupId>
|
||||
<artifactId>Renderer</artifactId>
|
||||
<version>0.1.1</version>
|
||||
<packaging>jar</packaging>
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<maven.compiler.source>21</maven.compiler.source>
|
||||
<maven.compiler.target>21</maven.compiler.target>
|
||||
<dl4j-master.version>1.0.0-M2</dl4j-master.version>
|
||||
<!-- Change the nd4j.backend property to nd4j-cuda-X-platform to use CUDA GPUs -->
|
||||
<nd4j.backend>nd4j-native</nd4j.backend>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<!--DeepLearning4J-->
|
||||
<dependency>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j-core</artifactId>
|
||||
<version>${dl4j-master.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j-nlp</artifactId>
|
||||
<version>${dl4j-master.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!--Tensorflow Java-->
|
||||
<!-- <dependency>
|
||||
<groupId>org.tensorflow</groupId>
|
||||
<artifactId>tensorflow</artifactId>
|
||||
<version>1.15.0</version>
|
||||
</dependency> -->
|
||||
<dependency>
|
||||
<groupId>org.tensorflow</groupId>
|
||||
<artifactId>tensorflow-core-api</artifactId>
|
||||
<version>1.0.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.tensorflow</groupId>
|
||||
<artifactId>tensorflow-core-platform</artifactId>
|
||||
<version>1.0.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.tensorflow</groupId>
|
||||
<artifactId>tensorflow-core-native</artifactId>
|
||||
<version>1.0.0</version>
|
||||
</dependency>
|
||||
|
||||
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>com.google.code.gson</groupId>
|
||||
<artifactId>gson</artifactId>
|
||||
<version>2.8.6</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<!-- MAIN BUILD -->
|
||||
<build>
|
||||
<plugins>
|
||||
|
||||
<!--Shade the jar (pack all dependencies)-->
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-shade-plugin</artifactId>
|
||||
<version>3.2.4</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>shade</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<transformers>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
|
||||
<mainClass>org.studiorailgun.Main</mainClass>
|
||||
</transformer>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer">
|
||||
</transformer>
|
||||
</transformers>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</project>
|
||||
17
pyproject.toml
Normal file
17
pyproject.toml
Normal file
@ -0,0 +1,17 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=68", "setuptools_scm[toml]>=8"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "trpg"
|
||||
requires-python = ">=3.8"
|
||||
dynamic = ["version"]
|
||||
dependencies = [
|
||||
# Add runtime dependencies here
|
||||
"numpy",
|
||||
"keras",
|
||||
"tensorflow"
|
||||
]
|
||||
|
||||
# Enables the usage of setuptools_scm
|
||||
[tool.setuptools_scm]
|
||||
280
src/main/java/org/studiorailgun/FileUtils.java
Normal file
280
src/main/java/org/studiorailgun/FileUtils.java
Normal file
@ -0,0 +1,280 @@
|
||||
package org.studiorailgun;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.GsonBuilder;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.io.Serializable;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.security.MessageDigest;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Utilities for dealing with files
|
||||
*/
|
||||
public class FileUtils {
|
||||
|
||||
/**
|
||||
* Creates the gson instance
|
||||
*/
|
||||
static {
|
||||
//init gson
|
||||
GsonBuilder gsonBuilder = new GsonBuilder();
|
||||
gson = gsonBuilder.create();
|
||||
}
|
||||
|
||||
//used for serialization/deserialization in file operations
|
||||
static Gson gson;
|
||||
|
||||
//maximum number of attempt to read the file
|
||||
static final int maxReadFails = 3;
|
||||
//Timeout duration between read attempts
|
||||
static final int READ_TIMEOUT_DURATION = 5;
|
||||
/**
|
||||
* Reads a file to a string
|
||||
* @param f The file
|
||||
* @return The string
|
||||
*/
|
||||
public static String readFileToString(File f){
|
||||
String rVal = "";
|
||||
BufferedReader reader;
|
||||
try {
|
||||
reader = Files.newBufferedReader(f.toPath());
|
||||
int failCounter = 0;
|
||||
boolean reading = true;
|
||||
StringBuilder builder = new StringBuilder("");
|
||||
while(reading){
|
||||
if(reader.ready()){
|
||||
failCounter = 0;
|
||||
int nextValue = reader.read();
|
||||
if(nextValue == -1){
|
||||
reading = false;
|
||||
} else {
|
||||
builder.append((char)nextValue);
|
||||
}
|
||||
} else {
|
||||
failCounter++;
|
||||
if(failCounter > maxReadFails){
|
||||
reading = false;
|
||||
} else {
|
||||
try {
|
||||
TimeUnit.MILLISECONDS.sleep(READ_TIMEOUT_DURATION);
|
||||
} catch (InterruptedException ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rVal = builder.toString();
|
||||
} catch (IOException ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
return rVal;
|
||||
}
|
||||
|
||||
|
||||
public static String readStreamToString(InputStream resourceInputStream){
|
||||
String rVal = "";
|
||||
BufferedReader reader;
|
||||
try {
|
||||
reader = new BufferedReader(new InputStreamReader(resourceInputStream));
|
||||
int failCounter = 0;
|
||||
boolean reading = true;
|
||||
StringBuilder builder = new StringBuilder("");
|
||||
while(reading){
|
||||
if(reader.ready()){
|
||||
failCounter = 0;
|
||||
int nextValue = reader.read();
|
||||
if(nextValue == -1){
|
||||
reading = false;
|
||||
} else {
|
||||
builder.append((char)nextValue);
|
||||
}
|
||||
} else {
|
||||
failCounter++;
|
||||
if(failCounter > maxReadFails){
|
||||
reading = false;
|
||||
} else {
|
||||
try {
|
||||
TimeUnit.MILLISECONDS.sleep(READ_TIMEOUT_DURATION);
|
||||
} catch (InterruptedException ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rVal = builder.toString();
|
||||
} catch (IOException ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
return rVal;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Sanitizes a relative file path, guaranteeing that the initial slash is correct
|
||||
* @param filePath The raw file path
|
||||
* @return The sanitized file path
|
||||
*/
|
||||
public static String sanitizeFilePath(String filePath){
|
||||
String rVal = new String(filePath);
|
||||
rVal = rVal.trim();
|
||||
if(rVal.startsWith("./")){
|
||||
return rVal;
|
||||
} else if(!rVal.startsWith("/")){
|
||||
rVal = "/" + rVal;
|
||||
}
|
||||
return rVal;
|
||||
}
|
||||
|
||||
/**
|
||||
* Serializes an object to a filepath
|
||||
* @param filePath The filepath
|
||||
* @param object The object
|
||||
*/
|
||||
public static void serializeObjectToFilePath(String filePath, Object object){
|
||||
Path path = new File(filePath).toPath();
|
||||
try {
|
||||
Files.write(path, gson.toJson(object).getBytes());
|
||||
} catch (IOException ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads an object from the assets folder
|
||||
* @param <T> The type of object
|
||||
* @param file The file to load from
|
||||
* @param className The class of the object inside the file
|
||||
* @return The file
|
||||
*/
|
||||
public static <T>T loadObjectFromFile(File file, Class<T> className){
|
||||
T rVal = null;
|
||||
try {
|
||||
rVal = gson.fromJson(Files.newBufferedReader(file.toPath()), className);
|
||||
} catch (IOException ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
return rVal;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a directory exists
|
||||
* @param fileName
|
||||
* @return true if directory exists, false otherwise
|
||||
*/
|
||||
public static boolean checkFileExists(String fileName){
|
||||
File targetDir = new File(sanitizeFilePath(fileName));
|
||||
if(targetDir.exists()){
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Trys to create a directory
|
||||
* @param directoryName
|
||||
* @return true if directory was created, false if it was not
|
||||
*/
|
||||
public static boolean createDirectory(String directoryName){
|
||||
String sanitizedPath = sanitizeFilePath(directoryName);
|
||||
File targetDir = new File(sanitizedPath);
|
||||
if(targetDir.exists()){
|
||||
return false;
|
||||
} else {
|
||||
return targetDir.mkdirs();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Lists the files in a directory
|
||||
* @param directoryName The path of the directory
|
||||
* @return A list containing the names of all files inside that directory
|
||||
*/
|
||||
public static List<String> listDirectory(String directoryName){
|
||||
List<String> rVal = new LinkedList<String>();
|
||||
String sanitizedPath = sanitizeFilePath(directoryName);
|
||||
File targetDir = new File(sanitizedPath);
|
||||
String[] files = targetDir.list();
|
||||
for(String name : files){
|
||||
rVal.add(name);
|
||||
}
|
||||
return rVal;
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively deletes a path
|
||||
* @param path The path
|
||||
*/
|
||||
public static void recursivelyDelete(String path){
|
||||
File file = new File(path);
|
||||
if(file.isDirectory()){
|
||||
for(File child : file.listFiles()){
|
||||
recursivelyDelete(child.getAbsolutePath());
|
||||
}
|
||||
}
|
||||
if(file.exists()){
|
||||
try {
|
||||
Files.delete(file.toPath());
|
||||
} catch (IOException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a file's path relative to a given directory
|
||||
* @param file The file
|
||||
* @param directory The directory
|
||||
* @return The relative path
|
||||
*/
|
||||
public static String relativize(File file, File directory){
|
||||
return directory.toURI().relativize(file.toURI()).getPath();
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Computes the checksum of an object
|
||||
* @param object The object
|
||||
* @return The checksum
|
||||
* @throws IOException Thrown on io errors reading the file
|
||||
* @throws NoSuchAlgorithmException Thrown if MD5 isn't supported
|
||||
*/
|
||||
public static String getChecksum(Serializable object) throws IOException, NoSuchAlgorithmException {
|
||||
ByteArrayOutputStream baos = null;
|
||||
ObjectOutputStream oos = null;
|
||||
try {
|
||||
baos = new ByteArrayOutputStream();
|
||||
oos = new ObjectOutputStream(baos);
|
||||
oos.writeObject(object);
|
||||
MessageDigest md = MessageDigest.getInstance("MD5");
|
||||
byte[] bytes = md.digest(baos.toByteArray());
|
||||
StringBuffer builder = new StringBuffer();
|
||||
for(byte byteCurr : bytes){
|
||||
builder.append(String.format("%02x",byteCurr));
|
||||
}
|
||||
return builder.toString();
|
||||
} finally {
|
||||
oos.close();
|
||||
baos.close();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
25
src/main/java/org/studiorailgun/Globals.java
Normal file
25
src/main/java/org/studiorailgun/Globals.java
Normal file
@ -0,0 +1,25 @@
|
||||
package org.studiorailgun;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
import org.studiorailgun.knowledge.KnowledgeWeb;
|
||||
|
||||
/**
|
||||
* Global variables
|
||||
*/
|
||||
public class Globals {
|
||||
|
||||
/**
|
||||
* The knowledge web
|
||||
*/
|
||||
public static KnowledgeWeb web;
|
||||
|
||||
/**
|
||||
* Initializes the knowledge web
|
||||
*/
|
||||
public static void init(){
|
||||
web = FileUtils.loadObjectFromFile(new File("web.json"), KnowledgeWeb.class);
|
||||
web.initLinks();
|
||||
}
|
||||
|
||||
}
|
||||
82
src/main/java/org/studiorailgun/Main.java
Normal file
82
src/main/java/org/studiorailgun/Main.java
Normal file
@ -0,0 +1,82 @@
|
||||
package org.studiorailgun;
|
||||
|
||||
import org.tensorflow.SavedModelBundle;
|
||||
import org.tensorflow.Tensor;
|
||||
import org.tensorflow.types.TInt64;
|
||||
|
||||
/**
|
||||
* The main class
|
||||
*/
|
||||
public class Main {
|
||||
|
||||
/**
|
||||
* The main method
|
||||
*/
|
||||
public static void main(String[] args){
|
||||
// AgentLoop.main();
|
||||
// try (Graph graph = new Graph()) {
|
||||
// prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"input": inputs}, {"output":output});
|
||||
SavedModelBundle model = SavedModelBundle.load("./data/semantic/model");
|
||||
System.out.println(model.functions());
|
||||
System.out.println(model.functions().get(0).signature());
|
||||
System.out.println(model.functions().get(1).signature());
|
||||
System.out.println(model.functions().get(2).signature());
|
||||
System.out.println(model.functions().get(0).signature().getInputs());
|
||||
System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
|
||||
//init
|
||||
// model.functions().get(2).session().runner()
|
||||
// .fetch("__saved_model_init_op")
|
||||
// .run();
|
||||
//run predict
|
||||
TInt64 tensor = TInt64.scalarOf(10);
|
||||
org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128);
|
||||
tensor = TInt64.tensorOf(shape);
|
||||
// NdArray<LongNdArray> arr = NdArrays.vectorOfObjects(NdArrays.vectorOf((long)0, (long)0, (long)0, (long)0));
|
||||
// System.out.println(arr.shape());
|
||||
// tensor = TInt64.tensorOf(arr.shape());
|
||||
|
||||
Tensor result = model.function("serve").call(tensor);
|
||||
// model.functions().get(0).session().runner()
|
||||
// .feed("keras_tensor", tensor)
|
||||
// .fetch("output_0")
|
||||
// .run();
|
||||
|
||||
// Graph graph = model.graph();
|
||||
// Session session = model.session();
|
||||
|
||||
// Tensor<Long> inputTensor = Tensor.of(Long.class, model.functions().get(0).signature().getInputs().values().iterator().next().shape);
|
||||
// NdArrays.ofLongs(shape(2, 3, 2));
|
||||
// TInt64 tensor = TInt64.vectorOf(10,100);
|
||||
// // TInt64 tensor = TInt64.tensorOf(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
|
||||
// Tensor result = session.runner()
|
||||
// .feed("keras_tensor", tensor)
|
||||
// .fetch("output_0")
|
||||
// .run().get(0);
|
||||
|
||||
// Result result = model.call(tensors);
|
||||
// byte[] graphDef = Files.readAllBytes(Paths.get("./data/semantic/model.keras"));
|
||||
// GraphDef def = GraphDef.newBuilder()
|
||||
// .mergeFrom(graphDef)
|
||||
// .build();
|
||||
// graph.importGraphDef(def);
|
||||
|
||||
// try (Session session = new Session(graph)) {
|
||||
// // // Prepare input data as a Tensor
|
||||
// // Tensor inputTensor = Tensor.create(inputData);
|
||||
|
||||
// // // Run inference
|
||||
// // Tensor result = session.runner()
|
||||
// // .feed("input_node_name", inputTensor) // Replace with actual input tensor name
|
||||
// // .fetch("output_node_name") // Replace with actual output tensor name
|
||||
// // .run().get(0);
|
||||
|
||||
// // // Process the result
|
||||
// // float[][] output = result.copyTo(new float[1][outputSize]);
|
||||
// // System.out.println(Arrays.toString(output[0]));
|
||||
// }
|
||||
// } catch (IOException e) {
|
||||
// e.printStackTrace();
|
||||
// }
|
||||
}
|
||||
|
||||
}
|
||||
59
src/main/java/org/studiorailgun/conversation/AgentLoop.java
Normal file
59
src/main/java/org/studiorailgun/conversation/AgentLoop.java
Normal file
@ -0,0 +1,59 @@
|
||||
package org.studiorailgun.conversation;
|
||||
|
||||
import java.util.Scanner;
|
||||
|
||||
import org.studiorailgun.Globals;
|
||||
import org.studiorailgun.conversation.parser.CommandParser;
|
||||
import org.studiorailgun.conversation.tracking.Actor;
|
||||
import org.studiorailgun.conversation.tracking.Conversation;
|
||||
import org.studiorailgun.conversation.tracking.Statement;
|
||||
import org.studiorailgun.kobold.KoboldPrinter;
|
||||
import org.studiorailgun.kobold.KoboldRequest;
|
||||
|
||||
public class AgentLoop {
|
||||
|
||||
/**
|
||||
* The size of the context available
|
||||
*/
|
||||
static final int CONTEXT_SIZE = 32 * 1024;
|
||||
|
||||
/**
|
||||
* Controls whether the main parser loop is running
|
||||
*/
|
||||
public static boolean running = true;
|
||||
|
||||
/**
|
||||
* The main method
|
||||
*/
|
||||
public static void main(){
|
||||
Globals.init();
|
||||
|
||||
//setup actual input
|
||||
Scanner scan = new Scanner(System.in);
|
||||
String prompt = "";
|
||||
|
||||
//setup conversation tracking
|
||||
Conversation convo = new Conversation();
|
||||
Actor player = new Actor("John");
|
||||
Actor ai = new Actor("Dave");
|
||||
convo.addParticipant(player);
|
||||
convo.addParticipant(ai);
|
||||
|
||||
//actual main loop
|
||||
while(running){
|
||||
|
||||
//handle player statement
|
||||
prompt = scan.nextLine();
|
||||
if(CommandParser.parseCommands(prompt)){
|
||||
continue;
|
||||
}
|
||||
convo.addStatement(new Statement(player, prompt));
|
||||
|
||||
//handle ai statement
|
||||
throw new UnsupportedOperationException("asdf");
|
||||
}
|
||||
|
||||
scan.close();
|
||||
}
|
||||
|
||||
}
|
||||
62
src/main/java/org/studiorailgun/conversation/LLMLoop.java
Normal file
62
src/main/java/org/studiorailgun/conversation/LLMLoop.java
Normal file
@ -0,0 +1,62 @@
|
||||
package org.studiorailgun.conversation;
|
||||
|
||||
import java.util.Scanner;
|
||||
|
||||
import org.studiorailgun.conversation.parser.CommandParser;
|
||||
import org.studiorailgun.conversation.tracking.Actor;
|
||||
import org.studiorailgun.conversation.tracking.Conversation;
|
||||
import org.studiorailgun.conversation.tracking.Statement;
|
||||
import org.studiorailgun.kobold.KoboldPrinter;
|
||||
import org.studiorailgun.kobold.KoboldRequest;
|
||||
|
||||
public class LLMLoop {
|
||||
|
||||
/**
|
||||
* The size of the context available
|
||||
*/
|
||||
static final int CONTEXT_SIZE = 32 * 1024;
|
||||
|
||||
/**
|
||||
* Controls whether the main parser loop is running
|
||||
*/
|
||||
public static boolean running = true;
|
||||
|
||||
/**
|
||||
* The main method
|
||||
*/
|
||||
public static void main(){
|
||||
|
||||
//setup actual input
|
||||
Scanner scan = new Scanner(System.in);
|
||||
String prompt = "";
|
||||
|
||||
//setup conversation tracking
|
||||
Conversation convo = new Conversation();
|
||||
Actor player = new Actor("John");
|
||||
Actor ai = new Actor("Dave");
|
||||
convo.addParticipant(player);
|
||||
convo.addParticipant(ai);
|
||||
|
||||
//actual main loop
|
||||
while(running){
|
||||
|
||||
//handle player statement
|
||||
prompt = scan.nextLine();
|
||||
if(CommandParser.parseCommands(prompt)){
|
||||
continue;
|
||||
}
|
||||
convo.addStatement(new Statement(player, prompt));
|
||||
|
||||
//handle ai statement
|
||||
KoboldRequest request = convo.generateStatementRequest(ai);
|
||||
request.fire();
|
||||
KoboldPrinter printer = new KoboldPrinter(request);
|
||||
printer.print();
|
||||
System.out.println("[System] Created prompt of size: " + request.getPrompt().length() + " (" + (request.getPrompt().length() / (float)CONTEXT_SIZE) + ")");
|
||||
convo.pushbackRequest(ai, request);
|
||||
}
|
||||
|
||||
scan.close();
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,116 @@
|
||||
package org.studiorailgun.conversation.categorization;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
import org.deeplearning4j.nn.conf.BackpropType;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.optimize.api.InvocationType;
|
||||
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
|
||||
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.learning.config.Sgd;
|
||||
|
||||
/**
|
||||
* Computational model for categorizing conversations
|
||||
*/
|
||||
public class ConvCategorizationNet {
|
||||
|
||||
/**
|
||||
* The configuration for the network
|
||||
*/
|
||||
MultiLayerConfiguration conf;
|
||||
|
||||
/**
|
||||
* The actual model created from the configuration
|
||||
*/
|
||||
MultiLayerNetwork model;
|
||||
|
||||
/**
|
||||
* The number of epochs to try to train for
|
||||
*/
|
||||
int targetEpochs;
|
||||
|
||||
/**
|
||||
* Initializes the model
|
||||
*/
|
||||
public void init(){
|
||||
this.conf = new NeuralNetConfiguration.Builder()
|
||||
//setup base config
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.activation(Activation.RELU)
|
||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||
.updater(new Sgd(0.05))
|
||||
.list()
|
||||
//add layers
|
||||
.layer(0, new DenseLayer.Builder()
|
||||
.nIn(4096)
|
||||
.nOut(4096)
|
||||
.build()
|
||||
)
|
||||
//add back propagation
|
||||
.backpropType(BackpropType.Standard)
|
||||
//finalize
|
||||
.build();
|
||||
|
||||
this.model = new MultiLayerNetwork(conf);
|
||||
model.init();
|
||||
}
|
||||
|
||||
public void apply(String input){
|
||||
// this.model.output();
|
||||
}
|
||||
|
||||
/**
|
||||
* Trains the model on a set of data
|
||||
* @param trainData The training data
|
||||
* @param testData The test data
|
||||
*/
|
||||
public void train(DataSetIterator trainData, DataSetIterator testData){
|
||||
System.out.println("Train model..");
|
||||
model.setListeners(new ScoreIterationListener(1), new EvaluativeListener(testData, 1, InvocationType.EPOCH_END));
|
||||
model.fit(trainData, targetEpochs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluates the model against the test data
|
||||
* @param testData The test data
|
||||
*/
|
||||
public void evaluate(DataSetIterator testData){
|
||||
System.out.println("Evaluate model..");
|
||||
Evaluation eval = model.evaluate(testData);
|
||||
System.out.println(eval.stats());
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves the model to disk
|
||||
* @param location The file to save the model into
|
||||
*/
|
||||
public void save(File location){
|
||||
try {
|
||||
model.save(location, true);
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads the model from disk
|
||||
* @param location The location of the model file
|
||||
*/
|
||||
public void load(File location){
|
||||
try {
|
||||
this.model = MultiLayerNetwork.load(location, true);
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,36 @@
|
||||
package org.studiorailgun.conversation.categorization;
|
||||
|
||||
/**
|
||||
* Categorizes sentences based on function
|
||||
*/
|
||||
public class SentenceFunctionCategorizor {
|
||||
|
||||
/**
|
||||
* The function of the sentence
|
||||
*/
|
||||
public static enum SentenceFunction {
|
||||
/**
|
||||
* Transfers information to the other party
|
||||
* ie, declaring facts, declaring perspective of the world (how you feel), etc
|
||||
* "The ball is red". "I don't like that".
|
||||
*/
|
||||
TRANSFER_INFORMATION,
|
||||
|
||||
/**
|
||||
* Query information from the other party
|
||||
* ie, asking for a fact or perspective
|
||||
* "Is that ball red?" "What color is your hat?" "How do you feel about that?"
|
||||
*/
|
||||
QUERY_INFORMATION,
|
||||
}
|
||||
|
||||
/**
|
||||
* Categorizes the sentence by function
|
||||
* @param input The input sentence
|
||||
* @return The function of the sentence
|
||||
*/
|
||||
public static SentenceFunction categorize(String input){
|
||||
return SentenceFunction.QUERY_INFORMATION;
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,30 @@
|
||||
package org.studiorailgun.conversation.parser;
|
||||
|
||||
import org.studiorailgun.conversation.AgentLoop;
|
||||
import org.studiorailgun.conversation.LLMLoop;
|
||||
|
||||
/**
|
||||
* Parses player commands to execute
|
||||
*/
|
||||
public class CommandParser {
|
||||
|
||||
/**
|
||||
* Parses the input for commands. If a valid command is recognized, returns true. Otherwise returns false.
|
||||
* @param input The input string
|
||||
* @return true if a command was executed, false otherwise
|
||||
*/
|
||||
public static boolean parseCommands(String input){
|
||||
if(input.equals("exit")){
|
||||
LLMLoop.running = false;
|
||||
AgentLoop.running = false;
|
||||
} else if(input.equals("s hypo")){ //save a response failing to consider a hypothetical
|
||||
|
||||
} else if(input.equals("s q")){ //save a question
|
||||
|
||||
} else if(input.equals("s format")){ //save a formatting error
|
||||
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,53 @@
|
||||
package org.studiorailgun.conversation.semantic;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
|
||||
import org.deeplearning4j.models.word2vec.Word2Vec;
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
import org.deeplearning4j.nn.conf.BackpropType;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.optimize.api.InvocationType;
|
||||
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
|
||||
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.learning.config.Sgd;
|
||||
|
||||
/**
|
||||
* Parses the subject of a sentence
|
||||
*/
|
||||
public class SentenceSubjectParser {
|
||||
/**
|
||||
* The configuration for the network
|
||||
*/
|
||||
MultiLayerConfiguration conf;
|
||||
|
||||
/**
|
||||
* The actual model created from the configuration
|
||||
*/
|
||||
MultiLayerNetwork model;
|
||||
|
||||
/**
|
||||
* The number of epochs to try to train for
|
||||
*/
|
||||
int targetEpochs;
|
||||
|
||||
/**
|
||||
* Initializes the model
|
||||
*/
|
||||
public void init(){
|
||||
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,39 @@
|
||||
package org.studiorailgun.conversation.tracking;
|
||||
|
||||
/**
|
||||
* A character in a conversation who can say statements
|
||||
*/
|
||||
public class Actor {
|
||||
|
||||
/**
|
||||
* The name of the actor
|
||||
*/
|
||||
String name;
|
||||
|
||||
/**
|
||||
* Creates an actor
|
||||
* @param name The name
|
||||
*/
|
||||
public Actor(String name){
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the name of the actor
|
||||
* @return The name
|
||||
*/
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the name of the actor
|
||||
* @param name The name
|
||||
*/
|
||||
public void setName(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
@ -0,0 +1,113 @@
|
||||
package org.studiorailgun.conversation.tracking;
|
||||
|
||||
import java.util.LinkedList;
|
||||
|
||||
import org.studiorailgun.kobold.KoboldRequest;
|
||||
import org.studiorailgun.kobold.KoboldSymbols;
|
||||
|
||||
public class Conversation {
|
||||
|
||||
/**
|
||||
* The participants in the conversation
|
||||
*/
|
||||
LinkedList<Actor> participants;
|
||||
|
||||
/**
|
||||
* The statements of the conversation
|
||||
*/
|
||||
LinkedList<Statement> statements;
|
||||
|
||||
/**
|
||||
* Number of statements to send in memory
|
||||
*/
|
||||
static final int STATEMENTS_TO_SEND_IN_MEMORY = 10;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*/
|
||||
public Conversation(){
|
||||
this.participants = new LinkedList<Actor>();
|
||||
this.statements = new LinkedList<Statement>();
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a participant to the conversation
|
||||
* @param actor The new actor
|
||||
*/
|
||||
public void addParticipant(Actor actor){
|
||||
this.participants.add(actor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a statement to the conversation
|
||||
* @param statement The statement
|
||||
*/
|
||||
public void addStatement(Statement statement){
|
||||
this.statements.add(statement);
|
||||
statement.setConversation(this);
|
||||
}
|
||||
|
||||
/**
|
||||
* Requests a statement for the actor
|
||||
* @param actor The actor
|
||||
*/
|
||||
public KoboldRequest generateStatementRequest(Actor actor){
|
||||
KoboldRequest rVal = null;
|
||||
String prompt = "";
|
||||
|
||||
//memory portion
|
||||
prompt = prompt + "<|system|>\\n";
|
||||
prompt = prompt + "You are named " + actor.getName() + "\\n";
|
||||
prompt = prompt + "You are in a conversation with \\n";
|
||||
for(Actor participant : this.participants){
|
||||
if(participant != actor){
|
||||
prompt = prompt + participant.getName() + "\\n";
|
||||
}
|
||||
}
|
||||
prompt = prompt + "The most recent dialog in the conversation is\\n";
|
||||
int increments = 0;
|
||||
for(int i = this.statements.size() - 1; i >= 0 && increments < STATEMENTS_TO_SEND_IN_MEMORY; i--){
|
||||
Statement statement = this.statements.get(i);
|
||||
prompt = prompt + statement.getActor().getName() + ": " + statement.getContent() + "\\n";
|
||||
increments++;
|
||||
}
|
||||
prompt = prompt + "<|end|>\\n";
|
||||
|
||||
//request new statement
|
||||
prompt = prompt + "<|user|>\\n";
|
||||
prompt = prompt + "What is the next thing " + actor.getName() + " should say?\\n";
|
||||
prompt = prompt + "Give me exactly the quote " + actor.getName() + " should say and nothing more.\\n";
|
||||
// prompt = prompt + "Do not describe the tone.\\n";
|
||||
prompt = prompt + "Do not explain the quote.\\n";
|
||||
prompt = prompt + "Do not give me dialog for any other character.\\n";
|
||||
prompt = prompt + "Do not describe what other characters should do.\\n";
|
||||
prompt = prompt + "Do not add notes.\\n";
|
||||
prompt = prompt + "Do not include ways to continue the conversation.\\n";
|
||||
prompt = prompt + "Do not include quotation marks.\\n";
|
||||
prompt = prompt + "<|end|>\\n";
|
||||
prompt = prompt + "<|assistant|>\\n";
|
||||
|
||||
//construct actual request from prompt
|
||||
rVal = new KoboldRequest(prompt);
|
||||
return rVal;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pushes the result of a kobold request into the conversation stack
|
||||
* @param actor The actor that said the statement
|
||||
* @param request The request
|
||||
*/
|
||||
public void pushbackRequest(Actor actor, KoboldRequest request){
|
||||
String rawResponse = request.getResponse();
|
||||
String parsed = rawResponse;
|
||||
//remove notes if they are detected
|
||||
if(KoboldSymbols.containsBannedSymbol(parsed)){
|
||||
parsed = KoboldSymbols.removeBannedSymbol(parsed);
|
||||
}
|
||||
|
||||
//actually create and store statement
|
||||
Statement statement = new Statement(actor, parsed);
|
||||
this.addStatement(statement);
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,79 @@
|
||||
package org.studiorailgun.conversation.tracking;
|
||||
|
||||
/**
|
||||
* A statement by a character in a conversation
|
||||
*/
|
||||
public class Statement {
|
||||
|
||||
/**
|
||||
* Incrementer for ids
|
||||
*/
|
||||
static long idIncrementer = 0;
|
||||
|
||||
/**
|
||||
* The id of this statement
|
||||
*/
|
||||
long id;
|
||||
|
||||
/**
|
||||
* The raw content of the statement
|
||||
*/
|
||||
String content;
|
||||
|
||||
/**
|
||||
* The actor who said the statement
|
||||
*/
|
||||
Actor actor;
|
||||
|
||||
/**
|
||||
* The conversation this statement was uttered in
|
||||
*/
|
||||
Conversation conversation;
|
||||
|
||||
/**
|
||||
* Creates a statement
|
||||
* @param actor The actor who said the statement
|
||||
* @param content The content of the statement
|
||||
*/
|
||||
public Statement(Actor actor, String content){
|
||||
this.id = idIncrementer;
|
||||
idIncrementer++;
|
||||
this.actor = actor;
|
||||
this.content = content;
|
||||
}
|
||||
|
||||
public long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public String getContent() {
|
||||
return content;
|
||||
}
|
||||
|
||||
public void setContent(String content) {
|
||||
this.content = content;
|
||||
}
|
||||
|
||||
public Actor getActor() {
|
||||
return actor;
|
||||
}
|
||||
|
||||
public void setActor(Actor actor) {
|
||||
this.actor = actor;
|
||||
}
|
||||
|
||||
public Conversation getConversation() {
|
||||
return conversation;
|
||||
}
|
||||
|
||||
public void setConversation(Conversation conversation) {
|
||||
this.conversation = conversation;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
99
src/main/java/org/studiorailgun/knowledge/KnowledgeWeb.java
Normal file
99
src/main/java/org/studiorailgun/knowledge/KnowledgeWeb.java
Normal file
@ -0,0 +1,99 @@
|
||||
package org.studiorailgun.knowledge;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
|
||||
/**
|
||||
* A knowledge web
|
||||
*/
|
||||
public class KnowledgeWeb {
|
||||
|
||||
|
||||
/**
|
||||
* The nodes;
|
||||
*/
|
||||
Map<Integer,Node> nodes;
|
||||
|
||||
/**
|
||||
* The relations
|
||||
*/
|
||||
Map<Integer,Relation> relations;
|
||||
|
||||
/**
|
||||
* Map of relation type -> list of relations of that type
|
||||
*/
|
||||
transient Map<String,List<Integer>> typeRelationLookup;
|
||||
|
||||
public KnowledgeWeb(){
|
||||
this.nodes = new HashMap<Integer,Node>();
|
||||
this.relations = new HashMap<Integer,Relation>();
|
||||
this.typeRelationLookup = new HashMap<String,List<Integer>>();
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the linked datastructures between items
|
||||
*/
|
||||
public void initLinks(){
|
||||
//set id incrementers
|
||||
for(Entry<Integer,Node> entry : this.nodes.entrySet()){
|
||||
if(entry.getValue().getId() >= Node.idIncrementer){
|
||||
Node.idIncrementer = entry.getValue().getId() + 1;
|
||||
}
|
||||
if(entry.getValue().getId() != entry.getKey()){
|
||||
throw new Error("Node id doesn't match map value! " + entry.getValue().getId() + " " + entry.getValue().getName() + " " + entry.getKey());
|
||||
}
|
||||
}
|
||||
for(Entry<Integer,Relation> entry : this.relations.entrySet()){
|
||||
Relation relation = entry.getValue();
|
||||
if(relation.getId() >= Relation.idIncrementer){
|
||||
Relation.idIncrementer = relation.getId() + 1;
|
||||
}
|
||||
if(this.typeRelationLookup.containsKey(relation.getName())){
|
||||
this.typeRelationLookup.get(relation.getName()).add(relation.getId());
|
||||
} else {
|
||||
LinkedList<Integer> list = new LinkedList<Integer>();
|
||||
list.add(relation.getId());
|
||||
this.typeRelationLookup.put(relation.getName(),list);
|
||||
}
|
||||
if(relation.getId() != entry.getKey()){
|
||||
throw new Error("Relation id doesn't match map value! " + relation.getId() + " " + relation.getName() + " " + entry.getKey());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public Collection<Node> getNodes() {
|
||||
return nodes.values();
|
||||
}
|
||||
|
||||
public Collection<Relation> getRelations() {
|
||||
return relations.values();
|
||||
}
|
||||
|
||||
public Node getNode(int id){
|
||||
return nodes.get(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a singleton relationship with a given parent
|
||||
* @param relationType The relationship type
|
||||
* @param parent The parent
|
||||
* @return The relation if it exists, null otherwise
|
||||
*/
|
||||
public Relation getSingletonRelation(String relationType, Node parent){
|
||||
List<Integer> relations = this.typeRelationLookup.get(relationType);
|
||||
if(relations != null){
|
||||
for(int relationId : relations){
|
||||
if(relationId == parent.getId()){
|
||||
return this.relations.get(relationId);
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
46
src/main/java/org/studiorailgun/knowledge/Node.java
Normal file
46
src/main/java/org/studiorailgun/knowledge/Node.java
Normal file
@ -0,0 +1,46 @@
|
||||
package org.studiorailgun.knowledge;
|
||||
|
||||
/**
|
||||
* A node in the knowledge web
|
||||
*/
|
||||
public class Node {
|
||||
|
||||
/**
|
||||
* The id incremented
|
||||
*/
|
||||
public static int idIncrementer = 0;
|
||||
|
||||
/**
|
||||
* The id of the node
|
||||
*/
|
||||
int id;
|
||||
|
||||
/**
|
||||
* The name of the node
|
||||
*/
|
||||
String name;
|
||||
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
* @param name The name of the node
|
||||
*/
|
||||
public Node(String name){
|
||||
this.id = idIncrementer;
|
||||
idIncrementer++;
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
|
||||
public int getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
64
src/main/java/org/studiorailgun/knowledge/Relation.java
Normal file
64
src/main/java/org/studiorailgun/knowledge/Relation.java
Normal file
@ -0,0 +1,64 @@
|
||||
package org.studiorailgun.knowledge;
|
||||
|
||||
import org.studiorailgun.Globals;
|
||||
|
||||
public class Relation {
|
||||
|
||||
/**
|
||||
* The id incremented
|
||||
*/
|
||||
public static int idIncrementer = 0;
|
||||
|
||||
/**
|
||||
* The id of the relationship
|
||||
*/
|
||||
int id;
|
||||
|
||||
/**
|
||||
* The name of the relationship
|
||||
*/
|
||||
String name;
|
||||
|
||||
/**
|
||||
* The parent of the relationship
|
||||
*/
|
||||
int parent;
|
||||
|
||||
/**
|
||||
* The child of the relationship
|
||||
*/
|
||||
int child;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
* @param name The name of the relation
|
||||
* @param parent The parent of the relation
|
||||
* @param child The child of the relation
|
||||
*/
|
||||
public Relation(String name, Node parent, Node child){
|
||||
this.id = idIncrementer;
|
||||
idIncrementer++;
|
||||
this.name = name;
|
||||
this.parent = parent.getId();
|
||||
this.child = child.getId();
|
||||
}
|
||||
|
||||
public int getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public Node getParent() {
|
||||
return Globals.web.getNode(this.parent);
|
||||
}
|
||||
|
||||
public Node getChild() {
|
||||
return Globals.web.getNode(this.child);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
68
src/main/java/org/studiorailgun/kobold/KoboldPrinter.java
Normal file
68
src/main/java/org/studiorailgun/kobold/KoboldPrinter.java
Normal file
@ -0,0 +1,68 @@
|
||||
package org.studiorailgun.kobold;
|
||||
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Prints the response from a kobold request
|
||||
*/
|
||||
public class KoboldPrinter {
|
||||
|
||||
/**
|
||||
* The number of milliseconds to wait between printing characters
|
||||
*/
|
||||
static final int CHAR_DELAY = 10;
|
||||
|
||||
/**
|
||||
* The interval to check at in milliseconds
|
||||
*/
|
||||
static final int CHECK_INTERVAL = 300;
|
||||
|
||||
/**
|
||||
* The request to print
|
||||
*/
|
||||
KoboldRequest request;
|
||||
|
||||
/**
|
||||
* Position of the printer (how many characters has it printed)
|
||||
*/
|
||||
int position = 0;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*/
|
||||
public KoboldPrinter(KoboldRequest request){
|
||||
this.request = request;
|
||||
}
|
||||
|
||||
/**
|
||||
* Prints the request. Blocks the thread until the printing has finished.
|
||||
*/
|
||||
public void print(){
|
||||
int i = 0;
|
||||
while(!this.request.isCompleted() || position < this.request.getResponse().length()){
|
||||
try {
|
||||
TimeUnit.MILLISECONDS.sleep(CHAR_DELAY);
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
if(!this.request.isCompleted() && i % (CHECK_INTERVAL / CHAR_DELAY) == 0){
|
||||
this.request.check();
|
||||
}
|
||||
if(this.position < this.request.getResponse().length()){
|
||||
System.out.print(this.request.getResponse().charAt(this.position));
|
||||
this.position++;
|
||||
}
|
||||
//check for occurrences of immediate dialong ending symbols
|
||||
if(KoboldSymbols.containsBannedSymbol(this.request.getResponse())){
|
||||
String message = KoboldSymbols.removeBannedSymbol(this.request.getResponse());
|
||||
this.request.markCompleted();
|
||||
System.out.print("\r" + message);
|
||||
break;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
//make sure we're on a fresh line
|
||||
System.out.println();
|
||||
}
|
||||
|
||||
}
|
||||
188
src/main/java/org/studiorailgun/kobold/KoboldRequest.java
Normal file
188
src/main/java/org/studiorailgun/kobold/KoboldRequest.java
Normal file
@ -0,0 +1,188 @@
|
||||
package org.studiorailgun.kobold;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.net.HttpURLConnection;
|
||||
import java.net.MalformedURLException;
|
||||
import java.net.ProtocolException;
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.net.URL;
|
||||
import java.net.URLConnection;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.UUID;
|
||||
|
||||
/**
|
||||
* A request to koboldcpp
|
||||
*/
|
||||
public class KoboldRequest {
|
||||
|
||||
/**
|
||||
* Base Kobold URL
|
||||
*/
|
||||
static final String BASE_URL = "http://localhost:5001";
|
||||
|
||||
/**
|
||||
* Endpoint to start generating
|
||||
*/
|
||||
static final String STREAM_ENDPOINT = "/api/v1/generate";
|
||||
|
||||
/**
|
||||
* Endpoint to query generation status
|
||||
*/
|
||||
static final String CHECK_ENDPOINT = "/api/extra/generate/check";
|
||||
|
||||
/**
|
||||
* The prompt
|
||||
*/
|
||||
String prompt;
|
||||
|
||||
/**
|
||||
* The gen key for streaming
|
||||
*/
|
||||
String genKey;
|
||||
|
||||
/**
|
||||
* The response string
|
||||
*/
|
||||
String response = "";
|
||||
|
||||
/**
|
||||
* True if has received the complete generation response from kobold, false otherwise
|
||||
*/
|
||||
boolean completed = false;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
* @param prompt The prompt
|
||||
*/
|
||||
public KoboldRequest(String prompt){
|
||||
this.prompt = prompt;
|
||||
this.genKey = UUID.randomUUID().toString();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Fires the request to generate
|
||||
*/
|
||||
public void fire(){
|
||||
try {
|
||||
String urlPath = BASE_URL + STREAM_ENDPOINT;
|
||||
URL url = new URI(urlPath).toURL();
|
||||
URLConnection con = url.openConnection();
|
||||
HttpURLConnection http = (HttpURLConnection)con;
|
||||
http.setRequestMethod("POST"); // PUT is another valid option
|
||||
http.setDoOutput(true);
|
||||
|
||||
String body = "{\"prompt\": \"" + this.prompt + "\", \"genkey\": \"" + genKey + "\"}";
|
||||
byte[] out = body.getBytes(StandardCharsets.UTF_8);
|
||||
int length = out.length;
|
||||
|
||||
http.setFixedLengthStreamingMode(length);
|
||||
http.setRequestProperty("Content-Type", "application/json; charset=UTF-8");
|
||||
http.connect();
|
||||
try(OutputStream os = http.getOutputStream()) {
|
||||
os.write(out);
|
||||
}
|
||||
http.disconnect();
|
||||
//if uncommented, will block until response is fully generated
|
||||
// try(InputStream is = http.getInputStream()){
|
||||
// byte[] response = is.readAllBytes();
|
||||
// String responseParsed = new String(response, StandardCharsets.UTF_8);
|
||||
// System.out.println("parsed: " + responseParsed);
|
||||
// }
|
||||
} catch (MalformedURLException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
} catch (ProtocolException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
} catch (IOException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
} catch (URISyntaxException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks for new text generated from the prompt
|
||||
*/
|
||||
public void check(){
|
||||
try {
|
||||
String urlPath = BASE_URL + CHECK_ENDPOINT;
|
||||
URL url = new URI(urlPath).toURL();
|
||||
URLConnection con = url.openConnection();
|
||||
HttpURLConnection http = (HttpURLConnection)con;
|
||||
http.setRequestMethod("POST"); // PUT is another valid option
|
||||
http.setDoOutput(true);
|
||||
http.setDoInput(true);
|
||||
String body = "{\"genkey\": \"" + genKey + "\"}";
|
||||
byte[] out = body.getBytes(StandardCharsets.UTF_8);
|
||||
int length = out.length;
|
||||
|
||||
http.setFixedLengthStreamingMode(length);
|
||||
http.setRequestProperty("Content-Type", "application/json; charset=UTF-8");
|
||||
http.connect();
|
||||
try(OutputStream os = http.getOutputStream()) {
|
||||
os.write(out);
|
||||
}
|
||||
try(InputStream is = http.getInputStream()){
|
||||
byte[] responseBytes = is.readAllBytes();
|
||||
String responseParsed = new String(responseBytes, StandardCharsets.UTF_8);
|
||||
String responseBody = responseParsed.substring(23, responseParsed.length() - 4);
|
||||
if(this.response.equals(responseBody) && this.response.length() > 0){
|
||||
this.completed = true;
|
||||
} else {
|
||||
this.response = responseBody;
|
||||
}
|
||||
}
|
||||
} catch (MalformedURLException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
} catch (ProtocolException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
} catch (IOException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
} catch (URISyntaxException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the response
|
||||
* @return The response
|
||||
*/
|
||||
public String getResponse(){
|
||||
return response;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets whether this request has completed generation or not
|
||||
* @return true if has completed generation, false otherwise
|
||||
*/
|
||||
public boolean isCompleted(){
|
||||
return completed;
|
||||
}
|
||||
|
||||
/**
|
||||
* Flags the request as completed
|
||||
*/
|
||||
public void markCompleted(){
|
||||
this.completed = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the prompt of the request
|
||||
* @return The prompt
|
||||
*/
|
||||
public String getPrompt(){
|
||||
return prompt;
|
||||
}
|
||||
|
||||
}
|
||||
55
src/main/java/org/studiorailgun/kobold/KoboldSymbols.java
Normal file
55
src/main/java/org/studiorailgun/kobold/KoboldSymbols.java
Normal file
@ -0,0 +1,55 @@
|
||||
package org.studiorailgun.kobold;
|
||||
|
||||
/**
|
||||
* Symbol tables for kobold parsing
|
||||
*/
|
||||
public class KoboldSymbols {
|
||||
|
||||
/**
|
||||
* A list of symbols we do not want to appear in the output
|
||||
*/
|
||||
public static final String[] BANNED_SYMBOLS = new String[]{
|
||||
"Note:",
|
||||
"**Note**:",
|
||||
"support:",
|
||||
"Follow-up Questions:",
|
||||
"Answer:",
|
||||
"Prompt:",
|
||||
"---",
|
||||
"\\n\\n",
|
||||
"Response to the instruction:",
|
||||
};
|
||||
|
||||
/**
|
||||
* Checks if the input contains a banned symbol
|
||||
* @param input The input
|
||||
* @return true if it contains a banned symbol, false otherwise
|
||||
*/
|
||||
public static boolean containsBannedSymbol(String input){
|
||||
for(String symbol : BANNED_SYMBOLS){
|
||||
if(input.contains(symbol)){
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes banned symbols from the input
|
||||
* @param input The input
|
||||
* @return The input with banned symbols removed
|
||||
*/
|
||||
public static String removeBannedSymbol(String input){
|
||||
String target = input;
|
||||
while(KoboldSymbols.containsBannedSymbol(target)){
|
||||
for(String symbol : BANNED_SYMBOLS){
|
||||
if(input.contains(symbol)){
|
||||
target = target.split(symbol)[0];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return target;
|
||||
}
|
||||
|
||||
}
|
||||
68
src/main/java/org/studiorailgun/schedule/Event.java
Normal file
68
src/main/java/org/studiorailgun/schedule/Event.java
Normal file
@ -0,0 +1,68 @@
|
||||
package org.studiorailgun.schedule;
|
||||
|
||||
import java.util.LinkedList;
|
||||
|
||||
/**
|
||||
* An event
|
||||
*/
|
||||
public class Event implements Comparable<Event>{
|
||||
|
||||
/**
|
||||
* The time when this event will happen
|
||||
*/
|
||||
Time time;
|
||||
|
||||
/**
|
||||
* The type of the event
|
||||
*/
|
||||
String type;
|
||||
|
||||
/**
|
||||
* The data of the event
|
||||
*/
|
||||
LinkedList<Object> data = new LinkedList<Object>();
|
||||
|
||||
/**
|
||||
* Creates an event
|
||||
* @param time The time the event occurs
|
||||
* @param type The type of the event
|
||||
* @param params The data of the event
|
||||
*/
|
||||
public Event(Time time, String type, Object ... params){
|
||||
this.time = time;
|
||||
this.type = type;
|
||||
for(Object el : params){
|
||||
data.add(el);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the time when the event occurs
|
||||
* @return The time
|
||||
*/
|
||||
public Time getTime(){
|
||||
return time;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the type of the event
|
||||
* @return The type of the event
|
||||
*/
|
||||
public String getType(){
|
||||
return type;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the data for the event
|
||||
* @return The data for the event
|
||||
*/
|
||||
public LinkedList<Object> getData(){
|
||||
return data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(Event other) {
|
||||
return this.time.compareTo(other.time);
|
||||
}
|
||||
|
||||
}
|
||||
56
src/main/java/org/studiorailgun/schedule/EventScheduler.java
Normal file
56
src/main/java/org/studiorailgun/schedule/EventScheduler.java
Normal file
@ -0,0 +1,56 @@
|
||||
package org.studiorailgun.schedule;
|
||||
|
||||
import java.util.PriorityQueue;
|
||||
|
||||
/**
|
||||
* Maintains the schedule of events
|
||||
*/
|
||||
public class EventScheduler {
|
||||
|
||||
/**
|
||||
* The queue of events
|
||||
*/
|
||||
PriorityQueue<Event> queue = new PriorityQueue<Event>();
|
||||
|
||||
|
||||
/**
|
||||
* Adds an event to the scheduler
|
||||
* @param event The event
|
||||
*/
|
||||
public void add(Event event){
|
||||
queue.add(event);
|
||||
}
|
||||
|
||||
/**
|
||||
* Peeks at the next event in the queue
|
||||
* @return The next event
|
||||
*/
|
||||
public Event peek(){
|
||||
return queue.peek();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the next scheduled event
|
||||
* @return The next scheduled event
|
||||
*/
|
||||
public Event get(){
|
||||
return queue.poll();
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes the event from the schedule
|
||||
* @param event The event
|
||||
*/
|
||||
public void remove(Event event){
|
||||
queue.remove(event);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the queue of events
|
||||
* @return The queue of events
|
||||
*/
|
||||
public PriorityQueue<Event> getQueue(){
|
||||
return queue;
|
||||
}
|
||||
|
||||
}
|
||||
149
src/main/java/org/studiorailgun/schedule/Time.java
Normal file
149
src/main/java/org/studiorailgun/schedule/Time.java
Normal file
@ -0,0 +1,149 @@
|
||||
package org.studiorailgun.schedule;
|
||||
|
||||
/**
|
||||
* A time in the simulation timeline
|
||||
*/
|
||||
public class Time implements Comparable<Time> {
|
||||
|
||||
/**
|
||||
* Number of months in a year
|
||||
*/
|
||||
static final int MONTHS_IN_YEAR = 12;
|
||||
|
||||
/**
|
||||
* Number of days in a month
|
||||
*/
|
||||
static final int DAYS_IN_MONTH = 30;
|
||||
|
||||
/**
|
||||
* Number of hours in a day
|
||||
*/
|
||||
static final int HOURS_IN_DAY = 24;
|
||||
|
||||
/**
|
||||
* Number of minutes in an hour
|
||||
*/
|
||||
static final int MINUTES_IN_HOUR = 60;
|
||||
|
||||
/**
|
||||
* Number of seconds in a minute
|
||||
*/
|
||||
static final int SECONDS_IN_MINUTE = 60;
|
||||
|
||||
/**
|
||||
* The value of the time
|
||||
*/
|
||||
long value;
|
||||
|
||||
/**
|
||||
* Creates a time with a raw value
|
||||
* @param value The value
|
||||
*/
|
||||
public Time(long value){
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a time at a given year, month, and day
|
||||
* @param year The year
|
||||
* @param month The month
|
||||
* @param day The day
|
||||
* @param hour The hour
|
||||
* @param minute The minute
|
||||
* @param second The second
|
||||
*/
|
||||
public Time(int year, int month, int day, int hour, int minute, int second){
|
||||
this.value = Time.getOffset(year, month, day, hour, minute, second);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a time offset from this time
|
||||
* @param year The number of years to offset
|
||||
* @param month The number of months to offset
|
||||
* @param day The number of days to offset
|
||||
* @param hour The number of hours to offset
|
||||
* @param minute The number of minutes to offset
|
||||
* @param second The number of seconds to offset
|
||||
* @return The time containing the offset value
|
||||
*/
|
||||
public Time offset(int year, int month, int day, int hour, int minute, int second){
|
||||
Time rVal = new Time(this.value);
|
||||
rVal.value = rVal.value + Time.getOffset(year, month, day, hour, minute, second);
|
||||
return rVal;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the year of this time
|
||||
* @return The year
|
||||
*/
|
||||
public int year(){
|
||||
return (int)this.value / (SECONDS_IN_MINUTE * MINUTES_IN_HOUR * HOURS_IN_DAY * DAYS_IN_MONTH * MONTHS_IN_YEAR);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the month of this time
|
||||
* @return The month
|
||||
*/
|
||||
public int month(){
|
||||
return (int)this.value / (SECONDS_IN_MINUTE * MINUTES_IN_HOUR * HOURS_IN_DAY * DAYS_IN_MONTH);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the day of this time
|
||||
* @return The day
|
||||
*/
|
||||
public int day(){
|
||||
return (int)this.value / (SECONDS_IN_MINUTE * MINUTES_IN_HOUR * HOURS_IN_DAY);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the hour of this time
|
||||
* @return The hour
|
||||
*/
|
||||
public int hour(){
|
||||
return (int)this.value / (SECONDS_IN_MINUTE * MINUTES_IN_HOUR);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the minute of this time
|
||||
* @return The minute
|
||||
*/
|
||||
public int minute(){
|
||||
return (int)this.value / SECONDS_IN_MINUTE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the second of this time
|
||||
* @return The second
|
||||
*/
|
||||
public int second(){
|
||||
return (int)this.value % SECONDS_IN_MINUTE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets an offset given a set of legible time values
|
||||
* @param year The number of years
|
||||
* @param month The number of months
|
||||
* @param day The number of days
|
||||
* @param hour The number of hours
|
||||
* @param minute The number of minutes
|
||||
* @param second The number of seconds
|
||||
* @return The offset
|
||||
*/
|
||||
private static long getOffset(int year, int month, int day, int hour, int minute, int second){
|
||||
return
|
||||
second +
|
||||
minute * SECONDS_IN_MINUTE +
|
||||
hour * SECONDS_IN_MINUTE * MINUTES_IN_HOUR +
|
||||
day * SECONDS_IN_MINUTE * MINUTES_IN_HOUR * HOURS_IN_DAY +
|
||||
month * SECONDS_IN_MINUTE * MINUTES_IN_HOUR * HOURS_IN_DAY * DAYS_IN_MONTH +
|
||||
year * SECONDS_IN_MINUTE * MINUTES_IN_HOUR * HOURS_IN_DAY * DAYS_IN_MONTH * MONTHS_IN_YEAR
|
||||
;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(Time other) {
|
||||
return (int)(this.value - other.value);
|
||||
}
|
||||
|
||||
}
|
||||
86
src/main/python/conversation/subject.py
Normal file
86
src/main/python/conversation/subject.py
Normal file
@ -0,0 +1,86 @@
|
||||
import tensorflow as tf
|
||||
keras = tf.keras
|
||||
from tensorflow import Tensor
|
||||
from keras.api.layers import TextVectorization, Embedding, LSTM, Dense, Input
|
||||
from keras.api.models import Sequential
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
|
||||
# Model constants.
|
||||
max_features: int = 20000
|
||||
embedding_dim: int = 128
|
||||
sequence_length: int = 500
|
||||
epochs: int = 50
|
||||
max_tokens: int = 5000
|
||||
output_sequence_length: int = 4
|
||||
|
||||
|
||||
# read sentences
|
||||
data_path: str = './data/semantic/subject.txt'
|
||||
data_raw: str = open(data_path).read()
|
||||
vocab: list[str] = data_raw.split('\n')
|
||||
|
||||
# read labels
|
||||
label_data_path: str = './data/semantic/subject_label.txt'
|
||||
label_data_raw: str = open(label_data_path).read()
|
||||
labels: list[int] = list(map(int,label_data_raw.split()))
|
||||
|
||||
# init vectorizer
|
||||
textVec: TextVectorization = TextVectorization(
|
||||
max_tokens=max_tokens,
|
||||
output_mode='int',
|
||||
output_sequence_length=output_sequence_length,
|
||||
pad_to_max_tokens=True)
|
||||
|
||||
# Add the vocab to the tokenizer
|
||||
textVec.adapt(vocab)
|
||||
input_data: list[str] = vocab
|
||||
data: Tensor = textVec.call(input_data)
|
||||
|
||||
# construct model
|
||||
model: Sequential = Sequential([
|
||||
keras.Input(shape=(None,), dtype="int64"),
|
||||
Embedding(max_features + 1, embedding_dim),
|
||||
LSTM(64),
|
||||
Dense(1, activation='sigmoid')
|
||||
])
|
||||
|
||||
#compile the model
|
||||
# model.build(keras.Input(shape=(None,), dtype="int64"))
|
||||
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
|
||||
|
||||
|
||||
# fit the training data
|
||||
npData = np.array(data)
|
||||
npLabel = np.array(labels)
|
||||
model.fit(npData,npLabel,epochs=epochs)
|
||||
|
||||
|
||||
# evaluate here
|
||||
|
||||
|
||||
# predict
|
||||
predictTargetRaw: list[str] = ['saf']
|
||||
predictTargetToken: list[int] = textVec.call(predictTargetRaw)
|
||||
npPredict: npt.NDArray[np.complex64] = np.array(predictTargetToken)
|
||||
# print(npPredict)
|
||||
result: list[int] = model.predict(npPredict)
|
||||
print("predict result:")
|
||||
print(predictTargetToken)
|
||||
print(result)
|
||||
print(data)
|
||||
print(labels)
|
||||
|
||||
|
||||
|
||||
# save the model so keras can reload
|
||||
# savePath: str = './data/semantic/model.keras'
|
||||
# model.save(savePath)
|
||||
|
||||
# export the model so java can leverage it
|
||||
exportPath: str = './data/semantic/model'
|
||||
model.export(exportPath)
|
||||
|
||||
# tf.keras.utils.get_file('asdf')
|
||||
# asdf: str = 'a'
|
||||
44
web.json
Normal file
44
web.json
Normal file
@ -0,0 +1,44 @@
|
||||
{
|
||||
"nodes" : {
|
||||
"0" : {
|
||||
"id" : 0,
|
||||
"name" : "Bert"
|
||||
},
|
||||
"1" : {
|
||||
"id" : 1,
|
||||
"name" : "Name"
|
||||
},
|
||||
"2" : {
|
||||
"id" : 2,
|
||||
"name" : "Self"
|
||||
},
|
||||
"3" : {
|
||||
"id" : 3,
|
||||
"name" : "Person"
|
||||
},
|
||||
"4" : {
|
||||
"id" : 4,
|
||||
"name" : "ConversationParticipant"
|
||||
}
|
||||
},
|
||||
"relations" : {
|
||||
"0" : {
|
||||
"id" : 0,
|
||||
"name" : "Name",
|
||||
"parent" : 0,
|
||||
"child" : 2
|
||||
},
|
||||
"1" : {
|
||||
"id" : 1,
|
||||
"name" : "InstanceOf",
|
||||
"parent" : 3,
|
||||
"child" : 2
|
||||
},
|
||||
"2" : {
|
||||
"id" : 2,
|
||||
"name" : "InstanceOf",
|
||||
"parent" : 1,
|
||||
"child" : 0
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user