migrate model loading out of main class

This commit is contained in:
austin 2024-12-28 21:11:40 -05:00
parent 3dc2bcd501
commit ef66c21cad
2 changed files with 38 additions and 102 deletions

View File

@ -1,8 +1,6 @@
package org.studiorailgun;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.types.TInt64;
import org.studiorailgun.conversation.AgentLoop;
/**
* The main class
@ -13,70 +11,7 @@ public class Main {
* The main method
*/
public static void main(String[] args){
// AgentLoop.main();
// try (Graph graph = new Graph()) {
// prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"input": inputs}, {"output":output});
SavedModelBundle model = SavedModelBundle.load("./data/semantic/model");
System.out.println(model.functions());
System.out.println(model.functions().get(0).signature());
System.out.println(model.functions().get(1).signature());
System.out.println(model.functions().get(2).signature());
System.out.println(model.functions().get(0).signature().getInputs());
System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
//init
// model.functions().get(2).session().runner()
// .fetch("__saved_model_init_op")
// .run();
//run predict
TInt64 tensor = TInt64.scalarOf(10);
org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128);
tensor = TInt64.tensorOf(shape);
// NdArray<LongNdArray> arr = NdArrays.vectorOfObjects(NdArrays.vectorOf((long)0, (long)0, (long)0, (long)0));
// System.out.println(arr.shape());
// tensor = TInt64.tensorOf(arr.shape());
Tensor result = model.function("serve").call(tensor);
// model.functions().get(0).session().runner()
// .feed("keras_tensor", tensor)
// .fetch("output_0")
// .run();
// Graph graph = model.graph();
// Session session = model.session();
// Tensor<Long> inputTensor = Tensor.of(Long.class, model.functions().get(0).signature().getInputs().values().iterator().next().shape);
// NdArrays.ofLongs(shape(2, 3, 2));
// TInt64 tensor = TInt64.vectorOf(10,100);
// // TInt64 tensor = TInt64.tensorOf(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
// Tensor result = session.runner()
// .feed("keras_tensor", tensor)
// .fetch("output_0")
// .run().get(0);
// Result result = model.call(tensors);
// byte[] graphDef = Files.readAllBytes(Paths.get("./data/semantic/model.keras"));
// GraphDef def = GraphDef.newBuilder()
// .mergeFrom(graphDef)
// .build();
// graph.importGraphDef(def);
// try (Session session = new Session(graph)) {
// // // Prepare input data as a Tensor
// // Tensor inputTensor = Tensor.create(inputData);
// // // Run inference
// // Tensor result = session.runner()
// // .feed("input_node_name", inputTensor) // Replace with actual input tensor name
// // .fetch("output_node_name") // Replace with actual output tensor name
// // .run().get(0);
// // // Process the result
// // float[][] output = result.copyTo(new float[1][outputSize]);
// // System.out.println(Arrays.toString(output[0]));
// }
// } catch (IOException e) {
// e.printStackTrace();
// }
AgentLoop.main();
}
}

View File

@ -1,53 +1,54 @@
package org.studiorailgun.conversation.semantic;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Sgd;
import org.tensorflow.Result;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;
/**
* Parses the subject of a sentence
*/
public class SentenceSubjectParser {
/**
* The configuration for the network
*/
MultiLayerConfiguration conf;
/**
* The actual model created from the configuration
* The model
*/
MultiLayerNetwork model;
/**
* The number of epochs to try to train for
*/
int targetEpochs;
SavedModelBundle model;
/**
* Initializes the model
*/
public void init(){
model = SavedModelBundle.load("./data/semantic/model");
//describe the model
// System.out.println(model.functions());
// System.out.println(model.functions().get(0).signature());
// System.out.println(model.functions().get(1).signature());
// System.out.println(model.functions().get(2).signature());
// System.out.println(model.functions().get(0).signature().getInputs());
// System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().name);
// System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
}
public void evaluate(String sentence){
//run predict
TInt64 tensor = TInt64.scalarOf(10);
org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128);
tensor = TInt64.tensorOf(shape);
Map<String,Tensor> inputMap = new HashMap<String,Tensor>();
inputMap.put("keras_tensor",tensor);
Result result = model.function("serve").call(inputMap);
System.out.println(result.get(0));
TFloat32 resultTensor = (TFloat32)result.get(0);
FloatNdArray floatArr = resultTensor.get(0);
System.out.println(floatArr.getFloat(0));
}
}