File size: 1,759 Bytes
4c3d645
5653972
4c3d645
a81ae9a
4c3d645
 
5653972
abbb908
5653972
 
 
 
 
 
 
4c3d645
5653972
 
 
 
4c3d645
a81ae9a
5653972
 
 
 
 
 
 
 
 
 
4c3d645
a81ae9a
 
5653972
 
 
 
 
 
 
 
a81ae9a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import pipeline

# Hugging Face Hub id of the fine-tuned tweet-emotion classifier.
model_path = "trnt/twitter_emotions"
# Set to True to run inference on CUDA instead of the CPU.
is_gpu = False
device = torch.device('cuda') if is_gpu else torch.device('cpu')
print(device)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.to(device)
model.eval()  # inference only
print("Model was loaded")

# transformers pipelines take an integer device index: -1 selects the CPU,
# 0 the first CUDA device. Spelled out explicitly instead of relying on
# bool arithmetic (True - 1 == 0, False - 1 == -1).
pipeline_device = 0 if is_gpu else -1
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, device=pipeline_device)
# Maps the model's generic output labels to human-readable emotion names.
# NOTE(review): assumes the model's id2label uses LABEL_0..LABEL_5 in this
# order — confirm against the model config on the Hub.
emotions = {'LABEL_0': 'sadness', 'LABEL_1': 'joy', 'LABEL_2': 'love', 'LABEL_3': 'anger', 'LABEL_4': 'fear',
            'LABEL_5': 'surprise'}
# Sample inputs shown as clickable examples in the Gradio UI.
examples = ["I love you!", "I hate you!"]


def predict(twitter):
    """Score a tweet against every emotion class of the model.

    Args:
        twitter: The tweet text to classify.

    Returns:
        A dict mapping capitalized emotion names (e.g. "Sadness", "Joy") to
        the classifier's score for that emotion, as consumed by ``gr.Label``.
    """
    # A single string is passed in, so [0] selects this input's list of
    # {"label": ..., "score": ...} entries (one per class).
    scores = classifier(twitter, return_all_scores=True)[0]
    # Key each score by its own label instead of by list position, so the
    # result does not silently mislabel emotions if the label order differs.
    # Labels missing from `emotions` fall back to their raw text.
    return {
        emotions.get(item["label"], item["label"]).capitalize(): item["score"]
        for item in scores
    }


if __name__ == '__main__':
    # Multi-line text box where the user enters the tweet to classify.
    tweet_box = gr.inputs.Textbox(placeholder="Enter a tweet here", label="Tweet content", lines=5)
    # Label widget showing up to the 6 highest-scoring classes from predict().
    emotion_label = gr.outputs.Label(num_top_classes=6, label="Emotions of this tweet is ")
    # NOTE(review): gr.inputs / gr.outputs, `verbose` and string themes come
    # from older Gradio releases — confirm the pinned Gradio version.
    demo = gr.Interface(
        fn=predict,
        inputs=tweet_box,
        outputs=emotion_label,
        verbose=True,
        examples=examples,
        title="Emotions of English tweet",
        description="",
        theme="grass",
    )
    demo.launch()