mdj1412's picture
Update app.py
bde04fe
raw history blame
No virus
8.39 kB
import gradio as gr
import fasttext
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
import numpy as np
import pandas as pd
import torch
# Class-index <-> label-name mappings passed to the sequence classifiers
# and used to name the entries of the prediction output.
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

# Title/description strings; in this file they are only referenced by the
# commented-out gr.Interface variants of the demo.
title = "Movie Review Score Discriminator"

# The earlier version wrote ""Default"" to quote the word. In Python,
# adjacent string literals are concatenated, so the quotation marks were
# silently dropped from the text. Using a single-quoted literal keeps them.
description = (
    "It is a program that classifies whether it is positive or negative by entering movie reviews. "
    "You can choose between the Korean version and the English version. "
    'It also provides a version called "Default", which determines whether it is Korean or English and predicts it.'
)
class LanguageIdentification:
    """Thin wrapper around a fastText language-identification model."""

    def __init__(self, model_path="./lid.176.ftz"):
        """Load the fastText model stored at *model_path*.

        Args:
            model_path: Path of the fastText model file. Defaults to
                ``./lid.176.ftz``, the path that was previously hard-coded.
        """
        self.model = fasttext.load_model(model_path)

    def predict_lang(self, text, k=200):
        """Return the model's top-*k* language predictions for *text*.

        Args:
            text: The string whose language should be identified.
            k: How many of the best-scoring labels to request. Defaults to
                200, the value that was previously hard-coded.

        Returns:
            Whatever ``self.model.predict`` returns. The caller in this file
            indexes it as ``(labels, probabilities)``.
        """
        # fastText's predict() processes one line at a time and raises
        # ValueError when the text contains "\n", so fold line breaks into
        # spaces instead of letting multi-line input crash the request.
        single_line = text.replace("\n", " ")
        return self.model.predict(single_line, k=k)
# Module-level language identifier, created once at import time.
# builder() calls LANGUAGE.predict_lang() when asked to auto-detect the language.
LANGUAGE = LanguageIdentification()
def tokenized_data(tokenizer, inputs, max_length=64):
    """Encode one text as a single-item batch padded/truncated to a fixed length.

    Args:
        tokenizer: Object exposing ``batch_encode_plus`` (the callers in this
            file pass tokenizers created by ``AutoTokenizer.from_pretrained``).
        inputs: The text to encode. It is wrapped in a one-element list
            before being handed to the tokenizer.
        max_length: Length every sequence is padded or truncated to.
            Defaults to 64, the value that was previously hard-coded.

    Returns:
        The tokenizer's ``batch_encode_plus`` result, requested as PyTorch
        tensors (``return_tensors="pt"``). ``builder`` reads its
        ``input_ids`` and ``attention_mask`` entries.
    """
    encode_options = {
        "return_tensors": "pt",
        "padding": "max_length",
        "max_length": max_length,
        "truncation": True,
    }
    return tokenizer.batch_encode_plus([inputs], **encode_options)
# Example inputs offered in the UI: [language tag, review text] pairs.
# Column 0 of the sheet is tagged 'Eng' and column 1 is tagged 'Kor'
# (presumably parallel English/Korean reviews -- confirm against examples.csv).
df = pd.read_csv('examples.csv', sep='\t', index_col='Unnamed: 0')

# Seed the global NumPy RNG so the same five row positions (drawn without
# replacement from 0..49) are selected on every start-up.
np.random.seed(100)
idx = np.random.choice(50, size=5, replace=False)

