from __future__ import annotations import psutil import pandas as pd import streamlit as st import plotly.express as px from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS from zeroshot_classification.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier print(f"Total mem: {psutil.virtual_memory().total}") def init_state(key: str): if key not in st.session_state: st.session_state[key] = None for k in [ "current_model", "current_model_option", "current_method_option", "current_prediction", "current_chart", ]: init_state(k) def load_model(model_option: str, method_option: str, random_state: int = 0): with st.spinner("Loading selected model..."): if method_option == "Natural Language Inference": st.session_state.current_model = NLIZeroshotClassifier( model_name=model_option, random_state=random_state ) else: st.session_state.current_model = NSPZeroshotClassifier( model_name=model_option, random_state=random_state ) st.success("Model loaded!") def visualize_output(labels: list[str], probabilities: list[float]): data = pd.DataFrame({"labels": labels, "probability": probabilities}).sort_values( by="probability", ascending=False ) chart = px.bar( data, x="probability", y="labels", color="labels", orientation="h", height=290, width=500, ).update_layout( { "xaxis": {"title": "probability", "visible": True, "showticklabels": True}, "yaxis": {"title": None, "visible": True, "showticklabels": True}, "margin": dict( l=10, # left r=10, # right t=50, # top b=10, # bottom ), "showlegend": False, } ) return chart st.title("Zero-shot Turkish Text Classification") method_option = st.radio( "Select a zero-shot classification method.", [ METHOD_OPTIONS["nli"], METHOD_OPTIONS["nsp"], ], ) if method_option == METHOD_OPTIONS["nli"]: model_option = st.selectbox( "Select a natural language inference model.", NLI_MODEL_OPTIONS, index=3 ) if method_option == METHOD_OPTIONS["nsp"]: model_option = st.selectbox( "Select a BERT model for next sentence prediction.", NSP_MODEL_OPTIONS, index=0 ) if model_option != st.session_state.current_model_option: st.session_state.current_model_option = model_option st.session_state.current_method_option = method_option load_model( st.session_state.current_model_option, st.session_state.current_method_option ) st.header("Configure prompts and labels") col1, col2 = st.columns(2) col1.subheader("Candidate labels") labels = col1.text_area( label="These are the labels that the model will try to predict for the given text input. Your input labels should be comma separated and meaningful.", value="tatli,burger,kebab,diğer,tuzlu", key="current_labels", ) col1.header("Make predictions") text = col1.text_area( "Enter a sentence or a paragraph to classify.", value="baklava", key="current_text", ) col2.subheader("Prompt template") prompt_template = col2.text_area( label="Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given.", value="{}", key="current_template", ) col2.header("") make_pred = col1.button("Predict") if make_pred: st.session_state.current_prediction = ( st.session_state.current_model.predict_on_texts( [st.session_state.current_text], candidate_labels=st.session_state.current_labels.split(","), prompt_template=st.session_state.current_template, ) ) if "scores" in st.session_state.current_prediction[0]: st.session_state.current_chart = visualize_output( st.session_state.current_prediction[0]["labels"], st.session_state.current_prediction[0]["scores"], ) elif "probabilities" in st.session_state.current_prediction[0]: st.session_state.current_chart = visualize_output( st.session_state.current_prediction[0]["labels"], st.session_state.current_prediction[0]["probabilities"], ) col2.plotly_chart(st.session_state.current_chart, use_container_width=True)