# Hugging Face Spaces page status: Sleeping
import numpy as np | |
# import gradio as gr | |
import os | |
import tempfile | |
import shutil | |
from trainer import Trainer | |
def predict(input_text, model_type):
    """Classify ``input_text`` with the model selected by ``model_type``.

    Relies on the module-level ``trainer`` object, which is bound in the
    ``__main__`` section of this script.

    Args:
        input_text: Raw text to classify.
        model_type: One of ``'lstm'``, ``'bilstm'``, ``'max_ent'`` or ``'svm'``.

    Returns:
        The predicted label converted to ``str``.

    Raises:
        ValueError: If ``model_type`` is not a supported value. Without this
            check an unknown type fell through every branch and crashed with
            ``UnboundLocalError`` on the return statement.
    """
    if model_type in ('lstm', 'bilstm'):
        predicted_label = trainer.predict(input_text)
    elif model_type == 'max_ent':
        # NOTE(review): predict_omni calls predict_maxent on ``trainer_maxent``,
        # but this branch uses ``trainer`` — confirm which Trainer instance
        # actually holds the max-ent model.
        predicted_label = trainer.predict_maxent(input_text)
    elif model_type == 'svm':
        # NOTE(review): same question as above — predict_omni uses ``trainer_svm``.
        predicted_label = trainer.predict_svm(input_text)
    else:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of "
            "'lstm', 'bilstm', 'max_ent', 'svm'"
        )
    return str(predicted_label)
def predict_omni(input_text, model_type):
    """Run all three classifiers on ``input_text`` and summarise their labels.

    ``model_type`` is accepted so the signature matches ``predict``, but it is
    ignored: the LSTM, max-ent and SVM models are always queried, in that
    order, via the module-level ``trainer``, ``trainer_maxent`` and
    ``trainer_svm`` objects bound in the ``__main__`` section.

    Returns:
        A string of the form ``"LSTM: <label>, Max Ent: <label>, SVM: <label>"``.
    """
    net_label = trainer.predict(input_text)
    maxent_label = trainer_maxent.predict_maxent(input_text)
    svm_label = trainer_svm.predict_svm(input_text)
    return f"LSTM: {net_label}, Max Ent: {maxent_label}, SVM: {svm_label}"
def create_demo():
    """Build the Gradio web UI for the text classifier.

    The page shows a markdown header, a text input, a "Classify" button and
    an output box. Clicking the button runs ``predict_omni`` on the input
    and shows the combined LSTM / Max Ent / SVM predictions.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` on it to serve the UI.
    """
    # Imported here rather than at module level: the top-level gradio import
    # is commented out so the command-line loop runs without gradio installed.
    # The previous body used ``gr`` with no active import (NameError).
    import gradio as gr

    USAGE = """## Text Classification
"""
    with gr.Blocks() as demo:
        gr.Markdown(USAGE)
        input_text = gr.Textbox(label="Input text")
        classify_button = gr.Button("Classify")
        output_text = gr.Textbox(label="Predictions")
        # predict_omni ignores its model_type argument (it always queries all
        # three models), so None is passed. This replaces a reference to an
        # undefined ``greet`` function and two gr.File components that were
        # never connected to any handler.
        classify_button.click(
            fn=lambda text: predict_omni(text, None),
            inputs=input_text,
            outputs=output_text,
        )
    return demo
if __name__ == "__main__":
    # Hyper-parameters shared by every Trainer instance below.
    vocab_size = 8000
    sequence_len = 150
    batch_size = 256
    nn_epochs = 20

    # predict() and predict_omni() read these names as module-level globals,
    # so they must stay bound at module scope (not inside a main() function).
    trainer = Trainer(vocab_size, sequence_len, batch_size, nn_epochs, "lstm")
    trainer_maxent = Trainer(vocab_size, sequence_len, batch_size, nn_epochs, "max_ent")
    trainer_svm = Trainer(vocab_size, sequence_len, batch_size, nn_epochs, "svm")

    # Read lines from stdin and print the predictions of all three models.
    # EOF (Ctrl-D) or Ctrl-C while waiting for input ends the loop cleanly
    # instead of crashing with a traceback.
    while True:
        try:
            input_text = input()
        except (EOFError, KeyboardInterrupt):
            break
        # predict_omni ignores model_type; "svm" matches the value the
        # previous code passed after its last reassignment.
        label = predict_omni(input_text, "svm")
        print(label)

    # To serve the Gradio UI instead of the stdin loop:
    #     demo = create_demo()
    #     demo.launch()
    # Run with: python app_local.py