|
import gradio as gr |
|
from models import * |
|
from huggingface_hub import hf_hub_download |
|
import os |
|
from config import * |
|
|
|
# Hugging Face Hub coordinates (repo id + file name) of the entity checkpoint.
ENTITY_REPO_ID = "vaivTA/absa_v2_entity"
ENTITY_FILENAME = "entity_model.pt"

# Hugging Face Hub coordinates (repo id + file name) of the sentiment checkpoint.
SENTIMENT_REPO_ID = "vaivTA/absa_v2_sentiment"
SENTIMENT_FILENAME = "sentiment_model.pt"
|
|
|
# Fetch both checkpoints from the Hugging Face Hub; hf_hub_download returns
# the local path of each downloaded file.
print("downloading model...")
sen_model_file = hf_hub_download(repo_id=SENTIMENT_REPO_ID, filename=SENTIMENT_FILENAME)
entity_model_file = hf_hub_download(repo_id=ENTITY_REPO_ID, filename=ENTITY_FILENAME)

# `base_model`, `AutoTokenizer`, `Classifier` and `torch` are not bound in this
# file; presumably they arrive through the star imports above
# (`models` / `config`) -- TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(base_model)

# NOTE(review): torch.load unpickles the downloaded files. They come from a
# remote repository, so consider `weights_only=True` if the installed torch
# version supports it.
sen_model = Classifier(base_model, num_labels=2, device='cpu', tokenizer=tokenizer)
sen_model.load_state_dict(torch.load(sen_model_file, map_location=torch.device('cpu')))
# Inference only: switch off dropout and other train-time behaviour, which
# torch.no_grad() alone does not disable.
sen_model.eval()

entity_model = Classifier(base_model, num_labels=2, device='cpu', tokenizer=tokenizer)
entity_model.load_state_dict(torch.load(entity_model_file, map_location=torch.device('cpu')))
entity_model.eval()
|
|
|
|
|
def _second_segment_mask(token_ids, sep_id=2):
    """Return a 0/1 list marking the tokens of the second input segment.

    The marked span starts two positions after the first occurrence of
    ``sep_id`` and ends just before the next occurrence of ``sep_id``.

    Args:
        token_ids: List of token ids for one encoded (sentence, pair) input.
        sep_id: Id treated as the separator. The value 2 is hard-coded here;
            presumably it equals the tokenizer's separator id -- TODO confirm
            against ``tokenizer.sep_token_id``.

    Returns:
        A list of the same length as ``token_ids`` holding 1 inside the span
        and 0 elsewhere.

    Raises:
        ValueError: If ``sep_id`` does not occur, or does not occur again at
            or after the span start.
    """
    start = token_ids.index(sep_id) + 2
    end = token_ids.index(sep_id, start)
    return [0] * start + [1] * (end - start) + [0] * (len(token_ids) - end)


def infer(test_sentence: str) -> str:
    """Run both classifiers over every entity/property pair for one sentence.

    For each key in ``entity_property_pair`` the sentence and the pair's text
    (``entity2str[pair]``) are encoded together. ``entity_model`` decides
    whether the pair applies, and if it does ``sen_model`` predicts its
    polarity.

    Args:
        test_sentence: Sentence to analyse.

    Returns:
        ``"Too long sentence!"`` when the input is longer than 500 characters.
        Otherwise one ``"<pair> - <polarity>"`` line per accepted pair, joined
        with newlines (an empty string when no pair is accepted).
    """
    if len(test_sentence) > 500:
        return "Too long sentence!"

    # The sentence side of the input is the same for every pair.
    sentence_part = test_sentence + "[SEP]"
    annotation = []

    with torch.no_grad():
        for pair in entity_property_pair:
            encoded = tokenizer(
                sentence_part,
                entity2str[pair] + "[SEP]",
                padding='max_length',
                max_length=512,
                truncation=True,
            )
            token_ids = encoded['input_ids']

            # Each tensor is wrapped in an outer list to form a batch of one.
            input_ids = torch.tensor([token_ids])
            attention_mask = torch.tensor([encoded['attention_mask']])
            mask = torch.tensor([_second_segment_mask(token_ids)])

            # Convert predictions to plain ints so the label tables can be
            # either sequences or int-keyed mappings.
            ce_logits = entity_model(input_ids, attention_mask, mask)
            ce_index = int(torch.argmax(ce_logits, dim=-1)[0])
            if tf_id_to_name[ce_index] != 'True':
                continue

            pc_logits = sen_model(input_ids, attention_mask, mask)
            pc_index = int(torch.argmax(pc_logits, dim=-1)[0])
            annotation.append(f"{pair} - {polarity_id_to_name[pc_index]}")

    return '\n'.join(annotation)
|
|
|
|
|
# Gradio UI: one text box in, one text box out, backed by `infer`.
demo = gr.Interface(
    fn=infer,
    inputs=gr.Textbox(type="text", label="Input Sentence"),
    outputs=gr.Textbox(type="text", label="Result Sentence"),
    # The literal used to span several physical lines inside a single-quoted
    # string, which is a syntax error; the line breaks are now explicit
    # escapes.
    # NOTE(review): the text after the heading looks like mis-decoded UTF-8
    # (mojibake). Restore it from the original source instead of guessing.
    article=(
        "**리뷰 μμ** : μννΈλ μ€λλμμ§λ§ λλ€κ° μ‘°μ©νκ³ μΎμ νμ¬ μ΄κΈ°μλ μμ£Ό μ’μ΅λλ€. ν° λ§νΈκ° μ£Όλ³μ μλ λ¨μ μ΄ μμ§λ§ μ΄μ΄μμ΄ λ§€μ° κ°κΉκ³ μνκΆ λ΄μ λ§μλ μλΉκ³Ό 컀νΌμμ΄ μ¦λΉν©λλ€ γ\n"
        "γ\n"
    ),
)

# share=True asks Gradio for a public share link in addition to the local server.
demo.launch(share=True)
|
|
|
|
|
|
|
|
|
|