ashrestha commited on
Commit
06ba6f1
1 Parent(s): db840ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import streamlit as st
 
2
 
3
  from transformers import pipeline
4
 
@@ -9,15 +10,21 @@ classifier = pipeline("zero-shot-classification",
9
  with st.form('inputs'):
10
  input_text = st.text_area("Input text")
11
  input_label = st.text_input("Labels", placeholder="support, help, important")
 
12
  submit_button = st.form_submit_button(label='Submit')
13
 
14
  if submit_button:
15
 
16
  labels = list(l.strip() for l in input_label.split(','))
17
- pred = classifier(input_text, labels, multi_class=True)
 
 
 
 
 
 
 
 
 
18
 
19
- out = f"Top predicted labels are {', '.join(p for p in pred['labels'][0:2])}"
20
-
21
- st.success(out)
22
-
23
  # st.markdown(pred)
 
1
  import streamlit as st
2
+ import json
3
 
4
  from transformers import pipeline
5
 
 
10
  with st.form('inputs'):
11
  input_text = st.text_area("Input text")
12
  input_label = st.text_input("Labels", placeholder="support, help, important")
13
+ input_multi = st.checkbox('Allow multiple true classes', value=False)
14
  submit_button = st.form_submit_button(label='Submit')
15
 
16
  if submit_button:
17
 
18
  labels = list(l.strip() for l in input_label.split(','))
19
+ pred = classifier(input_text, labels, multi_class=input_multi)
20
+
21
+ if input_multi:
22
+ st.vega_lite_chart(json.dumps(pred),
23
+ {'mark': 'bar', 'encoding': {'y': 'labels', 'x': 'scores'}}
24
+ )
25
+ else:
26
+ out = f"Top predicted labels are {', '.join(p for p in pred['labels'][0:2])}"
27
+
28
+ st.success(out)
29
 
 
 
 
 
30
  # st.markdown(pred)