File size: 1,881 Bytes
c7a74ff
b56b900
c7a74ff
82ab74d
 
 
 
 
 
 
 
 
 
 
3255f31
82ab74d
 
 
 
 
 
 
 
 
 
 
 
99aedab
033d543
 
 
 
 
 
82ab74d
 
c7a74ff
b56b900
 
033d543
 
3255f31
033d543
82ab74d
 
033d543
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import gradio as gr
from gradio.components import Textbox

import torch
from transformers import pipeline
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification

# Tokenizer for the stock DistilBERT (uncased) vocabulary.
# NOTE(review): this is loaded from the base checkpoint, not from the
# fine-tuned repo below; that assumes fine-tuning kept the base vocabulary —
# confirm.
tokenizer = AutoTokenizer.from_pretrained(
    'distilbert-base-uncased'
)

# Fine-tuned sequence-classification model pulled from the Hugging Face Hub.
model = AutoModelForSequenceClassification.from_pretrained(
    "skylord/pharma_classification"
)


def is_pharma(sentence, tokenize=tokenizer, model=model):
    """Classify *sentence* with the sequence-classification model.

    :param sentence: str, text to classify.
    :param tokenize: tokenizer callable used to encode *sentence*; defaults to
        the module-level ``tokenizer``.
    :param model: sequence-classification model; defaults to the module-level
        ``model``.
    :return: str, name of the highest-scoring class as given by
        ``model.config.id2label`` (or the class index as a string when the
        config has no entry for it).
    """
    # Use the GPU when present, otherwise the CPU. Calling .cuda()
    # unconditionally crashes on CPU-only hosts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)  # no-op when the model is already on `device`

    # Truncate so inputs longer than the tokenizer's max length do not raise
    # inside the model.
    inputs = tokenize(sentence, return_tensors="pt", truncation=True)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        # First output element = raw per-class scores (logits). They are not
        # softmax-normalized, but argmax is the same either way.
        logits = model(**inputs)[0]

    # Argmax over the label axis for the single input sentence.
    top_prediction = logits.argmax(dim=-1).item()

    # The label is looked up in the model's own config. The previous code
    # referenced an undefined `ds` dataset and failed with NameError.
    # NOTE(review): if the checkpoint was saved without a custom id2label, this
    # yields generic names such as "LABEL_1" — confirm against training config.
    id2label = getattr(model.config, "id2label", None) or {}
    return id2label.get(top_prediction, str(top_prediction))

def predict_sentiment(text):
    """
    Classify whether the input text is relevant to the pharma industry.

    Despite its name, this does not perform sentiment analysis. It delegates
    to ``is_pharma`` and formats the predicted label for the Gradio output
    box. No confidence score is computed.

    :param text: str, input text to analyze.
    :return: str, ``"TAG: <label>"`` where ``<label>`` is the predicted class name.
    """
    result = is_pharma(text)
    return f"TAG: {result}"

# Two-line free-text box where the user enters the news item.
input1 = Textbox(placeholder="Type your text here...", lines=2)

# Text-in / text-out Gradio app wrapping the classifier callback.
# NOTE(review): the description promises "Yes"/"No", but predict_sentiment
# returns "TAG: <label>" — confirm the labels actually read Yes/No.
iface = gr.Interface(
    fn=predict_sentiment,
    inputs=input1,
    outputs="text",
    title="Identify if the news item is relevant to the pharma industry",
    description="This model predicts the tag of the input text. Enter a sentence to see if it's pharma or not. Response is a Yes or a No",
)

# Start serving the interface.
iface.launch()