ajeetkumar01 commited on
Commit
9c9e66f
1 Parent(s): 6a7dca8

updated the app view

Browse files
Files changed (1) hide show
  1. app.py +40 -61
app.py CHANGED
@@ -7,75 +7,54 @@ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnl
7
  # Define Streamlit app
8
  def main():
9
  # Set page title and favicon
10
- st.set_page_config(
11
- page_title="Zero-Shot Text Classification",
12
- page_icon=":rocket:",
13
- layout="wide", # Set layout to wide for better spacing
14
- initial_sidebar_state="expanded" # Expand sidebar by default
15
- )
16
 
17
- # App title and description with colorful text
18
  st.title("Zero-Shot Text Classification")
19
- st.markdown(
20
  """
21
  This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
22
  Enter a sentence and candidate labels, and the model will predict the most relevant label.
23
  """
24
  )
25
 
26
- # Create a two-column layout
27
- col1, col2 = st.columns([1, 2])
28
-
29
- # Left pane: Input elements
30
- with col1:
31
- # Input text box for the sentence to classify
32
- sequence_to_classify = st.text_input("Enter the sentence to classify:")
33
-
34
- # Candidate labels input with help text
35
- st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
36
- candidate_labels = st.text_input("Candidate Labels:")
37
-
38
- # Confidence threshold slider with colorful track
39
- confidence_threshold = st.slider(
40
- "Confidence Threshold:",
41
- min_value=0.0,
42
- max_value=1.0,
43
- value=0.5,
44
- step=0.01,
45
- key="confidence_threshold",
46
- help="Move the slider to adjust the confidence threshold."
47
- )
48
-
49
- # Classification button with colorful background
50
- classify_button = st.button(
51
- "Classify",
52
- key="classify_button",
53
- help="Click the button to classify the input text with the provided labels."
54
- )
55
-
56
- # Right pane: Results
57
- with col2:
58
- if classify_button:
59
- if sequence_to_classify and candidate_labels:
60
- # Split candidate labels into a list
61
- candidate_labels = [label.strip() for label in candidate_labels.split(",")]
62
-
63
- # Perform classification
64
- classification_result = classifier(sequence_to_classify, candidate_labels)
65
-
66
- # Find label with highest score
67
- max_score_index = classification_result["scores"].index(max(classification_result["scores"]))
68
- max_label = classification_result["labels"][max_score_index]
69
- max_score = classification_result["scores"][max_score_index]
70
-
71
- # Display only the label with the highest score
72
- if max_score >= confidence_threshold:
73
- st.subheader("Classification Result:")
74
- st.write(f"- **{max_label}**: {max_score:.2f}", unsafe_allow_html=True)
75
- else:
76
- st.subheader("Classification Result:")
77
- st.write(f"- <span style='color: #888;'>{max_label}:</span> Below threshold ({max_score:.2f})", unsafe_allow_html=True)
78
 
79
  if __name__ == "__main__":
80
  main()
81
-
 
7
  # Define Streamlit app
8
  def main():
9
  # Set page title and favicon
10
+ st.set_page_config(page_title="Zero-Shot Text Classification", page_icon=":rocket:")
 
 
 
 
 
11
 
12
+ # App title and description
13
  st.title("Zero-Shot Text Classification")
14
+ st.write(
15
  """
16
  This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
17
  Enter a sentence and candidate labels, and the model will predict the most relevant label.
18
  """
19
  )
20
 
21
+ # Input text box for the sentence to classify
22
+ sequence_to_classify = st.text_input("Enter the sentence to classify:", key="input_sentence",
23
+ type="default", value="")
24
+
25
+ # Candidate labels input with help text
26
+ st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
27
+ candidate_labels = st.text_input("Candidate Labels:", key="input_labels", type="default", value="")
28
+
29
+ # Confidence threshold slider
30
+ confidence_threshold = st.slider("Confidence Threshold:", min_value=0.0, max_value=1.0, value=0.5, step=0.01,
31
+ key="confidence_threshold")
32
+
33
+ # Classification button
34
+ if st.button("Classify", key="classify_button"):
35
+ if sequence_to_classify and candidate_labels:
36
+ # Split candidate labels into a list
37
+ candidate_labels = [label.strip() for label in candidate_labels.split(",")]
38
+
39
+ # Perform classification
40
+ classification_result = classifier(sequence_to_classify, candidate_labels)
41
+
42
+ # Find label with highest score
43
+ max_score_index = classification_result["scores"].index(max(classification_result["scores"]))
44
+ max_label = classification_result["labels"][max_score_index]
45
+ max_score = classification_result["scores"][max_score_index]
46
+
47
+ # Display classification results
48
+ st.subheader("Classification Results:")
49
+ st.markdown("---")
50
+ st.markdown(f"**{max_label}**: {max_score:.2f}", unsafe_allow_html=True)
51
+ st.markdown("---")
52
+ for label, score in zip(classification_result["labels"], classification_result["scores"]):
53
+ if label != max_label:
54
+ if score >= confidence_threshold:
55
+ st.text(f"'{label}': {score:.2f}")
56
+ else:
57
+ st.text(f"'{label}': Below threshold ({score:.2f})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  if __name__ == "__main__":
60
  main()