multi-sentence architecture
All checks were successful
studiorailgun/trpg/pipeline/head This commit looks good

This commit is contained in:
austin 2024-12-30 12:45:52 -05:00
parent dbc0a96156
commit fd5ed6b181
16 changed files with 242 additions and 105 deletions

View File

@ -7,9 +7,13 @@ sitting in a tavern by a fireplace
Respond to "What color is your hat?"
- respond with the correct information
Respond to "Hello" with "Hello. What is your name?"
- Multiple sentences in synthesizer
- Progressive goals while synthesizing (update goals after each synthesis)
- Multiple sentences per quote
- Move eval to being evaluation-per-sentence
- "Small Talk" conversation goal
- Bank of questions to ask in small talk evaluator

View File

@ -3,7 +3,7 @@ package org.studiorailgun.conversation.categorization;
import java.util.HashMap;
import java.util.Map;
import org.studiorailgun.conversation.tracking.Quote;
import org.studiorailgun.conversation.tracking.Sentence;
import org.tensorflow.Result;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
@ -69,10 +69,10 @@ public class SentenceFunctionCategorizor {
/**
* Categorizes the sentence by function
* @param input The input quote
* @param input The sentence quote
* @return The function of the sentence
*/
public static void categorize(Quote input){
public static void categorize(Sentence input){
//construct input
TString inputTensor = TString.scalarOf(input.getRaw());
inputTensor.shape().append(1);

View File

@ -5,40 +5,69 @@ import org.studiorailgun.conversation.evaluators.goal.GoalEval;
import org.studiorailgun.conversation.evaluators.greet.GreetingEval;
import org.studiorailgun.conversation.evaluators.query.QueryEval;
import org.studiorailgun.conversation.evaluators.synthesis.ResponseEval;
import org.studiorailgun.conversation.parser.NLPParser;
import org.studiorailgun.conversation.tracking.Conversation;
import org.studiorailgun.conversation.tracking.Quote;
import org.studiorailgun.conversation.tracking.Sentence;
/**
* Evaluates a sentence based on data about the sentence
*/
public class EvaluationTree {
/**
* Number of times to evaluate response logic
*/
static final int RESPONSE_EVAL_COUNT = 1;
/**
* Evaluates a quote
* @param quote The quote
*/
public static Quote evaluate(Conversation conversation, Quote quote){
//parse data about the quote
SentenceFunctionCategorizor.categorize(quote);
//perform actions based on the tree
switch(quote.getFunction()){
case UTILITY: {
GreetingEval.evaluate(conversation, quote);
} break;
case QUERY: {
QueryEval.evaluate(conversation, quote);
} break;
default: {
throw new UnsupportedOperationException("Unsupported quote function: " + quote.getFunction());
//perform NLP evaluation of the quote
NLPParser.parse(quote);
//parse data about the quote
for(Sentence sentence : quote.getSentences()){
SentenceFunctionCategorizor.categorize(sentence);
}
//add the quote to the conversation
conversation.addQuote(quote);
//evaluate each sentence
for(Sentence sentence : quote.getSentences()){
switch(sentence.getFunction()){
case UTILITY: {
GreetingEval.evaluate(conversation, quote, sentence);
} break;
case QUERY: {
QueryEval.evaluate(conversation, quote, sentence);
} break;
default: {
throw new UnsupportedOperationException("Unsupported quote function: " + sentence.getFunction());
}
}
}
//evaluate the AI's current goal in the conversation
GoalEval.evaluate(conversation);
//synthesize language based on the results of the actions performed
Quote response = ResponseEval.evaluate(conversation);
//the response quote
Quote response = new Quote("");
//repeatedly check if we can generate a new sentence
for(int i = 0; i < RESPONSE_EVAL_COUNT; i++){
//evaluate the AI's current goal in the conversation
GoalEval.evaluate(conversation);
//synthesize language based on the results of the actions performed
Sentence newSent = ResponseEval.evaluate(conversation);
if(newSent != null){
response.appendSentence(newSent);
}
}
return response;
}

View File

@ -8,6 +8,7 @@ import java.util.List;
import org.studiorailgun.conversation.tracking.ConvParticipant;
import org.studiorailgun.conversation.tracking.Conversation;
import org.studiorailgun.conversation.tracking.Quote;
import org.studiorailgun.conversation.tracking.Sentence;
/**
* Evaluates a greeting
@ -34,7 +35,7 @@ public class GreetingEval {
* Evaluates a greeting
* @param input The sentence
*/
public static void evaluate(Conversation conversation, Quote input){
public static void evaluate(Conversation conversation, Quote input, Sentence sentence){
if(greetingStrings.contains(input.getRaw())){
ConvParticipant other = conversation.getOther();
if(!conversation.getGreetingData().getHaveGreeted().contains(other)){
@ -48,8 +49,8 @@ public class GreetingEval {
* @param conversation The conversation
* @return The greeting
*/
public static Quote constructGreeting(Conversation conversation){
Quote response = new Quote("Hello");
public static Sentence constructGreeting(Conversation conversation){
Sentence response = new Sentence("Hello");
return response;
}

View File

@ -5,6 +5,7 @@ import java.util.List;
import org.studiorailgun.Globals;
import org.studiorailgun.conversation.tracking.Conversation;
import org.studiorailgun.conversation.tracking.Quote;
import org.studiorailgun.conversation.tracking.Sentence;
import org.studiorailgun.knowledge.Node;
import org.studiorailgun.knowledge.query.InstanceQuery;
import org.studiorailgun.knowledge.query.QualiaQuery;
@ -43,7 +44,7 @@ public class Interrogative {
* @param items The items noun
* @return null
*/
public static void evalWhichQuery(Conversation conversation, Quote quote, NounStack interrogative, NounStack items){
public static void evalWhichQuery(Conversation conversation, Quote quote, Sentence sentence, NounStack interrogative, NounStack items){
if(QualiaQuery.isQualiaClarificationQuery(interrogative)){
Node finalQualifier = InstanceQuery.getConcept(items.indexedWord.originalText());
@ -56,8 +57,7 @@ public class Interrogative {
Node qualiaType = QualiaQuery.getRequestedQualiaType(interrogative);
conversation.addQuote(quote);
conversation.getGoalData().getQueryData().getRecentQueries().add(quote);
conversation.getGoalData().getQueryData().getRecentQueries().add(sentence);
} else {
throw new Error("Unknown query type!");
}
@ -71,7 +71,7 @@ public class Interrogative {
* @param items The items noun
* @return null
*/
public static void evalWhatQuery(Conversation conversation, Quote quote, NounStack interrogative, NounStack items){
public static void evalWhatQuery(Conversation conversation, Quote quote, Sentence sentence, NounStack interrogative, NounStack items){
if(QualiaQuery.isQualiaClarificationQuery(interrogative)){
Node finalQualifier = InstanceQuery.getConcept(items.indexedWord.originalText());
@ -84,8 +84,7 @@ public class Interrogative {
Node qualiaType = QualiaQuery.getRequestedQualiaType(interrogative);
conversation.addQuote(quote);
conversation.getGoalData().getQueryData().getRecentQueries().add(quote);
conversation.getGoalData().getQueryData().getRecentQueries().add(sentence);
} else {
throw new Error("Unknown query type!");
}

View File

@ -3,7 +3,7 @@ package org.studiorailgun.conversation.evaluators.query;
import java.util.LinkedList;
import java.util.List;
import org.studiorailgun.conversation.tracking.Quote;
import org.studiorailgun.conversation.tracking.Sentence;
/**
* Data about recent queries in the conversation
@ -13,22 +13,22 @@ public class QueryData {
/**
* The recent queries
*/
List<Quote> recentQueries = new LinkedList<Quote>();
List<Sentence> recentQueries = new LinkedList<Sentence>();
/**
* Gets the recent queries in the conversation
* @return The recent queries
*/
public List<Quote> getRecentQueries(){
public List<Sentence> getRecentQueries(){
return recentQueries;
}
/**
* Adds the quote to the recent queries
* @param quote The quote
* Adds the sentence to the recent queries
* @param sentence The sentence
*/
public void addQuery(Quote quote){
recentQueries.add(quote);
public void addQuery(Sentence sentence){
recentQueries.add(sentence);
}
}

View File

@ -3,10 +3,10 @@ package org.studiorailgun.conversation.evaluators.query;
import java.util.Iterator;
import java.util.Set;
import org.studiorailgun.conversation.parser.NLPParser;
import org.studiorailgun.conversation.parser.PennTreebankTagSet;
import org.studiorailgun.conversation.tracking.Conversation;
import org.studiorailgun.conversation.tracking.Quote;
import org.studiorailgun.conversation.tracking.Sentence;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.semgraph.SemanticGraph;
@ -21,9 +21,8 @@ public class QueryEval {
* @param conversation The conversation
* @param quote The quote
*/
public static void evaluate(Conversation conversation, Quote quote){
NLPParser.parse(quote);
SemanticGraph semanticGraph = quote.getGraph();
public static void evaluate(Conversation conversation, Quote quote, Sentence sentence){
SemanticGraph semanticGraph = sentence.getGraph();
if(semanticGraph.getRoots().size() > 1){
String message = "Multiple roots to sentence!\n" +
"\"" + quote.getRaw() + "\"\n" +
@ -33,16 +32,16 @@ public class QueryEval {
IndexedWord root = semanticGraph.getFirstRoot();
if(PennTreebankTagSet.isVerb(root.tag())){
if(PennTreebankTagSet.isBe(root.tag())){
QueryEval.evaluateBe(conversation, quote);
QueryEval.evaluateBe(conversation, quote, sentence);
} else {
String message = "Unsupported root verb type!\n" +
"\"" + quote.getRaw() + "\"\n" +
"\"" + sentence.getRaw() + "\"\n" +
semanticGraph;
throw new UnsupportedOperationException(message);
}
} else {
String message = "Unsupported root type!\n" +
"\"" + quote.getRaw() + "\"\n" +
"\"" + sentence.getRaw() + "\"\n" +
semanticGraph;
throw new UnsupportedOperationException(message);
}
@ -53,9 +52,9 @@ public class QueryEval {
* @param conversation The conversation
* @param quote The quote
*/
private static void evaluateBe(Conversation conversation, Quote quote){
private static void evaluateBe(Conversation conversation, Quote quote, Sentence sentence){
//get the two things we're comparing
SemanticGraph graph = quote.getGraph();
SemanticGraph graph = sentence.getGraph();
IndexedWord root = graph.getFirstRoot();
Set<IndexedWord> children = graph.getChildren(root);
Iterator<IndexedWord> iterator = children.iterator();
@ -77,10 +76,10 @@ public class QueryEval {
switch(interrogativeStack.interrogative.toLowerCase()){
case "which": {
Interrogative.evalWhichQuery(conversation, quote, interrogativeStack, qualifierStack);
Interrogative.evalWhichQuery(conversation, quote, sentence, interrogativeStack, qualifierStack);
} break;
case "what": {
Interrogative.evalWhatQuery(conversation, quote, interrogativeStack, qualifierStack);
Interrogative.evalWhatQuery(conversation, quote, sentence, interrogativeStack, qualifierStack);
} break;
default : {
throw new Error("Unhandled interrogative type! " + interrogativeStack.interrogative.toLowerCase());

View File

@ -3,7 +3,7 @@ package org.studiorailgun.conversation.evaluators.synthesis;
import org.studiorailgun.conversation.evaluators.greet.GreetingEval;
import org.studiorailgun.conversation.evaluators.transfer.TransferEval;
import org.studiorailgun.conversation.tracking.Conversation;
import org.studiorailgun.conversation.tracking.Quote;
import org.studiorailgun.conversation.tracking.Sentence;
/**
* Evaluates any response the ai might want to construct
@ -15,8 +15,8 @@ public class ResponseEval {
* @param conversation The conversation
* @return The quote encapsulating the AI's response
*/
public static Quote evaluate(Conversation conversation){
Quote response = null;
public static Sentence evaluate(Conversation conversation){
Sentence response = null;
switch(conversation.getGoalData().getGoal()){
case GREET: {
response = GreetingEval.constructGreeting(conversation);

View File

@ -4,10 +4,10 @@ import java.util.Iterator;
import java.util.Set;
import org.studiorailgun.conversation.evaluators.query.NounStack;
import org.studiorailgun.conversation.parser.NLPParser;
import org.studiorailgun.conversation.parser.PennTreebankTagSet;
import org.studiorailgun.conversation.tracking.Conversation;
import org.studiorailgun.conversation.tracking.Quote;
import org.studiorailgun.conversation.tracking.Sentence;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.semgraph.SemanticGraph;
@ -23,30 +23,29 @@ public class AnswerSynthesis {
* @param conversation The conversation
* @param quote The quote
*/
public static Quote evaluate(Conversation conversation, Quote query){
NLPParser.parse(query);
SemanticGraph semanticGraph = query.getGraph();
public static Sentence evaluate(Conversation conversation, Quote query, Sentence sentence){
SemanticGraph semanticGraph = sentence.getGraph();
if(semanticGraph.getRoots().size() > 1){
String message = "Multiple roots to sentence!\n" +
"\"" + query.getRaw() + "\"\n" +
"\"" + sentence.getRaw() + "\"\n" +
semanticGraph;
throw new UnsupportedOperationException(message);
}
IndexedWord root = semanticGraph.getFirstRoot();
if(PennTreebankTagSet.isVerb(root.tag())){
if(PennTreebankTagSet.isBe(root.tag())){
String answerRaw = AnswerSynthesis.evaluateBe(conversation, query);
Quote quote = new Quote(answerRaw);
return quote;
String answerRaw = AnswerSynthesis.evaluateBe(conversation, query, sentence);
Sentence rVal = new Sentence(answerRaw);
return rVal;
} else {
String message = "Unsupported root verb type!\n" +
"\"" + query.getRaw() + "\"\n" +
"\"" + sentence.getRaw() + "\"\n" +
semanticGraph;
throw new UnsupportedOperationException(message);
}
} else {
String message = "Unsupported root type!\n" +
"\"" + query.getRaw() + "\"\n" +
"\"" + sentence.getRaw() + "\"\n" +
semanticGraph;
throw new UnsupportedOperationException(message);
}
@ -57,9 +56,9 @@ public class AnswerSynthesis {
* @param conversation The conversation
* @param quote The quote
*/
private static String evaluateBe(Conversation conversation, Quote quote){
private static String evaluateBe(Conversation conversation, Quote quote, Sentence sentence){
//get the two things we're comparing
SemanticGraph graph = quote.getGraph();
SemanticGraph graph = sentence.getGraph();
IndexedWord root = graph.getFirstRoot();
Set<IndexedWord> children = graph.getChildren(root);
Iterator<IndexedWord> iterator = children.iterator();
@ -81,10 +80,10 @@ public class AnswerSynthesis {
switch(interrogativeStack.getInterrogative().toLowerCase()){
case "which": {
return Interrogative.evalWhichQuery(conversation, quote, interrogativeStack, qualifierStack);
return Interrogative.evalWhichQuery(conversation, quote, sentence, interrogativeStack, qualifierStack);
}
case "what": {
return Interrogative.evalWhatQuery(conversation, quote, interrogativeStack, qualifierStack);
return Interrogative.evalWhatQuery(conversation, quote, sentence, interrogativeStack, qualifierStack);
}
default : {
throw new Error("Unhandled interrogative type! " + interrogativeStack.getInterrogative().toLowerCase());

View File

@ -8,6 +8,7 @@ import org.studiorailgun.conversation.synthesis.NounStackSynthesizer;
import org.studiorailgun.conversation.synthesis.QualitySynthesizer;
import org.studiorailgun.conversation.tracking.Conversation;
import org.studiorailgun.conversation.tracking.Quote;
import org.studiorailgun.conversation.tracking.Sentence;
import org.studiorailgun.knowledge.Node;
import org.studiorailgun.knowledge.query.InstanceQuery;
import org.studiorailgun.knowledge.query.QualiaQuery;
@ -27,7 +28,7 @@ public class Interrogative {
* @param items The items noun
* @return null
*/
public static String evalWhichQuery(Conversation conversation, Quote quote, NounStack interrogative, NounStack items){
public static String evalWhichQuery(Conversation conversation, Quote quote, Sentence sentence, NounStack interrogative, NounStack items){
//get the thing that has the quality
if(QualiaQuery.isQualiaClarificationQuery(interrogative)){
@ -61,7 +62,7 @@ public class Interrogative {
* @param items The items noun
* @return null
*/
public static String evalWhatQuery(Conversation conversation, Quote quote, NounStack interrogative, NounStack items){
public static String evalWhatQuery(Conversation conversation, Quote quote, Sentence sentence, NounStack interrogative, NounStack items){
if(QualiaQuery.isQualiaClarificationQuery(interrogative)){
//get the thing that has the quality

View File

@ -2,7 +2,7 @@ package org.studiorailgun.conversation.evaluators.transfer;
import org.studiorailgun.conversation.evaluators.goal.GoalData;
import org.studiorailgun.conversation.tracking.Conversation;
import org.studiorailgun.conversation.tracking.Quote;
import org.studiorailgun.conversation.tracking.Sentence;
/**
* Evaluation for transfer quotes
@ -15,13 +15,13 @@ public class TransferEval {
* @param conversation The conversation
* @return The quote
*/
public static Quote synthesize(Conversation conversation){
Quote response = null;
public static Sentence synthesize(Conversation conversation){
Sentence response = null;
GoalData goalData = conversation.getGoalData();
switch(goalData.getGoal()){
case ANSWER: {
Quote questionToAnswer = goalData.getQueryData().getRecentQueries().remove(0);
response = AnswerSynthesis.evaluate(conversation, questionToAnswer);
Sentence questionToAnswer = goalData.getQueryData().getRecentQueries().remove(0);
response = AnswerSynthesis.evaluate(conversation, questionToAnswer.getParent(), questionToAnswer);
} break;
case GREET: {
throw new Error("Unsupported goal type for information transfer! " + goalData.getGoal());

View File

@ -5,6 +5,7 @@ import edu.stanford.nlp.semgraph.*;
import java.util.*;
import org.studiorailgun.conversation.tracking.Quote;
import org.studiorailgun.conversation.tracking.Sentence;
/**
* Parses a sentence
@ -45,9 +46,13 @@ public class NLPParser {
pipeline.annotate(document);
quote.setParsedDocument(document);
//store the semantic graph
SemanticGraph graph = document.sentences().get(0).dependencyParse();
quote.setGraph(graph);
for(CoreSentence coreSentence : document.sentences()){
//store the semantic graph
SemanticGraph graph = coreSentence.dependencyParse();
Sentence sentence = new Sentence(coreSentence.text());
sentence.setGraph(graph);
quote.addSentence(sentence);
}
}
//TODO: grab information from document here

View File

@ -1,9 +1,9 @@
package org.studiorailgun.conversation.tracking;
import org.studiorailgun.conversation.categorization.SentenceFunctionCategorizor.SentenceFunction;
import java.util.LinkedList;
import java.util.List;
import edu.stanford.nlp.pipeline.CoreDocument;
import edu.stanford.nlp.semgraph.SemanticGraph;
/**
* A quote stated during the conversation
@ -15,20 +15,15 @@ public class Quote {
*/
String raw;
/**
* The function of the sentence
*/
SentenceFunction function;
/**
* The CoreNLP parsed document
*/
CoreDocument parsedDocument;
/**
* The parsed semantic graph
* The list of sentences in this quote
*/
SemanticGraph graph;
List<Sentence> sentences;
/**
* Constructor
@ -36,6 +31,7 @@ public class Quote {
*/
public Quote(String input){
this.raw = input;
this.sentences = new LinkedList<Sentence>();
}
/**
@ -46,18 +42,6 @@ public class Quote {
return raw;
}
/**
* Gets the function of the quote
* @return The function
*/
public SentenceFunction getFunction(){
return function;
}
public void setFunction(SentenceFunction function){
this.function = function;
}
public CoreDocument getParsedDocument() {
return parsedDocument;
}
@ -66,15 +50,25 @@ public class Quote {
this.parsedDocument = parsedDocument;
}
public SemanticGraph getGraph() {
return graph;
public List<Sentence> getSentences(){
return sentences;
}
public void setGraph(SemanticGraph graph) {
this.graph = graph;
public void addSentence(Sentence sentence){
this.sentences.add(sentence);
}
/**
* Appends a sentence to the quote (including adding its raw text to the raw text of the quote)
* @param sentence The sentence
*/
public void appendSentence(Sentence sentence){
this.sentences.add(sentence);
if(this.raw.length() > 0){
this.raw = this.raw + " ";
}
this.raw = this.raw + sentence.getRaw();
}
}

View File

@ -0,0 +1,72 @@
package org.studiorailgun.conversation.tracking;
import org.studiorailgun.conversation.categorization.SentenceFunctionCategorizor.SentenceFunction;
import edu.stanford.nlp.semgraph.SemanticGraph;
/**
* A single sentence in a quote
*/
public class Sentence {
/**
* The raw text of the sentence
*/
String raw;
/**
* The function of the sentence
*/
SentenceFunction function;
/**
* The parsed semantic graph
*/
SemanticGraph graph;
/**
* The parent quote of this sentence
*/
Quote parent;
public Sentence(String raw){
this.raw = raw;
}
/**
* Gets the raw text of the quote
* @return The raw text of the quote
*/
public String getRaw(){
return raw;
}
/**
* Gets the function of the quote
* @return The function
*/
public SentenceFunction getFunction(){
return function;
}
public void setFunction(SentenceFunction function){
this.function = function;
}
public SemanticGraph getGraph() {
return graph;
}
public void setGraph(SemanticGraph graph) {
this.graph = graph;
}
public void setParent(Quote parent){
this.parent = parent;
}
public Quote getParent(){
return parent;
}
}

View File

@ -3,7 +3,7 @@ package org.studiorailgun;
import static org.junit.Assert.*;
import org.junit.Test;
import org.studiorailgun.conversation.evaluators.transfer.AnswerSynthesis;
import org.studiorailgun.conversation.ConvAI;
import org.studiorailgun.conversation.tracking.Quote;
/**
@ -14,7 +14,7 @@ public class AnswerTests {
@Test
public void testAnswerQuery(){
Globals.init("./data/webs/test/web.json");
Quote result = AnswerSynthesis.evaluate(Globals.conversation, new Quote("What color is your hat?"));
Quote result = ConvAI.simFrame("What color is your hat?");
assertEquals(result.getRaw().contains("Blue"), true);
}

View File

@ -4,9 +4,14 @@ import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.Test;
import org.studiorailgun.conversation.ConvAI;
import org.studiorailgun.conversation.categorization.SentenceFunctionCategorizor;
import org.studiorailgun.conversation.evaluators.greet.GreetingEval;
import org.studiorailgun.conversation.evaluators.query.QueryData;
import org.studiorailgun.conversation.evaluators.query.QueryEval;
import org.studiorailgun.conversation.parser.NLPParser;
import org.studiorailgun.conversation.tracking.Conversation;
import org.studiorailgun.conversation.tracking.Quote;
import org.studiorailgun.conversation.tracking.Sentence;
/**
* Query tests
@ -27,7 +32,36 @@ public class QueryTests {
@Test
public void testQueryEval(){
Globals.init("./data/webs/test/web.json");
QueryEval.evaluate(Globals.conversation, new Quote("What color is your hat?"));
//Linguistics structures
Conversation conversation = Globals.conversation;
Quote quote = new Quote("What color is your hat?");
//perform NLP evaluation of the quote
NLPParser.parse(quote);
//parse data about the quote
for(Sentence sentence : quote.getSentences()){
SentenceFunctionCategorizor.categorize(sentence);
}
//add the quote to the conversation
conversation.addQuote(quote);
//evaluate each sentence
for(Sentence sentence : quote.getSentences()){
switch(sentence.getFunction()){
case UTILITY: {
GreetingEval.evaluate(conversation, quote, sentence);
} break;
case QUERY: {
QueryEval.evaluate(conversation, quote, sentence);
} break;
default: {
throw new UnsupportedOperationException("Unsupported quote function: " + sentence.getFunction());
}
}
}
QueryData queryData = Globals.conversation.getGoalData().getQueryData();
assertEquals(queryData.getRecentQueries().size(), 1);
}