|
import streamlit as st |
|
from transformers import pipeline |
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def _load_classifier():
    """Build the zero-shot classification pipeline once and reuse it across reruns.

    Streamlit re-executes this script from the top on every widget interaction;
    without ``st.cache_resource`` the BART-Large-MNLI model would be rebuilt on
    each rerun. ``show_spinner=False`` keeps the cache from emitting a spinner
    element before ``st.set_page_config`` is called in ``main``.
    """
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


# Kept under its original module-level name so existing references still work.
classifier = _load_classifier()
|
|
|
|
|
def _parse_labels(raw_labels):
    """Split a comma-separated label string into a clean, de-duplicated list.

    Each label is stripped of surrounding whitespace. Empty entries (from a
    trailing or doubled comma) are dropped. Duplicates are removed while
    keeping first-seen order. The classifier therefore never receives "" or
    the same label twice.
    """
    stripped = (label.strip() for label in raw_labels.split(","))
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(label for label in stripped if label))


def _render_results(classification_result, confidence_threshold):
    """Display the top-scoring label prominently, then every other label.

    ``classification_result`` is the dict returned by the zero-shot pipeline,
    with parallel ``"labels"`` and ``"scores"`` lists. Any label whose score is
    below ``confidence_threshold`` is marked "Below threshold", including the
    top one, so the threshold is applied consistently.
    """
    labels = classification_result["labels"]
    scores = classification_result["scores"]
    max_score_index = scores.index(max(scores))
    max_label = labels[max_score_index]
    max_score = scores[max_score_index]

    st.subheader("Classification Results:")
    st.markdown("---")
    # Labels come from user input: render them as plain markdown, never raw HTML.
    if max_score >= confidence_threshold:
        st.markdown(f"**{max_label}**: {max_score:.2f}")
    else:
        st.markdown(f"**{max_label}**: Below threshold ({max_score:.2f})")
    st.markdown("---")

    for index, (label, score) in enumerate(zip(labels, scores)):
        # Skip the top entry by position, not by name, so no other entry is
        # hidden just because it shares the top label's text.
        if index == max_score_index:
            continue
        if score >= confidence_threshold:
            st.text(f"'{label}': {score:.2f}")
        else:
            st.text(f"'{label}': Below threshold ({score:.2f})")


def main():
    """Render the zero-shot text classification page and handle the Classify action.

    The page classifies the entered sentence against the comma-separated
    candidate labels using the module-level ``classifier`` pipeline. Nothing
    runs until "Classify" is pressed. If the sentence or labels are empty after
    trimming, a warning explains what is missing.
    """
    # Must be the first Streamlit command executed in each script run.
    st.set_page_config(page_title="Zero-Shot Text Classification", page_icon=":rocket:")

    st.title("Zero-Shot Text Classification")
    st.write(
        """
        This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.

        Enter a sentence and candidate labels, and the model will predict the most relevant label.
        """
    )

    sequence_to_classify = st.text_input(
        "Enter the sentence to classify:", key="input_sentence", type="default", value=""
    )

    st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
    raw_labels = st.text_input("Candidate Labels:", key="input_labels", type="default", value="")

    confidence_threshold = st.slider(
        "Confidence Threshold:",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.01,
        key="confidence_threshold",
    )

    if not st.button("Classify", key="classify_button"):
        return

    sequence_to_classify = sequence_to_classify.strip()
    candidate_labels = _parse_labels(raw_labels)
    if not sequence_to_classify or not candidate_labels:
        # Give explicit feedback instead of leaving the page unchanged.
        st.warning("Please enter a sentence and at least one candidate label.")
        return

    classification_result = classifier(sequence_to_classify, candidate_labels)
    _render_results(classification_result, confidence_threshold)
|
|
|
# Build the page only when this file is executed as a script (`streamlit run`
# executes it as __main__); a plain import defines `main` without rendering it.
if __name__ == "__main__":
    main()
|
|