saattrupdan commited on
Commit
af97119
1 Parent(s): 1a75433

feat: Update deps, sort out deprecation warnings

Browse files
Files changed (2) hide show
  1. app.py +106 -92
  2. requirements.txt +5 -2
app.py CHANGED
@@ -2,75 +2,30 @@
2
 
3
  from typing import Dict, Tuple
4
  import gradio as gr
5
- from transformers import pipeline
 
 
6
  from luga import language as detect_language
 
7
  import re
8
-
9
-
10
- def classification(
11
- doc: str,
12
- da_hypothesis_template: str,
13
- da_candidate_labels: str,
14
- sv_hypothesis_template: str,
15
- sv_candidate_labels: str,
16
- no_hypothesis_template: str,
17
- no_candidate_labels: str,
18
- ) -> Dict[str, float]:
19
- """Classify text into categories.
20
-
21
- Args:
22
- doc (str):
23
- Text to classify.
24
- da_hypothesis_template (str):
25
- Template for the hypothesis to be used for Danish classification.
26
- da_candidate_labels (str):
27
- Comma-separated list of candidate labels for Danish classification.
28
- sv_hypothesis_template (str):
29
- Template for the hypothesis to be used for Swedish classification.
30
- sv_candidate_labels (str):
31
- Comma-separated list of candidate labels for Swedish classification.
32
- no_hypothesis_template (str):
33
- Template for the hypothesis to be used for Norwegian classification.
34
- no_candidate_labels (str):
35
- Comma-separated list of candidate labels for Norwegian classification.
36
-
37
- Returns:
38
- dict of str to float:
39
- The predicted label and the confidence score.
40
- """
41
- # Detect the language of the text
42
- language = detect_language(doc.replace('\n', ' ')).name
43
-
44
- # Set the hypothesis template and candidate labels based on the detected language
45
- if language == "sv":
46
- hypothesis_template = sv_hypothesis_template
47
- candidate_labels = re.split(r', *', sv_candidate_labels)
48
- elif language == "no":
49
- hypothesis_template = no_hypothesis_template
50
- candidate_labels = re.split(r', *', no_candidate_labels)
51
- else:
52
- hypothesis_template = da_hypothesis_template
53
- candidate_labels = re.split(r', *', da_candidate_labels)
54
-
55
- # Run the classifier on the text
56
- result = classifier(
57
- doc,
58
- candidate_labels=candidate_labels,
59
- hypothesis_template=hypothesis_template,
60
- )
61
-
62
- print(result)
63
-
64
- # Return the predicted label
65
- return {lbl: score for lbl, score in zip(result["labels"], result["scores"])}
66
 
67
 
68
  def main():
 
 
69
 
70
  # Load the zero-shot classification pipeline
71
- global classifier
72
- classifier = pipeline(
73
- "zero-shot-classification", model="alexandrainst/scandi-nli-large"
 
 
 
 
 
 
74
  )
75
 
76
  # Create dictionary of descriptions for each task, containing the hypothesis template
@@ -124,8 +79,8 @@ def main():
124
  with gr.Blocks() as demo:
125
 
126
  # Create title and description
127
- gr.Markdown("# Scandinavian Zero-shot Text Classification")
128
- gr.Markdown("""
129
  Classify text in Danish, Swedish or Norwegian into categories, without
130
  finetuning on any training data!
131
 
@@ -140,13 +95,13 @@ def main():
140
  _Also, be patient, as this demo is running on a CPU!_
141
  """)
142
 
143
- with gr.Row():
144
 
145
  # Input column
146
- with gr.Column():
147
 
148
  # Create a dropdown menu for the task
149
- dropdown = gr.inputs.Dropdown(
150
  label="Task",
151
  choices=[
152
  "Sentiment classification",
@@ -155,37 +110,37 @@ def main():
155
  "Product feedback detection",
156
  "Define your own task!",
157
  ],
158
- default="Sentiment classification",
159
  )
160
 
161
- with gr.Row(variant="compact"):
162
- da_hypothesis_template = gr.inputs.Textbox(
163
  label="Danish hypothesis template",
164
- default="Dette eksempel er {}.",
165
  )
166
- da_candidate_labels = gr.inputs.Textbox(
167
  label="Danish candidate labels (comma separated)",
168
- default="positivt, negativt, neutralt",
169
  )
170
 
171
- with gr.Row(variant="compact"):
172
- sv_hypothesis_template = gr.inputs.Textbox(
173
  label="Swedish hypothesis template",
174
- default="Detta exempel är {}.",
175
  )
176
- sv_candidate_labels = gr.inputs.Textbox(
177
  label="Swedish candidate labels (comma separated)",
178
- default="positivt, negativt, neutralt",
179
  )
