# app.py: Hugging Face Space source.
# Last commit: 433ad7a "Update app.py" by kiptoozeff (file size 1.56 kB).
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import pipeline
# Hugging Face Hub checkpoint for the tweet emotion classifier.
model_path = "trnt/twitter_emotions"
# Set to True to run inference on the first CUDA device instead of the CPU.
is_gpu = False
device = torch.device('cuda') if is_gpu else torch.device('cpu')
print(device)

model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.to(device)
model.eval()  # inference only; no training happens in this app
print("Model was loaded")

# `pipeline` takes an integer device index: -1 selects the CPU, 0 the first GPU.
# Spelled out explicitly instead of the previous bool arithmetic (is_gpu - 1),
# which produced the same values but hid the intent.
classifier = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=0 if is_gpu else -1,
)

# Raw model label -> human-readable emotion name.
# NOTE(review): assumes the checkpoint's config uses the default LABEL_<i>
# names in this order; confirm against the model card / config.json.
emotions = {'LABEL_0': 'sadness', 'LABEL_1': 'joy', 'LABEL_2': 'love', 'LABEL_3': 'anger', 'LABEL_4': 'fear',
            'LABEL_5': 'surprise'}

# Example tweets passed to the Gradio interface.
examples = ["I am a happy man", "I hate you!"]
def predict(twitter):
    """Score every emotion for a single tweet.

    Args:
        twitter: The tweet text entered in the Gradio textbox.

    Returns:
        A dict mapping capitalized emotion names (e.g. "Joy") to the
        classifier's score for that emotion, as consumed by ``gr.Label``.
    """
    # ``top_k=None`` asks for a score for every label. It is the documented
    # replacement for the deprecated ``return_all_scores=True``, but it returns
    # results sorted by score, so scores are keyed by label name below rather
    # than read by list position.
    scores = classifier(twitter, top_k=None)
    # The legacy call path wrapped a single-string result in an extra list
    # (hence the original ``[0]``); unwrap defensively if that shape appears.
    if scores and isinstance(scores[0], list):
        scores = scores[0]
    # Map raw labels (LABEL_<i>) to names; "sadness" -> "Sadness", etc.
    # Unknown labels fall back to the raw label text instead of raising.
    return {
        emotions.get(item["label"], item["label"]).capitalize(): item["score"]
        for item in scores
    }
if __name__ == '__main__':
    # Gradio demo: one multi-line text input, and a label output listing the
    # score returned by `predict` for each emotion.
    demo = gr.Interface(
        fn=predict,
        inputs=gr.Textbox(placeholder="Enter a tweet here", label="Tweet content", lines=5),
        # Show every emotion the label map knows about instead of a hard-coded 6.
        outputs=gr.Label(num_top_classes=len(emotions), label="Emotions of this tweet"),
        examples=examples,
        title="Emotions of English tweet",
        description="",
    )
    demo.launch()