import tensorflow as tf

keras = tf.keras
from tensorflow import Tensor
from keras.layers import TextVectorization, Embedding, LSTM, Dense
from keras.models import Sequential
import numpy as np
import numpy.typing as npt

# Model constants.
max_features: int = 20000  # embedding table size (larger than max_tokens, so every token id fits)
embedding_dim: int = 128
sequence_length: int = 500  # unused here; output_sequence_length drives padding
epochs: int = 50
max_tokens: int = 5000  # vocabulary cap for the vectorizer
output_sequence_length: int = 4  # each sample is padded/truncated to 4 token ids

# Read sentences, one per line. splitlines() avoids the trailing empty
# sample that split('\n') would produce from the file's final newline.
data_path: str = './data/semantic/subject.txt'
with open(data_path) as f:
    vocab: list[str] = f.read().splitlines()

# Read labels, one integer per line, aligned with the sentences above.
label_data_path: str = './data/semantic/subject_label.txt'
with open(label_data_path) as f:
    labels: list[int] = [int(tok) for tok in f.read().split()]

# Init the vectorizer. In 'int' output mode, output_sequence_length controls
# padding/truncation; pad_to_max_tokens only applies to the multi_hot, count,
# and tf_idf modes, so it is omitted.
textVec: TextVectorization = TextVectorization(
    max_tokens=max_tokens,
    output_mode='int',
    output_sequence_length=output_sequence_length)

# Build the vocabulary from the training sentences.
textVec.adapt(vocab)

# Vectorize the training sentences. The layer is invoked directly; calling
# .call() bypasses Keras's __call__ machinery.
input_data: list[str] = vocab
data: Tensor = textVec(input_data)
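
# A quick sanity check on what adapt() learned (a debugging sketch; the name
# vocab_tokens is ours). get_vocabulary() lists tokens by descending
# frequency, with index 0 reserved for padding and index 1 for OOV.
vocab_tokens: list[str] = textVec.get_vocabulary()
print(f"vocab size: {len(vocab_tokens)}, top tokens: {vocab_tokens[2:7]}")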

# Construct the model: token ids -> embedding -> LSTM -> sigmoid probability.
model: Sequential = Sequential([
    keras.Input(shape=(None,), dtype="int64"),
    Embedding(max_features + 1, embedding_dim),
    LSTM(64),
    Dense(1, activation='sigmoid')
])
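
# Optional: print layer shapes and parameter counts. The Input layer above
# already builds the model, so summary() works before training.
model.summary()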

# Compile the model. No explicit build() call is needed: the Input layer
# already defines the input shape.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Fit the training data.
npData: npt.NDArray[np.int64] = np.array(data)
npLabel: npt.NDArray[np.int64] = np.array(labels, dtype=np.int64)
model.fit(npData, npLabel, epochs=epochs)

# Evaluate. This script has no held-out split, so the check below reuses the
# training data; a real evaluation would hold out a test set (for example via
# validation_split in fit()).
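# (Sketch: evaluate() returns [loss, accuracy] given the compile() metrics;
# the names train_loss/train_acc are ours.)
train_loss, train_acc = model.evaluate(npData, npLabel)
print(f"training-set loss: {train_loss:.4f}, accuracy: {train_acc:.4f}")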

# Predict on a new sentence.
predictTargetRaw: list[str] = ['saf']
predictTargetToken: Tensor = textVec(predictTargetRaw)
npPredict: npt.NDArray[np.int64] = np.array(predictTargetToken)
# print(npPredict)
# predict() returns sigmoid probabilities of shape (batch, 1).
result: npt.NDArray[np.float32] = model.predict(npPredict)
print("predict result:")
print(predictTargetToken)
print(result)
print(data)
print(labels)
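
# Turning the probability into a hard label (a sketch; the 0.5 threshold is
# a common default for binary classifiers, not something this script
# specifies, and predicted_label is our name).
predicted_label: int = int(result[0][0] > 0.5)
print(f"predicted label: {predicted_label}")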

# Save the model so Keras can reload it.
# savePath: str = './data/semantic/model.keras'
# model.save(savePath)

# Export the model as a TF SavedModel so Java can leverage it.
exportPath: str = './data/semantic/model'
model.export(exportPath)
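
# A sketch of reloading the exported artifact back in Python. With a Keras 3
# runtime (an assumption), model.export() writes a TF SavedModel whose
# default endpoint is named 'serve', which TFSMLayer can wrap for inference.
# reloaded = keras.layers.TFSMLayer(exportPath, call_endpoint='serve')
# print(reloaded(npPredict))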