Spaces:
Runtime error
Runtime error
File size: 1,824 Bytes
41889a9 72b7afd 41889a9 292d488 71411c5 72b7afd 292d488 71411c5 41889a9 292d488 72b7afd 41889a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
import gradio as gr
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Model repository id passed to `from_pretrained` for both tokenizer and model.
MODEL_PATH = 'finiteautomata/bertweet-base-sentiment-analysis'

# Load the tokenizer and model exactly once; every request reuses them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
# Inference-only app: make sure dropout and similar training-time layers are
# disabled (from_pretrained already does this; stated here for clarity).
model.eval()
def get_predictions(input_text: str) -> dict:
    """Classify *input_text* and return a confidence score per label.

    Args:
        input_text: Raw text to classify; it is truncated to the tokenizer's
            maximum length.

    Returns:
        A dict mapping every label name from ``model.config.id2label`` to a
        Python ``float`` score, ordered from highest to lowest score (the
        shape ``gr.Label`` accepts).
    """
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        inputs = tokenizer(input_text, return_tensors='pt', truncation=True)
        inputs = inputs.to(device)
        # Take the single batch row explicitly instead of squeeze(), which
        # would also collapse a one-logit head into a 0-d tensor.
        logits = model(**inputs).logits[0].cpu()

    config = model.config
    # Multi-class heads are normalised with softmax so scores sum to 1; only
    # multi-label or single-logit heads get an independent per-label sigmoid
    # (the same rule the transformers text-classification pipeline uses).
    multi_label = (
        config.problem_type == 'multi_label_classification'
        or config.num_labels == 1
    )
    probs = torch.sigmoid(logits) if multi_label else torch.softmax(logits, dim=-1)

    # Map scores to names through id2label (index -> label) and build a fresh
    # dict, so the shared model config is never modified.
    id2label = config.id2label
    scores = {id2label[i]: float(p) for i, p in enumerate(probs.tolist())}
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))
# Web UI: one text box in, the top-3 label confidences out.
demo = gr.Interface(
    fn=get_predictions,
    inputs=gr.components.Textbox(label='Input'),
    outputs=gr.components.Label(label='Predictions', num_top_classes=3),
    # No "Flag" button: user inputs are not collected.
    allow_flagging='never',
)
# `debug` takes a bool; debug mode blocks the main thread and prints errors
# to the console.
demo.launch(debug=True)