ajeetkumar01 commited on
Commit
385dc2f
1 Parent(s): 993989f

added app.py file

Browse files
Files changed (1) hide show
  1. app.py +47 -0
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline
3
+
4
+ # Load the zero-shot classification model
5
+ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
6
+
7
+ # Define Streamlit app
8
+ def main():
9
+ # Set page title and favicon
10
+ st.set_page_config(page_title="Zero-Shot Text Classification", page_icon=":rocket:")
11
+
12
+ # App title and description
13
+ st.title("Zero-Shot Text Classification")
14
+ st.markdown("""
15
+ This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
16
+ Enter a sentence and candidate labels, and the model will predict the most relevant label.
17
+ """)
18
+
19
+ # Input text box for the sentence to classify
20
+ sequence_to_classify = st.text_input("Enter the sentence to classify:")
21
+
22
+ # Candidate labels input with help text
23
+ st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
24
+ candidate_labels = st.text_input("Candidate Labels:")
25
+
26
+ # Confidence threshold slider
27
+ confidence_threshold = st.slider("Confidence Threshold:", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
28
+
29
+ # Classification button
30
+ if st.button("Classify"):
31
+ if sequence_to_classify and candidate_labels:
32
+ # Split candidate labels into a list
33
+ candidate_labels = [label.strip() for label in candidate_labels.split(",")]
34
+
35
+ # Perform classification
36
+ classification_result = classifier(sequence_to_classify, candidate_labels)
37
+
38
+ # Display classification results
39
+ st.subheader("Classification Results:")
40
+ for label, score in zip(classification_result["labels"], classification_result["scores"]):
41
+ if score >= confidence_threshold:
42
+ st.write(f"- {label}: {score}")
43
+ else:
44
+ st.write(f"- {label}: Below threshold ({score})")
45
+
46
+ if __name__ == "__main__":
47
+ main()