"""Streamlit app that zero-shot classifies a pasted review.

The review is scored separately against three user-editable label groups
(sentiments, topics and reasons); each group's scores are drawn as a
horizontal bar chart in the column where that group is configured.
"""
import streamlit as st
from transformers import pipeline
import plotly.express as px
import pandas as pd

st.set_page_config(layout="wide")

# ``st.cache(allow_output_mutation=True)`` is deprecated in favour of
# ``st.cache_resource`` (intended for shared, unhashable objects such as a
# model pipeline). Fall back to the legacy decorator on older Streamlit
# releases that do not provide ``cache_resource``.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def get_classifier_model():
    """Build (and cache) the zero-shot classification pipeline."""
    # Other checkpoints that were tried: "facebook/bart-large-mnli" and
    # "sentence-transformers/paraphrase-MiniLM-L6-v2".
    return pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")


def _merge_labels(selected, extra_text):
    """Combine multiselect picks with comma-separated free-text labels.

    Extra labels are whitespace-stripped and empty entries (e.g. from a
    trailing comma) are dropped. Duplicates are removed while keeping the
    first occurrence's position.
    """
    extras = [label.strip() for label in extra_text.split(",")] if extra_text else []
    merged = list(selected) + [label for label in extras if label]
    return list(dict.fromkeys(merged))


def get_classification(candidate_labels, model=None, sequence=None, multi_label=None):
    """Score a review against ``candidate_labels`` and return a bar chart.

    Args:
        candidate_labels: Non-empty list of label strings to score.
        model: Zero-shot pipeline to call; defaults to the module-level
            ``classifier`` (kept for backward compatibility).
        sequence: Text to classify; defaults to the module-level
            ``sequence_to_classify``.
        multi_label: Whether each label is scored independently; defaults
            to the module-level ``is_multi_class`` checkbox value.

    Returns:
        A Plotly horizontal bar figure with the highest-scoring label on top.
    """
    model = classifier if model is None else model
    sequence = sequence_to_classify if sequence is None else sequence
    multi_label = is_multi_class if multi_label is None else multi_label

    # ``multi_class`` is the deprecated spelling of ``multi_label``.
    output = model(sequence, candidate_labels, multi_label=multi_label)
    df = pd.DataFrame({"Class": output["labels"], "Scores": output["scores"]})
    df = df.sort_values(by="Scores", ascending=False)
    fig = px.bar(df, x="Scores", y="Class", orientation="h", width=400, height=500)
    # Plotly places the first category at the bottom of a categorical axis;
    # reverse it so the best-scoring label appears at the top.
    fig.update_layout(yaxis=dict(autorange="reversed"))
    return fig


st.title("Review Analyzer")
st.markdown("***")
text = st.text_area(label="Paste/Type the review here..")
st.markdown("***")

col1, col2, col3 = st.columns((1, 1, 1))

col1.header("Select Sentiments")
sentiments = col1.multiselect("", ["Happy", "Sad", "Neutral"], ["Happy", "Sad", "Neutral"])
col1.markdown(" \n")
col1.markdown(" \n")
additional_sentiments = col1.text_input("Enter comma separated sentiments.")
sentiments = _merge_labels(sentiments, additional_sentiments)

col2.header("Select Topics")
entities = col2.multiselect(
    "",
    ["Bank Account", "Credit Card", "Home Loan", "Motor Loan"],
    ["Bank Account", "Credit Card", "Home Loan", "Motor Loan"],
)
additional_entities = col2.text_input("Enter comma separated entities.")
entities = _merge_labels(entities, additional_entities)

col3.header("Select Reasons")
reasons = col3.multiselect(
    "",
    ["Poor Service", "No Empathy", "Abuse"],
    ["Poor Service", "No Empathy", "Abuse"],
)
additional_reasons = col3.text_input("Enter comma separated reasons.")
reasons = _merge_labels(reasons, additional_reasons)

is_multi_class = st.checkbox("Can have more than one classes", value=True)
st.markdown("***")
classify_button_clicked = st.button("Classify")

if classify_button_clicked:
    # A whitespace-only review has nothing to classify.
    if text.strip():
        st.markdown("***")
        with st.spinner(" Please wait while the text is being classified.."):
            classifier = get_classifier_model()
            sequence_to_classify = text.strip()
            if sentiments:
                # Spacer so the sentiment chart lines up with the other columns.
                col1.markdown(" \n")
            # Each label group is classified on its own and charted in its column.
            for column, labels in ((col1, sentiments), (col2, entities), (col3, reasons)):
                if labels:
                    fig = get_classification(labels, classifier, sequence_to_classify, is_multi_class)
                    column.plotly_chart(fig)
    else:
        st.warning("Please paste or type a review before classifying.")