# The same row positions are used for both languages.
eng_examples = [['Eng', df.iloc[row, 0]] for row in idx]
kor_examples = [['Kor', df.iloc[row, 1]] for row in idx]
examples = [*eng_examples, *kor_examples]
def _load_sentiment_model(model_name, step):
    """Load the tokenizer and fine-tuned 2-label classifier for *model_name*.

    Args:
        model_name: Hugging Face model identifier used for both the tokenizer
            and the base architecture.
        step: Number embedded in the local checkpoint file name
            (presumably the training step at which it was saved).

    Returns:
        A ``(tokenizer, file_name, state_dict, model)`` tuple, where
        ``file_name`` is the local checkpoint path that was read.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # '/' in a hub identifier (e.g. "klue/roberta-small") is not usable in a
    # flat file name, so it is replaced by '_'. Names without '/' are unchanged.
    file_name = "{}-{}.pt".format(model_name.replace('/', '_'), step)

    # map_location="cpu": the app never moves the model or its inputs to a
    # GPU, and a checkpoint saved from a CUDA device would otherwise fail to
    # deserialize on a CPU-only host.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(file_name, map_location="cpu")

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,
        id2label=id2label,
        label2id=label2id,
        state_dict=state_dict,
    )
    return tokenizer, file_name, state_dict, model


# English sentiment classifier.
eng_model_name = "roberta-base"
eng_step = 1900
eng_tokenizer, eng_file_name, eng_state_dict, eng_model = _load_sentiment_model(
    eng_model_name, eng_step)

# Korean sentiment classifier.
kor_model_name = "klue/roberta-small"
kor_step = 2400
kor_tokenizer, kor_file_name, kor_state_dict, kor_model = _load_sentiment_model(
    kor_model_name, kor_step)
# Dropdown value that asks builder() to detect the language itself. It is
# kept byte-for-byte identical to the choice offered by the UI dropdown so
# the equality test below keeps matching.
_AUTO_DETECT_CHOICE = '์–ธ์–ด๊ฐ์ง€ ๊ธฐ๋Šฅ ์‚ฌ์šฉ'


def _detect_language(text):
    """Pick 'Kor' or 'Eng' for *text* and return the normalized scores.

    Returns:
        ``(lang, percent_kor, percent_eng)`` where the two percentages are
        the Korean/English probabilities rescaled to sum to 1.

    Raises:
        ValueError: If the detector reports neither Korean nor English.
    """
    pred = LANGUAGE.predict_lang(text)
    scores = dict(zip(pred[0], pred[1]))
    # A label missing from the predictions counts as probability 0 instead
    # of leaving the variable unbound.
    p_eng = float(scores.get('__label__en', 0.0))
    p_kor = float(scores.get('__label__ko', 0.0))

    total = p_kor + p_eng
    if total <= 0:
        raise ValueError(
            "Could not detect Korean or English in the given text; "
            "please choose the language explicitly.")

    # Choose the more probable language. Merely checking that a label is
    # present is not enough: with the top 200 predictions requested, both
    # labels are usually present, which made Korean win unconditionally.
    lang = 'Kor' if p_kor >= p_eng else 'Eng'
    return lang, p_kor / total, p_eng / total


def _class_probabilities(model, tokenizer, text):
    """Return the softmax class probabilities the model assigns to *text*."""
    encoded = tokenized_data(tokenizer, text)
    with torch.no_grad():
        logits = model(input_ids=encoded['input_ids'],
                       attention_mask=encoded['attention_mask']).logits
    # Row 0 is the only item of the one-element batch.
    return torch.softmax(logits, dim=1)[0]


def _polarity_tag(positive_prob):
    """Map a word's POSITIVE probability to a highlight tag, or None."""
    if positive_prob > 0.99:
        return '+++'
    if positive_prob > 0.9:
        return '++'
    if positive_prob > 0.8:
        return '+'
    if positive_prob < 0.01:
        return '---'
    if positive_prob < 0.1:
        return '--'
    if positive_prob < 0.2:
        return '-'
    return None


def builder(Lang, Text):
    """Classify a movie review and explain the result word by word.

    Args:
        Lang: 'Eng', 'Kor', or the auto-detect dropdown value.
        Text: The review text.

    Returns:
        A three-element list:
        1. ``{'Kor': ..., 'Eng': ...}`` language scores (1/0 when the
           language was chosen explicitly, normalized detector scores
           otherwise);
        2. ``{'POSITIVE': p, 'NEGATIVE': q}`` probabilities for the review;
        3. a list of ``(word, tag)`` pairs, where tag is one of
           '+++', '++', '+', '-', '--', '---' or None, obtained by
           classifying each space-separated word on its own.

    Raises:
        ValueError: If the language is unsupported or cannot be detected.
    """
    percent_kor, percent_eng = 0, 0

    # [ output_1 ] language scores
    if Lang == _AUTO_DETECT_CHOICE:
        Lang, percent_kor, percent_eng = _detect_language(Text)
    elif Lang == 'Eng':
        percent_eng = 1
    elif Lang == 'Kor':
        percent_kor = 1

    if Lang == 'Eng':
        model, tokenizer = eng_model, eng_tokenizer
    elif Lang == 'Kor':
        model, tokenizer = kor_model, kor_tokenizer
    else:
        # Previously this fell through and failed later on unbound locals.
        raise ValueError("Unsupported language option: {!r}".format(Lang))

    # [ output_2 ] sentiment of the whole review
    model.eval()
    review_probs = _class_probabilities(model, tokenizer, Text)

    # [ output_3 ] per-word sentiment. Empty tokens produced by consecutive
    # spaces (or an empty review) carry no content, so they are skipped.
    words = [word for word in Text.split(' ') if word]
    output_analysis = [
        (word, _polarity_tag(_class_probabilities(model, tokenizer, word)[1].item()))
        for word in words
    ]

    return [{'Kor': percent_kor, 'Eng': percent_eng},
            {id2label[1]: review_probs[1].item(), id2label[0]: review_probs[0].item()},
            output_analysis]