180
 
181
- with gr.Row(variant="compact"):
182
- no_hypothesis_template = gr.inputs.Textbox(
183
  label="Norwegian hypothesis template",
184
- default="Dette eksemplet er {}.",
185
  )
186
- no_candidate_labels = gr.inputs.Textbox(
187
  label="Norwegian candidate labels (comma separated)",
188
- default="positivt, negativt, nøytralt",
189
  )
190
 
191
  # When a new task is chosen, update the description
@@ -203,16 +158,16 @@ def main():
203
  )
204
 
205
  # Output column
206
- with gr.Column():
207
 
208
  # Create a text box for the input text
209
- input_textbox = gr.inputs.Textbox(
210
- label="Input text", default="Jeg er helt vild med fodbolden 😊"
211
  )
212
 
213
- with gr.Row():
214
- clear_btn = gr.Button(value="Clear", width=0.5)
215
- submit_btn = gr.Button(value="Submit", width=0.5, variant="primary")
216
 
217
  # When the clear button is clicked, clear the input text box
218
  clear_btn.click(
@@ -220,10 +175,10 @@ def main():
220
  )
221
 
222
 
223
- with gr.Column():
224
 
225
  # Create output text box
226
- output_textbox = gr.Label(label="Result")
227
 
228
  # When the submit button is clicked, run the classifier on the input text
229
  # and display the result in the output text box
@@ -242,7 +197,66 @@ def main():
242
  )
243
 
244
  # Run the app
245
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
 
248
  if __name__ == "__main__":
 
2
 
3
  from typing import Dict, Tuple
4
  import gradio as gr
5
+ from gradio.components import Dropdown, Textbox, Row, Column, Button, Label, Markdown
6
+ from types import MethodType
7
+ from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
8
  from luga import language as detect_language
9
+ import torch
10
  import re
11
+ import os
12
+ import torch._dynamo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
 
15
  def main():
16
+ # Disable tokenizers parallelism
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
 
19
  # Load the zero-shot classification pipeline
20
+ global classifier, model, tokenizer
21
+ model_id = "alexandrainst/scandi-nli-large"
22
+ model = AutoModelForSequenceClassification.from_pretrained(model_id)
23
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
24
+ model = torch.compile(model=model, backend="aot_eager")
25
+ model.eval()
26
+ classifier = pipeline("zero-shot-classification", model=model, tokenizer=tokenizer)
27
+ classifier.get_inference_context = MethodType(
28
+ lambda self: torch.no_grad, classifier
29
  )
30
 
31
  # Create dictionary of descriptions for each task, containing the hypothesis template
 
79
  with gr.Blocks() as demo:
80
 
81
  # Create title and description
82
+ Markdown("# Scandinavian Zero-shot Text Classification")
83
+ Markdown("""
84
  Classify text in Danish, Swedish or Norwegian into categories, without
85
  finetuning on any training data!
86
 
 
95
  _Also, be patient, as this demo is running on a CPU!_
96
  """)
97
 
98
+ with Row():
99
 
100
  # Input column
101
+ with Column():
102
 
103
  # Create a dropdown menu for the task
104
+ dropdown = Dropdown(
105
  label="Task",
106
  choices=[
107
  "Sentiment classification",
 
110
  "Product feedback detection",
111
  "Define your own task!",
112
  ],
113
+ value="Sentiment classification",
114
  )
115
 
116
+ with Row(variant="compact"):
117
+ da_hypothesis_template = Textbox(
118
  label="Danish hypothesis template",
119
+ value="Dette eksempel er {}.",
120
  )
121
+ da_candidate_labels = Textbox(
122
  label="Danish candidate labels (comma separated)",
123
+ value="positivt, negativt, neutralt",
124
  )
125
 
126
+ with Row(variant="compact"):
127
+ sv_hypothesis_template = Textbox(
128
  label="Swedish hypothesis template",
129
+ value="Detta exempel är {}.",
130
  )
131
+ sv_candidate_labels = Textbox(
132
  label="Swedish candidate labels (comma separated)",
133
+ value="positivt, negativt, neutralt",
134
  )
135
 
136
+ with Row(variant="compact"):
137
+ no_hypothesis_template = Textbox(
138
  label="Norwegian hypothesis template",
139
+ value="Dette eksemplet er {}.",
140
  )
141
+ no_candidate_labels = Textbox(
142
  label="Norwegian candidate labels (comma separated)",
143
+ value="positivt, negativt, nøytralt",
144
  )
145
 
146
  # When a new task is chosen, update the description
 
158
  )
