# Gradio demo: tweet sentiment classification with
# finiteautomata/bertweet-base-sentiment-analysis (NEG / NEU / POS).
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import gradio as gr

# Run on GPU when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Pretrained tweet-sentiment model; presumably 3 classes (NEG/NEU/POS) —
# the Label component below shows the top 3.
MODEL_PATH = 'finiteautomata/bertweet-base-sentiment-analysis'
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
model = model.to(device)

# NOTE(review): the original file ran a throwaway inference on a hard-coded
# query here that overwrote the values of model.config.label2id in place with
# probabilities — corrupting the config — and then re-loaded the tokenizer and
# model from scratch to undo the damage. That scratch code produced no output
# and is removed; the module state after this block (tokenizer, model on
# `device`) is identical to the original's final state.

def get_predictions(input_text: str) -> dict:
    """Classify the sentiment of *input_text*.

    Returns a dict mapping each class label to its probability, sorted by
    descending probability — the {label: confidence} shape gr.Label expects.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.
    """
    inputs = tokenizer(input_text, return_tensors='pt', truncation=True)
    inputs = inputs.to(device)
    with torch.no_grad():  # inference only — no autograd graph needed
        logits = model(**inputs).logits
    # Single-label classification: softmax turns the class logits into a
    # proper probability distribution. (The original applied a sigmoid
    # per class, so the displayed scores did not sum to 1.)
    probs = torch.softmax(logits.squeeze(), dim=-1).cpu().numpy()
    # Build a fresh dict keyed by id2label. The original mutated
    # model.config.label2id in place on every call (corrupting the config)
    # and relied on dict insertion order matching the class ids.
    scores = {model.config.id2label[i]: float(p) for i, p in enumerate(probs)}
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))

# Minimal web UI: free-text box in, top-3 sentiment labels with scores out.
gr.Interface(
    fn=get_predictions,
    inputs=gr.components.Textbox(label='Input'),
    outputs=gr.components.Label(label='Predictions', num_top_classes=3),
    allow_flagging='never'
).launch(debug=True)  # was debug='True' — a str that only worked by truthiness