# trpg/src/main/python/conversation/subject.py
# 2024-12-28 18:01:52 -05:00
#
# 86 lines
# 2.1 KiB
# Python

import tensorflow as tf
keras = tf.keras
from tensorflow import Tensor
from keras.api.layers import TextVectorization, Embedding, LSTM, Dense, Input
from keras.api.models import Sequential
import numpy as np
import numpy.typing as npt
# Hyper-parameters and size limits used by the vectorizer and the model.
max_features: int = 20_000  # the Embedding table is sized max_features + 1
embedding_dim: int = 128  # width of each embedding vector
sequence_length: int = 500  # not referenced in the visible part of this script
epochs: int = 50  # number of epochs passed to model.fit
max_tokens: int = 5_000  # vocabulary cap handed to TextVectorization
output_sequence_length: int = 4  # token ids produced per vectorized sentence
# Load the training sentences: the file holds one sentence per line.
data_path: str = './data/semantic/subject.txt'
with open(data_path) as sentence_file:  # the context manager closes the handle
    data_raw: str = sentence_file.read()
vocab: list[str] = data_raw.split('\n')
# A file that ends with a newline makes split('\n') emit empty trailing
# entries; drop them so they are not counted as sentences (which would leave
# the sentences and labels with different lengths).
while vocab and not vocab[-1]:
    vocab.pop()
# Load the labels: whitespace-separated integers, paired with the sentences
# by position when the model is fitted.
label_data_path: str = './data/semantic/subject_label.txt'
with open(label_data_path) as label_file:
    label_data_raw: str = label_file.read()
labels: list[int] = [int(token) for token in label_data_raw.split()]
# Fail early, naming both files, instead of letting the mismatch surface later
# as a shape error inside model.fit.
if len(vocab) != len(labels):
    raise ValueError(
        f"{data_path} has {len(vocab)} sentences but "
        f"{label_data_path} has {len(labels)} labels"
    )
# Create the layer that turns each sentence into a fixed-length row of
# integer token ids.
# NOTE(review): the Keras docs describe pad_to_max_tokens as applying to the
# multi_hot / count / tf_idf output modes -- confirm it is needed for 'int'.
textVec: TextVectorization = TextVectorization(
    max_tokens=max_tokens,
    output_mode='int',
    output_sequence_length=output_sequence_length,
    pad_to_max_tokens=True,
)
# The sentences serve both as the corpus the vocabulary is learned from and
# as the training input that gets vectorized.
input_data: list[str] = vocab
textVec.adapt(input_data)
data: Tensor = textVec.call(input_data)
# Layer stack of the classifier: token ids -> embeddings -> LSTM -> a single
# sigmoid unit.
# NOTE(review): the embedding table is sized from max_features, while the
# vectorizer's vocabulary is capped by max_tokens -- confirm this is intended.
classifier_layers = [
    keras.Input(shape=(None,), dtype="int64"),
    Embedding(max_features + 1, embedding_dim),
    LSTM(64),
    Dense(1, activation='sigmoid'),
]
model: Sequential = Sequential(classifier_layers)
# Binary cross-entropy matches the single sigmoid output; accuracy is tracked
# as the reported metric.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
# Convert the vectorized sentences and the labels to NumPy arrays, then train.
npData = np.array(data)
npLabel = np.array(labels)
model.fit(x=npData, y=npLabel, epochs=epochs)
# evaluate here
# TODO: no evaluation step exists yet; only the prediction check below runs.
# Prediction check: vectorize one hand-written input with the fitted
# vectorizer and run it through the trained model.
predictTargetRaw: list[str] = ['saf']
# The vectorizer returns a tensor of token ids (same call as for the training
# data), not a Python list.
predictTargetToken: Tensor = textVec.call(predictTargetRaw)
# Token ids are integers, so the converted array has an integer dtype.
npPredict: npt.NDArray[np.integer] = np.array(predictTargetToken)
# model.predict returns a NumPy array holding the sigmoid output for each
# input row.
result: npt.NDArray[np.floating] = model.predict(npPredict)
print("predict result:")
print(predictTargetToken)
print(result)
# Debug output: dump the vectorized training data and the labels.
print(data)
print(labels)
# Saving in the native Keras format (so Keras itself can reload the model) is
# currently disabled:
# savePath: str = './data/semantic/model.keras'
# model.save(savePath)
# Export the trained model so the Java side can make use of it.
exportPath: str = './data/semantic/model'
model.export(filepath=exportPath)