|
import streamlit as st |
|
from transformers import pipeline |
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def load_classifier():
    """Build the zero-shot classification pipeline once and reuse it.

    Streamlit re-executes the main script from the top on every widget
    interaction, so constructing the pipeline directly at module level would
    reload the facebook/bart-large-mnli weights on each rerun.
    ``st.cache_resource`` keeps one shared instance across reruns instead.

    ``show_spinner=False`` keeps this call from rendering a Streamlit element
    before ``st.set_page_config`` runs.

    Returns:
        The Hugging Face ``zero-shot-classification`` pipeline.
    """
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


# Module-level handle kept for backward compatibility with existing callers.
classifier = load_classifier()
|
|
|
|
|
def parse_candidate_labels(raw_labels):
    """Split a comma-separated label string into a clean list of labels.

    Each label has surrounding whitespace stripped. Empty entries, such as
    those produced by ``"a,,b"`` or a trailing comma, are dropped. Duplicates
    are removed while keeping the position of the first occurrence.

    Args:
        raw_labels: User-entered text such as ``"travel, cooking, dancing"``.

    Returns:
        A list of non-empty, unique label strings in input order.
    """
    stripped = (label.strip() for label in raw_labels.split(","))
    # dict.fromkeys preserves insertion order, so this dedupes without reordering.
    return list(dict.fromkeys(label for label in stripped if label))


def main():
    """Render the zero-shot classification UI and show results on demand.

    Collects a sentence, comma-separated candidate labels and a confidence
    threshold. When "Classify" is pressed, it runs the module-level
    ``classifier`` and lists every returned label with its score, marking
    the ones below the threshold.
    """
    # Streamlit requires this to be the first Streamlit command of a run.
    st.set_page_config(page_title="Zero-Shot Text Classification", page_icon=":rocket:")

    st.title("Zero-Shot Text Classification")
    st.markdown("""

    This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.

    Enter a sentence and candidate labels, and the model will predict the most relevant label.

    """)

    sequence_to_classify = st.text_input("Enter the sentence to classify:")

    st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
    raw_labels = st.text_input("Candidate Labels:")

    confidence_threshold = st.slider("Confidence Threshold:", min_value=0.0, max_value=1.0, value=0.5, step=0.01)

    if not st.button("Classify"):
        return

    sentence = sequence_to_classify.strip()
    labels = parse_candidate_labels(raw_labels)

    # Whitespace-only input or a label list like " , " would otherwise reach
    # the model (or silently do nothing); tell the user what is missing instead.
    if not sentence or not labels:
        st.warning("Please enter a sentence and at least one candidate label.")
        return

    # Inference on a large model can take seconds; give visible feedback.
    with st.spinner("Classifying..."):
        classification_result = classifier(sentence, labels)

    st.subheader("Classification Results:")
    for label, score in zip(classification_result["labels"], classification_result["scores"]):
        if score >= confidence_threshold:
            st.write(f"- {label}: {score:.4f}")
        else:
            st.write(f"- {label}: Below threshold ({score:.4f})")
|
|
|
# Render the app when this file is executed as the main script
# (presumably launched via `streamlit run <this file>`).
if __name__ == "__main__":
    main()
|
|