File size: 3,388 Bytes
a83f80b
509d266
 
6671a55
509d266
6671a55
a83f80b
d564f5f
 
 
 
 
c49cab1
d564f5f
 
766e63e
9660558
 
 
 
 
 
7c020ac
 
 
 
9660558
 
509d266
 
fe7c35d
509d266
 
 
 
9660558
 
 
509d266
 
 
 
 
 
 
 
fe7c35d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509d266
fe7c35d
 
509d266
fe7c35d
 
509d266
 
 
 
 
 
 
 
 
 
 
fe7c35d
509d266
 
 
fe7c35d
509d266
 
1acabbc
 
 
 
509d266
837f208
 
 
7c020ac
 
fe7c35d
509d266
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import random

import gradio as gr
import pandas as pd
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer


# Long-form markdown blurb for the app (not referenced by the code in this file).
README = """
    # Movie Review Score Discriminator
    It is a program that classifies whether it is positive or negative by entering movie reviews.
    You can choose between the Korean version and the English version.
    ## Usage

"""


# Class index <-> label name, passed to both classifiers. The inverse mapping
# is derived from the forward one so the two can never drift apart.
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {label: index for index, label in id2label.items()}


# Header text shown by the Gradio interface.
title = "Movie Review Score Discriminator"
description = (
    "It is a program that classifies whether it is positive or negative by entering movie reviews. "
    "You can choose between the Korean version and the English version."
)




def tokenized_data(tokenizer, inputs, max_length=64):
    """Encode a single review as a batch of one for the classifier.

    Args:
        tokenizer: A Hugging Face tokenizer (e.g. from ``AutoTokenizer``).
        inputs: The raw review text to encode.
        max_length: Fixed sequence length. Longer texts are truncated and
            shorter ones are padded up to it. Defaults to 64, the length the
            app has always used.

    Returns:
        The tokenizer's batch encoding with PyTorch tensors
        (``return_tensors="pt"``); ``builder`` reads its ``input_ids`` and
        ``attention_mask`` entries.
    """
    # Wrap the text in a list so the encoding has a leading batch dimension of 1.
    return tokenizer.batch_encode_plus(
        [inputs],
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True,
    )




examples_eng = ["the greatest musicians ", "cold movie "]
examples_kor = ["긍정", "부정"]

# Example rows for the Gradio UI, drawn from a tab-separated file. Column 0 is
# offered as the English example and column 1 as the Korean one (presumably
# parallel versions of the same review -- TODO confirm against the CSV).
examples = []
df = pd.read_csv('examples.csv', sep='\t', index_col='Unnamed: 0')
# Pick two *distinct* rows from the first 51 (indices 0-50). The window is
# clamped to the file length so a short CSV cannot trigger an IndexError.
n_candidate_rows = min(len(df), 51)
for idx in random.sample(range(n_candidate_rows), k=min(2, n_candidate_rows)):
    examples.append(['Eng', df.iloc[idx, 0]])
    examples.append(['Kor', df.iloc[idx, 1]])


# English classifier: tokenizer and architecture are resolved by name through
# from_pretrained; the fine-tuned weights come from a local
# "<model>-<step>.pt" checkpoint.
eng_model_name = "roberta-base"
eng_step = 1900
eng_tokenizer = AutoTokenizer.from_pretrained(eng_model_name)
eng_file_name = "{}-{}.pt".format(eng_model_name, eng_step)
# map_location="cpu": nothing in this app moves the models or inputs to a GPU,
# so remap any GPU-saved tensors instead of failing on a CPU-only host.
eng_state_dict = torch.load(eng_file_name, map_location="cpu")
eng_model = AutoModelForSequenceClassification.from_pretrained(
    eng_model_name, num_labels=2, id2label=id2label, label2id=label2id,
    state_dict=eng_state_dict
)


# Korean classifier. The checkpoint file prefix uses '_' because '/' cannot
# appear in a file name; the model id is recovered by swapping it back to '/'.
kor_model_name = "klue_roberta-small"
kor_hub_name = kor_model_name.replace('_', '/')
kor_step = 2400
kor_tokenizer = AutoTokenizer.from_pretrained(kor_hub_name)
kor_file_name = "{}-{}.pt".format(kor_model_name, kor_step)
kor_state_dict = torch.load(kor_file_name, map_location="cpu")
kor_model = AutoModelForSequenceClassification.from_pretrained(
    kor_hub_name, num_labels=2, id2label=id2label, label2id=label2id,
    state_dict=kor_state_dict
)


def builder(lang, text):
    """Classify ``text`` as a positive or negative movie review.

    Args:
        lang: ``'Eng'`` selects the English classifier; any other value
            selects the Korean one.
        text: The review to classify.

    Returns:
        The predicted label name looked up in ``id2label``.
    """
    model, tokenizer = (
        (eng_model, eng_tokenizer) if lang == 'Eng' else (kor_model, kor_tokenizer)
    )

    encoded = tokenized_data(tokenizer, text)

    model.eval()  # inference mode for layers such as dropout
    with torch.no_grad():
        output = model(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
        )

    # Batch of one: argmax over the class axis yields a single class index.
    best_class = torch.argmax(output.logits, axis=1).item()
    return id2label[best_class]


def builder2(inputs):
    """Classify an English review; single-input variant of ``builder``.

    Args:
        inputs: The English review text.

    Returns:
        The predicted label name, as returned by ``builder``.
    """
    # The transformers model takes token-id tensors, not raw text, so go through
    # builder(), which tokenizes before the forward pass.
    return builder('Eng', inputs)


# The language selector defaults to 'Eng': builder() routes every value other
# than 'Eng' -- presumably including an empty selection -- to the Korean model.
demo = gr.Interface(
    builder,
    inputs=[gr.inputs.Dropdown(['Eng', 'Kor'], default='Eng'), "text"],
    outputs="text",
    title=title,
    description=description,
    examples=examples,
)

# demo2 = gr.Interface(builder2, inputs="text", outputs="text", 
#                          title=title, theme="peach",
#                          allow_flagging="auto",
#                          description=description, examples=examples)

# demo3 = gr.Interface.load("models/mdj1412/movie_review_score_discriminator_eng", inputs="text", outputs="text", 
#                          title=title, theme="peach",
#                          allow_flagging="auto",
#                          description=description, examples=examples)

if __name__ == "__main__":
    # print(examples)
    demo.launch()
    # demo3.launch()