graph exporting
This commit is contained in:
parent
bcf40e1334
commit
cde2fb8076
5
.gitignore
vendored
5
.gitignore
vendored
@ -7,4 +7,7 @@ src/trpg.egg-info
|
||||
# ml model files
|
||||
data/semantic/model
|
||||
**.keras
|
||||
**.h5
|
||||
**.h5
|
||||
|
||||
# graph related
|
||||
temp/*
|
||||
@ -4,6 +4,23 @@ chatbot is a wizard
|
||||
sitting in a tavern by a fireplace
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Parse "Hello"
|
||||
Nodes needed:
|
||||
- Conversation
|
||||
- Greeting
|
||||
- Participant
|
||||
- Instances of participants
|
||||
- Instances of conversations
|
||||
- Instances of greetings (created as the conversation starts)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
interaction:
|
||||
the user says hello
|
||||
the chatbot replies and queries the user's name (query for name relation to other participant)
|
||||
|
||||
3
data/sentence_function/greetings.txt
Normal file
3
data/sentence_function/greetings.txt
Normal file
@ -0,0 +1,3 @@
|
||||
Hello
|
||||
Hi
|
||||
Howdy
|
||||
44
data/test/webs/web.json
Normal file
44
data/test/webs/web.json
Normal file
@ -0,0 +1,44 @@
|
||||
{
|
||||
"nodes" : {
|
||||
"0" : {
|
||||
"id" : 0,
|
||||
"name" : "Bert"
|
||||
},
|
||||
"1" : {
|
||||
"id" : 1,
|
||||
"name" : "Name"
|
||||
},
|
||||
"2" : {
|
||||
"id" : 2,
|
||||
"name" : "Self"
|
||||
},
|
||||
"3" : {
|
||||
"id" : 3,
|
||||
"name" : "Person"
|
||||
},
|
||||
"4" : {
|
||||
"id" : 4,
|
||||
"name" : "ConversationParticipant"
|
||||
}
|
||||
},
|
||||
"relations" : {
|
||||
"0" : {
|
||||
"id" : 0,
|
||||
"name" : "Name",
|
||||
"parent" : 0,
|
||||
"child" : 2
|
||||
},
|
||||
"1" : {
|
||||
"id" : 1,
|
||||
"name" : "InstanceOf",
|
||||
"parent" : 3,
|
||||
"child" : 2
|
||||
},
|
||||
"2" : {
|
||||
"id" : 2,
|
||||
"name" : "InstanceOf",
|
||||
"parent" : 1,
|
||||
"child" : 0
|
||||
}
|
||||
}
|
||||
}
|
||||
16
pom.xml
16
pom.xml
@ -29,11 +29,6 @@
|
||||
</dependency>
|
||||
|
||||
<!--Tensorflow Java-->
|
||||
<!-- <dependency>
|
||||
<groupId>org.tensorflow</groupId>
|
||||
<artifactId>tensorflow</artifactId>
|
||||
<version>1.15.0</version>
|
||||
</dependency> -->
|
||||
<dependency>
|
||||
<groupId>org.tensorflow</groupId>
|
||||
<artifactId>tensorflow-core-api</artifactId>
|
||||
@ -52,12 +47,21 @@
|
||||
|
||||
|
||||
|
||||
|
||||
<!--GSON-->
|
||||
<dependency>
|
||||
<groupId>com.google.code.gson</groupId>
|
||||
<artifactId>gson</artifactId>
|
||||
<version>2.8.6</version>
|
||||
</dependency>
|
||||
|
||||
<!--JUnit-->
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter</artifactId>
|
||||
<version>5.10.3</version>
|
||||
</dependency>
|
||||
|
||||
|
||||
</dependencies>
|
||||
|
||||
<!-- MAIN BUILD -->
|
||||
|
||||
@ -7,8 +7,6 @@ import org.studiorailgun.conversation.parser.CommandParser;
|
||||
import org.studiorailgun.conversation.tracking.Actor;
|
||||
import org.studiorailgun.conversation.tracking.Conversation;
|
||||
import org.studiorailgun.conversation.tracking.Statement;
|
||||
import org.studiorailgun.kobold.KoboldPrinter;
|
||||
import org.studiorailgun.kobold.KoboldRequest;
|
||||
|
||||
public class AgentLoop {
|
||||
|
||||
@ -28,32 +26,30 @@ public class AgentLoop {
|
||||
public static void main(){
|
||||
Globals.init();
|
||||
|
||||
//setup actual input
|
||||
Scanner scan = new Scanner(System.in);
|
||||
String prompt = "";
|
||||
try (Scanner scan = new Scanner(System.in)) {
|
||||
String prompt = "";
|
||||
|
||||
//setup conversation tracking
|
||||
Conversation convo = new Conversation();
|
||||
Actor player = new Actor("John");
|
||||
Actor ai = new Actor("Dave");
|
||||
convo.addParticipant(player);
|
||||
convo.addParticipant(ai);
|
||||
//setup conversation tracking
|
||||
Conversation convo = new Conversation();
|
||||
Actor player = new Actor("John");
|
||||
Actor ai = new Actor("Dave");
|
||||
convo.addParticipant(player);
|
||||
convo.addParticipant(ai);
|
||||
|
||||
//actual main loop
|
||||
while(running){
|
||||
//actual main loop
|
||||
while(running){
|
||||
|
||||
//handle player statement
|
||||
prompt = scan.nextLine();
|
||||
if(CommandParser.parseCommands(prompt)){
|
||||
continue;
|
||||
//handle player statement
|
||||
prompt = scan.nextLine();
|
||||
if(CommandParser.parseCommands(prompt)){
|
||||
continue;
|
||||
}
|
||||
convo.addStatement(new Statement(player, prompt));
|
||||
|
||||
//handle ai statement
|
||||
throw new UnsupportedOperationException("asdf");
|
||||
}
|
||||
convo.addStatement(new Statement(player, prompt));
|
||||
|
||||
//handle ai statement
|
||||
throw new UnsupportedOperationException("asdf");
|
||||
}
|
||||
|
||||
scan.close();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -1,19 +1,38 @@
|
||||
package org.studiorailgun.conversation.evaluators;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Evaluates a greeting
|
||||
*/
|
||||
public class GreetingEval {
|
||||
|
||||
/**
|
||||
* The greeting strings
|
||||
*/
|
||||
private static List<String> greetingStrings = null;
|
||||
|
||||
/**
|
||||
* Inits the greeting evaluator
|
||||
*/
|
||||
public static void init(){
|
||||
try {
|
||||
greetingStrings = Files.readAllLines(new File("./data/sentence_function/greetings.txt").toPath());
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluates a greeting
|
||||
* @param input The sentence
|
||||
*/
|
||||
public static void evaluate(String input){
|
||||
switch(input){
|
||||
case "Hello": {
|
||||
|
||||
} break;
|
||||
if(greetingStrings.contains(input)){
|
||||
System.out.println("Contained!");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ package org.studiorailgun.conversation.parser;
|
||||
|
||||
import org.studiorailgun.conversation.AgentLoop;
|
||||
import org.studiorailgun.conversation.LLMLoop;
|
||||
import org.studiorailgun.knowledge.CSVExport;
|
||||
|
||||
/**
|
||||
* Parses player commands to execute
|
||||
@ -23,6 +24,8 @@ public class CommandParser {
|
||||
|
||||
} else if(input.equals("s format")){ //save a formatting error
|
||||
|
||||
} else if(input.equals("export")){
|
||||
CSVExport.export("current");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
76
src/main/java/org/studiorailgun/knowledge/CSVExport.java
Normal file
76
src/main/java/org/studiorailgun/knowledge/CSVExport.java
Normal file
@ -0,0 +1,76 @@
|
||||
package org.studiorailgun.knowledge;
|
||||
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.File;
|
||||
import java.io.FileWriter;
|
||||
import java.io.IOException;
|
||||
import java.io.PrintWriter;
|
||||
import java.nio.file.Files;
|
||||
|
||||
import org.studiorailgun.Globals;
|
||||
|
||||
/**
|
||||
* Exports a web as a set of CSVs
|
||||
*/
|
||||
public class CSVExport {
|
||||
|
||||
/**
|
||||
* Path for the export dir
|
||||
*/
|
||||
static final String EXPORT_DIR = "./temp/";
|
||||
|
||||
/**
|
||||
* Ending for the node file
|
||||
*/
|
||||
static final String NODE_FILE_END = "_nodes.csv";
|
||||
|
||||
/**
|
||||
* Ending for the edge file
|
||||
*/
|
||||
static final String EDGE_FILE_END = "_edges.csv";
|
||||
|
||||
/**
|
||||
* Exports a web as a pair of CSVs
|
||||
* @param name The name of the CSVs
|
||||
*/
|
||||
public static void export(String name){
|
||||
|
||||
|
||||
//make sure the export dir exists
|
||||
try {
|
||||
Files.createDirectories(new File(EXPORT_DIR).toPath());
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
|
||||
//write the nodes
|
||||
try (PrintWriter writer = new PrintWriter(new BufferedWriter(new FileWriter(new File(EXPORT_DIR + name + NODE_FILE_END))))) {
|
||||
//header
|
||||
writer.println("\"id\",\"label\",\"interval\"");
|
||||
//all nodes
|
||||
for(Node node : Globals.web.getNodes()){
|
||||
writer.println(node.getId() + ",\"" + node.getName() + "\",");
|
||||
}
|
||||
} catch (IOException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
|
||||
|
||||
//write the edges
|
||||
try (PrintWriter writer = new PrintWriter(new BufferedWriter(new FileWriter(new File(EXPORT_DIR + name + EDGE_FILE_END))))) {
|
||||
//header
|
||||
writer.println("\"source\",\"target\",\"type\",\"id\",\"label\",\"interval\",\"weight\"");
|
||||
//all relations
|
||||
for(Relation relation : Globals.web.getRelations()){
|
||||
writer.println(relation.getParent().getId() + "," + relation.getChild().getId() + ",\"" + relation.getName() + "\"," + relation.getId() + ",\"" + relation.getName() + "\",,");
|
||||
}
|
||||
} catch (IOException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
36
src/test/java/org/studiorailgun/LoadingTests.java
Normal file
36
src/test/java/org/studiorailgun/LoadingTests.java
Normal file
@ -0,0 +1,36 @@
|
||||
package org.studiorailgun;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.studiorailgun.knowledge.CSVExport;
|
||||
import org.studiorailgun.knowledge.KnowledgeWeb;
|
||||
|
||||
/**
|
||||
* Test loading webs
|
||||
*/
|
||||
public class LoadingTests {
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testLoadWeb(){
|
||||
assertDoesNotThrow(() -> {
|
||||
KnowledgeWeb web = FileUtils.loadObjectFromFile(new File("./data/test/webs/web.json"), KnowledgeWeb.class);
|
||||
web.initLinks();
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testExportCSV(){
|
||||
assertDoesNotThrow(() -> {
|
||||
KnowledgeWeb web = FileUtils.loadObjectFromFile(new File("./data/test/webs/web.json"), KnowledgeWeb.class);
|
||||
web.initLinks();
|
||||
Globals.web = web;
|
||||
CSVExport.export("test");
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user