Ryan Kim commited on
Commit
323c81e
1 Parent(s): 5bbb9e4

testing new alternative to caching and state and models and tokenizers

Browse files
Files changed (1) hide show
  1. src/main.py +34 -15
src/main.py CHANGED
@@ -1,15 +1,10 @@
1
  import streamlit as st
2
  from transformers import pipeline
 
3
 
4
- # Title
5
- st.title("CSGY-6613 Sentiment Analysis")
6
- # Subtitle
7
- st.markdown("### Ryan Kim (rk2546)")
8
- st.markdown("")
9
-
10
- @st.cache(allow_output_mutation=True)
11
- def load_model(model_name):
12
- return pipeline(model=model_name, task="sentiment-analysis")
13
 
14
  @st.cache(allow_output_mutation=True)
15
  def label_dictionary(model_name):
@@ -24,14 +19,38 @@ def label_dictionary(model_name):
24
  return twitter_roberta
25
  return lambda x: x
26
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  if "model" not in st.session_state:
28
  st.session_state.model_name = "cardiffnlp/twitter-roberta-base-sentiment"
29
- st.session_state.model = load_model("cardiffnlp/twitter-roberta-base-sentiment")
30
- st.session_state.label_parser = label_dictionary("cardiffnlp/twitter-roberta-base-sentiment")
 
 
 
31
 
32
  def model_change():
33
- st.session_state.model = load_model(st.session_state.model_name)
34
- st.session_state.label_parser = label_dictionary(st.session_state.model_name)
 
 
 
 
 
 
 
 
 
35
 
36
  model_option = st.selectbox(
37
  "What sentiment analysis model do you want to use?",
@@ -58,11 +77,11 @@ if submit:
58
  st.markdown("> {}".format(to_eval))
59
  st.write("Using the NLP model:")
60
  st.markdown("> {}".format(st.session_state.model_name))
61
- result = st.session_state.model(to_eval)
62
  label = result[0]['label']
63
  score = result[0]['score']
64
 
65
- label = st.session_state.label_parser(label)
66
 
67
  st.markdown("#### Result:")
68
  st.markdown("**{}**: {}".format(label,score))
 
1
  import streamlit as st
2
  from transformers import pipeline
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
+ # We'll be using Torch this time around
6
+ import torch
7
+ import torch.nn.functional as F
 
 
 
 
 
 
8
 
9
  @st.cache(allow_output_mutation=True)
10
  def label_dictionary(model_name):
 
19
  return twitter_roberta
20
  return lambda x: x
21
 
22
+ @st.cache(allow_output_mutation=True)
23
+ def load_model(model_name):
24
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
25
+ tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")
26
+ classifier = pipeline(task="sentiment-analysis", model=model, tokenizer=tokenizer)
27
+ parser = label_dictionary(model_name)
28
+ return model, tokenizer, classifier, parser
29
+
30
+ # We first initialize a state. The state will include the following:
31
+ # 1) the name of the model (default: cardiffnlp/twitter-roberta-base-sentiment)
32
+ # 2) the model itself, and
33
+ # 3) the parser for the outputs, in case we actually need to parse the output to something more sensible
34
  if "model" not in st.session_state:
35
  st.session_state.model_name = "cardiffnlp/twitter-roberta-base-sentiment"
36
+ model, tokenizer, classifier, parser = load_model("cardiffnlp/twitter-roberta-base-sentiment")
37
+ st.session_state.model = model
38
+ st.session_state.tokenizer = tokenizer
39
+ st.session_state.classifier = classifier
40
+ st.session_state.parser = parser
41
 
42
  def model_change():
43
+ model, tokenizer, classifier, parser = load_model(st.session_state.model_name)
44
+ st.session_state.model = model
45
+ st.session_state.tokenizer = tokenizer
46
+ st.session_state.classifier = classifier
47
+ st.session_state.parser = parser
48
+
49
+ # Title
50
+ st.title("CSGY-6613 Sentiment Analysis")
51
+ # Subtitle
52
+ st.markdown("### Ryan Kim (rk2546)")
53
+ st.markdown("")
54
 
55
  model_option = st.selectbox(
56
  "What sentiment analysis model do you want to use?",
 
77
  st.markdown("> {}".format(to_eval))
78
  st.write("Using the NLP model:")
79
  st.markdown("> {}".format(st.session_state.model_name))
80
+ result = st.session_state.classifier(to_eval)
81
  label = result[0]['label']
82
  score = result[0]['score']
83
 
84
+ label = st.session_state.parser(label)
85
 
86
  st.markdown("#### Result:")
87
  st.markdown("**{}**: {}".format(label,score))