File size: 2,672 Bytes
385dc2f
 
 
 
 
 
 
 
 
9c9e66f
6fcc580
9c9e66f
6a7dca8
9c9e66f
6fcc580
 
 
 
 
 
9c9e66f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385dc2f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import streamlit as st
from transformers import pipeline

# Load the zero-shot classification model.
# Streamlit re-executes this script from the top on every widget interaction,
# so the pipeline is built inside a cached factory: the model is instantiated
# once per server process and reused on later reruns.
# NOTE: st.cache_resource requires a Streamlit release that provides it
# (older releases only offer the legacy st.cache).
# show_spinner=False keeps this call from emitting any UI element before
# st.set_page_config() runs inside main().
@st.cache_resource(show_spinner=False)
def _load_classifier():
    """Build and return the zero-shot classification pipeline."""
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


classifier = _load_classifier()

# Define Streamlit app
def main():
    """Render the zero-shot classification page and handle a classify click.

    The page collects a sentence, a comma-separated list of candidate labels
    and a confidence threshold. When the "Classify" button is pressed, the
    module-level ``classifier`` is called with the sentence and the parsed
    labels; the best-scoring label is shown prominently and every other label
    is listed with its score, marked as below threshold where applicable.

    A warning is shown instead of calling the model when the sentence is
    blank or no usable label remains after parsing.
    """
    # Set page title and favicon
    st.set_page_config(page_title="Zero-Shot Text Classification", page_icon=":rocket:")

    # App title and description
    st.title("Zero-Shot Text Classification")
    st.write(
        """
        This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
        Enter a sentence and candidate labels, and the model will predict the most relevant label.
        """
    )

    # Input text box for the sentence to classify
    sequence_to_classify = st.text_input(
        "Enter the sentence to classify:",
        key="input_sentence",
        type="default",
        value="",
    )

    # Candidate labels input with help text
    st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
    candidate_labels = st.text_input(
        "Candidate Labels:", key="input_labels", type="default", value=""
    )

    # Confidence threshold slider
    confidence_threshold = st.slider(
        "Confidence Threshold:",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.01,
        key="confidence_threshold",
    )

    # Nothing more to do until the user asks for a classification.
    if not st.button("Classify", key="classify_button"):
        return

    labels = _parse_candidate_labels(candidate_labels)
    if not sequence_to_classify.strip() or not labels:
        # Give feedback instead of silently ignoring the click.
        st.warning("Please enter a sentence and at least one candidate label.")
        return

    # Perform classification; the result is indexed by "labels" and "scores".
    classification_result = classifier(sequence_to_classify, labels)
    scored_labels = list(
        zip(classification_result["labels"], classification_result["scores"])
    )

    # Find label with highest score (first one wins on ties).
    max_label, max_score = max(scored_labels, key=lambda pair: pair[1])

    # Display classification results
    st.subheader("Classification Results:")
    st.markdown("---")
    # Labels are user input, so they are rendered without raw-HTML passthrough.
    st.markdown(f"**{max_label}**: {max_score:.2f}")
    st.markdown("---")
    for label, score in scored_labels:
        # Labels were de-duplicated, so this skips only the top entry.
        if label == max_label:
            continue
        if score >= confidence_threshold:
            st.text(f"'{label}': {score:.2f}")
        else:
            st.text(f"'{label}': Below threshold ({score:.2f})")


def _parse_candidate_labels(raw_labels):
    """Split a comma-separated string into stripped, non-empty, unique labels.

    Order of first appearance is preserved. Empty entries (e.g. from a
    trailing comma or ",,") and repeated labels are dropped.
    """
    stripped = (part.strip() for part in raw_labels.split(","))
    return list(dict.fromkeys(label for label in stripped if label))

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()