# sentBERT / app.py
# sd99 · Update app.py · cf0a4fc
import functools

import gradio as gr
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
# Heading and blurb passed to the Gradio interface.
title = "SentBERT"
description = "Sentiment Analysis :) && :("

# Sample inputs offered in the UI; each inner list carries the value for the
# interface's single text input.
examples = [
    ["That ice cream was really bad"],
    ["Great to meet you!"],
    ["Hey, there's a snake there"],
]

# Maps the classifier's output index to the label shown to the user.
class2interpret = {0: "Positive/Neutral", 1: "Negative"}
# NOTE(review): this looks like the base pretrained checkpoint rather than a
# sentiment fine-tuned one — confirm these are the intended weights.
_CHECKPOINT = "distilbert-base-uncased"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the tokenizer and classifier once and reuse them for every request."""
    tokenizer = DistilBertTokenizer.from_pretrained(_CHECKPOINT)
    model = DistilBertForSequenceClassification.from_pretrained(_CHECKPOINT)
    return tokenizer, model


def classify(example):
    """Score a piece of text and return per-label probabilities.

    Args:
        example: The text to classify (the value of the interface's text input).

    Returns:
        A pair of equal dicts mapping each label in ``class2interpret`` to its
        softmax probability; one dict per output component of the interface.
    """
    # Loading is cached so the weights are not re-read on every call.
    tokenizer, model = _load_model()

    # Truncate so inputs longer than the model's maximum length do not fail.
    inputs = tokenizer(example, return_tensors="pt", truncation=True)

    with torch.no_grad():
        logits = model(**inputs).logits

    # A single example is scored, so take row 0 of the batch.
    probs = torch.softmax(logits, dim=1)[0].tolist()

    scores = {label: probs[index] for index, label in class2interpret.items()}
    # Hand back two independent dicts, one for each output component.
    return scores, dict(scores)
# Build the web UI: one text input; a label view and a JSON view as outputs
# (one output component per element of the pair returned by classify).
interface = gr.Interface(
    fn=classify,
    inputs="text",
    outputs=["label", "json"],
    title=title,
    description=description,
    examples=examples,
)

# Start serving the app.
interface.launch()