classification model for sentence function
All checks were successful
studiorailgun/trpg/pipeline/head This commit looks good
This commit is contained in:
parent e99520ccf0
commit 67360681c9
@@ -10,7 +10,7 @@ source ./.venv/Scripts/activate
 ```
 
 ```
-pip install tensorflow keras mypy numpy types-tensorflow
+pip install tensorflow keras mypy pandas numpy types-tensorflow
 ```
 
 In settings, make sure to check `Run Using Active Interpreter` in mypy settings
@@ -7,14 +7,7 @@ sitting in a tavern by a fireplace
 
 
 
-Parse "Hello"
-
-Nodes needed:
-- Conversation
-- Greeting
-- Participant
-- Instances of participants
-- Instances of conversations
-- Instances of greetings (created as the conversation starts)
+Respond to "What color is your hat?"
 
 
1
data/model/sent_func/fingerprint.pb
Normal file
@@ -0,0 +1 @@
+(binary fingerprint contents not shown)
BIN
data/model/sent_func/saved_model.pb
Normal file
Binary file not shown.
BIN
data/model/sent_func/variables/variables.data-00000-of-00001
Normal file
Binary file not shown.
BIN
data/model/sent_func/variables/variables.index
Normal file
Binary file not shown.
4
data/semantic/sent_func/test.csv
Normal file
@@ -0,0 +1,4 @@
+"utility","transfer","query","imperative","sentence"
+1,0,0,0,"Hello"
+1,0,0,0,"Hi"
+1,0,0,0,"Howdy"
4
data/semantic/sent_func/train.csv
Normal file
@@ -0,0 +1,4 @@
+"utility","transfer","query","imperative","sentence"
+1,0,0,0,"Hello"
+1,0,0,0,"Hi"
+1,0,0,0,"Howdy"
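
Each CSV row one-hot encodes a sentence's function across the four label columns, with the raw sentence in the last column. A minimal sanity-check sketch, not part of the commit (assumes pandas, per the updated install line):

```python
# Hypothetical check: every row should activate exactly one of the four
# sentence-function label columns.
import pandas as pd

df = pd.read_csv('./data/semantic/sent_func/train.csv')
label_cols = ["utility", "transfer", "query", "imperative"]
assert (df[label_cols].sum(axis=1) == 1).all(), "expected exactly one label per row"
print(df[["sentence"]].head())
```
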
|
@ -2,6 +2,7 @@ package org.studiorailgun;
|
|||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
|
||||||
|
import org.studiorailgun.conversation.categorization.SentenceFunctionCategorizor;
|
||||||
import org.studiorailgun.conversation.evaluators.greet.GreetingEval;
|
import org.studiorailgun.conversation.evaluators.greet.GreetingEval;
|
||||||
import org.studiorailgun.conversation.tracking.Conversation;
|
import org.studiorailgun.conversation.tracking.Conversation;
|
||||||
import org.studiorailgun.knowledge.KnowledgeWeb;
|
import org.studiorailgun.knowledge.KnowledgeWeb;
|
||||||
@ -27,6 +28,7 @@ public class Globals {
|
|||||||
public static void init(String webPath){
|
public static void init(String webPath){
|
||||||
//initialize evaluators
|
//initialize evaluators
|
||||||
GreetingEval.init();
|
GreetingEval.init();
|
||||||
|
SentenceFunctionCategorizor.init();
|
||||||
|
|
||||||
//init web
|
//init web
|
||||||
Globals.web = FileUtils.loadObjectFromFile(new File(webPath), KnowledgeWeb.class);
|
Globals.web = FileUtils.loadObjectFromFile(new File(webPath), KnowledgeWeb.class);
|
||||||
|
|||||||
@@ -1,6 +1,18 @@
 package org.studiorailgun.conversation.categorization;
 
+import java.util.HashMap;
+import java.util.Map;
+
 import org.studiorailgun.conversation.tracking.Quote;
+import org.tensorflow.Result;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Tensor;
+import org.tensorflow.ndarray.FloatNdArray;
+import org.tensorflow.ndarray.NdArray;
+import org.tensorflow.ndarray.NdArrays;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TString;
 
 /**
  * Categorizes sentences based on function
@@ -34,13 +46,78 @@ public class SentenceFunctionCategorizor {
         IMPERATIVE,
     }
 
+    /**
+     * The model
+     */
+    static SavedModelBundle model;
+
+    /**
+     * Initializes the categorization model
+     */
+    public static void init(){
+        model = SavedModelBundle.load("./data/model/sent_func");
+
+        //describe the model
+        // System.out.println(model.functions());
+        // System.out.println(model.functions().get(0).signature());
+        // System.out.println(model.functions().get(1).signature());
+        // System.out.println(model.functions().get(2).signature());
+        // System.out.println(model.functions().get(0).signature().getInputs());
+        // System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().name);
+        // System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
+    }
+
     /**
      * Categorizes the sentence by function
      * @param input The input quote
      * @return The function of the sentence
      */
     public static void categorize(Quote input){
+        //construct input
+        TString inputTensor = TString.scalarOf(input.getRaw());
+        inputTensor.shape().append(1);
+        NdArray<String> stringArr = NdArrays.ofObjects(String.class, Shape.of(1, 1));
+        stringArr.setObject(input.getRaw());
+        inputTensor = TString.tensorOf(stringArr);
+
+        Map<String,Tensor> inputMap = new HashMap<String,Tensor>();
+        inputMap.put("keras_tensor",inputTensor);
+
+        //call
+        Result result = model.function("serve").call(inputMap);
+
+        //parse results
+        TFloat32 resultFloats = (TFloat32)result.get(0);
+        FloatNdArray floatArr = resultFloats.get(0);
+        float[] classification = new float[]{
+            floatArr.getFloat(0),
+            floatArr.getFloat(1),
+            floatArr.getFloat(2),
+            floatArr.getFloat(3),
+        };
+
+        //figure out highest value category
+        int maxAt = 0;
+        for(int i = 0; i < classification.length; i++){
+            if(classification[i] > classification[maxAt]){
+                maxAt = i;
+            }
+        }
+
+        switch(maxAt){
+            case 0: {
                 input.setFunction(SentenceFunction.UTILITY);
+            } break;
+            case 1: {
+                input.setFunction(SentenceFunction.TRANSFER);
+            } break;
+            case 2: {
+                input.setFunction(SentenceFunction.QUERY);
+            } break;
+            case 3: {
+                input.setFunction(SentenceFunction.IMPERATIVE);
+            } break;
+        }
     }
 
 }
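
The `categorize` method above binds the input tensor under the name `keras_tensor` and calls the `serve` signature of the exported model. For reference, a hedged sketch exercising the same exported model from Python (the signature and input names are taken from the Java code; this snippet is illustrative, not part of the commit):

```python
# Hedged sketch: invoke the exported SavedModel's 'serve' signature from
# Python, mirroring model.function("serve").call(inputMap) on the Java side.
import tensorflow as tf

bundle = tf.saved_model.load('./data/model/sent_func')
serve = bundle.signatures['serve']
# 'keras_tensor' matches the input name the Java code puts in its input map
scores = serve(keras_tensor=tf.constant([['Hello']]))  # (1, 1) string batch
print(scores)  # output name -> (1, 4) float tensor of class scores
```
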
@@ -5,6 +5,8 @@ from keras.api.layers import TextVectorization, Embedding, LSTM, Dense, Input
 from keras.api.models import Sequential
 import numpy as np
 import numpy.typing as npt
+import pandas as pd
+from pandas import DataFrame
 
 
 
@@ -33,17 +35,48 @@ sequence_length: int = 500
 epochs: int = 50
 max_tokens: int = 5000
 output_sequence_length: int = 4
+num_classes: int = 4
 
-# read sentences
-data_path: str = './data/semantic/subject.txt'
-data_raw: str = open(data_path).read()
-vocab: list[str] = data_raw.split('\n')
 
-# read labels
-label_data_path: str = './data/semantic/subject_label.txt'
-label_data_raw: str = open(label_data_path).read()
-labels: list[int] = list(map(int,label_data_raw.split()))
+#
+# LOAD DATA
+#
+
+# read training sentences
+train_csv_path: str = './data/semantic/sent_func/train.csv'
+train_csv_raw: DataFrame = pd.read_csv(open(train_csv_path), dtype={
+    "utility": "int64",
+    "transfer": "int64",
+    "query": "int64",
+    "imperative": "int64",
+    "sentence": "string",
+})
+train_data_raw: DataFrame = train_csv_raw[["sentence"]]
+train_data_split: DataFrame = train_data_raw
+vocab: list[str] = train_data_raw.to_numpy()
+train_labels: DataFrame = train_csv_raw[["utility", "transfer", "query", "imperative"]]
+
+# read testing sentences
+test_csv_path: str = './data/semantic/sent_func/test.csv'
+test_csv_raw: DataFrame = pd.read_csv(open(test_csv_path), dtype={
+    "utility": "int64",
+    "transfer": "int64",
+    "query": "int64",
+    "imperative": "int64",
+    "sentence": "string",
+})
+test_data_raw: DataFrame = test_csv_raw[["sentence"]]
+test_data_split: DataFrame = test_data_raw
+test_labels: DataFrame = test_csv_raw[["utility", "transfer", "query", "imperative"]]
+
+
+
+#
+# CREATE VECTORIZER
+#
+
 
 # init vectorizer
 textVec: TextVectorization = TextVectorization(
@@ -54,15 +87,24 @@ textVec: TextVectorization = TextVectorization(
 
 # Add the vocab to the tokenizer
 textVec.adapt(vocab)
-input_data: list[str] = vocab
-data: Tensor = textVec.call(input_data)
+input_data: list[str] = train_data_split
+train_data: Tensor = textVec.call(input_data)
+
+
+
+
+#
+# CREATE MODEL
+#
 
 
 # construct model
 model: Sequential = Sequential([
-    keras.Input(shape=(None,), dtype="int64"),
+    keras.Input(shape=(1,), dtype=tf.string),
+    textVec,
     Embedding(max_features + 1, embedding_dim),
     LSTM(64),
-    Dense(1, activation='sigmoid')
+    Dense(num_classes, activation='sigmoid')
 ])
 
 #compile the model
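
With the `TextVectorization` layer now folded into the model and the input switched to `tf.string`, the network consumes raw sentences directly, so no separate tokenization pass is needed at inference time. A hypothetical usage sketch against the script's `model`:

```python
# Hypothetical usage: the first layer vectorizes text in-graph, so raw strings
# go straight in; the output is one score per sentence-function class.
import tensorflow as tf

probe = tf.constant([['What color is your hat?']])  # batch of one sentence, shape (1, 1)
print(model(probe))  # shape (1, num_classes)
```
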
@@ -70,26 +112,53 @@ model: Sequential = Sequential([
 model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
 
 
-# fit the training data
-npData = np.array(data)
-npLabel = np.array(labels)
-model.fit(npData,npLabel,epochs=epochs)
+#
+# TRAIN MODEL
+#
+
+# Final formatting of data
+npTrainData = train_data_split.to_numpy(dtype=object).flatten()
+npTrainLabel: npt.NDArray[np.complex64] = train_labels.to_numpy()
+if(npTrainData.shape[0] != npTrainLabel.shape[0]):
+    print("Data label size mismatch!")
+    print(npTrainData.shape)
+    print(npTrainLabel.shape)
+
+# try fitting data
+print("Training..")
+model.fit(npTrainData,npTrainLabel,epochs=epochs)
+
+
+
+
+#
+# EVALUATE MODEL
+#
 
 
 # evaluate here
+npTestData: npt.NDArray = test_data_split.to_numpy(dtype=object).flatten()
+npTestLabel: npt.NDArray[np.complex64] = test_labels.to_numpy()
+if(npTestData.shape[0] != npTestLabel.shape[0]):
+    print("Data label size mismatch!")
+    print(npTestData.shape)
+    print(npTestLabel.shape)
+
+print("Evaluating..")
+model.evaluate(npTestData,npTestLabel)
 
 
 # predict
-predictTargetRaw: list[str] = ['saf']
-predictTargetToken: list[int] = textVec.call(predictTargetRaw)
-npPredict: npt.NDArray[np.complex64] = np.array(predictTargetToken)
-# print(npPredict)
-result: list[int] = model.predict(npPredict)
+predictTargetRaw: Tensor = tf.constant(['Hello'])
+npPredict: npt.NDArray = np.array(predictTargetRaw, dtype=object)
+print("Prediction test..")
+result: list[int] = model.predict(predictTargetRaw)
 print("predict result:")
-print(predictTargetToken)
+print(predictTargetRaw)
 print(result)
-print(data)
-print(labels)
+print(train_data)
+print(train_labels)
 
 
 
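
The model emits four sigmoid scores, one per class, and the Java `categorize` method picks the arg-max. The same mapping as a standalone Python sketch (the example scores are made up):

```python
# Hedged sketch: map one row of model.predict(...) output to its class label,
# mirroring the max-score loop in SentenceFunctionCategorizor.categorize().
import numpy as np

classes = ["UTILITY", "TRANSFER", "QUERY", "IMPERATIVE"]
scores = np.array([0.91, 0.12, 0.30, 0.05])  # e.g. one row of predict() output
print(classes[int(np.argmax(scores))])       # -> UTILITY
```
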
@@ -97,8 +166,14 @@ print(labels)
 # savePath: str = './data/semantic/model.keras'
 # model.save(savePath)
 
+
+#
+# SAVE MODEL
+#
+
 # export the model so java can leverage it
-exportPath: str = './data/semantic/model'
+print("Saving..")
+exportPath: str = './data/model/sent_func'
 model.export(exportPath)
 
 # tf.keras.utils.get_file('asdf')
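
The new export path matches the directory the Java side loads with `SavedModelBundle.load("./data/model/sent_func")`, and the exported artifacts are the files this commit adds. A hypothetical post-export check (assumes the script runs from the repo root):

```python
# Hypothetical check, not part of the commit: confirm the artifacts that
# SavedModelBundle.load() expects were written by model.export().
import os

expected = ['fingerprint.pb', 'saved_model.pb', 'variables']
present = [name for name in expected if name in os.listdir('./data/model/sent_func')]
print(present)
```
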