# demo3 = gr.Interface.load("models/mdj1412/movie_review_score_discriminator_eng", inputs="text", outputs="text",
# title=title, theme="peach",
# allow_flagging="auto",
# description=description, examples=examples)
# demo = gr.Interface(builder, inputs=[gr.inputs.Dropdown(['Default', 'Eng', 'Kor']), gr.Textbox(placeholder="๋ฆฌ๋ทฐ๋ฅผ ์ž…๋ ฅํ•˜์‹œ์˜ค.")],
# outputs=[ gr.Label(num_top_classes=3, label='Lang'),
# gr.Label(num_top_classes=2, label='Result'),
# gr.HighlightedText(label="Analysis", combine_adjacent=False)
# .style(color_map={"+++": "#CF0000", "++": "#FF3232", "+": "#FFD4D4", "---": "#0004FE", "--": "#4C47FF", "-": "#BEBDFF"}) ],
# # outputs='label',
# title=title, description=description, examples=examples)
# Static texts for the page. The Korean strings appear mis-encoded in this
# copy of the file; they are runtime strings and are reproduced unchanged.
_HEADER_MD = """
<h1 align="center">
Movie Review Score Discriminator
</h1>
"""

_INTRO_MD = """
์˜ํ™” ๋ฆฌ๋ทฐ๋ฅผ ์ž…๋ ฅํ•˜๋ฉด, ๋ฆฌ๋ทฐ๊ฐ€ ๊ธ์ •์ธ์ง€ ๋ถ€์ •์ธ์ง€ ํŒ๋ณ„ํ•ด์ฃผ๋Š” ๋ชจ๋ธ์ด๋‹ค. \
์˜์–ด์™€ ํ•œ๊ธ€์„ ์ง€์›ํ•˜๋ฉฐ, ์–ธ์–ด๋ฅผ ์ง์ ‘ ์„ ํƒํ• ์ˆ˜๋„, ํ˜น์€ ๋ชจ๋ธ์ด ์–ธ์–ด๊ฐ์ง€๋ฅผ ์ง์ ‘ ํ•˜๋„๋ก ํ•  ์ˆ˜ ์žˆ๋‹ค.
๋ฆฌ๋ทฐ๋ฅผ ์ž…๋ ฅํ•˜๋ฉด, (1) ๊ฐ์ง€๋œ ์–ธ์–ด, (2) ๊ธ์ • ๋ฆฌ๋ทฐ์ผ ํ™•๋ฅ ๊ณผ ๋ถ€์ • ๋ฆฌ๋ทฐ์ผ ํ™•๋ฅ , (3) ์ž…๋ ฅ๋œ ๋ฆฌ๋ทฐ์˜ ์–ด๋Š ๋‹จ์–ด๊ฐ€ ๊ธ์ •/๋ถ€์ • ๊ฒฐ์ •์— ์˜ํ–ฅ์„ ์ฃผ์—ˆ๋Š”์ง€ \
(๊ธ์ •์ผ ๊ฒฝ์šฐ ๋นจ๊ฐ•์ƒ‰, ๋ถ€์ •์ผ ๊ฒฝ์šฐ ํŒŒ๋ž€์ƒ‰)๋ฅผ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.
"""

