ajeetkumar01 commited on
Commit
6f08be9
1 Parent(s): 7d24f90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -61
app.py CHANGED
@@ -7,69 +7,41 @@ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnl
7
  # Define Streamlit app
8
  def main():
9
  # Set page title and favicon
10
- st.set_page_config(
11
- page_title="Zero-Shot Text Classification",
12
- page_icon=":rocket:",
13
- layout="wide", # Set layout to wide for better spacing
14
- initial_sidebar_state="expanded" # Expand sidebar by default
15
- )
16
 
17
- # App title and description with colorful text
18
  st.title("Zero-Shot Text Classification")
19
- st.markdown(
20
- """
21
- This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
22
- Enter a sentence and candidate labels, and the model will predict the most relevant label.
23
- """
24
- )
25
-
26
- # Create a two-column layout
27
- col1, col2 = st.columns([1, 2])
28
-
29
- # Left pane: Input elements
30
- with col1:
31
- # Input text box for the sentence to classify
32
- sequence_to_classify = st.text_input("Enter the sentence to classify:")
33
-
34
- # Candidate labels input with help text
35
- st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
36
- candidate_labels = st.text_input("Candidate Labels:")
37
-
38
- # Confidence threshold slider with colorful track
39
- confidence_threshold = st.slider(
40
- "Confidence Threshold:",
41
- min_value=0.0,
42
- max_value=1.0,
43
- value=0.5,
44
- step=0.01,
45
- key="confidence_threshold",
46
- help="Move the slider to adjust the confidence threshold."
47
- )
48
-
49
- # Classification button with colorful background
50
- classify_button = st.button(
51
- "Classify",
52
- key="classify_button",
53
- help="Click the button to classify the input text with the provided labels."
54
- )
55
-
56
- # Right pane: Results
57
- with col2:
58
- if classify_button:
59
- if sequence_to_classify and candidate_labels:
60
- # Split candidate labels into a list
61
- candidate_labels = [label.strip() for label in candidate_labels.split(",")]
62
-
63
- # Perform classification
64
- classification_result = classifier(sequence_to_classify, candidate_labels)
65
-
66
- # Display classification results with colorful labels
67
- st.subheader("Classification Results:")
68
- for label, score in zip(classification_result["labels"], classification_result["scores"]):
69
- if score >= confidence_threshold:
70
- st.write(f"- **{label}**: {score:.2f}", unsafe_allow_html=True)
71
- else:
72
- st.write(f"- <span style='color: #888;'>{label}:</span> Below threshold ({score:.2f})", unsafe_allow_html=True)
73
 
74
  if __name__ == "__main__":
75
  main()
 
7
  # Define Streamlit app
8
  def main():
9
  # Set page title and favicon
10
+ st.set_page_config(page_title="Zero-Shot Text Classification", page_icon=":rocket:")
 
 
 
 
 
11
 
12
+ # App title and description
13
  st.title("Zero-Shot Text Classification")
14
+ st.markdown("""
15
+ This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
16
+ Enter a sentence and candidate labels, and the model will predict the most relevant label.
17
+ """)
18
+
19
+ # Input text box for the sentence to classify
20
+ sequence_to_classify = st.text_input("Enter the sentence to classify:")
21
+
22
+ # Candidate labels input with help text
23
+ st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
24
+ candidate_labels = st.text_input("Candidate Labels:")
25
+
26
+ # Confidence threshold slider
27
+ confidence_threshold = st.slider("Confidence Threshold:", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
28
+
29
+ # Classification button
30
+ if st.button("Classify"):
31
+ if sequence_to_classify and candidate_labels:
32
+ # Split candidate labels into a list
33
+ candidate_labels = [label.strip() for label in candidate_labels.split(",")]
34
+
35
+ # Perform classification
36
+ classification_result = classifier(sequence_to_classify, candidate_labels)
37
+
38
+ # Display classification results
39
+ st.subheader("Classification Results:")
40
+ for label, score in zip(classification_result["labels"], classification_result["scores"]):
41
+ if score >= confidence_threshold:
42
+ st.write(f"- {label}: {score}")
43
+ else:
44
+ st.write(f"- {label}: Below threshold ({score})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  if __name__ == "__main__":
47
  main()