Ryan Kim commited on
Commit
9f21123
1 Parent(s): 323c81e

attempting to fix parser issue

Browse files
Files changed (1) hide show
  1. src/main.py +5 -5
src/main.py CHANGED
@@ -33,18 +33,18 @@ def load_model(model_name):
33
  # 3) the parser for the outputs, in case we actually need to parse the output to something more sensible
34
  if "model" not in st.session_state:
35
  st.session_state.model_name = "cardiffnlp/twitter-roberta-base-sentiment"
36
- model, tokenizer, classifier, parser = load_model("cardiffnlp/twitter-roberta-base-sentiment")
37
  st.session_state.model = model
38
  st.session_state.tokenizer = tokenizer
39
  st.session_state.classifier = classifier
40
- st.session_state.parser = parser
41
 
42
  def model_change():
43
- model, tokenizer, classifier, parser = load_model(st.session_state.model_name)
44
  st.session_state.model = model
45
  st.session_state.tokenizer = tokenizer
46
  st.session_state.classifier = classifier
47
- st.session_state.parser = parser
48
 
49
  # Title
50
  st.title("CSGY-6613 Sentiment Analysis")
@@ -81,7 +81,7 @@ if submit:
81
  label = result[0]['label']
82
  score = result[0]['score']
83
 
84
- label = st.session_state.parser(label)
85
 
86
  st.markdown("#### Result:")
87
  st.markdown("**{}**: {}".format(label,score))
 
33
  # 3) the parser for the outputs, in case we actually need to parse the output to something more sensible
34
  if "model" not in st.session_state:
35
  st.session_state.model_name = "cardiffnlp/twitter-roberta-base-sentiment"
36
+ model, tokenizer, classifier, label_parser = load_model("cardiffnlp/twitter-roberta-base-sentiment")
37
  st.session_state.model = model
38
  st.session_state.tokenizer = tokenizer
39
  st.session_state.classifier = classifier
40
+ st.session_state.label_parser = label_parser
41
 
42
  def model_change():
43
+ model, tokenizer, classifier, label_parser = load_model(st.session_state.model_name)
44
  st.session_state.model = model
45
  st.session_state.tokenizer = tokenizer
46
  st.session_state.classifier = classifier
47
+ st.session_state.label_parser = label_parser
48
 
49
  # Title
50
  st.title("CSGY-6613 Sentiment Analysis")
 
81
  label = result[0]['label']
82
  score = result[0]['score']
83
 
84
+ label = st.session_state.label_parser(label)
85
 
86
  st.markdown("#### Result:")
87
  st.markdown("**{}**: {}".format(label,score))