classification model for sentence function

austin 2024-12-29 14:05:27 -05:00
parent e99520ccf0
commit 67360681c9
11 changed files with 191 additions and 35 deletions


@@ -10,7 +10,7 @@ source ./.venv/Scripts/activate
 ```
 ```
-pip install tensorflow keras mypy numpy types-tensorflow
+pip install tensorflow keras mypy pandas numpy types-tensorflow
 ```
 In settings, make sure to check `Run Using Active Interpreter` in mypy settings
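A quick way to confirm the updated dependency list resolves inside the venv (pandas is the new addition here) is an import check; this snippet is illustrative, not part of the commit:

```python
# sanity-check the packages installed above
import keras
import numpy
import pandas
import tensorflow
print(tensorflow.__version__, keras.__version__, pandas.__version__)
```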


@@ -7,14 +7,7 @@ sitting in a tavern by a fireplace
-Parse "Hello"
-Nodes needed:
-- Conversation
-- Greeting
-- Participant
-- Instances of participants
-- Instances of conversations
-- Instances of greetings (created as the conversation starts)
+Respond to "What color is your hat?"


Binary file not shown.

Binary file not shown.

Binary file not shown.


@@ -0,0 +1,4 @@
+"utility","transfer","query","imperative","sentence"
+1,0,0,0,"Hello"
+1,0,0,0,"Hi"
+1,0,0,0,"Howdy"


@@ -0,0 +1,4 @@
+"utility","transfer","query","imperative","sentence"
+1,0,0,0,"Hello"
+1,0,0,0,"Hi"
+1,0,0,0,"Howdy"
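Both CSVs use the same layout: four one-hot label columns followed by the raw sentence, so the class of a row is the argmax over the label columns. A minimal sketch of decoding a row, assuming only the columns shown above:

```python
import numpy as np
import pandas as pd

CLASSES = ["utility", "transfer", "query", "imperative"]  # column order from the CSV header

df = pd.read_csv("./data/semantic/sent_func/train.csv")
labels = df[CLASSES].to_numpy()  # shape (n, 4), one-hot rows
for sentence, row in zip(df["sentence"], labels):
    print(sentence, "->", CLASSES[int(np.argmax(row))])  # e.g. "Hello" -> utility
```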


@@ -2,6 +2,7 @@ package org.studiorailgun;
 import java.io.File;
+import org.studiorailgun.conversation.categorization.SentenceFunctionCategorizor;
 import org.studiorailgun.conversation.evaluators.greet.GreetingEval;
 import org.studiorailgun.conversation.tracking.Conversation;
 import org.studiorailgun.knowledge.KnowledgeWeb;
@@ -27,6 +28,7 @@ public class Globals {
     public static void init(String webPath){
         //initialize evaluators
         GreetingEval.init();
+        SentenceFunctionCategorizor.init();
         //init web
         Globals.web = FileUtils.loadObjectFromFile(new File(webPath), KnowledgeWeb.class);


@@ -1,6 +1,18 @@
 package org.studiorailgun.conversation.categorization;
+import java.util.HashMap;
+import java.util.Map;
 import org.studiorailgun.conversation.tracking.Quote;
+import org.tensorflow.Result;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Tensor;
+import org.tensorflow.ndarray.FloatNdArray;
+import org.tensorflow.ndarray.NdArray;
+import org.tensorflow.ndarray.NdArrays;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TString;
 /**
  * Categorizes sentences based on function
@@ -34,13 +46,78 @@ public class SentenceFunctionCategorizor {
         IMPERATIVE,
     }
+    /**
+     * The model
+     */
+    static SavedModelBundle model;
+
+    /**
+     * Initializes the categorization model
+     */
+    public static void init(){
+        model = SavedModelBundle.load("./data/model/sent_func");
+        //describe the model
+        // System.out.println(model.functions());
+        // System.out.println(model.functions().get(0).signature());
+        // System.out.println(model.functions().get(1).signature());
+        // System.out.println(model.functions().get(2).signature());
+        // System.out.println(model.functions().get(0).signature().getInputs());
+        // System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().name);
+        // System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
+    }
+
     /**
      * Categorizes the sentence by function
      * @param input The input quote
      * @return The function of the sentence
      */
     public static void categorize(Quote input){
-        input.setFunction(SentenceFunction.UTILITY);
+        //construct a (1, 1) string tensor holding the raw sentence
+        NdArray<String> stringArr = NdArrays.ofObjects(String.class, Shape.of(1, 1));
+        stringArr.setObject(input.getRaw(), 0, 0);
+        TString inputTensor = TString.tensorOf(stringArr);
+        Map<String,Tensor> inputMap = new HashMap<String,Tensor>();
+        inputMap.put("keras_tensor",inputTensor);
+        //call
+        Result result = model.function("serve").call(inputMap);
+        //parse results
+        TFloat32 resultFloats = (TFloat32)result.get(0);
+        FloatNdArray floatArr = resultFloats.get(0);
+        float[] classification = new float[]{
+            floatArr.getFloat(0),
+            floatArr.getFloat(1),
+            floatArr.getFloat(2),
+            floatArr.getFloat(3),
+        };
+        //figure out highest value category
+        int maxAt = 0;
+        for(int i = 0; i < classification.length; i++){
+            if(classification[i] > classification[maxAt]){
+                maxAt = i;
+            }
+        }
+        switch(maxAt){
+            case 0: {
+                input.setFunction(SentenceFunction.UTILITY);
+            } break;
+            case 1: {
+                input.setFunction(SentenceFunction.TRANSFER);
+            } break;
+            case 2: {
+                input.setFunction(SentenceFunction.QUERY);
+            } break;
+            case 3: {
+                input.setFunction(SentenceFunction.IMPERATIVE);
+            } break;
+        }
     }
 }
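The `"keras_tensor"` input name and the `"serve"` function above are whatever Keras assigned at export time, which is what the commented-out `model.functions()` prints were used to discover. One way to confirm the same signature from Python before hard-coding it on the Java side (assuming the default `serving_default` signature name):

```python
import tensorflow as tf

# load the exported model the Java code consumes
loaded = tf.saved_model.load("./data/model/sent_func")
sig = loaded.signatures["serving_default"]
print(sig.structured_input_signature)  # expect a tf.string input named like "keras_tensor"
print(sig.structured_outputs)          # expect a (None, 4) float32 output
```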


@@ -5,6 +5,8 @@ from keras.api.layers import TextVectorization, Embedding, LSTM, Dense, Input
 from keras.api.models import Sequential
 import numpy as np
 import numpy.typing as npt
+import pandas as pd
+from pandas import DataFrame
@@ -33,17 +35,48 @@ sequence_length: int = 500
 epochs: int = 50
 max_tokens: int = 5000
 output_sequence_length: int = 4
+num_classes: int = 4
-# read sentences
-data_path: str = './data/semantic/subject.txt'
-data_raw: str = open(data_path).read()
-vocab: list[str] = data_raw.split('\n')
-# read labels
-label_data_path: str = './data/semantic/subject_label.txt'
-label_data_raw: str = open(label_data_path).read()
-labels: list[int] = list(map(int,label_data_raw.split()))
+
+#
+# LOAD DATA
+#
+
+# read training sentences
+train_csv_path: str = './data/semantic/sent_func/train.csv'
+train_csv_raw: DataFrame = pd.read_csv(train_csv_path, dtype={
+    "utility": "int64",
+    "transfer": "int64",
+    "query": "int64",
+    "imperative": "int64",
+    "sentence": "string",
+})
+train_data_raw: DataFrame = train_csv_raw[["sentence"]]
+train_data_split: DataFrame = train_data_raw
+vocab: npt.NDArray = train_data_raw.to_numpy()
+train_labels: DataFrame = train_csv_raw[["utility", "transfer", "query", "imperative"]]
+
+# read testing sentences
+test_csv_path: str = './data/semantic/sent_func/test.csv'
+test_csv_raw: DataFrame = pd.read_csv(test_csv_path, dtype={
+    "utility": "int64",
+    "transfer": "int64",
+    "query": "int64",
+    "imperative": "int64",
+    "sentence": "string",
+})
+test_data_raw: DataFrame = test_csv_raw[["sentence"]]
+test_data_split: DataFrame = test_data_raw
+test_labels: DataFrame = test_csv_raw[["utility", "transfer", "query", "imperative"]]
+
+#
+# CREATE VECTORIZER
+#
 # init vectorizer
 textVec: TextVectorization = TextVectorization(
@@ -54,15 +87,24 @@ textVec: TextVectorization = TextVectorization(
 # Add the vocab to the tokenizer
 textVec.adapt(vocab)
-input_data: list[str] = vocab
-data: Tensor = textVec.call(input_data)
+input_data: DataFrame = train_data_split
+train_data: Tensor = textVec.call(input_data)
+
+#
+# CREATE MODEL
+#
+
 # construct model
 model: Sequential = Sequential([
-    keras.Input(shape=(None,), dtype="int64"),
+    keras.Input(shape=(1,), dtype=tf.string),
+    textVec,
     Embedding(max_features + 1, embedding_dim),
     LSTM(64),
-    Dense(1, activation='sigmoid')
+    Dense(num_classes, activation='sigmoid')
 ])
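Putting `textVec` inside the `Sequential` stack is what lets the exported SavedModel accept raw strings, which in turn is why the Java side can feed a `TString` directly. A standalone sketch of the same pattern, with hypothetical toy values in place of the script's hyperparameters:

```python
import tensorflow as tf
import keras
from keras.api.layers import TextVectorization, Embedding, LSTM, Dense
from keras.api.models import Sequential

# toy hyperparameters; the real script uses max_tokens etc. from above
vec = TextVectorization(max_tokens=100, output_sequence_length=4)
vec.adapt(tf.constant(["Hello", "Hi", "Howdy"]))

toy = Sequential([
    keras.Input(shape=(1,), dtype=tf.string),  # raw strings in
    vec,                                       # tokenization happens in-graph
    Embedding(101, 8),
    LSTM(4),
    Dense(4, activation='sigmoid'),
])
print(toy(tf.constant([["Hello"]])).shape)  # (1, 4) class scores
```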
 #compile the model
@@ -70,26 +112,53 @@ model: Sequential = Sequential([
 model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
-# fit the training data
-npData = np.array(data)
-npLabel = np.array(labels)
-model.fit(npData,npLabel,epochs=epochs)
+
+#
+# TRAIN MODEL
+#
+
+# Final formatting of data
+npTrainData: npt.NDArray = train_data_split.to_numpy(dtype=object).flatten()
+npTrainLabel: npt.NDArray[np.int64] = train_labels.to_numpy()
+if npTrainData.shape[0] != npTrainLabel.shape[0]:
+    print("Data label size mismatch!")
+    print(npTrainData.shape)
+    print(npTrainLabel.shape)
+
+# try fitting data
+print("Training..")
+model.fit(npTrainData, npTrainLabel, epochs=epochs)
+
+#
+# EVALUATE MODEL
+#
+
 # evaluate here
+npTestData: npt.NDArray = test_data_split.to_numpy(dtype=object).flatten()
+npTestLabel: npt.NDArray[np.int64] = test_labels.to_numpy()
+if npTestData.shape[0] != npTestLabel.shape[0]:
+    print("Data label size mismatch!")
+    print(npTestData.shape)
+    print(npTestLabel.shape)
+
+print("Evaluating..")
+model.evaluate(npTestData, npTestLabel)
 # predict
-predictTargetRaw: list[str] = ['saf']
-predictTargetToken: list[int] = textVec.call(predictTargetRaw)
-npPredict: npt.NDArray[np.complex64] = np.array(predictTargetToken)
-# print(npPredict)
-result: list[int] = model.predict(npPredict)
+predictTargetRaw: Tensor = tf.constant(['Hello'])
+print("Prediction test..")
+result: npt.NDArray = model.predict(predictTargetRaw)
 print("predict result:")
-print(predictTargetToken)
+print(predictTargetRaw)
 print(result)
-print(data)
-print(labels)
+print(train_data)
+print(train_labels)
@@ -97,8 +166,14 @@ print(labels)
 # savePath: str = './data/semantic/model.keras'
 # model.save(savePath)
+
+#
+# SAVE MODEL
+#
+
 # export the model so java can leverage it
-exportPath: str = './data/semantic/model'
+print("Saving..")
+exportPath: str = './data/model/sent_func'
 model.export(exportPath)
 # tf.keras.utils.get_file('asdf')
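Once exported, a round trip against the SavedModel mirrors what `SentenceFunctionCategorizor` does in Java. This check is illustrative; it assumes the export above succeeded and that the input is named `keras_tensor`, as the Java code expects:

```python
import numpy as np
import tensorflow as tf

loaded = tf.saved_model.load('./data/model/sent_func')
serve = loaded.signatures['serving_default']
scores = serve(keras_tensor=tf.constant([['Hello']]))  # same (1, 1) string input the Java side builds
probs = next(iter(scores.values())).numpy()[0]
functions = ['UTILITY', 'TRANSFER', 'QUERY', 'IMPERATIVE']  # mirrors the SentenceFunction enum
print(functions[int(np.argmax(probs))])
```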