# Hugging Face Spaces page header at capture time — status: Runtime error
import functools

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
# Load the fine-tuned tokenizer and model from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("ethanrom/a2")
model = AutoModelForSequenceClassification.from_pretrained("ethanrom/a2")

# Candidate labels the zero-shot classifier scores each input against.
class_labels = ["Negative", "Positive", "No Impact", "Mixed"]

# Use the first GPU when one is present, otherwise the CPU (-1).
# A hard-coded device=0 makes pipeline construction fail on CPU-only hardware.
DEVICE = 0 if torch.cuda.is_available() else -1

# Zero-shot classification pipeline wrapping the fine-tuned model.
# NOTE(review): the zero-shot pipeline expects an NLI-style model whose config
# maps an "entailment" label; confirm ethanrom/a2 is such a model — if it is a
# plain 4-way classifier, the "text-classification" task is the right one.
classifier = pipeline(
    "zero-shot-classification",
    model=model,
    tokenizer=tokenizer,
    device=DEVICE,
)
# Prediction logic wired into the Gradio interface.
@functools.lru_cache(maxsize=1)
def _get_sentiment_classifier():
    """Build the default sentiment-analysis pipeline once and reuse it.

    Constructing a pipeline loads model weights, so doing it per request (as
    before) made every "bert" prediction pay the full load cost.
    """
    device = 0 if torch.cuda.is_available() else -1
    return pipeline("sentiment-analysis", device=device)


def predict_sentiment(text, model_choice):
    """Classify the sentiment of *text* with the model chosen in the UI.

    Args:
        text: Input text from the Gradio textbox.
        model_choice: ``"bert"`` selects the transformers default
            sentiment-analysis pipeline; any other value (including ``None``
            when nothing is selected) uses the zero-shot ``classifier`` over
            ``class_labels``.

    Returns:
        ``"<label> (<score>)"`` for the top prediction, with the score rounded
        to two decimals, or a short message when no text was entered.
    """
    # Both pipelines reject or mis-handle empty input; answer it directly.
    if not text or not text.strip():
        return "No input text provided."

    if model_choice == "bert":
        result = _get_sentiment_classifier()(text)[0]
        return f"{result['label']} ({result['score']:.2f})"

    # `candidate_labels` is the keyword the zero-shot pipeline reads; the old
    # `labels=` keyword was not recognised, so no candidates were passed.
    # `multi_label` replaces the deprecated `multi_class` argument.
    result = classifier(
        text,
        candidate_labels=class_labels,
        hypothesis_template="This text is about {}.",
        multi_label=True,
    )
    # The pipeline returns labels sorted by descending score.
    return f"{result['labels'][0]} ({result['scores'][0]:.2f})"
# Gradio components: a text box and a model selector in, a text box out.
# (The gr.inputs / gr.outputs namespaces were removed in Gradio 4; the
# top-level component classes are the supported API.)
inputs = [
    gr.Textbox(label="Input Text"),
    gr.Radio(["bert", "fine-tuned RoBERTa"], label="Model Choice"),
]
outputs = gr.Textbox(label="Sentiment Prediction")

# Build the interface and start the web server.
demo = gr.Interface(predict_sentiment, inputs, outputs, title="Sentiment Analysis App")
demo.launch()