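# Hugging Face Spaces page header captured with this source
# ("Spaces:" status: "Sleeping"); kept as a comment so the file parses.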
import os
import re

import requests
import streamlit as st
# Hosted Inference API endpoint for the distilgpt2 text-generation model.
# NOTE(review): confirm this model is still served at this endpoint.
API_URL = "https://api-inference.huggingface.co/models/distilgpt2"
API_TOKEN = os.environ.get('HUGGINGFACEHUB_API_TOKEN')
# Only send an Authorization header when a token is configured; formatting a
# missing token would otherwise send the literal header "Bearer None".
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}
# Word-start label matchers. Anchoring at a word boundary avoids the false hits
# a bare substring test gives ("possible"/"purpose" contain "pos", "renegade"
# contains "neg") while still accepting short labels such as "POS" / "NEG".
_POSITIVE_RE = re.compile(r"\bpos(?:itiv\w*)?\b")
_NEGATIVE_RE = re.compile(r"\bneg(?:ativ\w*)?\b")
_MIXED_RE = re.compile(r"\bmix\w*")


def get_sentiment_category(sentiment_label):
    """Map a free-form sentiment label to "Positive", "Negative" or "Mixed".

    Matching is case-insensitive and word-based: "pos", "positive",
    "positively" count as positive; "neg", "negative", "negatively" count as
    negative; any word starting with "mix" counts as mixed.

    Args:
        sentiment_label: Label text extracted from the model output. ``None``
            is treated like an empty string.

    Returns:
        "Mixed" when the label says mixed, mentions both polarities (e.g. the
        model echoing "positive or negative"), or matches neither polarity;
        otherwise "Positive" or "Negative".
    """
    label = (sentiment_label or "").lower()
    is_positive = _POSITIVE_RE.search(label) is not None
    is_negative = _NEGATIVE_RE.search(label) is not None
    # Resolve explicit "mixed" and ambiguous labels first so check order can
    # no longer bias the result toward "Positive".
    if _MIXED_RE.search(label) or (is_positive and is_negative):
        return "Mixed"
    if is_positive:
        return "Positive"
    if is_negative:
        return "Negative"
    # Unrecognized output keeps the original fallback of "Mixed".
    return "Mixed"
# Page layout: review input plus the prompting strategy. Every prompt variant
# ends with "Sentiment:" so the model's continuation is read as the label.
st.title("DistilGPT2 Movie Review Sentiment Analysis")
input_text = st.text_area("Enter movie review:", "")
analysis_type = st.radio("Select analysis type:", ["Zero-shot", "One-shot", "Few-shot"])

PROMPT_INSTRUCTION = "Classify sentiment as positive or negative or mixed: "

if analysis_type == "Zero-shot":
    prompt = f"{PROMPT_INSTRUCTION}\n\n{input_text}\n\nSentiment:"
elif analysis_type == "One-shot":
    example = st.text_area("Input one example:")
    prompt = f"{PROMPT_INSTRUCTION}\n{example}\n\nMovie review:\n{input_text}\n\nSentiment:"
elif analysis_type == "Few-shot":
    examples = st.text_area("Input few-shot examples, one per line:")
    # Keep the examples one per line, as the user entered them (joining with
    # ", " flattened them into one run-on line), and drop blank lines so they
    # do not become empty entries.
    examples_list = [line.strip() for line in examples.splitlines() if line.strip()]
    examples_block = "\n".join(examples_list)
    prompt = f"{PROMPT_INSTRUCTION}\n{examples_block}\n\nMovie review:\n{input_text}\n\nSentiment:"
def _extract_sentiment_label(generated_text, prompt):
    """Return the label the model generated right after ``prompt``.

    If the API echoed the prompt, it is stripped first, so a "Sentiment:"
    appearing inside user-supplied examples is never mistaken for the model's
    answer. The label is the first non-empty line of the continuation, cut at
    the first period (when there is one).
    """
    if generated_text.startswith(prompt):
        generated_text = generated_text[len(prompt):]
    first_line = generated_text.strip().split("\n", 1)[0]
    return first_line.split(".", 1)[0].strip()


if st.button("Analyze"):
    if not input_text.strip():
        st.warning("Please enter a movie review to analyze.")
    else:
        try:
            response = requests.post(
                API_URL,
                headers=HEADERS,
                # Ask for the continuation only; the prefix strip in
                # _extract_sentiment_label still covers backends that echo.
                json={"inputs": prompt, "parameters": {"return_full_text": False}},
                timeout=10,
            )
            response.raise_for_status()
            payload = response.json()
        except requests.exceptions.RequestException as e:
            st.error("Error reaching API\n{}".format(e))
        except ValueError as e:
            # Older requests versions raise a plain ValueError on bad JSON.
            st.error("API returned invalid JSON\n{}".format(e))
        else:
            try:
                generated_text = payload[0]["generated_text"]
            except (KeyError, IndexError, TypeError):
                # e.g. {"error": "..."} bodies returned with a 2xx status.
                st.error("Unexpected API response\n{}".format(payload))
            else:
                sentiment_label = _extract_sentiment_label(generated_text, prompt)
                sentiment_category = get_sentiment_category(sentiment_label)
                st.write(f"Sentiment: {sentiment_category}")