migrate model loading out of main class
This commit is contained in:
		
							parent
							
								
									3dc2bcd501
								
							
						
					
					
						commit
						ef66c21cad
					
				| @ -1,8 +1,6 @@ | |||||||
| package org.studiorailgun; | package org.studiorailgun; | ||||||
| 
 | 
 | ||||||
| import org.tensorflow.SavedModelBundle; | import org.studiorailgun.conversation.AgentLoop; | ||||||
| import org.tensorflow.Tensor; |  | ||||||
| import org.tensorflow.types.TInt64; |  | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * The main class |  * The main class | ||||||
| @ -13,70 +11,7 @@ public class Main { | |||||||
|      * The main method |      * The main method | ||||||
|      */ |      */ | ||||||
|     public static void main(String[] args){ |     public static void main(String[] args){ | ||||||
|         // AgentLoop.main(); |         AgentLoop.main(); | ||||||
|         // try (Graph graph = new Graph()) { |  | ||||||
|             // prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"input": inputs}, {"output":output}); |  | ||||||
|             SavedModelBundle model = SavedModelBundle.load("./data/semantic/model"); |  | ||||||
|             System.out.println(model.functions()); |  | ||||||
|             System.out.println(model.functions().get(0).signature()); |  | ||||||
|             System.out.println(model.functions().get(1).signature()); |  | ||||||
|             System.out.println(model.functions().get(2).signature()); |  | ||||||
|             System.out.println(model.functions().get(0).signature().getInputs()); |  | ||||||
|             System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape); |  | ||||||
|             //init |  | ||||||
|             // model.functions().get(2).session().runner() |  | ||||||
|             // .fetch("__saved_model_init_op") |  | ||||||
|             // .run(); |  | ||||||
|             //run predict |  | ||||||
|             TInt64 tensor = TInt64.scalarOf(10); |  | ||||||
|             org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128); |  | ||||||
|             tensor = TInt64.tensorOf(shape); |  | ||||||
|             // NdArray<LongNdArray> arr = NdArrays.vectorOfObjects(NdArrays.vectorOf((long)0, (long)0, (long)0, (long)0)); |  | ||||||
|             // System.out.println(arr.shape()); |  | ||||||
|             // tensor = TInt64.tensorOf(arr.shape()); |  | ||||||
|              |  | ||||||
|             Tensor result = model.function("serve").call(tensor); |  | ||||||
|             // model.functions().get(0).session().runner() |  | ||||||
|             // .feed("keras_tensor", tensor) |  | ||||||
|             // .fetch("output_0") |  | ||||||
|             // .run(); |  | ||||||
| 
 |  | ||||||
|             // Graph graph = model.graph(); |  | ||||||
|             // Session session = model.session(); |  | ||||||
| 
 |  | ||||||
|             // Tensor<Long> inputTensor = Tensor.of(Long.class, model.functions().get(0).signature().getInputs().values().iterator().next().shape); |  | ||||||
|             // NdArrays.ofLongs(shape(2, 3, 2)); |  | ||||||
|             // TInt64 tensor = TInt64.vectorOf(10,100); |  | ||||||
|             // // TInt64 tensor = TInt64.tensorOf(model.functions().get(0).signature().getInputs().values().iterator().next().shape); |  | ||||||
|             // Tensor result = session.runner() |  | ||||||
|             //     .feed("keras_tensor", tensor) |  | ||||||
|             //     .fetch("output_0") |  | ||||||
|             //     .run().get(0); |  | ||||||
|              |  | ||||||
|             // Result result = model.call(tensors); |  | ||||||
|             // byte[] graphDef = Files.readAllBytes(Paths.get("./data/semantic/model.keras")); |  | ||||||
|             // GraphDef def = GraphDef.newBuilder() |  | ||||||
|             // .mergeFrom(graphDef) |  | ||||||
|             // .build(); |  | ||||||
|             // graph.importGraphDef(def); |  | ||||||
| 
 |  | ||||||
|             // try (Session session = new Session(graph)) { |  | ||||||
|             //     // // Prepare input data as a Tensor |  | ||||||
|             //     // Tensor inputTensor = Tensor.create(inputData); |  | ||||||
|                  |  | ||||||
|             //     // // Run inference |  | ||||||
|             //     // Tensor result = session.runner() |  | ||||||
|             //     //                         .feed("input_node_name", inputTensor) // Replace with actual input tensor name |  | ||||||
|             //     //                         .fetch("output_node_name") // Replace with actual output tensor name |  | ||||||
|             //     //                         .run().get(0); |  | ||||||
| 
 |  | ||||||
|             //     // // Process the result |  | ||||||
|             //     // float[][] output = result.copyTo(new float[1][outputSize]);  |  | ||||||
|             //     // System.out.println(Arrays.toString(output[0])); |  | ||||||
|             // } |  | ||||||
|         // } catch (IOException e) { |  | ||||||
|         //     e.printStackTrace(); |  | ||||||
|         // } |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  | |||||||
| @ -1,53 +1,54 @@ | |||||||
| package org.studiorailgun.conversation.semantic; | package org.studiorailgun.conversation.semantic; | ||||||
| 
 | 
 | ||||||
| import java.io.File; | import java.util.HashMap; | ||||||
| import java.io.IOException; | import java.util.Map; | ||||||
| 
 | 
 | ||||||
| import org.deeplearning4j.models.word2vec.Word2Vec; | import org.tensorflow.Result; | ||||||
| import org.deeplearning4j.nn.api.OptimizationAlgorithm; | import org.tensorflow.SavedModelBundle; | ||||||
| import org.deeplearning4j.nn.conf.BackpropType; | import org.tensorflow.Tensor; | ||||||
| import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | import org.tensorflow.ndarray.FloatNdArray; | ||||||
| import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | import org.tensorflow.types.TFloat32; | ||||||
| import org.deeplearning4j.nn.conf.layers.DenseLayer; | import org.tensorflow.types.TInt64; | ||||||
| import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; |  | ||||||
| import org.deeplearning4j.nn.weights.WeightInit; |  | ||||||
| import org.deeplearning4j.optimize.api.InvocationType; |  | ||||||
| import org.deeplearning4j.optimize.listeners.EvaluativeListener; |  | ||||||
| import org.deeplearning4j.optimize.listeners.ScoreIterationListener; |  | ||||||
| import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator; |  | ||||||
| import org.deeplearning4j.text.sentenceiterator.SentenceIterator; |  | ||||||
| import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; |  | ||||||
| import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; |  | ||||||
| import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; |  | ||||||
| import org.nd4j.evaluation.classification.Evaluation; |  | ||||||
| import org.nd4j.linalg.activations.Activation; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; |  | ||||||
| import org.nd4j.linalg.learning.config.Sgd; |  | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * Parses the subject of a sentence |  * Parses the subject of a sentence | ||||||
|  */ |  */ | ||||||
| public class SentenceSubjectParser { | public class SentenceSubjectParser { | ||||||
|     /** |  | ||||||
|      * The configuration for the network |  | ||||||
|      */ |  | ||||||
|     MultiLayerConfiguration conf; |  | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|      * The actual model created from the configuration |      * The model | ||||||
|      */ |      */ | ||||||
|     MultiLayerNetwork model; |     SavedModelBundle model; | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * The number of epochs to try to train for |  | ||||||
|      */ |  | ||||||
|     int targetEpochs; |  | ||||||
|      |      | ||||||
|     /** |     /** | ||||||
|      * Initializes the model |      * Initializes the model | ||||||
|      */ |      */ | ||||||
|     public void init(){ |     public void init(){ | ||||||
|  |         model = SavedModelBundle.load("./data/semantic/model"); | ||||||
| 
 | 
 | ||||||
|  |         //describe the model | ||||||
|  |         // System.out.println(model.functions()); | ||||||
|  |         // System.out.println(model.functions().get(0).signature()); | ||||||
|  |         // System.out.println(model.functions().get(1).signature()); | ||||||
|  |         // System.out.println(model.functions().get(2).signature()); | ||||||
|  |         // System.out.println(model.functions().get(0).signature().getInputs()); | ||||||
|  |         // System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().name); | ||||||
|  |         // System.out.println(model.functions().get(0).signature().getInputs().values().iterator().next().shape); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public void evaluate(String sentence){ | ||||||
|  |         //run predict | ||||||
|  |         TInt64 tensor = TInt64.scalarOf(10); | ||||||
|  |         org.tensorflow.ndarray.Shape shape = tensor.shape().append(2001).append(128); | ||||||
|  |         tensor = TInt64.tensorOf(shape); | ||||||
|  | 
 | ||||||
|  |         Map<String,Tensor> inputMap = new HashMap<String,Tensor>(); | ||||||
|  |         inputMap.put("keras_tensor",tensor); | ||||||
|  | 
 | ||||||
|  |         Result result = model.function("serve").call(inputMap); | ||||||
|  |         System.out.println(result.get(0)); | ||||||
|  |         TFloat32 resultTensor = (TFloat32)result.get(0); | ||||||
|  |         FloatNdArray floatArr = resultTensor.get(0); | ||||||
|  |         System.out.println(floatArr.getFloat(0)); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user