player command parsing
All checks were successful
studiorailgun/trpg/pipeline/head This commit looks good
All checks were successful
studiorailgun/trpg/pipeline/head This commit looks good
This commit is contained in:
parent
aaf093f85a
commit
ea0af29777
1
data/model/sim/command_cat/fingerprint.pb
Normal file
1
data/model/sim/command_cat/fingerprint.pb
Normal file
@ -0,0 +1 @@
|
|||||||
|
д╢б÷Э²Ф⌡ЬЗс▐зК╜·З═╔в╔▓ИКЙG сМ╘И┐Ь▓б├(≤≥╞юлЬ║╟·2
|
||||||
BIN
data/model/sim/command_cat/saved_model.pb
Normal file
BIN
data/model/sim/command_cat/saved_model.pb
Normal file
Binary file not shown.
Binary file not shown.
BIN
data/model/sim/command_cat/variables/variables.index
Normal file
BIN
data/model/sim/command_cat/variables/variables.index
Normal file
Binary file not shown.
@ -1,2 +0,0 @@
|
|||||||
"utility","transfer","query","imperative","sentence"
|
|
||||||
1,0,0,0,"Hello"
|
|
||||||
|
@ -1,2 +0,0 @@
|
|||||||
"look","transfer","query","imperative","sentence"
|
|
||||||
1,0,0,0,"Hello"
|
|
||||||
|
25
data/semantic/sim/rpg_command/test.csv
Normal file
25
data/semantic/sim/rpg_command/test.csv
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
help,look,walk,speak,sentence
|
||||||
|
0,1,0,0, I look around.
|
||||||
|
1,0,0,0,Help.
|
||||||
|
0,0,1,0,I approach one of the merchants.
|
||||||
|
0,0,1,0,I head towards the adventurer's guild.
|
||||||
|
0,0,1,0,I walk up to the gnome.
|
||||||
|
0,1,0,0,I look at the paper describing the trial quest.
|
||||||
|
0,0,0,1,"I then ask, ""Where can I find the kobold caves?"""
|
||||||
|
0,0,0,1,"""Thank you!"" I reply."
|
||||||
|
0,0,0,1,"""Thank you again."" I say."
|
||||||
|
0,0,1,0,I then follow his directions.
|
||||||
|
0,1,0,0,I look around briefly.
|
||||||
|
0,0,1,0,I follow the path up to the cave.
|
||||||
|
0,1,0,0,I look around once I arrive.
|
||||||
|
0,0,1,0,I enter the cave and begin to follow the footprints.
|
||||||
|
0,0,1,0,I approach the sounds of battle cautiously.
|
||||||
|
0,0,1,0,I sneak up behind a kobold.
|
||||||
|
0,1,0,0,I try to assess whether the battle is still ongoing or has finished.
|
||||||
|
0,0,0,1,"I ask them ""Are y'all ok?"""
|
||||||
|
0,1,0,0,I inspect the human warrior.
|
||||||
|
0,1,0,0,I inspect the halfling.
|
||||||
|
0,1,0,0,I inspect the elf.
|
||||||
|
0,0,1,0,I then head towards the entrance of the cave.
|
||||||
|
0,0,1,0,I follow the human.
|
||||||
|
0,0,1,0,I head towards the hot spring connected to the inn.
|
||||||
|
25
data/semantic/sim/rpg_command/train.csv
Normal file
25
data/semantic/sim/rpg_command/train.csv
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
help,look,walk,speak,sentence
|
||||||
|
0,1,0,0, I look around.
|
||||||
|
1,0,0,0,Help.
|
||||||
|
0,0,1,0,I approach one of the merchants.
|
||||||
|
0,0,1,0,I head towards the adventurer's guild.
|
||||||
|
0,0,1,0,I walk up to the gnome.
|
||||||
|
0,1,0,0,I look at the paper describing the trial quest.
|
||||||
|
0,0,0,1,"I then ask, ""Where can I find the kobold caves?"""
|
||||||
|
0,0,0,1,"""Thank you!"" I reply."
|
||||||
|
0,0,0,1,"""Thank you again."" I say."
|
||||||
|
0,0,1,0,I then follow his directions.
|
||||||
|
0,1,0,0,I look around briefly.
|
||||||
|
0,0,1,0,I follow the path up to the cave.
|
||||||
|
0,1,0,0,I look around once I arrive.
|
||||||
|
0,0,1,0,I enter the cave and begin to follow the footprints.
|
||||||
|
0,0,1,0,I approach the sounds of battle cautiously.
|
||||||
|
0,0,1,0,I sneak up behind a kobold.
|
||||||
|
0,1,0,0,I try to assess whether the battle is still ongoing or has finished.
|
||||||
|
0,0,0,1,"I ask them ""Are y'all ok?"""
|
||||||
|
0,1,0,0,I inspect the human warrior.
|
||||||
|
0,1,0,0,I inspect the halfling.
|
||||||
|
0,1,0,0,I inspect the elf.
|
||||||
|
0,0,1,0,I then head towards the entrance of the cave.
|
||||||
|
0,0,1,0,I follow the human.
|
||||||
|
0,0,1,0,I head towards the hot spring connected to the inn.
|
||||||
|
@ -24,6 +24,7 @@ public class GameLoop {
|
|||||||
*/
|
*/
|
||||||
public static void main(){
|
public static void main(){
|
||||||
Globals.init("web.json");
|
Globals.init("web.json");
|
||||||
|
GameCommandParser.init();
|
||||||
Globals.world = WorldGenerator.generateWorld();
|
Globals.world = WorldGenerator.generateWorld();
|
||||||
Globals.playerCharacter = PlayerCharSourcer.getPlayerCharacter();
|
Globals.playerCharacter = PlayerCharSourcer.getPlayerCharacter();
|
||||||
|
|
||||||
|
|||||||
@ -23,6 +23,12 @@ public class Character {
|
|||||||
*/
|
*/
|
||||||
EmotionData emotions;
|
EmotionData emotions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Simulates this character
|
||||||
|
*/
|
||||||
|
public void simulate(){
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets the appearance of the character
|
* Gets the appearance of the character
|
||||||
* @return The appearance of the character
|
* @return The appearance of the character
|
||||||
|
|||||||
@ -1,16 +1,53 @@
|
|||||||
package org.studiorailgun.sim.eval;
|
package org.studiorailgun.sim.eval;
|
||||||
|
|
||||||
|
import org.studiorailgun.sim.eval.command.LookCommand;
|
||||||
|
import org.studiorailgun.sim.eval.command.SpeakCommand;
|
||||||
|
import org.studiorailgun.sim.eval.command.SystemCommand;
|
||||||
|
import org.studiorailgun.sim.eval.command.WalkCommand;
|
||||||
|
import org.studiorailgun.sim.eval.parse.CommandCategorizationModel;
|
||||||
|
import org.studiorailgun.sim.eval.parse.CommandCategorizationModel.CommandType;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parses game commands from the user
|
* Parses game commands from the user
|
||||||
*/
|
*/
|
||||||
public class GameCommandParser {
|
public class GameCommandParser {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the command parser
|
||||||
|
*/
|
||||||
|
public static void init(){
|
||||||
|
CommandCategorizationModel.init();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parses the player's input
|
* Parses the player's input
|
||||||
* @param input The player's input
|
* @param input The player's input
|
||||||
*/
|
*/
|
||||||
public static void parse(String input){
|
public static void parse(String input){
|
||||||
|
CommandType type = CommandCategorizationModel.categorize(input);
|
||||||
|
switch(type){
|
||||||
|
|
||||||
|
case HELP: {
|
||||||
|
SystemCommand.handle(input);
|
||||||
|
} break;
|
||||||
|
|
||||||
|
case LOOK: {
|
||||||
|
LookCommand.handle(input);
|
||||||
|
} break;
|
||||||
|
|
||||||
|
case WALK: {
|
||||||
|
WalkCommand.handle(input);
|
||||||
|
} break;
|
||||||
|
|
||||||
|
case SPEAK: {
|
||||||
|
SpeakCommand.handle(input);
|
||||||
|
} break;
|
||||||
|
|
||||||
|
default: {
|
||||||
|
String message = "Could not determine what the command was. \"" + input + "\" " + type;
|
||||||
|
throw new Error(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,8 @@
|
|||||||
package org.studiorailgun.sim.eval;
|
package org.studiorailgun.sim.eval;
|
||||||
|
|
||||||
|
import org.studiorailgun.Globals;
|
||||||
|
import org.studiorailgun.sim.character.Character;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Simulates the game world
|
* Simulates the game world
|
||||||
*/
|
*/
|
||||||
@ -9,9 +12,9 @@ public class Simulator {
|
|||||||
* Simulates a single frame of the game world
|
* Simulates a single frame of the game world
|
||||||
*/
|
*/
|
||||||
public static void simulate(){
|
public static void simulate(){
|
||||||
|
for(Character character : Globals.world.getCharacters()){
|
||||||
|
character.simulate();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -0,0 +1,15 @@
|
|||||||
|
package org.studiorailgun.sim.eval.command;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handler for look command
|
||||||
|
*/
|
||||||
|
public class LookCommand {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles a look command
|
||||||
|
*/
|
||||||
|
public static void handle(String input){
|
||||||
|
System.out.println("Look command..");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@ -0,0 +1,16 @@
|
|||||||
|
package org.studiorailgun.sim.eval.command;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles a speaking command
|
||||||
|
*/
|
||||||
|
public class SpeakCommand {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles a speaking command
|
||||||
|
* @param input The raw command
|
||||||
|
*/
|
||||||
|
public static void handle(String input){
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@ -0,0 +1,15 @@
|
|||||||
|
package org.studiorailgun.sim.eval.command;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles system commands
|
||||||
|
*/
|
||||||
|
public class SystemCommand {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles a system command
|
||||||
|
*/
|
||||||
|
public static void handle(String input){
|
||||||
|
System.out.println("System command.. " + input);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@ -0,0 +1,16 @@
|
|||||||
|
package org.studiorailgun.sim.eval.command;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles a walk command
|
||||||
|
*/
|
||||||
|
public class WalkCommand {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles a walk command
|
||||||
|
* @param input The input text
|
||||||
|
*/
|
||||||
|
public static void handle(String input){
|
||||||
|
System.out.println("Walk command..");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@ -0,0 +1,141 @@
|
|||||||
|
package org.studiorailgun.sim.eval.parse;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.tensorflow.Result;
|
||||||
|
import org.tensorflow.SavedModelBundle;
|
||||||
|
import org.tensorflow.Tensor;
|
||||||
|
import org.tensorflow.ndarray.FloatNdArray;
|
||||||
|
import org.tensorflow.ndarray.NdArray;
|
||||||
|
import org.tensorflow.ndarray.NdArrays;
|
||||||
|
import org.tensorflow.ndarray.Shape;
|
||||||
|
import org.tensorflow.types.TFloat32;
|
||||||
|
import org.tensorflow.types.TString;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Categorizes player commands
|
||||||
|
*/
|
||||||
|
public class CommandCategorizationModel {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The type of the command
|
||||||
|
*/
|
||||||
|
public static enum CommandType {
|
||||||
|
/**
|
||||||
|
* Asking the system for help
|
||||||
|
*/
|
||||||
|
HELP,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Looking at something
|
||||||
|
*/
|
||||||
|
LOOK,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Walk towards something
|
||||||
|
*/
|
||||||
|
WALK,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Speak at something
|
||||||
|
*/
|
||||||
|
SPEAK,
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of different command types
|
||||||
|
*/
|
||||||
|
static final int NUM_COMMAND_TYPES = 4;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cutoff after which we say the command was not classified successfully
|
||||||
|
*/
|
||||||
|
static final float CLASSIFICAITON_CUTOFF = 0.5f;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The model
|
||||||
|
*/
|
||||||
|
static SavedModelBundle model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the categorization model
|
||||||
|
*/
|
||||||
|
public static void init(){
|
||||||
|
model = SavedModelBundle.load("./data/model/sim/command_cat");
|
||||||
|
|
||||||
|
//describe the model
|
||||||
|
// System.out.println(model.functions());
|
||||||
|
// System.out.println(model.functions().get(0).signature());
|
||||||
|
// System.out.println(model.functions().get(1).signature());
|
||||||
|
// System.out.println(model.functions().get(2).signature());
|
||||||
|
// System.out.println(model.functions().get(0).signature().getInputs());
|
||||||
|
// System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().name);
|
||||||
|
// System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Categorizes the sentence by function
|
||||||
|
* @param input The input string
|
||||||
|
* @return The function of the sentence
|
||||||
|
*/
|
||||||
|
public static CommandType categorize(String input){
|
||||||
|
//construct input
|
||||||
|
TString inputTensor = TString.scalarOf(input);
|
||||||
|
inputTensor.shape().append(1);
|
||||||
|
NdArray<String> stringArr = NdArrays.ofObjects(String.class, Shape.of(1, 1));
|
||||||
|
stringArr.setObject(input);
|
||||||
|
inputTensor = TString.tensorOf(stringArr);
|
||||||
|
|
||||||
|
Map<String,Tensor> inputMap = new HashMap<String,Tensor>();
|
||||||
|
inputMap.put("keras_tensor",inputTensor);
|
||||||
|
|
||||||
|
//call
|
||||||
|
Result result = model.function("serve").call(inputMap);
|
||||||
|
|
||||||
|
//parse results
|
||||||
|
TFloat32 resultFloats = (TFloat32)result.get(0);
|
||||||
|
FloatNdArray floatArr = resultFloats.get(0);
|
||||||
|
float[] classification = new float[NUM_COMMAND_TYPES];
|
||||||
|
for(int i = 0; i < NUM_COMMAND_TYPES; i++){
|
||||||
|
classification[i] = floatArr.getFloat(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
//figure out highest value category
|
||||||
|
int maxAt = 0;
|
||||||
|
for(int i = 0; i < classification.length; i++){
|
||||||
|
if(classification[i] > classification[maxAt]){
|
||||||
|
maxAt = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//throw an error if not confident in the classification
|
||||||
|
if(classification[maxAt] < CLASSIFICAITON_CUTOFF){
|
||||||
|
String message = "Failed to classify command definitively! \n";
|
||||||
|
for(int i = 0; i < NUM_COMMAND_TYPES; i++){
|
||||||
|
message = message + "class[" + i + "] " + classification[i] + "\n";
|
||||||
|
}
|
||||||
|
throw new Error(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
//convert to enum of command types
|
||||||
|
switch(maxAt){
|
||||||
|
case 0: {
|
||||||
|
return CommandType.HELP;
|
||||||
|
}
|
||||||
|
case 1: {
|
||||||
|
return CommandType.LOOK;
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
return CommandType.WALK;
|
||||||
|
}
|
||||||
|
case 3: {
|
||||||
|
return CommandType.SPEAK;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
throw new Error("Failed to categorize!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@ -1,5 +1,10 @@
|
|||||||
package org.studiorailgun.sim.space;
|
package org.studiorailgun.sim.space;
|
||||||
|
|
||||||
|
import java.util.LinkedList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.studiorailgun.sim.character.Character;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Top level spatial container
|
* Top level spatial container
|
||||||
*/
|
*/
|
||||||
@ -10,6 +15,11 @@ public class World {
|
|||||||
*/
|
*/
|
||||||
Region region;
|
Region region;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The list of characters
|
||||||
|
*/
|
||||||
|
List<Character> characters = new LinkedList<Character>();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets the top level region of this world
|
* Gets the top level region of this world
|
||||||
* @return The top level region of this world
|
* @return The top level region of this world
|
||||||
@ -26,4 +36,22 @@ public class World {
|
|||||||
this.region = region;
|
this.region = region;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the list of characters in this world
|
||||||
|
* @return The list of characters in this world
|
||||||
|
*/
|
||||||
|
public List<Character> getCharacters() {
|
||||||
|
return characters;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the list of characters in this world
|
||||||
|
* @param characters The list of characters in this world
|
||||||
|
*/
|
||||||
|
public void setCharacters(List<Character> characters) {
|
||||||
|
this.characters = characters;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
220
src/main/python/sim/command_cat.py
Normal file
220
src/main/python/sim/command_cat.py
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
import tensorflow as tf
|
||||||
|
keras = tf.keras
|
||||||
|
from tensorflow import Tensor
|
||||||
|
from keras.api.layers import TextVectorization, Embedding, LSTM, Dense, Bidirectional
|
||||||
|
from keras.api.models import Sequential
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
import pandas as pd
|
||||||
|
from pandas import DataFrame
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Description: The purpose of this model is to categorize the command a player inputs
|
||||||
|
#
|
||||||
|
# - "HELP" - Asking the system for help (ie for what commands are available)
|
||||||
|
#
|
||||||
|
# - "LOOK" - Looks at something
|
||||||
|
#
|
||||||
|
# - "WALK" - Walks towards something
|
||||||
|
#
|
||||||
|
# - "SPEAK" - Speaks at something
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# MODEL CONSTANTS
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
|
# this is the maximum allowed size of the vocabulary
|
||||||
|
max_features: int = 20000
|
||||||
|
|
||||||
|
# the dimension of the output from the embedding layer
|
||||||
|
embedding_dim: int = 128
|
||||||
|
|
||||||
|
# The number of epochs to train for
|
||||||
|
epochs: int = 50
|
||||||
|
|
||||||
|
# Maximum size of the vocab for this layer
|
||||||
|
max_tokens: int = 5000
|
||||||
|
|
||||||
|
# (Only valid in INT mode) If set, the output will have its time dimension padded or truncated to exactly output_sequence_length values
|
||||||
|
output_sequence_length: int = 4
|
||||||
|
|
||||||
|
# The number of classes we're training for
|
||||||
|
num_classes: int = 4
|
||||||
|
|
||||||
|
# Path to the data
|
||||||
|
data_path: str = './data/semantic/sim/rpg_command'
|
||||||
|
|
||||||
|
# Path to export the model to
|
||||||
|
export_path: str = './data/model/sim/command_cat'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# LOAD DATA
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
# read training sentences
|
||||||
|
train_csv_path: str = data_path + '/train.csv'
|
||||||
|
train_csv_raw: DataFrame = pd.read_csv(open(train_csv_path), dtype={
|
||||||
|
"help": "int64",
|
||||||
|
"look": "int64",
|
||||||
|
"walk": "int64",
|
||||||
|
"speak": "int64",
|
||||||
|
"sentence": "string",
|
||||||
|
})
|
||||||
|
train_data_raw: DataFrame = train_csv_raw[["sentence"]]
|
||||||
|
train_data_split: DataFrame = train_data_raw
|
||||||
|
vocab: list[str] = train_data_raw.to_numpy()
|
||||||
|
train_labels: DataFrame = train_csv_raw[["help", "look", "walk", "speak"]]
|
||||||
|
|
||||||
|
# read testing sentences
|
||||||
|
test_csv_path: str = data_path + '/test.csv'
|
||||||
|
test_csv_raw: DataFrame = pd.read_csv(open(test_csv_path), dtype={
|
||||||
|
"help": "int64",
|
||||||
|
"look": "int64",
|
||||||
|
"walk": "int64",
|
||||||
|
"speak": "int64",
|
||||||
|
"sentence": "string",
|
||||||
|
})
|
||||||
|
test_data_raw: DataFrame = test_csv_raw[["sentence"]]
|
||||||
|
test_data_split: DataFrame = test_data_raw
|
||||||
|
test_labels: DataFrame = test_csv_raw[["help", "look", "walk", "speak"]]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# CREATE VECTORIZER
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
# init vectorizer
|
||||||
|
textVec: TextVectorization = TextVectorization(
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
output_mode='int',
|
||||||
|
output_sequence_length=output_sequence_length,
|
||||||
|
pad_to_max_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the vocab to the tokenizer
|
||||||
|
textVec.adapt(vocab)
|
||||||
|
input_data: list[str] = train_data_split
|
||||||
|
train_data: Tensor = textVec.call(input_data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# CREATE MODEL
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
# construct model
|
||||||
|
model: Sequential = Sequential([
|
||||||
|
keras.Input(shape=(1,), dtype=tf.string),
|
||||||
|
textVec,
|
||||||
|
Embedding(max_features + 1, embedding_dim),
|
||||||
|
Bidirectional(LSTM(256)),
|
||||||
|
Dense(64, activation='relu'),
|
||||||
|
Dense(num_classes, activation='sigmoid')
|
||||||
|
])
|
||||||
|
|
||||||
|
#compile the model
|
||||||
|
# model.build(keras.Input(shape=(None,), dtype="int64"))
|
||||||
|
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# TRAIN MODEL
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
|
# Final formatting of data
|
||||||
|
npTrainData = train_data_split.to_numpy(dtype=object).flatten()
|
||||||
|
npTrainLabel: npt.NDArray[np.complex64] = train_labels.to_numpy()
|
||||||
|
if(npTrainData.shape[0] != npTrainLabel.shape[0]):
|
||||||
|
print("Data label size mismatch!")
|
||||||
|
print(npTrainData.shape)
|
||||||
|
print(npTrainLabel.shape)
|
||||||
|
|
||||||
|
# try fitting data
|
||||||
|
print("Training..")
|
||||||
|
model.fit(npTrainData,npTrainLabel,epochs=epochs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# EVALUATE MODEL
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
# evaluate here
|
||||||
|
npTestData: npt.NDArray = test_data_split.to_numpy(dtype=object).flatten()
|
||||||
|
npTestLabel: npt.NDArray[np.complex64] = test_labels.to_numpy()
|
||||||
|
if(npTestData.shape[0] != npTestLabel.shape[0]):
|
||||||
|
print("Data label size mismatch!")
|
||||||
|
print(npTestData.shape)
|
||||||
|
print(npTestLabel.shape)
|
||||||
|
|
||||||
|
print("Evaluating..")
|
||||||
|
model.evaluate(npTestData,npTestLabel)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# PREDICT
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
|
# predictTargetRaw: Tensor = tf.constant(['Hello'])
|
||||||
|
# npPredict: npt.NDArray = np.array(predictTargetRaw, dtype=object)
|
||||||
|
# print("Prediction test..")
|
||||||
|
# result: list[int] = model.predict(predictTargetRaw)
|
||||||
|
# print("predict result:")
|
||||||
|
# print(predictTargetRaw)
|
||||||
|
# print(result)
|
||||||
|
# print(train_data)
|
||||||
|
# print(train_labels)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# SAVE (DEVELOPMENT) MODEL
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
|
# save the model so keras can reload
|
||||||
|
# savePath: str = export_path
|
||||||
|
# model.save(savePath)
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# SAVE (PRODUCTION) MODEL
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
|
# export the model so java can leverage it
|
||||||
|
print("Saving..")
|
||||||
|
model.export(export_path)
|
||||||
|
|
||||||
Loading…
Reference in New Issue
Block a user