migrate model loading out of main class

parent 3dc2bcd501
commit ef66c21cad
Main.java

@@ -1,8 +1,6 @@
 package org.studiorailgun;
 
-import org.tensorflow.SavedModelBundle;
-import org.tensorflow.Tensor;
-import org.tensorflow.types.TInt64;
+import org.studiorailgun.conversation.AgentLoop;
 
 /**
  * The main class
@@ -13,70 +11,7 @@ public class Main {
      * The main method
      */
     public static void main(String[] args){
-        // AgentLoop.main();
-        // try (Graph graph = new Graph()) {
-        // prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"input": inputs}, {"output":output});
-        SavedModelBundle model = SavedModelBundle.load("./data/semantic/model");
-        System.out.println(model.functions());
-        System.out.println(model.functions().get(0).signature());
-        System.out.println(model.functions().get(1).signature());
-        System.out.println(model.functions().get(2).signature());
-        System.out.println(model.functions().get(0).signature().getInputs());
-        System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
-        //init
-        // model.functions().get(2).session().runner()
-        // .fetch("__saved_model_init_op")
-        // .run();
-        //run predict
-        TInt64 tensor = TInt64.scalarOf(10);
-        org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128);
-        tensor = TInt64.tensorOf(shape);
-        // NdArray<LongNdArray> arr = NdArrays.vectorOfObjects(NdArrays.vectorOf((long)0, (long)0, (long)0, (long)0));
-        // System.out.println(arr.shape());
-        // tensor = TInt64.tensorOf(arr.shape());
-
-        Tensor result = model.function("serve").call(tensor);
-        // model.functions().get(0).session().runner()
-        // .feed("keras_tensor", tensor)
-        // .fetch("output_0")
-        // .run();
-
-        // Graph graph = model.graph();
-        // Session session = model.session();
-
-        // Tensor<Long> inputTensor = Tensor.of(Long.class, model.functions().get(0).signature().getInputs().values().iterator().next().shape);
-        // NdArrays.ofLongs(shape(2, 3, 2));
-        // TInt64 tensor = TInt64.vectorOf(10,100);
-        // // TInt64 tensor = TInt64.tensorOf(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
-        // Tensor result = session.runner()
-        // .feed("keras_tensor", tensor)
-        // .fetch("output_0")
-        // .run().get(0);
-
-        // Result result = model.call(tensors);
-        // byte[] graphDef = Files.readAllBytes(Paths.get("./data/semantic/model.keras"));
-        // GraphDef def = GraphDef.newBuilder()
-        // .mergeFrom(graphDef)
-        // .build();
-        // graph.importGraphDef(def);
-
-        // try (Session session = new Session(graph)) {
-        // // // Prepare input data as a Tensor
-        // // Tensor inputTensor = Tensor.create(inputData);
-
-        // // // Run inference
-        // // Tensor result = session.runner()
-        // // .feed("input_node_name", inputTensor) // Replace with actual input tensor name
-        // // .fetch("output_node_name") // Replace with actual output tensor name
-        // // .run().get(0);
-
-        // // // Process the result
-        // // float[][] output = result.copyTo(new float[1][outputSize]);
-        // // System.out.println(Arrays.toString(output[0]));
-        // }
-        // } catch (IOException e) {
-        // e.printStackTrace();
-        // }
+        AgentLoop.main();
     }
 
 }
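The net effect of these hunks: Main.main() now only delegates to AgentLoop.main(), and the TensorFlow SavedModelBundle experimentation moves into SentenceSubjectParser (diffed below). A minimal sketch of how the relocated code could be driven after this commit; the SemanticSmokeTest class and its call site are illustrative assumptions, not part of the commit:

    import org.studiorailgun.conversation.semantic.SentenceSubjectParser;

    // Hypothetical driver for the class introduced in the diff below.
    public class SemanticSmokeTest {
        public static void main(String[] args) {
            SentenceSubjectParser parser = new SentenceSubjectParser();
            parser.init();                             // loads ./data/semantic/model once via SavedModelBundle
            parser.evaluate("the cat sat on the mat"); // feeds the "serve" signature and prints the first output value
        }
    }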
SentenceSubjectParser.java

@@ -1,53 +1,54 @@
 package org.studiorailgun.conversation.semantic;
 
-import java.io.File;
-import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
 
-import org.deeplearning4j.models.word2vec.Word2Vec;
-import org.deeplearning4j.nn.api.OptimizationAlgorithm;
-import org.deeplearning4j.nn.conf.BackpropType;
-import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.layers.DenseLayer;
-import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
-import org.deeplearning4j.nn.weights.WeightInit;
-import org.deeplearning4j.optimize.api.InvocationType;
-import org.deeplearning4j.optimize.listeners.EvaluativeListener;
-import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
-import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
-import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
-import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
-import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
-import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
-import org.nd4j.evaluation.classification.Evaluation;
-import org.nd4j.linalg.activations.Activation;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
-import org.nd4j.linalg.learning.config.Sgd;
+import org.tensorflow.Result;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Tensor;
+import org.tensorflow.ndarray.FloatNdArray;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TInt64;
 
 /**
  * Parses the subject of a sentence
  */
 public class SentenceSubjectParser {
-    /**
-     * The configuration for the network
-     */
-    MultiLayerConfiguration conf;
-
     /**
-     * The actual model created from the configuration
+     * The model
      */
-    MultiLayerNetwork model;
+    SavedModelBundle model;
 
-    /**
-     * The number of epochs to try to train for
-     */
-    int targetEpochs;
-
     /**
      * Initializes the model
      */
     public void init(){
+        model = SavedModelBundle.load("./data/semantic/model");
+
+        //describe the model
+        // System.out.println(model.functions());
+        // System.out.println(model.functions().get(0).signature());
+        // System.out.println(model.functions().get(1).signature());
+        // System.out.println(model.functions().get(2).signature());
+        // System.out.println(model.functions().get(0).signature().getInputs());
+        // System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().name);
+        // System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
+    }
+
+    public void evaluate(String sentence){
+        //run predict
+        TInt64 tensor = TInt64.scalarOf(10);
+        org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128);
+        tensor = TInt64.tensorOf(shape);
+
+        Map<String,Tensor> inputMap = new HashMap<String,Tensor>();
+        inputMap.put("keras_tensor",tensor);
+
+        Result result = model.function("serve").call(inputMap);
+        System.out.println(result.get(0));
+        TFloat32 resultTensor = (TFloat32)result.get(0);
+        FloatNdArray floatArr = resultTensor.get(0);
+        System.out.println(floatArr.getFloat(0));
     }
 }
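Note that evaluate(String sentence) does not consume its argument yet: the TInt64 input is allocated with shape [2001, 128] (a scalar's empty shape with 2001 and 128 appended) and handed to "keras_tensor" without being filled from the sentence. A hedged sketch of one way the sentence could be encoded into an input of that shape; the Map-based vocabulary, the unknown-token id of 0, and the choice to write into row 0 are assumptions for illustration only:

    import java.util.HashMap;
    import java.util.Map;
    import org.tensorflow.ndarray.Shape;
    import org.tensorflow.types.TInt64;

    // Hypothetical encoder: writes token ids for the sentence into row 0 of a
    // tensor with the same [2001, 128] shape the commit builds.
    class InputEncodingSketch {
        static TInt64 encode(String sentence, Map<String, Long> vocabulary) {
            TInt64 tensor = TInt64.tensorOf(Shape.of(2001, 128));
            String[] tokens = sentence.toLowerCase().split("\\s+");
            for (int i = 0; i < tokens.length && i < 128; i++) {
                long id = vocabulary.getOrDefault(tokens[i], 0L); // 0 assumed to mean "unknown"
                tensor.setLong(id, 0, i);                         // token id for position i
            }
            return tensor;
        }

        public static void main(String[] args) {
            Map<String, Long> vocab = new HashMap<>();
            vocab.put("cat", 7L); // toy vocabulary entry
            TInt64 input = encode("the cat sat", vocab);
            System.out.println(input.getLong(0, 1)); // prints 7, the id written for "cat"
        }
    }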
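A possible follow-up this commit does not make: in TensorFlow Java, SavedModelBundle, Tensor, and Result are AutoCloseable, so the load/predict path could be scoped with try-with-resources to release native memory deterministically rather than waiting on garbage collection. A sketch under that assumption; the class name and scoping are illustrative, not the project's code:

    import java.util.HashMap;
    import java.util.Map;
    import org.tensorflow.Result;
    import org.tensorflow.SavedModelBundle;
    import org.tensorflow.Tensor;
    import org.tensorflow.ndarray.Shape;
    import org.tensorflow.types.TFloat32;
    import org.tensorflow.types.TInt64;

    // Hypothetical scoped variant of the load/predict sequence from the diff above.
    class ScopedPredictionSketch {
        public static void main(String[] args) {
            try (SavedModelBundle model = SavedModelBundle.load("./data/semantic/model");
                 TInt64 input = TInt64.tensorOf(Shape.of(2001, 128))) {
                Map<String, Tensor> inputs = new HashMap<>();
                inputs.put("keras_tensor", input);
                try (Result result = model.function("serve").call(inputs)) {
                    TFloat32 scores = (TFloat32) result.get(0);
                    System.out.println(scores.get(0).getFloat(0)); // first value of the first output row
                }
            } // bundle and tensors are closed here, freeing native memory
        }
    }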