sileod commited on
Commit
3f8cae3
·
verified ·
1 Parent(s): 89f43c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -14
app.py CHANGED
@@ -1,23 +1,74 @@
1
  import gradio as gr
2
  from transformers import pipeline
 
3
 
4
- def zero_shot_classification(text, labels):
5
- classifier = pipeline("zero-shot-classification", model="models/tasksource/ModernBERT-nli")
6
- result = classifier(text, labels)
7
- return {label: score for label, score in zip(result['labels'], result['scores'])}
 
 
 
 
 
 
8
 
9
- default_text = "all cats are blue"
10
- default_labels = ['true', 'false']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- demo = gr.Interface(
13
- fn=zero_shot_classification,
 
14
  inputs=[
15
- gr.Textbox(label="Input Text", value=default_text),
16
- gr.Textbox(label="Possible Labels (comma-separated)", value=','.join(default_labels))
 
 
 
 
 
 
 
 
17
  ],
18
- outputs=gr.Label(label="Classification Scores"),
19
- title="Zero-Shot Classification",
20
- description="Classify a text into labels without prior training for the specific labels."
 
 
 
 
 
21
  )
22
 
23
- demo.launch()
 
 
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ import torch
4
 
5
+ # Initialize the zero-shot classification pipeline
6
+ try:
7
+ classifier = pipeline(
8
+ "zero-shot-classification",
9
+ model="models/tasksource/ModernBERT-nli",
10
+ device=0 if torch.cuda.is_available() else -1
11
+ )
12
+ except Exception as e:
13
+ print(f"Error loading model: {e}")
14
+ classifier = None
15
 
16
+ def classify_text(text, candidate_labels):
17
+ """
18
+ Perform zero-shot classification on input text.
19
+
20
+ Args:
21
+ text (str): Input text to classify
22
+ candidate_labels (str): Comma-separated string of possible labels
23
+
24
+ Returns:
25
+ dict: Dictionary containing labels and their corresponding scores
26
+ """
27
+ if classifier is None:
28
+ return {"Error": "Model failed to load"}
29
+
30
+ try:
31
+ # Convert comma-separated string to list
32
+ labels = [label.strip() for label in candidate_labels.split(",")]
33
+
34
+ # Perform classification
35
+ result = classifier(text, labels)
36
+
37
+ # Create formatted output
38
+ output = {}
39
+ for label, score in zip(result["labels"], result["scores"]):
40
+ output[label] = f"{score:.4f}"
41
+
42
+ return output
43
+
44
+ except Exception as e:
45
+ return {"Error": str(e)}
46
 
47
+ # Create Gradio interface
48
+ iface = gr.Interface(
49
+ fn=classify_text,
50
  inputs=[
51
+ gr.Textbox(
52
+ label="Text to classify",
53
+ placeholder="Enter text here...",
54
+ value="all cats are blue"
55
+ ),
56
+ gr.Textbox(
57
+ label="Possible labels (comma-separated)",
58
+ placeholder="Enter labels...",
59
+ value="true,false"
60
+ )
61
  ],
62
+ outputs=gr.Label(label="Classification Results"),
63
+ title="Zero-Shot Text Classification",
64
+ description="Classify text into given categories without any training examples.",
65
+ examples=[
66
+ ["all cats are blue", "true,false"],
67
+ ["the sky is above us", "true,false"],
68
+ ["birds can fly", "true,false,unknown"]
69
+ ]
70
  )
71
 
72
+ # Launch the app
73
+ if __name__ == "__main__":
74
+ iface.launch(share=True)