159
 
160
  # Output column
161
+ with Column():
162
 
163
  # Create a text box for the input text
164
+ input_textbox = Textbox(
165
+ label="Input text", value="Jeg er helt vild med fodbolden 😊"
166
  )
167
 
168
+ with Row():
169
+ clear_btn = Button(value="Clear")
170
+ submit_btn = Button(value="Submit", variant="primary")
171
 
172
  # When the clear button is clicked, clear the input text box
173
  clear_btn.click(
 
175
  )
176
 
177
 
178
+ with Column():
179
 
180
  # Create output text box
181
+ output_textbox = Label(label="Result")
182
 
183
  # When the submit button is clicked, run the classifier on the input text
184
  # and display the result in the output text box
 
197
  )
198
 
199
  # Run the app
200
+ demo.launch(width=.5)
201
+
202
+
203
+ @torch.compile()
204
+ def classification(
205
+ doc: str,
206
+ da_hypothesis_template: str,
207
+ da_candidate_labels: str,
208
+ sv_hypothesis_template: str,
209
+ sv_candidate_labels: str,
210
+ no_hypothesis_template: str,
211
+ no_candidate_labels: str,
212
+ ) -> Dict[str, float]:
213
+ """Classify text into categories.
214
+
215
+ Args:
216
+ doc (str):
217
+ Text to classify.
218
+ da_hypothesis_template (str):
219
+ Template for the hypothesis to be used for Danish classification.
220
+ da_candidate_labels (str):
221
+ Comma-separated list of candidate labels for Danish classification.
222
+ sv_hypothesis_template (str):
223
+ Template for the hypothesis to be used for Swedish classification.
224
+ sv_candidate_labels (str):
225
+ Comma-separated list of candidate labels for Swedish classification.
226
+ no_hypothesis_template (str):
227
+ Template for the hypothesis to be used for Norwegian classification.
228
+ no_candidate_labels (str):
229
+ Comma-separated list of candidate labels for Norwegian classification.
230
+
231
+ Returns:
232
+ dict of str to float:
233
+ The predicted label and the confidence score.
234
+ """
235
+ # Detect the language of the text
236
+ language = detect_language(doc.replace('\n', ' ')).name
237
+
238
+ # Set the hypothesis template and candidate labels based on the detected language
239
+ if language == "sv":
240
+ hypothesis_template = sv_hypothesis_template
241
+ candidate_labels = re.split(r', *', sv_candidate_labels)
242
+ elif language == "no":
243
+ hypothesis_template = no_hypothesis_template
244
+ candidate_labels = re.split(r', *', no_candidate_labels)
245
+ else:
246
+ hypothesis_template = da_hypothesis_template
247
+ candidate_labels = re.split(r', *', da_candidate_labels)
248
+
249
+ # Run the classifier on the text
250
+ result = classifier(
251
+ doc,
252
+ candidate_labels=candidate_labels,
253
+ hypothesis_template=hypothesis_template,
254
+ )
255
+
256
+ print(result)
257
+
258
+ # Return the predicted label
259
+ return {lbl: score for lbl, score in zip(result["labels"], result["scores"])}
260
 
261
 
262
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -35,7 +35,9 @@ MarkupSafe==2.1.1
35
  matplotlib==3.6.2
36
  mdit-py-plugins==0.3.1
37
  mdurl==0.1.2
 
38
  multidict==6.0.2
 
39
  nptyping==1.4.4
40
  numpy==1.23.5
41
  orjson==3.8.2
@@ -62,10 +64,11 @@ six==1.16.0
62
  sniffio==1.3.0
63
  soupsieve==2.3.2.post1
64
  starlette==0.22.0
 
65
  tokenizers==0.13.2
66
- torch==1.12.1
67
  tqdm==4.64.1
68
- transformers==4.24.0
69
  typing_extensions==4.4.0
70
  typish==1.9.3
71
  uc-micro-py==1.0.1
 
35
  matplotlib==3.6.2
36
  mdit-py-plugins==0.3.1
37
  mdurl==0.1.2
38
+ mpmath==1.3.0
39
  multidict==6.0.2
40
+ networkx==3.1
41
  nptyping==1.4.4
42
  numpy==1.23.5
43
  orjson==3.8.2
 
64
  sniffio==1.3.0
65
  soupsieve==2.3.2.post1
66
  starlette==0.22.0
67
+ sympy==1.11.1
68
  tokenizers==0.13.2
69
+ torch==2.0.0
70
  tqdm==4.64.1
71
+ transformers==4.28.1
72
  typing_extensions==4.4.0
73
  typish==1.9.3
74
  uc-micro-py==1.0.1