import gradio as gr
from models import *
from huggingface_hub import hf_hub_download
import os
from config import *
# Hugging Face Hub locations of the two fine-tuned ABSA checkpoints:
# the "entity" model decides whether an entity/property pair is mentioned,
# the "sentiment" model classifies the polarity of a mentioned pair.
ENTITY_REPO_ID = 'vaivTA/absa_v2_entity'
ENTITY_FILENAME = "entity_model.pt"
SENTIMENT_REPO_ID = 'vaivTA/absa_v2_sentiment'
SENTIMENT_FILENAME = "sentiment_model.pt"

print("downloading model...")
sen_model_file = hf_hub_download(repo_id=SENTIMENT_REPO_ID, filename=SENTIMENT_FILENAME)
entity_model_file = hf_hub_download(repo_id=ENTITY_REPO_ID, filename=ENTITY_FILENAME)

# `base_model` comes from `config` (star import); one tokenizer is shared by
# both classifiers.  (Removed the no-op `base_model = base_model` line.)
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Both heads are binary classifiers run on CPU; weights come from the
# downloaded checkpoints.
sen_model = Classifier(base_model, num_labels=2, device='cpu', tokenizer=tokenizer)
sen_model.load_state_dict(torch.load(sen_model_file, map_location=torch.device('cpu')))
entity_model = Classifier(base_model, num_labels=2, device='cpu', tokenizer=tokenizer)
entity_model.load_state_dict(torch.load(entity_model_file, map_location=torch.device('cpu')))

# Fix: switch both models to eval mode so dropout/normalization layers behave
# deterministically at inference time (the original left them in training
# mode; the commented-out `.eval()` calls in `infer` suggest this was the
# intent all along).
sen_model.eval()
entity_model.eval()
def infer(test_sentence):
    """Run aspect-based sentiment analysis on a single input sentence.

    For every candidate entity/property pair (``entity_property_pair`` from
    config), the entity model decides whether the pair is mentioned in the
    sentence; if it is, the sentiment model assigns a polarity.

    Returns one ``"<pair> - <polarity>"`` line per detected pair, joined by
    newlines (an empty string when nothing is detected), or the literal
    message "Too long sentence!" for inputs over 500 characters.
    """
    # Fix: ensure deterministic inference — the original had these eval-mode
    # calls commented out, leaving dropout active if the models were still in
    # training mode.
    entity_model.eval()
    sen_model.eval()

    form = test_sentence
    # Guard: reject overly long inputs before tokenizing anything.
    if len(form) > 500:
        return "Too long sentence!"

    annotation = []
    for pair in entity_property_pair:
        # The sentence and the verbalized pair each get an explicit "[SEP]"
        # marker appended, in addition to the tokenizer's own special tokens.
        tokenized_data = tokenizer(form + "[SEP]", entity2str[pair] + "[SEP]",
                                   padding='max_length', max_length=512, truncation=True)
        input_ids = torch.tensor([tokenized_data['input_ids']])
        attention_mask = torch.tensor([tokenized_data['attention_mask']])

        # Build a 0/1 mask over the second segment (the pair text): the token
        # span between the first separator (+2 skips the doubled separator)
        # and the next one.  NOTE(review): token id 2 is assumed to be the
        # separator (RoBERTa-style vocab) — confirm against `base_model`.
        first_sep = tokenized_data['input_ids'].index(2)
        last_sep = tokenized_data['input_ids'][first_sep + 2:].index(2) + (first_sep + 2)
        mask = [0] * len(tokenized_data['input_ids'])
        for i in range(first_sep + 2, last_sep):
            mask[i] = 1
        mask = torch.tensor([mask])

        # Stage 1: is this entity/property pair mentioned at all?
        with torch.no_grad():
            ce_logits = entity_model(input_ids, attention_mask, mask)
        ce_predictions = torch.argmax(ce_logits, dim=-1)
        if tf_id_to_name[ce_predictions[0]] == 'True':
            # Stage 2: classify the polarity of the mentioned pair.
            with torch.no_grad():
                pc_logits = sen_model(input_ids, attention_mask, mask)
            pc_predictions = torch.argmax(pc_logits, dim=-1)
            annotation.append(f"{pair} - {polarity_id_to_name[pc_predictions[0]]}")

    return '\n'.join(annotation)
# Gradio UI wiring: a single text box in, the plain-text annotation from
# `infer` out.  `share=True` additionally opens a public tunnel URL.
# NOTE(review): the `article` string below appears to be mojibake (Korean
# example text mis-decoded at some point).  It is a runtime string rendered
# in the UI, so it is kept byte-for-byte as found; re-encode it from the
# original source if available.
demo = gr.Interface(fn=infer,
inputs=gr.Textbox(type="text", label="Input Sentence"),
outputs=gr.Textbox(type="text", label="Result Sentence"),
article="**리뷰 μμ** : μννΈλ μ€λλμμ§λ§ λλ€κ° μ‘°μ©νκ³ μΎμ νμ¬ μ΄κΈ°μλ μμ£Ό μ’μ΅λλ€. ν° λ§νΈκ° μ£Όλ³μ μλ λ¨μ μ΄ μμ§λ§ μ΄μ΄μμ΄ λ§€μ° κ°κΉκ³ μνκΆ λ΄μ λ§μλ μλΉκ³Ό 컀νΌμμ΄ μ¦λΉν©λλ€ γ
γ
"
)
demo.launch(share=True)