"""Local entry point for the text-classification demo.

Loads one Trainer per model family (LSTM, maximum entropy, SVM), then either
runs a simple stdin prediction loop or serves a Gradio UI via create_demo().
"""

import gradio as gr

from trainer import Trainer
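
# trainer.py is assumed (from the calls made below) to expose roughly this
# interface; this is a sketch of the expected API, not the implementation:
#
#     class Trainer:
#         def __init__(self, vocab_size, sequence_len, batch_size,
#                      nn_epochs, model_type): ...
#         def predict(self, text): ...         # lstm / bilstm
#         def predict_maxent(self, text): ...  # maximum entropy
#         def predict_svm(self, text): ...     # SVM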

def predict(input_text, model_type):
    """Classify input_text with the single model selected by model_type."""
    if model_type in ("lstm", "bilstm"):
        predicted_label = trainer.predict(input_text)
    elif model_type == "max_ent":
        predicted_label = trainer.predict_maxent(input_text)
    elif model_type == "svm":
        predicted_label = trainer.predict_svm(input_text)
    else:
        raise ValueError(f"unknown model_type: {model_type!r}")
    return str(predicted_label)
    
def predict_omni(input_text):
    """Run all three loaded models on input_text and report each prediction."""
    predicted_label_net = trainer.predict(input_text)
    predicted_label_maxent = trainer_maxent.predict_maxent(input_text)
    predicted_label_svm = trainer_svm.predict_svm(input_text)
    return (
        f"LSTM: {predicted_label_net}, "
        f"Max Ent: {predicted_label_maxent}, "
        f"SVM: {predicted_label_svm}"
    )

def create_demo():
    """Build the Gradio UI: one textbox in, the combined predictions out."""
    USAGE = """## Text Classification

    Enter a sentence and each loaded model (LSTM, MaxEnt, SVM) reports its
    predicted label.
    """

    with gr.Blocks() as demo:
        gr.Markdown(USAGE)
        gr.Interface(fn=predict_omni, inputs="textbox", outputs="textbox")

    return demo
    
if __name__ == "__main__":
    # Hyperparameters shared by every Trainer instance.
    vocab_size = 8000
    sequence_len = 150
    batch_size = 256  # reduced from 1024
    nn_epochs = 20

    # Load one trainer per model family so predict_omni can compare them.
    trainer = Trainer(vocab_size, sequence_len, batch_size, nn_epochs, "lstm")
    trainer_maxent = Trainer(vocab_size, sequence_len, batch_size, nn_epochs, "max_ent")
    trainer_svm = Trainer(vocab_size, sequence_len, batch_size, nn_epochs, "svm")

    # Interactive loop: read one line of text from stdin and print every
    # model's prediction. Stop with Ctrl-C.
    while True:
        input_text = input()
        print(predict_omni(input_text))

    # To serve the Gradio UI instead, replace the loop above with:
    #     demo = create_demo()
    #     demo.launch()
    # then run: python app_local.py
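
    # Example session (hypothetical input and labels; actual output depends
    # on the models trained in trainer.py):
    #
    #     $ python app_local.py
    #     this movie was fantastic
    #     LSTM: 1, Max Ent: 1, SVM: 1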