classification model for sentence function
All checks were successful
studiorailgun/trpg/pipeline/head This commit looks good
This commit is contained in:
parent e99520ccf0
commit 67360681c9
@@ -10,7 +10,7 @@ source ./.venv/Scripts/activate
 ```
 
 ```
-pip install tensorflow keras mypy numpy types-tensorflow
+pip install tensorflow keras mypy pandas numpy types-tensorflow
 ```
 
 In settings, make sure to check `Run Using Active Interpreter` in mypy settings
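The dependency list above now pulls in pandas for the CSV data pipeline. A minimal sanity check that the environment resolves, assuming the packages above are installed in the active venv (this script is not part of the repo):

```python
# quick import check for the toolchain installed above (sketch)
import tensorflow as tf
import keras
import numpy as np
import pandas as pd

print(tf.__version__, keras.__version__, np.__version__, pd.__version__)
```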
@@ -7,14 +7,7 @@ sitting in a tavern by a fireplace
 
 
-
-Parse "Hello"
-Nodes needed:
-- Conversation
-- Greeting
-- Participant
-- Instances of participants
-- Instances of conversations
-- Instances of greetings (created as the conversation starts)
+
+
 Respond to "What color is your hat?"
 
 
1 data/model/sent_func/fingerprint.pb Normal file
@@ -0,0 +1 @@
+(binary fingerprint data not shown)
BIN data/model/sent_func/saved_model.pb Normal file
Binary file not shown.
BIN data/model/sent_func/variables/variables.data-00000-of-00001 Normal file
Binary file not shown.
BIN data/model/sent_func/variables/variables.index Normal file
Binary file not shown.
4 data/semantic/sent_func/test.csv Normal file
@@ -0,0 +1,4 @@
+"utility","transfer","query","imperative","sentence"
+1,0,0,0,"Hello"
+1,0,0,0,"Hi"
+1,0,0,0,"Howdy"
4 data/semantic/sent_func/train.csv Normal file
@@ -0,0 +1,4 @@
+"utility","transfer","query","imperative","sentence"
+1,0,0,0,"Hello"
+1,0,0,0,"Hi"
+1,0,0,0,"Howdy"
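Both CSVs share the same schema: four one-hot label columns (utility, transfer, query, imperative) followed by the raw sentence. A sketch of appending a new labeled example, assuming pandas and the paths above:

```python
import pandas as pd

# hypothetical helper flow: add one labeled sentence to the training set
row = {"utility": 0, "transfer": 0, "query": 1, "imperative": 0,
       "sentence": "What color is your hat?"}
df = pd.read_csv('./data/semantic/sent_func/train.csv')
df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
df.to_csv('./data/semantic/sent_func/train.csv', index=False)
```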
@@ -2,6 +2,7 @@ package org.studiorailgun;
 
 import java.io.File;
 
+import org.studiorailgun.conversation.categorization.SentenceFunctionCategorizor;
 import org.studiorailgun.conversation.evaluators.greet.GreetingEval;
 import org.studiorailgun.conversation.tracking.Conversation;
 import org.studiorailgun.knowledge.KnowledgeWeb;
@@ -27,6 +28,7 @@ public class Globals {
     public static void init(String webPath){
         //initialize evaluators
         GreetingEval.init();
+        SentenceFunctionCategorizor.init();
 
         //init web
         Globals.web = FileUtils.loadObjectFromFile(new File(webPath), KnowledgeWeb.class);
@@ -1,6 +1,18 @@
 package org.studiorailgun.conversation.categorization;
 
+import java.util.HashMap;
+import java.util.Map;
+
 import org.studiorailgun.conversation.tracking.Quote;
+import org.tensorflow.Result;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Tensor;
+import org.tensorflow.ndarray.FloatNdArray;
+import org.tensorflow.ndarray.NdArray;
+import org.tensorflow.ndarray.NdArrays;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TString;
 
 /**
  * Categorizes sentences based on function
@@ -34,13 +46,78 @@ public class SentenceFunctionCategorizor {
         IMPERATIVE,
     }
 
+    /**
+     * The trained sentence-function model
+     */
+    static SavedModelBundle model;
+
+    /**
+     * Initializes the categorization model
+     */
+    public static void init(){
+        model = SavedModelBundle.load("./data/model/sent_func");
+
+        //describe the model
+        // System.out.println(model.functions());
+        // System.out.println(model.functions().get(0).signature());
+        // System.out.println(model.functions().get(1).signature());
+        // System.out.println(model.functions().get(2).signature());
+        // System.out.println(model.functions().get(0).signature().getInputs());
+        // System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().name);
+        // System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
+    }
+
+    /**
+     * Categorizes the sentence by function, storing the result on the quote
+     * @param input The input quote
+     */
+    public static void categorize(Quote input){
+        //construct the input: a [1,1] string tensor holding the raw sentence
+        NdArray<String> stringArr = NdArrays.ofObjects(String.class, Shape.of(1, 1));
+        stringArr.setObject(input.getRaw(), 0, 0);
+        TString inputTensor = TString.tensorOf(stringArr);
+
+        Map<String,Tensor> inputMap = new HashMap<String,Tensor>();
+        inputMap.put("keras_tensor",inputTensor);
+
+        //call
+        Result result = model.function("serve").call(inputMap);
+
+        //parse results
+        TFloat32 resultFloats = (TFloat32)result.get(0);
+        FloatNdArray floatArr = resultFloats.get(0);
+        float[] classification = new float[]{
+            floatArr.getFloat(0),
+            floatArr.getFloat(1),
+            floatArr.getFloat(2),
+            floatArr.getFloat(3),
+        };
+
+        //figure out highest value category
+        int maxAt = 0;
+        for(int i = 0; i < classification.length; i++){
+            if(classification[i] > classification[maxAt]){
+                maxAt = i;
+            }
+        }
+
+        switch(maxAt){
+            case 0: {
+                input.setFunction(SentenceFunction.UTILITY);
+            } break;
+            case 1: {
+                input.setFunction(SentenceFunction.TRANSFER);
+            } break;
+            case 2: {
+                input.setFunction(SentenceFunction.QUERY);
+            } break;
+            case 3: {
+                input.setFunction(SentenceFunction.IMPERATIVE);
+            } break;
+        }
+    }
 
 }
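For reference, the inference path the Java categorizer takes can be sketched in Python against the exported SavedModel. The input name keras_tensor matches the Java code above; the 'serving_default' signature key and the output indexing are assumptions based on the default Keras export:

```python
import numpy as np
import tensorflow as tf

# load the exported bundle and call its serving signature, mirroring
# SentenceFunctionCategorizor.categorize (sketch; signature name assumed)
loaded = tf.saved_model.load('./data/model/sent_func')
infer = loaded.signatures['serving_default']
scores = infer(keras_tensor=tf.constant([['Hello']]))  # shape (1, 1) string input
vec = next(iter(scores.values())).numpy()[0]
labels = ['UTILITY', 'TRANSFER', 'QUERY', 'IMPERATIVE']
print(labels[int(np.argmax(vec))])
```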
@@ -5,6 +5,8 @@ from keras.api.layers import TextVectorization, Embedding, LSTM, Dense, Input
 from keras.api.models import Sequential
 import numpy as np
 import numpy.typing as npt
+import pandas as pd
+from pandas import DataFrame
 
 
 
@@ -33,17 +35,48 @@ sequence_length: int = 500
 epochs: int = 50
 max_tokens: int = 5000
 output_sequence_length: int = 4
+num_classes: int = 4
 
 
-# read sentences
-data_path: str = './data/semantic/subject.txt'
-data_raw: str = open(data_path).read()
-vocab: list[str] = data_raw.split('\n')
-
-# read labels
-label_data_path: str = './data/semantic/subject_label.txt'
-label_data_raw: str = open(label_data_path).read()
-labels: list[int] = list(map(int,label_data_raw.split()))
+#
+# LOAD DATA
+#
+
+
+# read training sentences
+train_csv_path: str = './data/semantic/sent_func/train.csv'
+train_csv_raw: DataFrame = pd.read_csv(train_csv_path, dtype={
+    "utility": "int64",
+    "transfer": "int64",
+    "query": "int64",
+    "imperative": "int64",
+    "sentence": "string",
+})
+train_data_raw: DataFrame = train_csv_raw[["sentence"]]
+train_data_split: DataFrame = train_data_raw
+vocab: npt.NDArray = train_data_raw.to_numpy()
+train_labels: DataFrame = train_csv_raw[["utility", "transfer", "query", "imperative"]]
+
+# read testing sentences
+test_csv_path: str = './data/semantic/sent_func/test.csv'
+test_csv_raw: DataFrame = pd.read_csv(test_csv_path, dtype={
+    "utility": "int64",
+    "transfer": "int64",
+    "query": "int64",
+    "imperative": "int64",
+    "sentence": "string",
+})
+test_data_raw: DataFrame = test_csv_raw[["sentence"]]
+test_data_split: DataFrame = test_data_raw
+test_labels: DataFrame = test_csv_raw[["utility", "transfer", "query", "imperative"]]
+
+
+
+#
+# CREATE VECTORIZER
+#
 
 
 # init vectorizer
 textVec: TextVectorization = TextVectorization(
@@ -54,15 +87,24 @@ textVec: TextVectorization = TextVectorization(
 
 # Add the vocab to the tokenizer
 textVec.adapt(vocab)
-input_data: list[str] = vocab
-data: Tensor = textVec.call(input_data)
+input_data: DataFrame = train_data_split
+train_data: Tensor = textVec.call(input_data)
 
 
 
 
+#
+# CREATE MODEL
+#
+
 
 # construct model
 model: Sequential = Sequential([
-    keras.Input(shape=(None,), dtype="int64"),
+    keras.Input(shape=(1,), dtype=tf.string),
+    textVec,
     Embedding(max_features + 1, embedding_dim),
     LSTM(64),
-    Dense(1, activation='sigmoid')
+    Dense(num_classes, activation='sigmoid')
 ])
 
 #compile the model
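One design note on the head above: the four label columns are mutually exclusive one-hot rows, so a softmax head trained with categorical cross-entropy is the more conventional fit than sigmoid with binary cross-entropy, which treats each class as an independent yes/no. A sketch of that alternative, reusing the names from this script:

```python
# alternative head: mutually exclusive classes -> softmax + categorical_crossentropy
alt_model: Sequential = Sequential([
    keras.Input(shape=(1,), dtype=tf.string),
    textVec,
    Embedding(max_features + 1, embedding_dim),
    LSTM(64),
    Dense(num_classes, activation='softmax'),
])
alt_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
```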
@@ -70,26 +112,53 @@ model: Sequential = Sequential([
 model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
 
 
-# fit the training data
-npData = np.array(data)
-npLabel = np.array(labels)
-model.fit(npData,npLabel,epochs=epochs)
+#
+# TRAIN MODEL
+#
+
+# Final formatting of data
+npTrainData = train_data_split.to_numpy(dtype=object).flatten()
+npTrainLabel: npt.NDArray[np.int64] = train_labels.to_numpy()
+if(npTrainData.shape[0] != npTrainLabel.shape[0]):
+    print("Data label size mismatch!")
+    print(npTrainData.shape)
+    print(npTrainLabel.shape)
+
+# try fitting data
+print("Training..")
+model.fit(npTrainData,npTrainLabel,epochs=epochs)
+
+
+
+
+#
+# EVALUATE MODEL
+#
+
+
+# evaluate here
+npTestData: npt.NDArray = test_data_split.to_numpy(dtype=object).flatten()
+npTestLabel: npt.NDArray[np.int64] = test_labels.to_numpy()
+if(npTestData.shape[0] != npTestLabel.shape[0]):
+    print("Data label size mismatch!")
+    print(npTestData.shape)
+    print(npTestLabel.shape)
+
+print("Evaluating..")
+model.evaluate(npTestData,npTestLabel)
 
 
 # predict
-predictTargetRaw: list[str] = ['saf']
-predictTargetToken: list[int] = textVec.call(predictTargetRaw)
-npPredict: npt.NDArray[np.complex64] = np.array(predictTargetToken)
-# print(npPredict)
-result: list[int] = model.predict(npPredict)
+predictTargetRaw: Tensor = tf.constant(['Hello'])
+print("Prediction test..")
+result: npt.NDArray = model.predict(predictTargetRaw)
 print("predict result:")
-print(predictTargetToken)
+print(predictTargetRaw)
 print(result)
-print(data)
-print(labels)
+print(train_data)
+print(train_labels)
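model.predict returns an array of shape (1, num_classes); picking the winning class mirrors the argmax loop on the Java side. A small follow-on sketch, assuming result from the lines above:

```python
# map the highest-scoring column back to its label (sketch)
classes = ['utility', 'transfer', 'query', 'imperative']
print(classes[int(np.argmax(result[0]))])
```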
@@ -97,8 +166,14 @@ print(labels)
 # savePath: str = './data/semantic/model.keras'
 # model.save(savePath)
 
+
+#
+# SAVE MODEL
+#
 
 # export the model so java can leverage it
-exportPath: str = './data/semantic/model'
+print("Saving..")
+exportPath: str = './data/model/sent_func'
 model.export(exportPath)
 
 # tf.keras.utils.get_file('asdf')
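After model.export writes the SavedModel, a quick round-trip check from Python confirms the artifact Java will load; the 'serving_default' signature name is the Keras export default and an assumption here:

```python
import tensorflow as tf

# reload the exported model and inspect the serving signature's input spec
loaded = tf.saved_model.load('./data/model/sent_func')
infer = loaded.signatures['serving_default']
print(infer.structured_input_signature)  # expect a string input named 'keras_tensor'
```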