Spaces:
Runtime error
Runtime error
Ryan Kim
commited on
Commit
•
9f21123
1
Parent(s):
323c81e
attempting to fix parser issue
Browse files- src/main.py +5 -5
src/main.py
CHANGED
@@ -33,18 +33,18 @@ def load_model(model_name):
|
|
33 |
# 3) the parser for the outputs, in case we actually need to parse the output to something more sensible
|
34 |
if "model" not in st.session_state:
|
35 |
st.session_state.model_name = "cardiffnlp/twitter-roberta-base-sentiment"
|
36 |
-
model, tokenizer, classifier,
|
37 |
st.session_state.model = model
|
38 |
st.session_state.tokenizer = tokenizer
|
39 |
st.session_state.classifier = classifier
|
40 |
-
st.session_state.
|
41 |
|
42 |
def model_change():
|
43 |
-
model, tokenizer, classifier,
|
44 |
st.session_state.model = model
|
45 |
st.session_state.tokenizer = tokenizer
|
46 |
st.session_state.classifier = classifier
|
47 |
-
st.session_state.
|
48 |
|
49 |
# Title
|
50 |
st.title("CSGY-6613 Sentiment Analysis")
|
@@ -81,7 +81,7 @@ if submit:
|
|
81 |
label = result[0]['label']
|
82 |
score = result[0]['score']
|
83 |
|
84 |
-
label = st.session_state.
|
85 |
|
86 |
st.markdown("#### Result:")
|
87 |
st.markdown("**{}**: {}".format(label,score))
|
|
|
33 |
# 3) the parser for the outputs, in case we actually need to parse the output to something more sensible
|
34 |
if "model" not in st.session_state:
|
35 |
st.session_state.model_name = "cardiffnlp/twitter-roberta-base-sentiment"
|
36 |
+
model, tokenizer, classifier, label_parser = load_model("cardiffnlp/twitter-roberta-base-sentiment")
|
37 |
st.session_state.model = model
|
38 |
st.session_state.tokenizer = tokenizer
|
39 |
st.session_state.classifier = classifier
|
40 |
+
st.session_state.label_parser = label_parser
|
41 |
|
42 |
def model_change():
|
43 |
+
model, tokenizer, classifier, label_parser = load_model(st.session_state.model_name)
|
44 |
st.session_state.model = model
|
45 |
st.session_state.tokenizer = tokenizer
|
46 |
st.session_state.classifier = classifier
|
47 |
+
st.session_state.label_parser = label_parser
|
48 |
|
49 |
# Title
|
50 |
st.title("CSGY-6613 Sentiment Analysis")
|
|
|
81 |
label = result[0]['label']
|
82 |
score = result[0]['score']
|
83 |
|
84 |
+
label = st.session_state.label_parser(label)
|
85 |
|
86 |
st.markdown("#### Result:")
|
87 |
st.markdown("**{}**: {}".format(label,score))
|