File size: 1,385 Bytes
6ec105b
 
17a8518
6ec105b
bb514b9
202a9f4
6ec105b
 
bb514b9
 
bad8793
17a8518
 
bad8793
17a8518
3b5b50d
2bf97c9
f665bf2
 
0f571f7
bad8793
 
2bf97c9
332e4a1
0f571f7
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from os import write
from typing import Sequence
import streamlit as st
from hf_model import classifier_zero,load_model
from utils import plot_result,examples_load
import json

# Build the zero-shot classifier once at script start.
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so load_model() runs on each rerun unless hf_model caches it
# internally (e.g. with st.cache_resource) — confirm.
classifier = load_model()

# Example input text and comma-separated topic labels; the form uses them
# as the default widget values.
ex_text, ex_labels = examples_load()


if __name__ == '__main__':
    st.header("Zero Shot Classification")
    st.write("This app allows you to classify any text into any categories you are interested in.")

    # Radio options are named once so the label shown to the user and the
    # value compared against below can never drift apart.
    single_topic_option = 'Only one topic can be correct at a time'
    multi_topic_option = 'Multiple topics can be correct at a time'

    with st.form(key='my_form'):
        text_input = st.text_area("Input any text you want to classify here:", ex_text)
        raw_labels = st.text_input(
            'Write any topic keywords you are interested in here (separate different topics with a ","):',
            ex_labels,
            max_chars=1000,
        )
        # Split on commas, trim whitespace, drop empty entries, and de-duplicate.
        # dict.fromkeys keeps first-seen order; set() would shuffle the labels.
        labels = list(dict.fromkeys(x.strip() for x in raw_labels.split(',') if x.strip()))
        radio = st.radio("Select Multiclass", (single_topic_option, multi_topic_option))
        multi_class = radio == multi_topic_option
        submit_button = st.form_submit_button(label='Submit')

    if submit_button:
        if not text_input.strip() or not labels:
            # Nothing to classify: show guidance instead of calling the model
            # with empty text or an empty label list.
            st.write('Enter some text and at least one possible topic to see predictions.')
        else:
            top_topics, scores = classifier_zero(
                classifier, sequence=text_input, labels=labels, multi_class=multi_class
            )
            # Take the first 10 entries of the returned lists, in reversed order.
            # NOTE(review): presumably classifier_zero returns topics sorted by
            # descending score, so this plots the top 10 with the best last
            # (e.g. at the top of a horizontal bar chart) — confirm.
            plot_result(top_topics[::-1][-10:], scores[::-1][-10:])