"""Gradio demo: detect entity/property pairs in a sentence and their polarity.

For every candidate pair an entity classifier decides whether the pair is
mentioned in the sentence; for the pairs it accepts, a sentiment classifier
predicts the polarity.
"""
import gradio as gr
from models import *
from huggingface_hub import hf_hub_download
import os
from config import *

# NOTE(review): `torch`, `AutoTokenizer`, `Classifier`, `base_model`,
# `entity_property_pair`, `entity2str`, `tf_id_to_name` and
# `polarity_id_to_name` are not imported explicitly in this file; presumably
# the star imports above provide them - confirm.

ENTITY_REPO_ID = 'vaivTA/absa_v2_entity'
ENTITY_FILENAME = "entity_model.pt"

SENTIMENT_REPO_ID = 'vaivTA/absa_v2_sentiment'
SENTIMENT_FILENAME = "sentiment_model.pt"

# Inputs longer than this many characters are rejected before tokenization.
_MAX_INPUT_CHARS = 500
# Every encoded pair is padded/truncated to this many tokens.
_MAX_TOKENS = 512
# Token id used to locate the segment boundaries in the encoded input.
# NOTE(review): presumably the id of the tokenizer's separator token - confirm
# against the tokenizer loaded for `base_model`.
_SEP_TOKEN_ID = 2

print("downloading model...")
sen_model_file = hf_hub_download(repo_id=SENTIMENT_REPO_ID, filename=SENTIMENT_FILENAME)
entity_model_file = hf_hub_download(repo_id=ENTITY_REPO_ID, filename=ENTITY_FILENAME)

tokenizer = AutoTokenizer.from_pretrained(base_model)

sen_model = Classifier(base_model, num_labels=2, device='cpu', tokenizer=tokenizer)
sen_model.load_state_dict(torch.load(sen_model_file, map_location=torch.device('cpu')))
entity_model = Classifier(base_model, num_labels=2, device='cpu', tokenizer=tokenizer)
entity_model.load_state_dict(torch.load(entity_model_file, map_location=torch.device('cpu')))

# Switch both models to evaluation mode. Without this, layers such as dropout
# keep their training behaviour and predictions can differ between calls.
# (Assumes `Classifier` is a torch.nn.Module, as its `load_state_dict` suggests.)
sen_model.eval()
entity_model.eval()


def _pair_token_mask(input_ids):
    """Return a 0/1 list that marks the tokens of the second segment.

    The marked span starts two positions after the first separator and ends
    just before the next separator found from there on.

    Raises:
        ValueError: if the expected separators are not present.
    """
    first_sep = input_ids.index(_SEP_TOKEN_ID)
    # Skip the first separator and the token right after it.
    # NOTE(review): presumably that token is the separator the tokenizer adds
    # itself after the literal "[SEP]" appended to the text - confirm.
    start = first_sep + 2
    last_sep = input_ids.index(_SEP_TOKEN_ID, start)

    mask = [0] * len(input_ids)
    mask[start:last_sep] = [1] * (last_sep - start)
    return mask


def _predict_class(model, input_ids, attention_mask, mask):
    """Run `model` without gradient tracking and return the top class id."""
    with torch.no_grad():
        logits = model(input_ids, attention_mask, mask)
    # `.item()` yields a plain int, so the id can index a list or key a dict.
    return torch.argmax(logits, dim=-1)[0].item()


def infer(test_sentence: str) -> str:
    """Annotate a sentence with the detected entity/property pairs.

    Args:
        test_sentence: The text to analyse.

    Returns:
        One "<pair> - <polarity>" line per detected pair, joined with
        newlines; an empty string when nothing is detected or the input is
        blank; the message "Too long sentence!" when the input exceeds
        500 characters.
    """
    form = test_sentence or ""
    if len(form) > _MAX_INPUT_CHARS:
        return "Too long sentence!"
    if not form.strip():
        # Nothing to analyse; avoid running every pair through the models.
        return ""

    # The sentence segment is identical for every pair, so build it once.
    form_ = form + "[SEP]"
    annotation = []

    for pair in entity_property_pair:
        pair_ = entity2str[pair] + "[SEP]"
        tokenized_data = tokenizer(
            form_,
            pair_,
            padding='max_length',
            max_length=_MAX_TOKENS,
            truncation=True,
        )
        input_ids = torch.tensor([tokenized_data['input_ids']])
        attention_mask = torch.tensor([tokenized_data['attention_mask']])
        mask = torch.tensor([_pair_token_mask(tokenized_data['input_ids'])])

        # Step 1: is this pair present in the sentence at all?
        ce_result = tf_id_to_name[
            _predict_class(entity_model, input_ids, attention_mask, mask)
        ]
        if ce_result != 'True':
            continue

        # Step 2: polarity of the accepted pair.
        pc_result = polarity_id_to_name[
            _predict_class(sen_model, input_ids, attention_mask, mask)
        ]
        annotation.append(f"{pair} - {pc_result}")

    return '\n'.join(annotation)


demo = gr.Interface(
    fn=infer,
    inputs=gr.Textbox(type="text", label="Input Sentence"),
    outputs=gr.Textbox(type="text", label="Result Sentence"),
    article="**리뷰 예시** : 아파트는 오래되었지만 동네가 조용하고 쾌적하여 살기에는 아주 좋습니다. 큰 마트가 주변에 없는 단점이 잇지만 이촌역이 매우 가깝고 생활권 내에 맛잇는 식당과 커피숖이 즐비합니다 ㅎㅎ",
)

demo.launch(share=True)