|
import html

import streamlit as st
from transformers import pipeline
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def load_classifier():
    """Build the zero-shot classification pipeline once and reuse it.

    Streamlit re-executes this script from the top on every widget
    interaction, so an uncached module-level ``pipeline(...)`` call would
    reload the facebook/bart-large-mnli weights on every rerun.
    ``st.cache_resource`` keeps one shared instance instead. The spinner is
    disabled so that no Streamlit element is emitted before ``main()`` calls
    ``st.set_page_config``.
    """
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


classifier = load_classifier()
|
|
|
|
|
def _parse_labels(raw_labels):
    """Split a comma-separated string into stripped, non-empty labels.

    Blank entries (e.g. from a trailing comma or ``"a,,b"``) are dropped so
    the classifier is never asked to score an empty label.
    """
    return [label.strip() for label in raw_labels.split(",") if label.strip()]


def _render_result(label, score, threshold):
    """Show the top-scoring label, greyed out when ``score`` is below ``threshold``."""
    # The label comes straight from user input and is rendered with
    # unsafe_allow_html=True, so escape it to prevent markup injection.
    safe_label = html.escape(label)
    st.subheader("Classification Result:")
    if score >= threshold:
        st.write(f"- **{safe_label}**: {score:.2f}", unsafe_allow_html=True)
    else:
        st.write(
            f"- <span style='color: #888;'>{safe_label}:</span> Below threshold ({score:.2f})",
            unsafe_allow_html=True,
        )


def main():
    """Render the zero-shot classification page and handle the Classify action.

    The left column collects the sentence, the comma-separated candidate
    labels, and a confidence threshold. When Classify is pressed, the right
    column shows the highest-scoring label and whether it clears the threshold.
    """
    st.set_page_config(
        page_title="Zero-Shot Text Classification",
        page_icon=":rocket:",
        layout="wide",
        initial_sidebar_state="expanded",
    )

    st.title(":paintbrush: Zero-Shot Text Classification")
    st.markdown(
        """
        This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
        Enter a sentence and candidate labels, and the model will predict the most relevant label.
        """
    )

    col1, col2 = st.columns([1, 2])

    with col1:
        sequence_to_classify = st.text_input("Enter the sentence to classify:")

        st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
        raw_labels = st.text_input("Candidate Labels:")

        confidence_threshold = st.slider(
            "Confidence Threshold:",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.01,
            key="confidence_threshold",
            help="Move the slider to adjust the confidence threshold.",
        )

        classify_button = st.button(
            "Classify",
            key="classify_button",
            help="Click the button to classify the input text with the provided labels.",
        )

    with col2:
        if not classify_button:
            return

        labels = _parse_labels(raw_labels)
        if not sequence_to_classify.strip() or not labels:
            # Tell the user why nothing happened instead of ignoring the click.
            st.warning("Please enter a sentence and at least one candidate label.")
            return

        classification_result = classifier(sequence_to_classify, labels)

        # Pick the highest-scoring (label, score) pair; max() with a key keeps
        # the first pair on ties, matching a first-index-of-max lookup.
        max_label, max_score = max(
            zip(classification_result["labels"], classification_result["scores"]),
            key=lambda pair: pair[1],
        )
        _render_result(max_label, max_score, confidence_threshold)
|
|
|
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|