bentrevett commited on
Commit
e9c923a
1 Parent(s): c5b1982
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -2,17 +2,21 @@ import streamlit as st
2
  import transformers
3
  import matplotlib.pyplot as plt
4
 
 
5
  @st.cache(allow_output_mutation=True)
6
  def get_pipe():
7
  model_name = "joeddav/distilbert-base-uncased-go-emotions-student"
8
  model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
9
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
10
- pipe = transformers.pipeline('text-classification', model=model, tokenizer=tokenizer, return_all_scores=True, truncation=True)
 
11
  return pipe
12
 
 
13
  def sort_predictions(predictions):
14
  return sorted(predictions, key=lambda x: x['score'], reverse=True)
15
 
 
16
  st.set_page_config(page_title="Emotion Prediction")
17
  st.title("Emotion Prediction")
18
  st.write("Type text into the text box and then press 'Predict' to get the predicted emotion.")
@@ -24,7 +28,7 @@ text = st.text_area('Enter text here:')
24
  submit = st.button('Predict')
25
 
26
  if submit:
27
-
28
  prediction = pipe(text)[0]
29
  prediction = sort_predictions(prediction)
30
 
@@ -34,7 +38,7 @@ if submit:
34
  tick_label=[p['label'] for p in prediction])
35
  ax.tick_params(rotation=90)
36
  ax.set_ylim(0, 1)
37
-
38
  st.header('Prediction:')
39
  st.pyplot(fig)
40
 
 
2
  import transformers
3
  import matplotlib.pyplot as plt
4
 
5
+
6
  @st.cache(allow_output_mutation=True)
7
  def get_pipe():
8
  model_name = "joeddav/distilbert-base-uncased-go-emotions-student"
9
  model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
10
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
11
+ pipe = transformers.pipeline('text-classification', model=model, tokenizer=tokenizer,
12
+ return_all_scores=True, truncation=True)
13
  return pipe
14
 
15
+
16
  def sort_predictions(predictions):
17
  return sorted(predictions, key=lambda x: x['score'], reverse=True)
18
 
19
+
20
  st.set_page_config(page_title="Emotion Prediction")
21
  st.title("Emotion Prediction")
22
  st.write("Type text into the text box and then press 'Predict' to get the predicted emotion.")
 
28
  submit = st.button('Predict')
29
 
30
  if submit:
31
+
32
  prediction = pipe(text)[0]
33
  prediction = sort_predictions(prediction)
34
 
 
38
  tick_label=[p['label'] for p in prediction])
39
  ax.tick_params(rotation=90)
40
  ax.set_ylim(0, 1)
41
+
42
  st.header('Prediction:')
43
  st.pyplot(fig)
44