File size: 2,573 Bytes
ab2adfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a58366
8aa390f
ab2adfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6930b7b
ab2adfb
6930b7b
ab2adfb
6930b7b
ab2adfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import numpy as np

import gradio as gr


import os

import tempfile

import shutil

from trainer import Trainer

def predict(input_text):
    """Return the LSTM model's label for *input_text*, converted to ``str``.

    Relies on the module-level ``trainer``, which is only assigned inside the
    ``if __name__ == "__main__":`` block; calling this before that runs raises
    ``NameError``.
    """
    return str(trainer.predict(input_text))
    
def predict_maxent(input_text):
    """Return the MaxEnt model's label for *input_text*, converted to ``str``.

    Relies on the module-level ``trainer_maxent``, which is only assigned inside
    the ``if __name__ == "__main__":`` block.
    """
    return str(trainer_maxent.predict_maxent(input_text))
    
def predict_svm(input_text):
    """Return the SVM model's label for *input_text*, converted to ``str``.

    Relies on the module-level ``trainer_svm``, which is only assigned inside
    the ``if __name__ == "__main__":`` block.
    """
    return str(trainer_svm.predict_svm(input_text))
    

def create_demo():
    """Build the Gradio app with one text-in/text-out panel per classifier.

    The panels call ``predict``, ``predict_maxent`` and ``predict_svm``, which
    read the module-level trainers created in the ``__main__`` block.

    Returns:
        The assembled ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    USAGE = """## Text Classification

    ### Online demo for Artificial Intelligence Principles 2024 spring course project.
    """

    # (callback, panel title) pairs, rendered top to bottom in this order.
    panels = (
        (predict, 'LSTM'),
        (predict_maxent, 'MaxEnt'),
        (predict_svm, 'SVM'),
    )

    with gr.Blocks() as app:
        gr.Markdown(USAGE)
        for callback, heading in panels:
            gr.Interface(fn=callback, inputs="textbox", outputs="textbox", title=heading)

    return app
    
if __name__ == "__main__":
    # Hyper-parameters shared by every model variant below.
    vocab_size = 8000
    sequence_len = 150
    batch_size = 256  # 1024 was also tried
    nn_epochs = 20

    # One Trainer per backend. These module globals are what predict(),
    # predict_maxent() and predict_svm() look up when Gradio calls them.
    trainer = Trainer(vocab_size, sequence_len, batch_size, nn_epochs, "lstm")  # "bilstm" is an alternative
    trainer_maxent = Trainer(vocab_size, sequence_len, batch_size, nn_epochs, "max_ent")
    trainer_svm = Trainer(vocab_size, sequence_len, batch_size, nn_epochs, "svm")

    create_demo().launch()