# Spaces: Sleeping
# (Page-header residue captured together with this file; not part of the program.)
import gradio as gr | |
import fasttext | |
from transformers import AutoModelForSequenceClassification | |
from transformers import AutoTokenizer | |
import numpy as np | |
import pandas as pd | |
import torch | |
# Class index -> label name, and the inverse mapping derived from it.
# Both are passed to the two sequence-classification models loaded below.
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {name: index for index, name in id2label.items()}

# Heading / description text. Neither name is referenced in the visible part
# of this file.
# NOTE(review): these literals look like Korean UTF-8 text that was decoded
# with the wrong code page somewhere upstream -- confirm against the original
# file before relying on them. They are reproduced here exactly as found.
title = "ํ๊ตญ์ด/์์ด ๊ฐ์ ๋ถ์ ์์ (๋ค์ด๋ฒ ์ํ ๋ฆฌ๋ทฐ๋ฅผ ํ์ฉ)"
description = "์ํํ์ ์ ๋ ฅํ์ฌ ๊ธ์ ์ ์ธ์ง ๋ถ์ ์ ์ธ์ง๋ฅผ ๋ถ๋ฅํ๋ ๋ชจ๋ธ์ ๋๋ค."
class LanguageIdentification:
    """Thin wrapper around a fastText language-identification model."""

    def __init__(self, model_path="./lid.176.ftz"):
        """Load the fastText model stored at *model_path*.

        Args:
            model_path: Path of the model file handed to
                ``fasttext.load_model``. Defaults to the file this module
                has always used, so ``LanguageIdentification()`` behaves as
                before.
        """
        self.model = fasttext.load_model(model_path)

    def predict_lang(self, text):
        """Return the model's prediction for *text*, asking for up to 200 labels.

        Line breaks are replaced with spaces before the call: fastText's
        Python ``predict`` handles one line at a time and raises
        ``ValueError`` when the input contains a newline, which multi-line
        text-box input would otherwise trigger.

        Args:
            text: The string to classify.

        Returns:
            Whatever ``self.model.predict(..., k=200)`` returns. The caller
            in this file treats it as ``(labels, probabilities)`` with
            labels such as ``'__label__en'``.
        """
        single_line = text.replace("\n", " ")
        return self.model.predict(single_line, k=200)
# Single shared detector, built once when this module is imported; the
# request handler below calls LANGUAGE.predict_lang(...) on it.
LANGUAGE = LanguageIdentification()
def tokenized_data(tokenizer, inputs, max_length=64):
    """Encode a single text as a one-item batch of PyTorch tensors.

    Args:
        tokenizer: Object exposing ``batch_encode_plus`` (the tokenizers
            created with ``AutoTokenizer.from_pretrained`` in this file).
        inputs: The text to encode; it is wrapped in a one-element list.
        max_length: Length every sequence is padded or truncated to.
            Defaults to 64, the value that used to be hard-coded.

    Returns:
        The result of ``tokenizer.batch_encode_plus``; callers in this file
        read its ``'input_ids'`` and ``'attention_mask'`` entries.
    """
    return tokenizer.batch_encode_plus(
        [inputs],
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True,
    )
# English classifier: the "roberta-base" architecture/tokenizer combined with
# weights read from a local checkpoint file named "roberta-base-1900.pt".
# NOTE(review): eng_step is presumably the training step of that checkpoint --
# confirm with whoever produced the file.
eng_model_name = "roberta-base"
eng_step = 1900
eng_tokenizer = AutoTokenizer.from_pretrained(eng_model_name)
eng_file_name = "{}-{}.pt".format(eng_model_name, eng_step)
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a host without CUDA; without it torch.load raises there.
eng_state_dict = torch.load(eng_file_name, map_location="cpu")
eng_model = AutoModelForSequenceClassification.from_pretrained(
    eng_model_name,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
    state_dict=eng_state_dict,
)
# Korean classifier: the "klue/roberta-small" architecture/tokenizer combined
# with weights read from a local checkpoint. The "/" in the model name is
# replaced by "_" so the file name ("klue_roberta-small-2400.pt") contains no
# path separator.
# NOTE(review): kor_step is presumably the training step of that checkpoint --
# confirm with whoever produced the file.
kor_model_name = "klue/roberta-small"
kor_step = 2400
kor_tokenizer = AutoTokenizer.from_pretrained(kor_model_name)
kor_file_name = "{}-{}.pt".format(kor_model_name.replace('/', '_'), kor_step)
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a host without CUDA; without it torch.load raises there.
kor_state_dict = torch.load(kor_file_name, map_location="cpu")
kor_model = AutoModelForSequenceClassification.from_pretrained(
    kor_model_name,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
    state_dict=kor_state_dict,
)
def _detect_language(text):
    """Choose 'Kor' or 'Eng' for *text*; return ``(lang, percent_kor, percent_eng)``."""
    # fastText's predict works on a single line, so flatten line breaks first.
    labels, scores = LANGUAGE.predict_lang(text.replace('\n', ' '))
    score_of = dict(zip(labels, scores))
    p_eng = float(score_of.get('__label__en', 0.0))
    p_kor = float(score_of.get('__label__ko', 0.0))
    total = p_kor + p_eng
    if total <= 0:
        # Neither language was scored: fall back to Korean, the language the
        # previous implementation ended up with whenever it had the choice.
        return 'Kor', 1, 0
    # Compare the two scores instead of letting the Korean label win merely
    # because it is present in the prediction list.
    lang = 'Kor' if p_kor >= p_eng else 'Eng'
    return lang, p_kor / total, p_eng / total


