import tensorflow as tf
from tensorflow import Tensor, keras
from tensorflow.keras.layers import TextVectorization, Embedding, LSTM, Dense
from tensorflow.keras.models import Sequential
import numpy as np
import numpy.typing as npt

# Model constants.
max_features: int = 20000
embedding_dim: int = 128
sequence_length: int = 500  # currently unused
epochs: int = 50
max_tokens: int = 5000
output_sequence_length: int = 4  # every sentence is padded/truncated to 4 token ids

# read sentences (one per line); adapt() below builds the vocabulary from them
data_path: str = './data/semantic/subject.txt'
with open(data_path) as f:
    vocab: list[str] = f.read().splitlines()

# read labels (whitespace-separated 0/1 values, one per sentence)
label_data_path: str = './data/semantic/subject_label.txt'
with open(label_data_path) as f:
    labels: list[int] = [int(x) for x in f.read().split()]

# init vectorizer; pad_to_max_tokens was dropped because it only applies to the
# "multi_hot"/"count"/"tf_idf" output modes, not to "int"
textVec: TextVectorization = TextVectorization(
    max_tokens=max_tokens,
    output_mode='int',
    output_sequence_length=output_sequence_length)

# Build the vocabulary from the training sentences
textVec.adapt(vocab)
input_data: list[str] = vocab
# shape (num_sentences, output_sequence_length), dtype int64
data: Tensor = textVec(input_data)

# construct model
model: Sequential = Sequential([
    keras.Input(shape=(None,), dtype="int64"),
    # the vectorizer emits indices below max_tokens, so max_tokens rows would
    # suffice; max_features + 1 simply over-allocates the embedding table
    Embedding(max_features + 1, embedding_dim),
    LSTM(64),
    Dense(1, activation='sigmoid')
])

# compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# fit the training data
npData = np.array(data)
npLabel = np.array(labels)
model.fit(npData, npLabel, epochs=epochs)

# evaluate here

# predict on a raw string: vectorize it, then run it through the model
predictTargetRaw: list[str] = ['saf']
predictTargetToken: Tensor = textVec(predictTargetRaw)
npPredict: npt.NDArray[np.int64] = np.array(predictTargetToken)
# sigmoid probabilities in [0, 1], one row per input sentence
result: npt.NDArray[np.float32] = model.predict(npPredict)
print("predict result:")
print(predictTargetToken)
print(result)
print(data)
print(labels)

# save the model so keras can reload
# savePath: str = './data/semantic/model.keras'
# model.save(savePath)

# export the model as a TF SavedModel so java can leverage it
exportPath: str = './data/semantic/model'
model.export(exportPath)
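
# Optional sanity check (a sketch, not part of the original script): reload the
# exported SavedModel and push the same tokenized input through it, mirroring what
# a Java client would do via e.g. SavedModelBundle.load(exportPath, "serve").
# Keras' model.export() documents a `serve` endpoint on the reloaded artifact; if
# your Keras/TF version differs, inspect the export with `saved_model_cli show`.
reloaded = tf.saved_model.load(exportPath)
reloaded_result = reloaded.serve(tf.constant(npPredict, dtype=tf.int64))
print("reloaded SavedModel result:")
print(reloaded_result)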