diff --git a/.gitignore b/.gitignore index 246763b..82a69ad 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,7 @@ src/trpg.egg-info # ml model files data/semantic/model **.keras -**.h5 \ No newline at end of file +**.h5 + +# graph related +temp/* \ No newline at end of file diff --git a/current_goal.txt b/current_goal.txt index 166945a..b859478 100644 --- a/current_goal.txt +++ b/current_goal.txt @@ -4,6 +4,23 @@ chatbot is a wizard sitting in a tavern by a fireplace + + + +Parse "Hello" +Nodes needed: + - Conversation + - Greeting + - Participant + - Instances of participants + - Instances of conversations + - Instances of greetings (created as the conversation starts) + + + + + + interaction: the user says hello the chatbot replies and queries the user's name (query for name relation to other participant) diff --git a/data/sentence_function/greetings.txt b/data/sentence_function/greetings.txt new file mode 100644 index 0000000..83fa01b --- /dev/null +++ b/data/sentence_function/greetings.txt @@ -0,0 +1,3 @@ +Hello +Hi +Howdy \ No newline at end of file diff --git a/data/test/webs/web.json b/data/test/webs/web.json new file mode 100644 index 0000000..9029418 --- /dev/null +++ b/data/test/webs/web.json @@ -0,0 +1,44 @@ +{ + "nodes" : { + "0" : { + "id" : 0, + "name" : "Bert" + }, + "1" : { + "id" : 1, + "name" : "Name" + }, + "2" : { + "id" : 2, + "name" : "Self" + }, + "3" : { + "id" : 3, + "name" : "Person" + }, + "4" : { + "id" : 4, + "name" : "ConversationParticipant" + } + }, + "relations" : { + "0" : { + "id" : 0, + "name" : "Name", + "parent" : 0, + "child" : 2 + }, + "1" : { + "id" : 1, + "name" : "InstanceOf", + "parent" : 3, + "child" : 2 + }, + "2" : { + "id" : 2, + "name" : "InstanceOf", + "parent" : 1, + "child" : 0 + } + } +} \ No newline at end of file diff --git a/pom.xml b/pom.xml index f4d6481..2863660 100644 --- a/pom.xml +++ b/pom.xml @@ -29,11 +29,6 @@ - org.tensorflow tensorflow-core-api @@ -52,12 +47,21 @@ - + com.google.code.gson gson 2.8.6 + + + + org.junit.jupiter + junit-jupiter + 5.10.3 + + + diff --git a/src/main/java/org/studiorailgun/conversation/AgentLoop.java b/src/main/java/org/studiorailgun/conversation/AgentLoop.java index 2f74ddd..c38d667 100644 --- a/src/main/java/org/studiorailgun/conversation/AgentLoop.java +++ b/src/main/java/org/studiorailgun/conversation/AgentLoop.java @@ -7,8 +7,6 @@ import org.studiorailgun.conversation.parser.CommandParser; import org.studiorailgun.conversation.tracking.Actor; import org.studiorailgun.conversation.tracking.Conversation; import org.studiorailgun.conversation.tracking.Statement; -import org.studiorailgun.kobold.KoboldPrinter; -import org.studiorailgun.kobold.KoboldRequest; public class AgentLoop { @@ -28,32 +26,30 @@ public class AgentLoop { public static void main(){ Globals.init(); - //setup actual input - Scanner scan = new Scanner(System.in); - String prompt = ""; + try (Scanner scan = new Scanner(System.in)) { + String prompt = ""; - //setup conversation tracking - Conversation convo = new Conversation(); - Actor player = new Actor("John"); - Actor ai = new Actor("Dave"); - convo.addParticipant(player); - convo.addParticipant(ai); + //setup conversation tracking + Conversation convo = new Conversation(); + Actor player = new Actor("John"); + Actor ai = new Actor("Dave"); + convo.addParticipant(player); + convo.addParticipant(ai); - //actual main loop - while(running){ + //actual main loop + while(running){ - //handle player statement - prompt = scan.nextLine(); - if(CommandParser.parseCommands(prompt)){ - continue; + //handle player statement + prompt = scan.nextLine(); + if(CommandParser.parseCommands(prompt)){ + continue; + } + convo.addStatement(new Statement(player, prompt)); + + //handle ai statement + throw new UnsupportedOperationException("asdf"); } - convo.addStatement(new Statement(player, prompt)); - - //handle ai statement - throw new UnsupportedOperationException("asdf"); } - - scan.close(); } } diff --git a/src/main/java/org/studiorailgun/conversation/evaluators/GreetingEval.java b/src/main/java/org/studiorailgun/conversation/evaluators/GreetingEval.java index be5ef63..247fb92 100644 --- a/src/main/java/org/studiorailgun/conversation/evaluators/GreetingEval.java +++ b/src/main/java/org/studiorailgun/conversation/evaluators/GreetingEval.java @@ -1,19 +1,38 @@ package org.studiorailgun.conversation.evaluators; +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.util.List; + /** * Evaluates a greeting */ public class GreetingEval { + + /** + * The greeting strings + */ + private static List greetingStrings = null; + + /** + * Inits the greeting evaluator + */ + public static void init(){ + try { + greetingStrings = Files.readAllLines(new File("./data/sentence_function/greetings.txt").toPath()); + } catch (IOException e) { + e.printStackTrace(); + } + } /** * Evaluates a greeting * @param input The sentence */ public static void evaluate(String input){ - switch(input){ - case "Hello": { - - } break; + if(greetingStrings.contains(input)){ + System.out.println("Contained!"); } } diff --git a/src/main/java/org/studiorailgun/conversation/parser/CommandParser.java b/src/main/java/org/studiorailgun/conversation/parser/CommandParser.java index 061afec..04f6011 100644 --- a/src/main/java/org/studiorailgun/conversation/parser/CommandParser.java +++ b/src/main/java/org/studiorailgun/conversation/parser/CommandParser.java @@ -2,6 +2,7 @@ package org.studiorailgun.conversation.parser; import org.studiorailgun.conversation.AgentLoop; import org.studiorailgun.conversation.LLMLoop; +import org.studiorailgun.knowledge.CSVExport; /** * Parses player commands to execute @@ -23,6 +24,8 @@ public class CommandParser { } else if(input.equals("s format")){ //save a formatting error + } else if(input.equals("export")){ + CSVExport.export("current"); } return false; } diff --git a/src/main/java/org/studiorailgun/knowledge/CSVExport.java b/src/main/java/org/studiorailgun/knowledge/CSVExport.java new file mode 100644 index 0000000..4ae0edf --- /dev/null +++ b/src/main/java/org/studiorailgun/knowledge/CSVExport.java @@ -0,0 +1,76 @@ +package org.studiorailgun.knowledge; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.io.PrintWriter; +import java.nio.file.Files; + +import org.studiorailgun.Globals; + +/** + * Exports a web as a set of CSVs + */ +public class CSVExport { + + /** + * Path for the export dir + */ + static final String EXPORT_DIR = "./temp/"; + + /** + * Ending for the node file + */ + static final String NODE_FILE_END = "_nodes.csv"; + + /** + * Ending for the edge file + */ + static final String EDGE_FILE_END = "_edges.csv"; + + /** + * Exports a web as a pair of CSVs + * @param name The name of the CSVs + */ + public static void export(String name){ + + + //make sure the export dir exists + try { + Files.createDirectories(new File(EXPORT_DIR).toPath()); + } catch (IOException e) { + e.printStackTrace(); + } + + + //write the nodes + try (PrintWriter writer = new PrintWriter(new BufferedWriter(new FileWriter(new File(EXPORT_DIR + name + NODE_FILE_END))))) { + //header + writer.println("\"id\",\"label\",\"interval\""); + //all nodes + for(Node node : Globals.web.getNodes()){ + writer.println(node.getId() + ",\"" + node.getName() + "\","); + } + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + + + + //write the edges + try (PrintWriter writer = new PrintWriter(new BufferedWriter(new FileWriter(new File(EXPORT_DIR + name + EDGE_FILE_END))))) { + //header + writer.println("\"source\",\"target\",\"type\",\"id\",\"label\",\"interval\",\"weight\""); + //all relations + for(Relation relation : Globals.web.getRelations()){ + writer.println(relation.getParent().getId() + "," + relation.getChild().getId() + ",\"" + relation.getName() + "\"," + relation.getId() + ",\"" + relation.getName() + "\",,"); + } + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + +} diff --git a/src/test/java/org/studiorailgun/LoadingTests.java b/src/test/java/org/studiorailgun/LoadingTests.java new file mode 100644 index 0000000..f7ef11b --- /dev/null +++ b/src/test/java/org/studiorailgun/LoadingTests.java @@ -0,0 +1,36 @@ +package org.studiorailgun; + +import static org.junit.jupiter.api.Assertions.*; + +import java.io.File; + +import org.junit.jupiter.api.Test; +import org.studiorailgun.knowledge.CSVExport; +import org.studiorailgun.knowledge.KnowledgeWeb; + +/** + * Test loading webs + */ +public class LoadingTests { + + + + @Test + public void testLoadWeb(){ + assertDoesNotThrow(() -> { + KnowledgeWeb web = FileUtils.loadObjectFromFile(new File("./data/test/webs/web.json"), KnowledgeWeb.class); + web.initLinks(); + }); + } + + @Test + public void testExportCSV(){ + assertDoesNotThrow(() -> { + KnowledgeWeb web = FileUtils.loadObjectFromFile(new File("./data/test/webs/web.json"), KnowledgeWeb.class); + web.initLinks(); + Globals.web = web; + CSVExport.export("test"); + }); + } + +}