ajeetkumar01 commited on
Commit
6fcc580
1 Parent(s): 6f08be9

updated script for top result classifications

Browse files
Files changed (1) hide show
  1. app.py +68 -34
app.py CHANGED
@@ -7,41 +7,75 @@ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnl
7
  # Define Streamlit app
8
  def main():
9
  # Set page title and favicon
10
- st.set_page_config(page_title="Zero-Shot Text Classification", page_icon=":rocket:")
11
-
12
- # App title and description
13
- st.title("Zero-Shot Text Classification")
14
- st.markdown("""
15
- This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
16
- Enter a sentence and candidate labels, and the model will predict the most relevant label.
17
- """)
18
-
19
- # Input text box for the sentence to classify
20
- sequence_to_classify = st.text_input("Enter the sentence to classify:")
21
-
22
- # Candidate labels input with help text
23
- st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
24
- candidate_labels = st.text_input("Candidate Labels:")
25
-
26
- # Confidence threshold slider
27
- confidence_threshold = st.slider("Confidence Threshold:", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
28
-
29
- # Classification button
30
- if st.button("Classify"):
31
- if sequence_to_classify and candidate_labels:
32
- # Split candidate labels into a list
33
- candidate_labels = [label.strip() for label in candidate_labels.split(",")]
34
-
35
- # Perform classification
36
- classification_result = classifier(sequence_to_classify, candidate_labels)
37
-
38
- # Display classification results
39
- st.subheader("Classification Results:")
40
- for label, score in zip(classification_result["labels"], classification_result["scores"]):
41
- if score >= confidence_threshold:
42
- st.write(f"- {label}: {score}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  else:
44
- st.write(f"- {label}: Below threshold ({score})")
 
45
 
46
  if __name__ == "__main__":
47
  main()
 
 
7
  # Define Streamlit app
8
  def main():
9
  # Set page title and favicon
10
+ st.set_page_config(
11
+ page_title="Zero-Shot Text Classification",
12
+ page_icon=":rocket:",
13
+ layout="wide", # Set layout to wide for better spacing
14
+ initial_sidebar_state="expanded" # Expand sidebar by default
15
+ )
16
+
17
+ # App title and description with colorful text
18
+ st.title(":paintbrush: Zero-Shot Text Classification")
19
+ st.markdown(
20
+ """
21
+ This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
22
+ Enter a sentence and candidate labels, and the model will predict the most relevant label.
23
+ """
24
+ )
25
+
26
+ # Create a two-column layout
27
+ col1, col2 = st.columns([1, 2])
28
+
29
+ # Left pane: Input elements
30
+ with col1:
31
+ # Input text box for the sentence to classify
32
+ sequence_to_classify = st.text_input("Enter the sentence to classify:")
33
+
34
+ # Candidate labels input with help text
35
+ st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
36
+ candidate_labels = st.text_input("Candidate Labels:")
37
+
38
+ # Confidence threshold slider with colorful track
39
+ confidence_threshold = st.slider(
40
+ "Confidence Threshold:",
41
+ min_value=0.0,
42
+ max_value=1.0,
43
+ value=0.5,
44
+ step=0.01,
45
+ key="confidence_threshold",
46
+ help="Move the slider to adjust the confidence threshold."
47
+ )
48
+
49
+ # Classification button with colorful background
50
+ classify_button = st.button(
51
+ "Classify",
52
+ key="classify_button",
53
+ help="Click the button to classify the input text with the provided labels."
54
+ )
55
+
56
+ # Right pane: Results
57
+ with col2:
58
+ if classify_button:
59
+ if sequence_to_classify and candidate_labels:
60
+ # Split candidate labels into a list
61
+ candidate_labels = [label.strip() for label in candidate_labels.split(",")]
62
+
63
+ # Perform classification
64
+ classification_result = classifier(sequence_to_classify, candidate_labels)
65
+
66
+ # Find label with highest score
67
+ max_score_index = classification_result["scores"].index(max(classification_result["scores"]))
68
+ max_label = classification_result["labels"][max_score_index]
69
+ max_score = classification_result["scores"][max_score_index]
70
+
71
+ # Display only the label with the highest score
72
+ if max_score >= confidence_threshold:
73
+ st.subheader("Classification Result:")
74
+ st.write(f"- **{max_label}**: {max_score:.2f}", unsafe_allow_html=True)
75
  else:
76
+ st.subheader("Classification Result:")
77
+ st.write(f"- <span style='color: #888;'>{max_label}:</span> Below threshold ({max_score:.2f})", unsafe_allow_html=True)
78
 
79
  if __name__ == "__main__":
80
  main()
81
+