import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
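# Run on the GPU when one is available, otherwise fall back to the CPU.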
device = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_PATH = 'finiteautomata/bertweet-base-sentiment-analysis'
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
model = model.to(device)
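# Assumption based on the checkpoint's model card: this is a three-way tweet
# sentiment classifier whose labels are NEG, NEU and POS (see model.config.id2label).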
def get_predictions(input_text: str) -> dict:
    """Return {label: probability} for the input text, sorted highest first."""
    inputs = tokenizer(input_text, return_tensors='pt', truncation=True)
    inputs = inputs.to(device)
    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # The labels are mutually exclusive classes, so softmax gives a proper
    # probability distribution over them; sigmoid would score each label
    # independently and the values would not sum to 1.
    probs = torch.softmax(logits, dim=-1).squeeze().cpu()
    # Build a fresh dict from id2label rather than overwriting the values of
    # model.config.label2id in place on every call.
    id2label = model.config.id2label
    predictions = {id2label[i]: float(p) for i, p in enumerate(probs)}
    return dict(sorted(predictions.items(), key=lambda item: item[1], reverse=True))
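# Quick sanity check (shape of the output only; the probability values depend
# on the model weights and are not shown here):
#   get_predictions('great movie!')
#   -> {'POS': ..., 'NEU': ..., 'NEG': ...}   # sorted descending, sums to 1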
gr.Interface(
    fn=get_predictions,
    inputs=gr.components.Textbox(label='Input'),
    outputs=gr.components.Label(label='Predictions', num_top_classes=3),
    allow_flagging='never'
).launch(debug=True)