graph exporting

This commit is contained in:
austin 2024-12-29 11:01:40 -05:00
parent bcf40e1334
commit cde2fb8076
10 changed files with 235 additions and 34 deletions

5
.gitignore vendored
View File

@ -7,4 +7,7 @@ src/trpg.egg-info
# ml model files
data/semantic/model
**.keras
**.h5
**.h5
# graph related
temp/*

View File

@ -4,6 +4,23 @@ chatbot is a wizard
sitting in a tavern by a fireplace
Parse "Hello"
Nodes needed:
- Conversation
- Greeting
- Participant
- Instances of participants
- Instances of conversations
- Instances of greetings (created as the conversation starts)
interaction:
the user says hello
the chatbot replies and queries the user's name (query for name relation to other participant)

View File

@ -0,0 +1,3 @@
Hello
Hi
Howdy

44
data/test/webs/web.json Normal file
View File

@ -0,0 +1,44 @@
{
"nodes" : {
"0" : {
"id" : 0,
"name" : "Bert"
},
"1" : {
"id" : 1,
"name" : "Name"
},
"2" : {
"id" : 2,
"name" : "Self"
},
"3" : {
"id" : 3,
"name" : "Person"
},
"4" : {
"id" : 4,
"name" : "ConversationParticipant"
}
},
"relations" : {
"0" : {
"id" : 0,
"name" : "Name",
"parent" : 0,
"child" : 2
},
"1" : {
"id" : 1,
"name" : "InstanceOf",
"parent" : 3,
"child" : 2
},
"2" : {
"id" : 2,
"name" : "InstanceOf",
"parent" : 1,
"child" : 0
}
}
}

16
pom.xml
View File

@ -29,11 +29,6 @@
</dependency>
<!--Tensorflow Java-->
<!-- <dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.15.0</version>
</dependency> -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-api</artifactId>
@ -52,12 +47,21 @@
<!--GSON-->
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.6</version>
</dependency>
<!--JUnit-->
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<version>5.10.3</version>
</dependency>
</dependencies>
<!-- MAIN BUILD -->

View File

@ -7,8 +7,6 @@ import org.studiorailgun.conversation.parser.CommandParser;
import org.studiorailgun.conversation.tracking.Actor;
import org.studiorailgun.conversation.tracking.Conversation;
import org.studiorailgun.conversation.tracking.Statement;
import org.studiorailgun.kobold.KoboldPrinter;
import org.studiorailgun.kobold.KoboldRequest;
public class AgentLoop {
@ -28,32 +26,30 @@ public class AgentLoop {
public static void main(){
Globals.init();
//setup actual input
Scanner scan = new Scanner(System.in);
String prompt = "";
try (Scanner scan = new Scanner(System.in)) {
String prompt = "";
//setup conversation tracking
Conversation convo = new Conversation();
Actor player = new Actor("John");
Actor ai = new Actor("Dave");
convo.addParticipant(player);
convo.addParticipant(ai);
//setup conversation tracking
Conversation convo = new Conversation();
Actor player = new Actor("John");
Actor ai = new Actor("Dave");
convo.addParticipant(player);
convo.addParticipant(ai);
//actual main loop
while(running){
//actual main loop
while(running){
//handle player statement
prompt = scan.nextLine();
if(CommandParser.parseCommands(prompt)){
continue;
//handle player statement
prompt = scan.nextLine();
if(CommandParser.parseCommands(prompt)){
continue;
}
convo.addStatement(new Statement(player, prompt));
//handle ai statement
throw new UnsupportedOperationException("asdf");
}
convo.addStatement(new Statement(player, prompt));
//handle ai statement
throw new UnsupportedOperationException("asdf");
}
scan.close();
}
}

View File

@ -1,19 +1,38 @@
package org.studiorailgun.conversation.evaluators;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.util.List;
/**
* Evaluates a greeting
*/
public class GreetingEval {
/**
* The greeting strings
*/
private static List<String> greetingStrings = null;
/**
* Inits the greeting evaluator
*/
public static void init(){
try {
greetingStrings = Files.readAllLines(new File("./data/sentence_function/greetings.txt").toPath());
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* Evaluates a greeting
* @param input The sentence
*/
public static void evaluate(String input){
switch(input){
case "Hello": {
} break;
if(greetingStrings.contains(input)){
System.out.println("Contained!");
}
}

View File

@ -2,6 +2,7 @@ package org.studiorailgun.conversation.parser;
import org.studiorailgun.conversation.AgentLoop;
import org.studiorailgun.conversation.LLMLoop;
import org.studiorailgun.knowledge.CSVExport;
/**
* Parses player commands to execute
@ -23,6 +24,8 @@ public class CommandParser {
} else if(input.equals("s format")){ //save a formatting error
} else if(input.equals("export")){
CSVExport.export("current");
}
return false;
}

View File

@ -0,0 +1,76 @@
package org.studiorailgun.knowledge;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.file.Files;
import org.studiorailgun.Globals;
/**
* Exports a web as a set of CSVs
*/
public class CSVExport {
/**
* Path for the export dir
*/
static final String EXPORT_DIR = "./temp/";
/**
* Ending for the node file
*/
static final String NODE_FILE_END = "_nodes.csv";
/**
* Ending for the edge file
*/
static final String EDGE_FILE_END = "_edges.csv";
/**
* Exports a web as a pair of CSVs
* @param name The name of the CSVs
*/
public static void export(String name){
//make sure the export dir exists
try {
Files.createDirectories(new File(EXPORT_DIR).toPath());
} catch (IOException e) {
e.printStackTrace();
}
//write the nodes
try (PrintWriter writer = new PrintWriter(new BufferedWriter(new FileWriter(new File(EXPORT_DIR + name + NODE_FILE_END))))) {
//header
writer.println("\"id\",\"label\",\"interval\"");
//all nodes
for(Node node : Globals.web.getNodes()){
writer.println(node.getId() + ",\"" + node.getName() + "\",");
}
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
//write the edges
try (PrintWriter writer = new PrintWriter(new BufferedWriter(new FileWriter(new File(EXPORT_DIR + name + EDGE_FILE_END))))) {
//header
writer.println("\"source\",\"target\",\"type\",\"id\",\"label\",\"interval\",\"weight\"");
//all relations
for(Relation relation : Globals.web.getRelations()){
writer.println(relation.getParent().getId() + "," + relation.getChild().getId() + ",\"" + relation.getName() + "\"," + relation.getId() + ",\"" + relation.getName() + "\",,");
}
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}

View File

@ -0,0 +1,36 @@
package org.studiorailgun;
import static org.junit.jupiter.api.Assertions.*;
import java.io.File;
import org.junit.jupiter.api.Test;
import org.studiorailgun.knowledge.CSVExport;
import org.studiorailgun.knowledge.KnowledgeWeb;
/**
* Test loading webs
*/
public class LoadingTests {
@Test
public void testLoadWeb(){
assertDoesNotThrow(() -> {
KnowledgeWeb web = FileUtils.loadObjectFromFile(new File("./data/test/webs/web.json"), KnowledgeWeb.class);
web.initLinks();
});
}
@Test
public void testExportCSV(){
assertDoesNotThrow(() -> {
KnowledgeWeb web = FileUtils.loadObjectFromFile(new File("./data/test/webs/web.json"), KnowledgeWeb.class);
web.initLinks();
Globals.web = web;
CSVExport.export("test");
});
}
}