BeMerciless's picture
Update app.py
567331a
raw
history blame
No virus
2.05 kB
import functools

import gradio as gr
import torch
from transformers import AutoTokenizer
#def greet(name):
# return "Hello " + name + "!!"
# Hugging Face Hub tokenizer shared by every checkpoint below.
MODEL_NAME = "beomi/KcELECTRA-base"

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Classification mode (the dropdown value) -> local checkpoint file.
_MODE_CHECKPOINTS = {
    "Malicious_comment": "best.pt",
    "Economic_article": "best2.pt",
}


@functools.lru_cache(maxsize=None)
def _load_tokenizer():
    """Load the KcELECTRA tokenizer once and reuse it for every request."""
    return AutoTokenizer.from_pretrained(MODEL_NAME)


@functools.lru_cache(maxsize=len(_MODE_CHECKPOINTS))
def _load_model(checkpoint):
    """Load the pickled model in *checkpoint* onto DEVICE once, in eval mode."""
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # local files. On newer torch releases the default weights_only=True may
    # reject a fully pickled model; pass weights_only=False there if loading
    # fails (TODO confirm the deployed torch version).
    model = torch.load(checkpoint, map_location=DEVICE)
    model.eval()  # inference mode for layers such as dropout
    print("device:", DEVICE)
    print(model)
    return model


def _format_result(sent, mode, label):
    """Turn a predicted class index into the user-facing verdict for *mode*."""
    if mode == "Malicious_comment":
        verdicts = {
            0: sent + ">> 악성글로 판단됩니다. 조심하세요.",  # judged malicious; be careful
            1: sent + ">> 악의적인 내용이 보이지 않습니다.",  # no malicious content found
        }
    else:  # "Economic_article"
        verdicts = {0: "중립", 1: "긍정", 2: "부정"}  # neutral / positive / negative
    # Fall back to the raw index instead of leaking a tensor for an unexpected class.
    return verdicts.get(label, str(label))


def greet(sent, mode):
    """Classify a Korean sentence with the model selected by *mode*.

    Args:
        sent: Input sentence.
        mode: 'Malicious_comment' (malicious-comment detection) or
            'Economic_article' (economic-article sentiment).

    Returns:
        The verdict string for the predicted class.

    Raises:
        ValueError: If *mode* is not a supported mode (e.g. None when the
            dropdown was left empty).
    """
    print("input_sent= " + sent)
    checkpoint = _MODE_CHECKPOINTS.get(mode) if isinstance(mode, str) else None
    if checkpoint is None:
        raise ValueError(
            f"Unknown mode {mode!r}; expected one of {list(_MODE_CHECKPOINTS)}"
        )
    print(checkpoint)

    model = _load_model(checkpoint)
    tokenizer = _load_tokenizer()

    # Tokenize the input sentence (truncated to at most 128 tokens) and move
    # the resulting tensors to the same device as the model.
    tokenized_sent = tokenizer(
        sent,
        return_tensors="pt",
        truncation=True,
        add_special_tokens=True,
        max_length=128,
    ).to(DEVICE)

    # Predict without tracking gradients.
    with torch.no_grad():
        outputs = model(
            input_ids=tokenized_sent["input_ids"],
            attention_mask=tokenized_sent["attention_mask"],
            # .get: pass None instead of raising if the tokenizer omits them.
            token_type_ids=tokenized_sent.get("token_type_ids"),
        )

    # Logits: final-layer values before any activation function; highest wins.
    logits = outputs[0].detach().cpu()
    label = int(logits.argmax(-1).item())
    return _format_result(sent, mode, label)
# Gradio UI: a free-text sentence plus a dropdown choosing which classifier runs.
sentence_input = "text"
mode_input = gr.Dropdown(
    choices=['Malicious_comment', 'Economic_article'],
    # Pre-select a mode so greet() never receives None when the user skips the dropdown.
    value='Malicious_comment',
)
iface = gr.Interface(
    fn=greet,
    title='Korean classification',
    # Description reads: "Korean malicious-comment && economic-article sentiment classifier".
    description="한국어 악플 && 경제기사 긍부정 판별기",
    inputs=[sentence_input, mode_input],
    outputs="text",
)
iface.launch()