Spaces:
Sleeping
Sleeping
import numpy as np | |
import gradio as gr | |
import os | |
import tempfile | |
import shutil | |
from trainer import Trainer | |
def predict(input_text):
    """Classify *input_text* with the LSTM model and return the label as a string.

    Reads the module-level ``trainer`` that is bound in the ``__main__`` block,
    so it only works once that name exists.
    """
    return str(trainer.predict(input_text))
def predict_maxent(input_text):
    """Classify *input_text* with the maximum-entropy model; return the label as text.

    Reads the module-level ``trainer_maxent`` that is bound in the ``__main__``
    block, so it only works once that name exists.
    """
    return str(trainer_maxent.predict_maxent(input_text))
def predict_svm(input_text):
    """Classify *input_text* with the SVM model and return the label as a string.

    Reads the module-level ``trainer_svm`` that is bound in the ``__main__``
    block, so it only works once that name exists.
    """
    return str(trainer_svm.predict_svm(input_text))
def create_demo():
    """Build the Gradio page that exposes the three text classifiers.

    The page holds a Markdown header followed by one text-in / text-out
    ``gr.Interface`` per model, in the order LSTM, MaxEnt, SVM. The callbacks
    (``predict``, ``predict_maxent``, ``predict_svm``) look up module-level
    trainer objects when a request arrives. Those trainers must therefore be
    defined before the demo serves predictions.

    Returns:
        The assembled ``gr.Blocks`` instance; the caller is responsible for
        calling ``.launch()`` on it.
    """
    USAGE = """## Text Classification
### Online demo for Artificial Intelligence Principles 2024 spring course project.
"""
    # (prediction callback, panel title) pairs, rendered top to bottom.
    models = (
        (predict, 'LSTM'),
        (predict_maxent, 'MaxEnt'),
        (predict_svm, 'SVM'),
    )
    with gr.Blocks() as demo:
        gr.Markdown(USAGE)
        for fn, title in models:
            gr.Interface(fn=fn, inputs="textbox", outputs="textbox", title=title)
    return demo
if __name__ == "__main__":
    # Hyper-parameters shared by every Trainer constructed below.
    vocab_size = 8000
    sequence_len = 150
    batch_size = 256  # an earlier configuration used 1024
    nn_epochs = 20
    shared_args = (vocab_size, sequence_len, batch_size, nn_epochs)

    # One Trainer per model type, built in the same order as before. These
    # are bound as module globals on purpose: predict / predict_maxent /
    # predict_svm look them up by name.
    trainer = Trainer(*shared_args, "lstm")  # "bilstm" is the commented-out alternative
    trainer_maxent = Trainer(*shared_args, "max_ent")
    trainer_svm = Trainer(*shared_args, "svm")

    demo = create_demo()
    demo.launch()