diff --git a/src/main/java/org/studiorailgun/Main.java b/src/main/java/org/studiorailgun/Main.java index a73f21e..788fc62 100644 --- a/src/main/java/org/studiorailgun/Main.java +++ b/src/main/java/org/studiorailgun/Main.java @@ -1,8 +1,6 @@ package org.studiorailgun; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Tensor; -import org.tensorflow.types.TInt64; +import org.studiorailgun.conversation.AgentLoop; /** * The main class @@ -13,70 +11,7 @@ public class Main { * The main method */ public static void main(String[] args){ - // AgentLoop.main(); - // try (Graph graph = new Graph()) { - // prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"input": inputs}, {"output":output}); - SavedModelBundle model = SavedModelBundle.load("./data/semantic/model"); - System.out.println(model.functions()); - System.out.println(model.functions().get(0).signature()); - System.out.println(model.functions().get(1).signature()); - System.out.println(model.functions().get(2).signature()); - System.out.println(model.functions().get(0).signature().getInputs()); - System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape); - //init - // model.functions().get(2).session().runner() - // .fetch("__saved_model_init_op") - // .run(); - //run predict - TInt64 tensor = TInt64.scalarOf(10); - org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128); - tensor = TInt64.tensorOf(shape); - // NdArray arr = NdArrays.vectorOfObjects(NdArrays.vectorOf((long)0, (long)0, (long)0, (long)0)); - // System.out.println(arr.shape()); - // tensor = TInt64.tensorOf(arr.shape()); - - Tensor result = model.function("serve").call(tensor); - // model.functions().get(0).session().runner() - // .feed("keras_tensor", tensor) - // .fetch("output_0") - // .run(); - - // Graph graph = model.graph(); - // Session session = model.session(); - - // Tensor inputTensor = Tensor.of(Long.class, model.functions().get(0).signature().getInputs().values().iterator().next().shape); - // NdArrays.ofLongs(shape(2, 3, 2)); - // TInt64 tensor = TInt64.vectorOf(10,100); - // // TInt64 tensor = TInt64.tensorOf(model.functions().get(0).signature().getInputs().values().iterator().next().shape); - // Tensor result = session.runner() - // .feed("keras_tensor", tensor) - // .fetch("output_0") - // .run().get(0); - - // Result result = model.call(tensors); - // byte[] graphDef = Files.readAllBytes(Paths.get("./data/semantic/model.keras")); - // GraphDef def = GraphDef.newBuilder() - // .mergeFrom(graphDef) - // .build(); - // graph.importGraphDef(def); - - // try (Session session = new Session(graph)) { - // // // Prepare input data as a Tensor - // // Tensor inputTensor = Tensor.create(inputData); - - // // // Run inference - // // Tensor result = session.runner() - // // .feed("input_node_name", inputTensor) // Replace with actual input tensor name - // // .fetch("output_node_name") // Replace with actual output tensor name - // // .run().get(0); - - // // // Process the result - // // float[][] output = result.copyTo(new float[1][outputSize]); - // // System.out.println(Arrays.toString(output[0])); - // } - // } catch (IOException e) { - // e.printStackTrace(); - // } + AgentLoop.main(); } } diff --git a/src/main/java/org/studiorailgun/conversation/semantic/SentenceSubjectParser.java b/src/main/java/org/studiorailgun/conversation/semantic/SentenceSubjectParser.java index 026572c..ff1d113 100644 --- a/src/main/java/org/studiorailgun/conversation/semantic/SentenceSubjectParser.java +++ b/src/main/java/org/studiorailgun/conversation/semantic/SentenceSubjectParser.java @@ -1,53 +1,54 @@ package org.studiorailgun.conversation.semantic; -import java.io.File; -import java.io.IOException; +import java.util.HashMap; +import java.util.Map; -import org.deeplearning4j.models.word2vec.Word2Vec; -import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.BackpropType; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.DenseLayer; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.optimize.api.InvocationType; -import org.deeplearning4j.optimize.listeners.EvaluativeListener; -import org.deeplearning4j.optimize.listeners.ScoreIterationListener; -import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator; -import org.deeplearning4j.text.sentenceiterator.SentenceIterator; -import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; -import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; -import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.nd4j.evaluation.classification.Evaluation; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.learning.config.Sgd; +import org.tensorflow.Result; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt64; /** * Parses the subject of a sentence */ public class SentenceSubjectParser { - /** - * The configuration for the network - */ - MultiLayerConfiguration conf; /** - * The actual model created from the configuration + * The model */ - MultiLayerNetwork model; - - /** - * The number of epochs to try to train for - */ - int targetEpochs; + SavedModelBundle model; /** * Initializes the model */ public void init(){ - + model = SavedModelBundle.load("./data/semantic/model"); + + //describe the model + // System.out.println(model.functions()); + // System.out.println(model.functions().get(0).signature()); + // System.out.println(model.functions().get(1).signature()); + // System.out.println(model.functions().get(2).signature()); + // System.out.println(model.functions().get(0).signature().getInputs()); + // System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().name); + // System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape); + } + + public void evaluate(String sentence){ + //run predict + TInt64 tensor = TInt64.scalarOf(10); + org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128); + tensor = TInt64.tensorOf(shape); + + Map inputMap = new HashMap(); + inputMap.put("keras_tensor",tensor); + + Result result = model.function("serve").call(inputMap); + System.out.println(result.get(0)); + TFloat32 resultTensor = (TFloat32)result.get(0); + FloatNdArray floatArr = resultTensor.get(0); + System.out.println(floatArr.getFloat(0)); } }