migrate model loading out of main class

parent 3dc2bcd501
commit ef66c21cad
Main.java
@@ -1,8 +1,6 @@
package org.studiorailgun;

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.types.TInt64;
import org.studiorailgun.conversation.AgentLoop;

/**
 * The main class
@@ -13,70 +11,7 @@ public class Main {
     * The main method
     */
    public static void main(String[] args){
        // AgentLoop.main();
        // try (Graph graph = new Graph()) {
            // prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"input": inputs}, {"output":output});
            SavedModelBundle model = SavedModelBundle.load("./data/semantic/model");
            System.out.println(model.functions());
            System.out.println(model.functions().get(0).signature());
            System.out.println(model.functions().get(1).signature());
            System.out.println(model.functions().get(2).signature());
            System.out.println(model.functions().get(0).signature().getInputs());
            System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
            //init
            // model.functions().get(2).session().runner()
            // .fetch("__saved_model_init_op")
            // .run();
            //run predict
            TInt64 tensor = TInt64.scalarOf(10);
            org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128);
            tensor = TInt64.tensorOf(shape);
            // NdArray<LongNdArray> arr = NdArrays.vectorOfObjects(NdArrays.vectorOf((long)0, (long)0, (long)0, (long)0));
            // System.out.println(arr.shape());
            // tensor = TInt64.tensorOf(arr.shape());

            Tensor result = model.function("serve").call(tensor);
            // model.functions().get(0).session().runner()
            // .feed("keras_tensor", tensor)
            // .fetch("output_0")
            // .run();

            // Graph graph = model.graph();
            // Session session = model.session();

            // Tensor<Long> inputTensor = Tensor.of(Long.class, model.functions().get(0).signature().getInputs().values().iterator().next().shape);
            // NdArrays.ofLongs(shape(2, 3, 2));
            // TInt64 tensor = TInt64.vectorOf(10,100);
            // // TInt64 tensor = TInt64.tensorOf(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
            // Tensor result = session.runner()
            //     .feed("keras_tensor", tensor)
            //     .fetch("output_0")
            //     .run().get(0);

            // Result result = model.call(tensors);
            // byte[] graphDef = Files.readAllBytes(Paths.get("./data/semantic/model.keras"));
            // GraphDef def = GraphDef.newBuilder()
            // .mergeFrom(graphDef)
            // .build();
            // graph.importGraphDef(def);

            // try (Session session = new Session(graph)) {
            //     // // Prepare input data as a Tensor
            //     // Tensor inputTensor = Tensor.create(inputData);

            //     // // Run inference
            //     // Tensor result = session.runner()
            //     //                         .feed("input_node_name", inputTensor) // Replace with actual input tensor name
            //     //                         .fetch("output_node_name") // Replace with actual output tensor name
            //     //                         .run().get(0);

            //     // // Process the result
            //     // float[][] output = result.copyTo(new float[1][outputSize]);
            //     // System.out.println(Arrays.toString(output[0]));
            // }
        // } catch (IOException e) {
        //     e.printStackTrace();
        // }
        AgentLoop.main();
    }

}
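Note: the scratch code removed from Main above was exercising the TensorFlow Java SavedModel serving path. A minimal standalone sketch of that same pattern follows; the model path, the [2001, 128] input shape, and the "keras_tensor"/"serve" names are taken from the removed code, not verified against the exported model, and the class name ServeSketch is only illustrative.

import java.util.HashMap;
import java.util.Map;

import org.tensorflow.Result;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.TInt64;

public class ServeSketch {
    public static void main(String[] args) {
        // Load the exported SavedModel directory (same path the removed code used).
        try (SavedModelBundle model = SavedModelBundle.load("./data/semantic/model")) {
            // Dummy int64 input matching the shape the removed code built.
            TInt64 input = TInt64.tensorOf(Shape.of(2001, 128));
            Map<String, Tensor> inputs = new HashMap<>();
            inputs.put("keras_tensor", input);
            // Invoke the "serve" signature and print the shape of the first output tensor.
            Result result = model.function("serve").call(inputs);
            System.out.println(result.get(0).shape());
        }
    }
}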
SentenceSubjectParser.java
@@ -1,53 +1,54 @@
package org.studiorailgun.conversation.semantic;

import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Sgd;
import org.tensorflow.Result;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;

/**
 * Parses the subject of a sentence
 */
public class SentenceSubjectParser {
    /**
     * The configuration for the network
     */
    MultiLayerConfiguration conf;

    /**
     * The actual model created from the configuration
     * The model
     */
    MultiLayerNetwork model;

    /**
     * The number of epochs to try to train for
     */
    int targetEpochs;
    SavedModelBundle model;

    /**
     * Initializes the model
     */
    public void init(){
        model = SavedModelBundle.load("./data/semantic/model");

        //describe the model
        // System.out.println(model.functions());
        // System.out.println(model.functions().get(0).signature());
        // System.out.println(model.functions().get(1).signature());
        // System.out.println(model.functions().get(2).signature());
        // System.out.println(model.functions().get(0).signature().getInputs());
        // System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().name);
        // System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape);
    }

    public void evaluate(String sentence){
        //run predict
        TInt64 tensor = TInt64.scalarOf(10);
        org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128);
        tensor = TInt64.tensorOf(shape);

        Map<String,Tensor> inputMap = new HashMap<String,Tensor>();
        inputMap.put("keras_tensor",tensor);

        Result result = model.function("serve").call(inputMap);
        System.out.println(result.get(0));
        TFloat32 resultTensor = (TFloat32)result.get(0);
        FloatNdArray floatArr = resultTensor.get(0);
        System.out.println(floatArr.getFloat(0));
    }
}
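For reference, a caller of the new class would presumably look something like the following; this is hypothetical usage, not part of the commit, and note that evaluate() currently ignores its sentence argument and feeds a dummy tensor.

SentenceSubjectParser parser = new SentenceSubjectParser();
parser.init();                          // loads ./data/semantic/model into a SavedModelBundle
parser.evaluate("The cat sat down.");   // builds a dummy input, calls the "serve" signature, prints the first output value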