skylord's picture
Update app.py
82ab74d verified
raw history blame
No virus
1.88 kB
import gradio as gr
from gradio.components import Textbox
import torch
from transformers import pipeline
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
# Hub checkpoint whose tokenizer is used to encode inputs (DistilBERT, uncased).
# NOTE(review): presumably the classifier below was fine-tuned from this base
# model so the vocabularies match — confirm against the training setup.
_TOKENIZER_CHECKPOINT = 'distilbert-base-uncased'
# Hub checkpoint holding the fine-tuned sequence-classification model.
_MODEL_CHECKPOINT = "skylord/pharma_classification"

tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_CHECKPOINT)
def is_pharma(sentence, tokenize=tokenizer, model=model):
    """Classify ``sentence`` and return the predicted label name.

    Args:
        sentence: Raw text to classify.
        tokenize: Tokenizer callable used to encode ``sentence``; defaults to
            the module-level ``tokenizer``.
        model: Sequence-classification model; defaults to the module-level
            ``model``.

    Returns:
        The label string for the highest-scoring class, looked up in
        ``model.config.id2label``.
    """
    # Use the GPU when one is present; otherwise fall back to CPU instead of
    # crashing (the previous unconditional .cuda() failed on CPU-only hosts).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Honor the `tokenize` argument (it was previously ignored) and truncate so
    # inputs longer than the model's maximum sequence length do not raise.
    inputs = tokenize(sentence, return_tensors="pt", truncation=True)
    # Model and inputs must live on the same device.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        # First model output: unnormalized scores (logits), one per class.
        logits = model(**inputs)[0]

    # A single sentence is encoded, so take the first (only) row's argmax.
    top_prediction = logits.argmax(dim=-1)[0].item()

    # The original looked the name up in a training dataset (`ds`) that is
    # never loaded here, so every call raised NameError. The model config's
    # id2label mapping is the self-contained source of label names.
    # NOTE(review): assumes the checkpoint's config stores the same label names
    # as the training dataset's 'labels' feature; if it only has the default
    # LABEL_0/LABEL_1, map them to readable names here — TODO confirm.
    return model.config.id2label[top_prediction]
def predict_sentiment(text):
    """
    Classify the input text with the pharma-relevance model and format the tag.

    Despite its name, this does not perform sentiment analysis and reports no
    confidence score: it delegates to ``is_pharma`` and returns only the
    predicted label.

    :param text: str, input text to classify.
    :return: str of the form ``"TAG: <label>"``.
    """
    result = is_pharma(text)
    return f"TAG: {result}"
# Two-line free-text input box for the news item.
input1 = Textbox(lines=2, placeholder="Type your text here...")
# Create a Gradio interface that feeds the text box into predict_sentiment and
# shows its returned string.
# NOTE(review): the description promises "Yes or No", but predict_sentiment
# returns "TAG: <label>" with whatever label names the model provides — confirm.
iface = gr.Interface(fn=predict_sentiment,
                     inputs=input1,
                     outputs="text",
                     title="Identify if the news item is relevant to the pharma industry",
                     description="This model predicts the tag of the input text. Enter a sentence to see if it's pharma or not. Response is a Yes or a No")

# Launch the web UI only when executed as a script (e.g. `python app.py`), so
# importing this module does not start a server as a side effect.
if __name__ == "__main__":
    iface.launch()