# NOTE(review): this text names bert-base-uncased and klue/roberta-base, while
# the file loads "roberta-base" and "klue/roberta-small" -- confirm which is right.
_MODEL_NOTES_MD = """
์˜์–ด ๋ชจ๋ธ์€ bert-base-uncased ๊ธฐ๋ฐ˜์œผ๋กœ, ์˜์–ด ์˜ํ™” ๋ฆฌ๋ทฐ ๋ถ„์„ ๋ฐ์ดํ„ฐ์…‹์ธ SST-2๋กœ ํ•™์Šต ๋ฐ ํ‰๊ฐ€๋˜์—ˆ๋‹ค.
ํ•œ๊ธ€ ๋ชจ๋ธ์€ klue/roberta-base ๊ธฐ๋ฐ˜์ด๋‹ค. ๊ธฐ์กด ํ•œ๊ธ€ ์˜ํ™” ๋ฆฌ๋ทฐ ๋ถ„์„ ๋ฐ์ดํ„ฐ์…‹์ด ์กด์žฌํ•˜์ง€ ์•Š์•„, ๋„ค์ด๋ฒ„ ์˜ํ™”์˜ ๋ฆฌ๋ทฐ๋ฅผ ํฌ๋กค๋งํ•ด์„œ ์˜ํ™” ๋ฆฌ๋ทฐ ๋ถ„์„ ๋ฐ์ดํ„ฐ์…‹์„ ์ œ์ž‘ํ•˜๊ณ , ์ด๋ฅผ ์ด์šฉํ•˜์—ฌ ๋ชจ๋ธ์„ ํ•™์Šต ๋ฐ ํ‰๊ฐ€ํ•˜์˜€๋‹ค.
์˜์–ด ๋ชจ๋ธ์€ SST-2์—์„œ 92.8%, ํ•œ๊ธ€ ๋ชจ๋ธ์€ ๋„ค์ด๋ฒ„ ์˜ํ™” ๋ฆฌ๋ทฐ ๋ฐ์ดํ„ฐ์…‹์—์„œ 94%์˜ ์ •ํ™•๋„๋ฅผ ๊ฐ€์ง„๋‹ค (test set ๊ธฐ์ค€).
์–ธ์–ด๊ฐ์ง€๋Š” fasttext์˜ language detector๋ฅผ ์‚ฌ์šฉํ•˜์˜€๋‹ค. ๋ฆฌ๋ทฐ์˜ ๋‹จ์–ด๋ณ„ ์˜ํ–ฅ๋ ฅ์€, ๋‹จ์–ด ๊ฐ๊ฐ์„ ๋ชจ๋ธ์— ๋„ฃ์—ˆ์„ ๋•Œ ๊ฒฐ๊ณผ๊ฐ€ ๊ธ์ •์œผ๋กœ ๋‚˜์˜ค๋Š”์ง€ ๋ถ€์ •์œผ๋กœ ๋‚˜์˜ค๋Š”์ง€๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ์ธก์ •ํ•˜์˜€๋‹ค.
"""

# Language dropdown options. The first entry is the value builder() compares
# against to decide whether to auto-detect the language.
_LANG_CHOICES = ['์–ธ์–ด๊ฐ์ง€ ๊ธฐ๋Šฅ ์‚ฌ์šฉ', 'Eng', 'Kor']

# Highlight colour for each per-word tag produced by builder().
_HIGHLIGHT_COLORS = {
    "+++": "#CF0000",
    "++": "#FF3232",
    "+": "#FFD4D4",
    "---": "#0004FE",
    "--": "#4C47FF",
    "-": "#BEBDFF",
}

# NOTE(review): the indentation of this file was lost, so the nesting of the
# layout containers below is reconstructed from statement order -- confirm
# against the deployed layout.
with gr.Blocks() as demo1:
    gr.Markdown(_HEADER_MD)
    gr.Markdown(_INTRO_MD)

    # Collapsed-by-default section with the model notes.
    with gr.Accordion(label="๋ชจ๋ธ์— ๋Œ€ํ•œ ์„ค๋ช…", open=False):
        gr.Markdown(_MODEL_NOTES_MD)

    with gr.Row():
        # Input column: language selector, review text box, submit button.
        with gr.Column():
            inputs_1 = gr.Dropdown(choices=_LANG_CHOICES, value=_LANG_CHOICES[0], label='Lang')
            inputs_2 = gr.Textbox(placeholder="๋ฆฌ๋ทฐ๋ฅผ ์ž…๋ ฅํ•˜์‹œ์˜ค.", label='Text')
            with gr.Row():
                btn = gr.Button("์ œ์ถœํ•˜๊ธฐ")

        # Output column: one component per element of builder()'s result.
        with gr.Column():
            output_1 = gr.Label(num_top_classes=3, label='Lang')
            output_2 = gr.Label(num_top_classes=2, label='Result')
            highlighted = gr.HighlightedText(label="Analysis", combine_adjacent=False)
            output_3 = highlighted.style(color_map=_HIGHLIGHT_COLORS)

    # Clicking the button runs builder() on the two inputs and fills the outputs.
    btn.click(fn=builder, inputs=[inputs_1, inputs_2], outputs=[output_1, output_2, output_3])
    gr.Examples(examples, inputs=[inputs_1, inputs_2])
if __name__ == "__main__":
    # Start the Blocks-based UI when the file is run as a script. The older
    # gr.Interface-based `demo` is commented out in this file, so only
    # demo1 is launched.
    demo1.launch()