# Gradio demo: compare emotion predictions from three fine-tuned
# sequence-classification models (BERT, XLNet, and GPT-2).
import gradio as gr
import torch
from transformers import (
    BertTokenizer,
    XLNetTokenizer,
    GPT2Tokenizer,
    AutoModelForSequenceClassification
)

# Hugging Face Hub repositories holding the fine-tuned checkpoints.
model_repos = {
    "BERT": "sk23aib/Bert_emotion_model",
    "XLNet": "sk23aib/Xlnet_emotion_model",
    "GPT-2": "sk23aib/Gpt2_emotion_model"
}

# Output label order; index i corresponds to the class id i used during fine-tuning.
emotion_labels = [
    "anger", "boredom", "empty", "enthusiasm", "fun", "happiness", "hate",
    "love", "neutral", "relief", "sadness", "surprise", "worry"
]

# Tokenizer/model pairs, keyed by display name.
loaded_models = {}

# BERT: WordPiece tokenizer; no special padding setup required.
bert_tokenizer = BertTokenizer.from_pretrained(model_repos["BERT"])
bert_model = AutoModelForSequenceClassification.from_pretrained(model_repos["BERT"])
bert_model.eval()
loaded_models["BERT"] = {"tokenizer": bert_tokenizer, "model": bert_model}

# XLNet: SentencePiece tokenizer, loaded the same way as BERT.
xlnet_tokenizer = XLNetTokenizer.from_pretrained(model_repos["XLNet"])
xlnet_model = AutoModelForSequenceClassification.from_pretrained(model_repos["XLNet"])
xlnet_model.eval()
loaded_models["XLNet"] = {"tokenizer": xlnet_tokenizer, "model": xlnet_model}

# GPT-2 has no pad token by default, so reuse the EOS token, pad on the left,
# and tell the model which token id to treat as padding.
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(model_repos["GPT-2"], padding_side="left")
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
gpt2_model = AutoModelForSequenceClassification.from_pretrained(model_repos["GPT-2"])
gpt2_model.config.pad_token_id = gpt2_tokenizer.pad_token_id
gpt2_model.eval()
loaded_models["GPT-2"] = {"tokenizer": gpt2_tokenizer, "model": gpt2_model}


def predict_emotions(text):
    """Run the text through every loaded model and report each model's top emotion."""
    output_lines = []
    with torch.no_grad():
        for model_name, components in loaded_models.items():
            tokenizer = components["tokenizer"]
            model = components["model"]
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
            logits = model(**inputs).logits
            # Softmax over the label dimension; batch size is 1, so take row 0.
            probs = torch.nn.functional.softmax(logits, dim=1)[0]
            top_idx = torch.argmax(probs).item()
            top_emotion = emotion_labels[top_idx]
            top_confidence = round(float(probs[top_idx]), 4)
            output_lines.append(f"{model_name}: {top_emotion} ({top_confidence})")
    return "\n".join(output_lines)


interface = gr.Interface(
    fn=predict_emotions,
    inputs=gr.Textbox(lines=3, placeholder="Type a sentence to analyze..."),
    outputs=gr.Textbox(label="Top Emotion by Model"),
    title="Multi-Model Emotion Classifier",
    description="See which emotion is predicted by BERT, XLNet, and GPT-2, along with their confidence."
)
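
# launch() starts a local server; passing share=True (a standard Gradio
# option) would additionally create a temporary public link if desired.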

interface.launch()