File size: 3,487 Bytes
adb7c0e
ca4ab65
323c81e
ca4ab65
323c81e
 
 
adb7c0e
7883937
 
 
 
 
 
 
 
 
 
 
 
49709ee
323c81e
 
6cfdc3f
323c81e
 
 
 
 
 
 
 
c8eeef2
 
9f21123
323c81e
 
 
9f21123
cffde68
c8eeef2
9f21123
323c81e
 
 
9f21123
323c81e
 
 
 
 
 
cffde68
9d5a1f8
c8eeef2
 
 
161261c
7f3c646
 
c8eeef2
 
 
 
9d5a1f8
 
 
cffde68
 
 
 
 
 
 
7f3c646
 
 
 
323c81e
cffde68
 
 
9f21123
ec1e335
c8eeef2
ec1e335
4ba9444
0f1e899
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import streamlit as st
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# We'll be using Torch this time around
import torch
import torch.nn.functional as F

def label_dictionary(model_name):
    """Return a callable that turns a raw model label into a display label.

    Args:
        model_name: Name of the sentiment model whose labels will be parsed.

    Returns:
        For "cardiffnlp/twitter-roberta-base-sentiment", a function mapping
        "LABEL_0" to "Negative", "LABEL_2" to "Positive" and every other
        label to "Neutral". For any other model name, a pass-through
        function that returns the label unchanged.
    """
    # Models other than the Twitter RoBERTa one need no translation.
    if model_name != "cardiffnlp/twitter-roberta-base-sentiment":
        return lambda x: x

    def twitter_roberta(label):
        if label == "LABEL_0":
            return "Negative"
        # Anything that is neither LABEL_0 nor LABEL_2 is reported as neutral.
        return "Positive" if label == "LABEL_2" else "Neutral"

    return twitter_roberta

# NOTE(review): newer Streamlit releases deprecate st.cache in favour of
# st.cache_resource - confirm against the Streamlit version this app pins.
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Load the model, tokenizer, pipeline and label parser for ``model_name``.

    Args:
        model_name: Identifier handed to both ``from_pretrained`` loaders and
            to ``label_dictionary``.

    Returns:
        A ``(model, tokenizer, classifier, parser)`` tuple, where
        ``classifier`` is a "sentiment-analysis" pipeline built from the
        loaded model and tokenizer, and ``parser`` is the label translator
        returned by ``label_dictionary``.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return (
        model,
        tokenizer,
        pipeline(task="sentiment-analysis", model=model, tokenizer=tokenizer),
        label_dictionary(model_name),
    )

# Seed the session state when no model has been stored in it yet. The state holds:
# 1) the name of the model (default: cardiffnlp/twitter-roberta-base-sentiment),
# 2) the model itself, together with its tokenizer and the classifier pipeline, and
# 3) the parser for the outputs, in case a raw label needs to be turned into something more sensible.
if "model" not in st.session_state:
    st.session_state["model_name"] = "cardiffnlp/twitter-roberta-base-sentiment"
    model, tokenizer, classifier, label_parser = load_model(
        st.session_state["model_name"]
    )
    st.session_state["model"] = model
    st.session_state["tokenizer"] = tokenizer
    st.session_state["classifier"] = classifier
    st.session_state["label_parser"] = label_parser

def model_change():
    """Reload the model named in the session state and store the results.

    Reads ``st.session_state.model_name``, calls ``load_model`` with it and
    overwrites the ``model``, ``tokenizer``, ``classifier`` and
    ``label_parser`` entries of the session state with the returned values.
    """
    (
        st.session_state.model,
        st.session_state.tokenizer,
        st.session_state.classifier,
        st.session_state.label_parser,
    ) = load_model(st.session_state.model_name)

# Page header: title, author subtitle and an empty markdown element as a spacer.
st.title("CSGY-6613 Sentiment Analysis")
st.markdown("### Ryan Kim (rk2546)")
st.markdown("")

# Model selector. The widget is bound to the "model_name" session-state key,
# and model_change is registered as its on_change callback.
model_option = st.selectbox(
    "What sentiment analysis model do you want to use?",
    (
        "cardiffnlp/twitter-roberta-base-sentiment",
        "finiteautomata/beto-sentiment-analysis",
        "bhadresh-savani/distilbert-base-uncased-emotion",
        "siebert/sentiment-roberta-large-english",
    ),
    key="model_name",
    on_change=model_change,
)

# Sample text shown as the text-area placeholder; it is also what gets
# evaluated when the form is submitted without any input.
placeholder = (
    "@AmericanAir just landed - 3hours Late Flight - and now we need to wait "
    "TWENTY MORE MINUTES for a gate! I have patience but none for incompetence."
)

# Input form: one text area plus a submit button.
form = st.form(key="sentiment-analysis-form")
text_input = form.text_area(
    "Enter some text for sentiment analysis! If you just want to test it out "
    'without entering anything, just press the "Submit" button and the model '
    "will look at the placeholder.",
    placeholder=placeholder,
)
submit = form.form_submit_button("Submit")

if submit:
    # Fall back to the placeholder when nothing (or only whitespace) was entered.
    to_eval = (text_input or "").strip() or placeholder

    # Echo the evaluated text and the selected model back to the user.
    st.write("You entered:")
    st.markdown(f"> {to_eval}")
    st.write("Using the NLP model:")
    st.markdown(f"> {st.session_state.model_name}")

    # Run the classifier; only the first entry of its output is used.
    result = st.session_state.classifier(to_eval)
    label, score = result[0]['label'], result[0]['score']

    # Translate the raw label with the parser stored for the current model.
    label = st.session_state.label_parser(label)

    st.markdown("#### Result:")
    st.markdown(f"**{label}**: {score}")
    st.write("")
    st.write("")