gefedya commited on
Commit
74ce976
1 Parent(s): d2fa891

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -7,7 +7,7 @@ import json
7
  @st.cache()
8
  def get_model():
9
  model = AutoModelForSequenceClassification.from_pretrained("siebert/sentiment-roberta-large-english", num_labels=2)
10
- model.load_state_dict(torch.load('model'))
11
  return model
12
 
13
  @st.cache()
@@ -15,9 +15,11 @@ def get_tokenizer():
15
  tokenizer = AutoTokenizer.from_pretrained("siebert/sentiment-roberta-large-english")
16
  return tokenizer
17
 
18
- def make_prediction():
19
  model = get_model()
20
  tokenizer = tokenizer()
 
 
21
 
22
 
23
 
@@ -41,7 +43,8 @@ with st.form(key='input_form'):
41
  button = st.form_submit_button(label='Classify')
42
  if button:
43
  if to_analyze:
44
- make_prediction(to_analyze)
 
45
  else:
46
  st.markdown("Empty request. Please resubmit")
47
 
 
7
  @st.cache()
8
  def get_model():
9
  model = AutoModelForSequenceClassification.from_pretrained("siebert/sentiment-roberta-large-english", num_labels=2)
10
+ model.load_state_dict(torch.load('cached_model.pth'))
11
  return model
12
 
13
  @st.cache()
 
15
  tokenizer = AutoTokenizer.from_pretrained("siebert/sentiment-roberta-large-english")
16
  return tokenizer
17
 
18
+ def make_prediction(to_analyze):
19
  model = get_model()
20
  tokenizer = tokenizer()
21
+ to_return = model(**tokenizer(to_anayze))
22
+ return to_return
23
 
24
 
25
 
 
43
  button = st.form_submit_button(label='Classify')
44
  if button:
45
  if to_analyze:
46
+ pred = make_prediction(to_analyze)
47
+ st.markdown(pred)
48
  else:
49
  st.markdown("Empty request. Please resubmit")
50