migrate model loading out of main class
This commit is contained in:
parent
3dc2bcd501
commit
ef66c21cad
@ -1,8 +1,6 @@
|
||||
package org.studiorailgun;
|
||||
|
||||
import org.tensorflow.SavedModelBundle;
|
||||
import org.tensorflow.Tensor;
|
||||
import org.tensorflow.types.TInt64;
|
||||
import org.studiorailgun.conversation.AgentLoop;
|
||||
|
||||
/**
|
||||
* The main class
|
||||
@ -13,70 +11,7 @@ public class Main {
|
||||
* The main method
|
||||
*/
|
||||
public static void main(String[] args){
|
||||
// AgentLoop.main();
|
||||
// try (Graph graph = new Graph()) {
|
||||
// prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"input": inputs}, {"output":output});
|
||||
SavedModelBundle model = SavedModelBundle.load("./data/semantic/model");
|
||||
System.out.println(model.functions());
|
||||
System.out.println(model.functions().get(0).signature());
|
||||
System.out.println(model.functions().get(1).signature());
|
||||
System.out.println(model.functions().get(2).signature());
|
||||
System.out.println(model.functions().get(0).signature().getInputs());
|
||||
System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
|
||||
//init
|
||||
// model.functions().get(2).session().runner()
|
||||
// .fetch("__saved_model_init_op")
|
||||
// .run();
|
||||
//run predict
|
||||
TInt64 tensor = TInt64.scalarOf(10);
|
||||
org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128);
|
||||
tensor = TInt64.tensorOf(shape);
|
||||
// NdArray<LongNdArray> arr = NdArrays.vectorOfObjects(NdArrays.vectorOf((long)0, (long)0, (long)0, (long)0));
|
||||
// System.out.println(arr.shape());
|
||||
// tensor = TInt64.tensorOf(arr.shape());
|
||||
|
||||
Tensor result = model.function("serve").call(tensor);
|
||||
// model.functions().get(0).session().runner()
|
||||
// .feed("keras_tensor", tensor)
|
||||
// .fetch("output_0")
|
||||
// .run();
|
||||
|
||||
// Graph graph = model.graph();
|
||||
// Session session = model.session();
|
||||
|
||||
// Tensor<Long> inputTensor = Tensor.of(Long.class, model.functions().get(0).signature().getInputs().values().iterator().next().shape);
|
||||
// NdArrays.ofLongs(shape(2, 3, 2));
|
||||
// TInt64 tensor = TInt64.vectorOf(10,100);
|
||||
// // TInt64 tensor = TInt64.tensorOf(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
|
||||
// Tensor result = session.runner()
|
||||
// .feed("keras_tensor", tensor)
|
||||
// .fetch("output_0")
|
||||
// .run().get(0);
|
||||
|
||||
// Result result = model.call(tensors);
|
||||
// byte[] graphDef = Files.readAllBytes(Paths.get("./data/semantic/model.keras"));
|
||||
// GraphDef def = GraphDef.newBuilder()
|
||||
// .mergeFrom(graphDef)
|
||||
// .build();
|
||||
// graph.importGraphDef(def);
|
||||
|
||||
// try (Session session = new Session(graph)) {
|
||||
// // // Prepare input data as a Tensor
|
||||
// // Tensor inputTensor = Tensor.create(inputData);
|
||||
|
||||
// // // Run inference
|
||||
// // Tensor result = session.runner()
|
||||
// // .feed("input_node_name", inputTensor) // Replace with actual input tensor name
|
||||
// // .fetch("output_node_name") // Replace with actual output tensor name
|
||||
// // .run().get(0);
|
||||
|
||||
// // // Process the result
|
||||
// // float[][] output = result.copyTo(new float[1][outputSize]);
|
||||
// // System.out.println(Arrays.toString(output[0]));
|
||||
// }
|
||||
// } catch (IOException e) {
|
||||
// e.printStackTrace();
|
||||
// }
|
||||
AgentLoop.main();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -1,53 +1,54 @@
|
||||
package org.studiorailgun.conversation.semantic;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.deeplearning4j.models.word2vec.Word2Vec;
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
import org.deeplearning4j.nn.conf.BackpropType;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.optimize.api.InvocationType;
|
||||
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
|
||||
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.learning.config.Sgd;
|
||||
import org.tensorflow.Result;
|
||||
import org.tensorflow.SavedModelBundle;
|
||||
import org.tensorflow.Tensor;
|
||||
import org.tensorflow.ndarray.FloatNdArray;
|
||||
import org.tensorflow.types.TFloat32;
|
||||
import org.tensorflow.types.TInt64;
|
||||
|
||||
/**
|
||||
* Parses the subject of a sentence
|
||||
*/
|
||||
public class SentenceSubjectParser {
|
||||
/**
|
||||
* The configuration for the network
|
||||
*/
|
||||
MultiLayerConfiguration conf;
|
||||
|
||||
/**
|
||||
* The actual model created from the configuration
|
||||
* The model
|
||||
*/
|
||||
MultiLayerNetwork model;
|
||||
|
||||
/**
|
||||
* The number of epochs to try to train for
|
||||
*/
|
||||
int targetEpochs;
|
||||
SavedModelBundle model;
|
||||
|
||||
/**
|
||||
* Initializes the model
|
||||
*/
|
||||
public void init(){
|
||||
|
||||
model = SavedModelBundle.load("./data/semantic/model");
|
||||
|
||||
//describe the model
|
||||
// System.out.println(model.functions());
|
||||
// System.out.println(model.functions().get(0).signature());
|
||||
// System.out.println(model.functions().get(1).signature());
|
||||
// System.out.println(model.functions().get(2).signature());
|
||||
// System.out.println(model.functions().get(0).signature().getInputs());
|
||||
// System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().name);
|
||||
// System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
|
||||
}
|
||||
|
||||
public void evaluate(String sentence){
|
||||
//run predict
|
||||
TInt64 tensor = TInt64.scalarOf(10);
|
||||
org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128);
|
||||
tensor = TInt64.tensorOf(shape);
|
||||
|
||||
Map<String,Tensor> inputMap = new HashMap<String,Tensor>();
|
||||
inputMap.put("keras_tensor",tensor);
|
||||
|
||||
Result result = model.function("serve").call(inputMap);
|
||||
System.out.println(result.get(0));
|
||||
TFloat32 resultTensor = (TFloat32)result.get(0);
|
||||
FloatNdArray floatArr = resultTensor.get(0);
|
||||
System.out.println(floatArr.getFloat(0));
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user