joonkim's picture
Update app.py
8f775fb
raw
history blame contribute delete
No virus
1.4 kB
import gradio as gr
import numpy as np
import torch.nn.functional as F
from model import *
from transformers import BertTokenizer
# Inference is CPU-only; every checkpoint tensor is remapped onto this device.
DEVICE = torch.device('cpu')
PATH = 'checkpoints/'

# Restore the pickled model object first (its class must be importable here,
# presumably via `from model import *`), then overwrite its parameters with
# the separately saved state dict.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# trusted checkpoints. On PyTorch >= 2.6 torch.load defaults to
# weights_only=True, which rejects a whole pickled nn.Module; confirm the
# pinned torch version (or pass weights_only=False for model.pt).
model = torch.load(f'{PATH}model.pt', map_location=DEVICE)
model.load_state_dict(
    torch.load(f'{PATH}model_state_dict.pt', map_location=DEVICE)
)
# Put all submodules in evaluation mode (affects layers such as dropout /
# batch-norm, if the model contains any).
model.eval()

# Must match the vocabulary the classifier was trained with -- presumably
# bert-base-uncased; verify against the training code.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def evaluate(text):
    """Score *text* as Liberal vs. Conservative with the loaded BERT model.

    The text is tokenized with the module-level ``tokenizer`` (special
    tokens added, padded/truncated to 150 tokens) and fed to the
    module-level ``model`` together with its attention mask.

    Args:
        text: Raw input string (e.g. a headline typed into the UI).

    Returns:
        dict mapping ``'Liberal'`` to the softmax probability of output
        column 0 and ``'Conservative'`` to that of column 1, each rounded
        to two decimals -- the label->confidence shape a Gradio ``Label``
        output expects.
    """
    encoding = tokenizer.encode_plus(
        text,
        max_length=150,
        padding='max_length',
        truncation=True,
        add_special_tokens=True,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_id = encoding['input_ids']
    attention_mask = encoding['attention_mask']

    # The forward pass itself must run under no_grad: if it runs outside,
    # the output requires grad whenever the model's parameters do, and
    # converting it to NumPy raises. It also avoids building an autograd
    # graph for pure inference.
    with torch.no_grad():
        logits = model(input_id, attention_mask)
        # assumes logits are (batch=1, num_classes=2) with column 0 =
        # Liberal and column 1 = Conservative -- TODO confirm vs. training.
        probs = F.softmax(logits, dim=1)[0].tolist()

    # Round the Python floats (not float32 values) so the reported
    # confidences are clean two-decimal numbers.
    return {'Liberal': round(probs[0], 2), 'Conservative': round(probs[1], 2)}
# Web UI: a free-text box in, a two-class confidence Label out, plus two
# clickable sample headlines.
iface = gr.Interface(
    fn=evaluate,
    title='Political Sentiment Classification Using BERT Transformer',
    inputs='text',
    outputs=gr.Label(num_top_classes=2),
    examples=[
        ["Biden speech draws 38.2 million U.S. TV viewers"],
        ["Biden's first State of the Union address in 67 seconds"],
    ],
)

# Start the Gradio server. There is no `__main__` guard, so this runs
# whenever the module is executed or imported.
iface.launch()