File size: 2,239 Bytes
92051e7
9f3f4d0
40399bd
 
a218fec
2e7807e
5471072
40399bd
a218fec
 
 
 
 
f439bdc
a218fec
40399bd
a218fec
e59550c
 
40399bd
bfb7cd7
40399bd
bfb7cd7
a218fec
643df72
bfb7cd7
 
a218fec
643df72
bfb7cd7
 
 
a218fec
643df72
bfb7cd7
 
643df72
 
e59550c
4d96058
b5600cf
a218fec
e143930
 
a218fec
70285b0
a218fec
 
 
 
 
 
 
 
 
 
3e55229
e59550c
643df72
a218fec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import re

import requests
import streamlit as st

# Hugging Face Inference API endpoint for the hosted GPT-2 model.
API_URL = "https://api-inference.huggingface.co/models/gpt2"
# Token read from the environment; None (or "") when the variable is unset.
API_TOKEN = os.environ.get('HUGGINGFACEHUB_API_TOKEN')
# Only send an Authorization header when a token is configured; otherwise the
# request would carry the meaningless header value "Bearer None".
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}

def get_sentiment_category(sentiment_score):
    """Classify a numeric sentiment score by its sign.

    Returns "positive" for a score above zero, "negative" for a score below
    zero, and "mixed" otherwise (zero, or a value such as NaN that compares
    neither above nor below zero).
    """
    # +1 above zero, -1 below zero, 0 for anything else.
    sign = int(sentiment_score > 0) - int(sentiment_score < 0)
    labels_by_sign = {1: "positive", -1: "negative"}
    return labels_by_sign.get(sign, "mixed")

# Matches a leading signed integer or decimal score such as "8", "-3" or "7.5"
# (so "8/10" yields 8 and "0.5" yields 0.5 rather than being cut at the dot).
_SCORE_PATTERN = re.compile(r"[+-]?\d+(?:\.\d+)?")
# Textual answers accepted directly as a category.
_SENTIMENT_LABELS = ("positive", "negative", "mixed")
# Every prompt built below ends with this marker.
_SENTIMENT_MARKER = "Sentiment:"


def _model_answer(generated_text, prompt_text):
    """Return the part of ``generated_text`` that follows the prompt, stripped.

    If the text starts with the prompt (the API response may echo it), the
    prompt is sliced off so that "Sentiment:" markers inside user-supplied
    examples are not mistaken for the answer. Otherwise the text after the
    last "Sentiment:" marker is used, or the whole text if there is none.
    """
    if generated_text.startswith(prompt_text):
        return generated_text[len(prompt_text):].strip()
    cut = generated_text.rfind(_SENTIMENT_MARKER)
    if cut != -1:
        return generated_text[cut + len(_SENTIMENT_MARKER):].strip()
    return generated_text.strip()


def _answer_to_category(answer):
    """Map the model's answer to "positive", "negative" or "mixed".

    Only the first non-blank line is considered. A leading number (e.g.
    "8/10" or "-2.5") is classified by sign via ``get_sentiment_category``;
    otherwise a leading label word ("positive", "negative", "mixed") is
    accepted as-is. Anything unrecognized, including an empty answer, is
    reported as "mixed".
    """
    lines = answer.strip().splitlines()
    first_line = lines[0].strip() if lines else ""
    number = _SCORE_PATTERN.match(first_line)
    if number:
        return get_sentiment_category(float(number.group()))
    words = first_line.split()
    if words:
        label = words[0].strip(".,;:!?\"'").lower()
        if label in _SENTIMENT_LABELS:
            return label
    return "mixed"


st.title("GPT-2 Movie Review Sentiment Analysis")

input_text = st.text_area("Enter movie review:", "")

analysis_type = st.radio("Select analysis type:", ["Zero-shot", "One-shot", "Few-shot"])

# Each prompt ends with "Sentiment:" so that the model's continuation is its answer.
if analysis_type == "Zero-shot":
    prompt = f"Sentiment analysis of the following movie review: \n\n{input_text}\n\nSentiment:"

elif analysis_type == "One-shot":
    example = st.text_area("Input one example:")
    prompt = f"{example}\n\nMovie review:\n{input_text}\n\nSentiment:"

elif analysis_type == "Few-shot":
    examples = st.text_area("Input few-shot examples, one per line:")
    # Keep one example per line (joining with ", " would merge them into a
    # single run-on line) and drop blank lines.
    examples_list = [line.strip() for line in examples.splitlines() if line.strip()]
    examples_block = "\n".join(examples_list)
    prompt = f"Sentiment analysis examples: \n{examples_block}\n\nMovie review: \n{input_text}\n\nSentiment:"

if st.button("Analyze"):
    if not input_text.strip():
        st.warning("Please enter a movie review to analyze.")
    else:
        try:
            response = requests.post(API_URL, headers=HEADERS, json={"inputs": prompt}, timeout=10)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            st.error(f"Error reaching API\n{e}")
        else:
            try:
                # Assumes the text-generation response shape [{"generated_text": ...}].
                result = response.json()[0]['generated_text']
            except (ValueError, KeyError, IndexError, TypeError) as e:
                # Not JSON, or not the expected list-of-dicts shape.
                st.error(f"Unexpected API response\n{e}")
            else:
                sentiment_category = _answer_to_category(_model_answer(result, prompt))
                st.write(f"Sentiment: {sentiment_category}")