def _class_probabilities(model, tokenizer, text):
    """Run *model* on *text*; return softmax probabilities ``[negative, positive]``."""
    encoded = tokenized_data(tokenizer, text)
    with torch.no_grad():
        logits = model(input_ids=encoded['input_ids'],
                       attention_mask=encoded['attention_mask']).logits
    return torch.softmax(logits, dim=1)[0].tolist()


def _polarity_tag(positive):
    """Map a positive-class probability to a highlight tag ('+++' .. '---') or None."""
    if positive > 0.99:
        return '+++'
    if positive > 0.9:
        return '++'
    if positive > 0.8:
        return '+'
    if positive < 0.01:
        return '---'
    if positive < 0.1:
        return '--'
    if positive < 0.2:
        return '-'
    return None


def builder(Lang, Text):
    """Classify the sentiment of *Text* and tag each of its words.

    Args:
        Lang: 'Eng' or 'Kor' to force a model. Any other value (including
            the dropdown's auto-detect entry) triggers language detection,
            so an unexpected value no longer leaves the model unbound.
        Text: The review text to analyse.

    Returns:
        A three-element list:
          1. ``{'Kor': share, 'Eng': share}`` -- 1/0 when the language was
             forced, otherwise the two detector scores normalised to sum
             to 1.
          2. ``{'POSITIVE': p, 'NEGATIVE': p}`` for the whole text.
          3. ``[(word, tag), ...]`` where tag is one of '+++', '++', '+',
             '-', '--', '---' or ``None``, from scoring each word alone.
    """
    if Lang == 'Eng':
        percent_kor, percent_eng = 0, 1
    elif Lang == 'Kor':
        percent_kor, percent_eng = 1, 0
    else:
        Lang, percent_kor, percent_eng = _detect_language(Text)

    if Lang == 'Eng':
        model, tokenizer = eng_model, eng_tokenizer
    else:
        model, tokenizer = kor_model, kor_tokenizer

    model.eval()
    negative, positive = _class_probabilities(model, tokenizer, Text)

    # split() without an argument skips the empty tokens that repeated or
    # leading/trailing whitespace would otherwise send through the model.
    output_analysis = [
        (word, _polarity_tag(_class_probabilities(model, tokenizer, word)[1]))
        for word in Text.split()
    ]

    return [
        {'Kor': percent_kor, 'Eng': percent_eng},
        {id2label[1]: positive, id2label[0]: negative},
        output_analysis,
    ]
# Dropdown entry that asks the handler to detect the language itself. It must
# be one of `choices` AND equal the literal the handler compares against;
# previously the default value was not among the offered choices, and the
# shorter entry that was offered matched none of the handler's branches.
auto_detect_choice = '์ธ์ด๊ฐ์ง ๊ธฐ๋ฅ ์ฌ์ฉ'

# Page layout: inputs and submit button in the left column, the three result
# widgets in the right column.
with gr.Blocks() as demo1:
    gr.Markdown(
        """
<h1 align="center">
ํ๊ตญ์ด/์์ด ๊ฐ์ ๋ถ์ ์์
</h1>
""")
    with gr.Row():
        with gr.Column():
            inputs_1 = gr.Dropdown(choices=[auto_detect_choice, 'Eng', 'Kor'], value=auto_detect_choice, label='Lang')
            inputs_2 = gr.Textbox(placeholder="๋ฆฌ๋ทฐ๋ฅผ ์ ๋ ฅํ์์ค.", label='Text')
            with gr.Row():
                btn = gr.Button("์ ์ถํ๊ธฐ")
        with gr.Column():
            output_1 = gr.Label(num_top_classes=3, label='Lang')
            output_2 = gr.Label(num_top_classes=2, label='Result')
            # Tags produced by the handler mapped to highlight colours:
            # red shades for '+' tags, blue shades for '-' tags.
            output_3 = gr.HighlightedText(label="Analysis", combine_adjacent=False) \
                .style(color_map={"+++": "#CF0000", "++": "#FF3232", "+": "#FFD4D4", "---": "#0004FE", "--": "#4C47FF", "-": "#BEBDFF"})
    # The handler returns three values, one per output component, in order.
    btn.click(fn=builder, inputs=[inputs_1, inputs_2], outputs=[output_1, output_2, output_3])

if __name__ == "__main__":
    demo1.launch()