sdiazlor HF staff commited on
Commit
4e19310
1 Parent(s): 857f1ba

update textcat (separate prompt and labels) and use input parameters

Browse files
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import re
2
  import uuid
3
  from typing import List, Union
@@ -24,7 +25,6 @@ from src.distilabel_dataset_generator.pipelines.embeddings import (
24
  )
25
  from src.distilabel_dataset_generator.pipelines.textcat import (
26
  DEFAULT_DATASET_DESCRIPTIONS,
27
- PROMPT_CREATION_PROMPT,
28
  generate_pipeline_code,
29
  get_labeller_generator,
30
  get_prompt_generator,
@@ -44,36 +44,33 @@ def generate_system_prompt(dataset_description, temperature, progress=gr.Progres
44
  progress(0.3, desc="Initializing text generation")
45
  generate_description = get_prompt_generator(temperature)
46
  progress(0.7, desc="Generating text classification task")
47
- system_prompt = next(
48
  generate_description.process(
49
  [
50
  {
51
- "system_prompt": PROMPT_CREATION_PROMPT,
52
  "instruction": dataset_description,
53
  }
54
  ]
55
  )
56
  )[0]["generation"]
57
  progress(1.0, desc="Text classification task generated")
58
- return system_prompt, pd.DataFrame()
 
 
 
59
 
60
-
61
- def generate_sample_dataset(system_prompt, progress=gr.Progress()):
62
- df = generate_dataset(
63
  system_prompt=system_prompt,
64
- difficulty="mixed",
65
- clarity="mixed",
66
- labels=[],
67
- num_labels=1,
68
  num_rows=10,
69
  progress=progress,
70
  is_sample=True,
71
  )
72
- if "label" in df.columns:
73
- df = df[["label", "text"]]
74
- elif "labels" in df.columns:
75
- df = df[["labels", "text"]]
76
- return df
77
 
78
 
79
  def generate_dataset(
@@ -86,17 +83,13 @@ def generate_dataset(
86
  is_sample: bool = False,
87
  progress=gr.Progress(),
88
  ) -> pd.DataFrame:
89
- if is_sample:
90
- multiplier = 1
91
- else:
92
- multiplier = 2
93
  progress(0.0, desc="(1/2) Generating text classification data")
94
  labels = get_preprocess_labels(labels)
95
  textcat_generator = get_textcat_generator(
96
  difficulty=difficulty, clarity=clarity, is_sample=is_sample
97
  )
98
  labeller_generator = get_labeller_generator(
99
- system_prompt=system_prompt,
100
  labels=labels,
101
  num_labels=num_labels,
102
  )
@@ -108,13 +101,15 @@ def generate_dataset(
108
  textcat_results = []
109
  while n_processed < num_rows:
110
  progress(
111
- multiplier * 0.5 * n_processed / num_rows,
112
  total=total_steps,
113
  desc="(1/2) Generating text classification data",
114
  )
115
  remaining_rows = num_rows - n_processed
116
  batch_size = min(batch_size, remaining_rows)
117
- inputs = [{"task": system_prompt} for _ in range(batch_size)]
 
 
118
  batch = list(textcat_generator.process(inputs=inputs))
119
  textcat_results.extend(batch[0])
120
  n_processed += batch_size
@@ -122,58 +117,41 @@ def generate_dataset(
122
  result["text"] = result["input_text"]
123
 
124
  # label text classification data
125
- progress(multiplier * 0.5, desc="(1/2) Generating text classification data")
126
- if not is_sample:
127
- n_processed = 0
128
- labeller_results = []
129
- while n_processed < num_rows:
130
- progress(
131
- 0.5 + 0.5 * n_processed / num_rows,
132
- total=total_steps,
133
- desc="(1/2) Labeling text classification data",
134
- )
135
- batch = textcat_results[n_processed : n_processed + batch_size]
136
- labels_batch = list(labeller_generator.process(inputs=batch))
137
- labeller_results.extend(labels_batch[0])
138
- n_processed += batch_size
139
  progress(
140
- 1,
141
  total=total_steps,
142
- desc="(2/2) Creating dataset",
143
  )
 
 
 
 
 
 
 
 
 
144
 
145
  # create final dataset
146
  distiset_results = []
147
- source_results = textcat_results if is_sample else labeller_results
148
- for result in source_results:
149
  record = {
150
  key: result[key]
151
- for key in ["text", "label" if is_sample else "labels"]
152
  if key in result
153
  }
154
  distiset_results.append(record)
155
 
156
  dataframe = pd.DataFrame(distiset_results)
157
- if not is_sample:
158
- if num_labels == 1:
159
- dataframe = dataframe.rename(columns={"labels": "label"})
160
- dataframe["label"] = dataframe["label"].apply(
161
- lambda x: x.lower().strip() if x.lower().strip() in labels else None
162
- )
163
- else:
164
- dataframe["labels"] = dataframe["labels"].apply(
165
- lambda x: (
166
- list(
167
- set(
168
- label.lower().strip()
169
- for label in x
170
- if label.lower().strip() in labels
171
- )
172
- )
173
- if isinstance(x, list)
174
- else None
175
- )
176
- )
177
  progress(1.0, desc="Dataset generation completed")
178
  return dataframe
179
 
@@ -281,7 +259,7 @@ def push_dataset_to_argilla(
281
  )
282
 
283
  dataframe["text_length"] = dataframe["text"].apply(len)
284
- dataframe["text_embeddings"] = get_embeddings(dataframe["text"])
285
 
286
  progress(0.5, desc="Creating dataset")
287
  rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
@@ -330,15 +308,6 @@ def push_dataset_to_argilla(
330
  return ""
331
 
332
 
333
- def update_suggested_labels(system_prompt):
334
- new_labels = re.findall(r"'(\b[\w-]+\b)'", system_prompt)
335
- if not new_labels:
336
- return gr.Warning(
337
- "No labels found in the system prompt. Please add labels manually."
338
- )
339
- return gr.update(choices=new_labels, value=new_labels)
340
-
341
-
342
  def validate_input_labels(labels):
343
  if not labels or len(labels) < 2:
344
  raise gr.Error(
@@ -448,7 +417,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
448
  )
449
  with gr.Column(scale=3):
450
  dataframe = gr.Dataframe(
451
- headers=["labels", "text"], wrap=True, height=300, column_widths=[1, 3]
452
  )
453
 
454
  gr.HTML("<hr>")
@@ -496,27 +465,35 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
496
  label="Distilabel Pipeline Code",
497
  )
498
 
499
- gr.on(
500
- triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
501
  fn=generate_system_prompt,
502
  inputs=[dataset_description, temperature],
503
- outputs=[system_prompt, dataframe],
504
  show_progress=True,
505
  ).then(
506
  fn=generate_sample_dataset,
507
- inputs=[system_prompt],
508
  outputs=[dataframe],
509
  show_progress=True,
510
  ).then(
511
- fn=update_suggested_labels,
512
- inputs=[system_prompt],
513
- outputs=labels,
514
- ).then(
515
  fn=update_max_num_labels,
516
  inputs=[labels],
517
  outputs=[num_labels],
518
  )
519
 
 
 
 
 
 
 
 
 
 
 
 
 
 
520
  btn_push_to_hub.click(
521
  fn=validate_argilla_user_workspace_dataset,
522
  inputs=[repo_name],
 
1
+ import json
2
  import re
3
  import uuid
4
  from typing import List, Union
 
25
  )
26
  from src.distilabel_dataset_generator.pipelines.textcat import (
27
  DEFAULT_DATASET_DESCRIPTIONS,
 
28
  generate_pipeline_code,
29
  get_labeller_generator,
30
  get_prompt_generator,
 
44
  progress(0.3, desc="Initializing text generation")
45
  generate_description = get_prompt_generator(temperature)
46
  progress(0.7, desc="Generating text classification task")
47
+ result = next(
48
  generate_description.process(
49
  [
50
  {
 
51
  "instruction": dataset_description,
52
  }
53
  ]
54
  )
55
  )[0]["generation"]
56
  progress(1.0, desc="Text classification task generated")
57
+ data = json.loads(result)
58
+ system_prompt = data["classification_task"]
59
+ labels = data["labels"]
60
+ return system_prompt, labels
61
 
62
+ def generate_sample_dataset(system_prompt, difficulty, clarity, labels, num_labels, progress=gr.Progress()):
63
+ dataframe = generate_dataset(
 
64
  system_prompt=system_prompt,
65
+ difficulty=difficulty,
66
+ clarity=clarity,
67
+ labels=labels,
68
+ num_labels=num_labels,
69
  num_rows=10,
70
  progress=progress,
71
  is_sample=True,
72
  )
73
+ return dataframe
 
 
 
 
74
 
75
 
76
  def generate_dataset(
 
83
  is_sample: bool = False,
84
  progress=gr.Progress(),
85
  ) -> pd.DataFrame:
 
 
 
 
86
  progress(0.0, desc="(1/2) Generating text classification data")
87
  labels = get_preprocess_labels(labels)
88
  textcat_generator = get_textcat_generator(
89
  difficulty=difficulty, clarity=clarity, is_sample=is_sample
90
  )
91
  labeller_generator = get_labeller_generator(
92
+ system_prompt=f"{system_prompt} {', '.join(labels)}",
93
  labels=labels,
94
  num_labels=num_labels,
95
  )
 
101
  textcat_results = []
102
  while n_processed < num_rows:
103
  progress(
104
+ 2 * 0.5 * n_processed / num_rows,
105
  total=total_steps,
106
  desc="(1/2) Generating text classification data",
107
  )
108
  remaining_rows = num_rows - n_processed
109
  batch_size = min(batch_size, remaining_rows)
110
+ inputs = [
111
+ {"task": f"{system_prompt} {', '.join(labels)}"} for _ in range(batch_size)
112
+ ]
113
  batch = list(textcat_generator.process(inputs=inputs))
114
  textcat_results.extend(batch[0])
115
  n_processed += batch_size
 
117
  result["text"] = result["input_text"]
118
 
119
  # label text classification data
120
+ progress(2 * 0.5, desc="(1/2) Generating text classification data")
121
+ n_processed = 0
122
+ labeller_results = []
123
+ while n_processed < num_rows:
 
 
 
 
 
 
 
 
 
 
124
  progress(
125
+ 0.5 + 0.5 * n_processed / num_rows,
126
  total=total_steps,
127
+ desc="(1/2) Labeling text classification data",
128
  )
129
+ batch = textcat_results[n_processed : n_processed + batch_size]
130
+ labels_batch = list(labeller_generator.process(inputs=batch))
131
+ labeller_results.extend(labels_batch[0])
132
+ n_processed += batch_size
133
+ progress(
134
+ 1,
135
+ total=total_steps,
136
+ desc="(2/2) Creating dataset",
137
+ )
138
 
139
  # create final dataset
140
  distiset_results = []
141
+ for result in labeller_results:
 
142
  record = {
143
  key: result[key]
144
+ for key in ["labels", "text"]
145
  if key in result
146
  }
147
  distiset_results.append(record)
148
 
149
  dataframe = pd.DataFrame(distiset_results)
150
+ if num_labels == 1:
151
+ dataframe = dataframe.rename(columns={"labels": "label"})
152
+ dataframe["label"] = dataframe["label"].apply(
153
+ lambda x: x.lower().strip() if x.lower().strip() in labels else None
154
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  progress(1.0, desc="Dataset generation completed")
156
  return dataframe
157
 
 
259
  )
260
 
261
  dataframe["text_length"] = dataframe["text"].apply(len)
262
+ dataframe["text_embeddings"] = get_embeddings(dataframe["text"].to_list())
263
 
264
  progress(0.5, desc="Creating dataset")
265
  rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
 
308
  return ""
309
 
310
 
 
 
 
 
 
 
 
 
 
311
  def validate_input_labels(labels):
312
  if not labels or len(labels) < 2:
313
  raise gr.Error(
 
417
  )
418
  with gr.Column(scale=3):
419
  dataframe = gr.Dataframe(
420
+ headers=["labels", "text"], wrap=True, height=500, interactive=False
421
  )
422
 
423
  gr.HTML("<hr>")
 
465
  label="Distilabel Pipeline Code",
466
  )
467
 
468
+ load_btn.click(
 
469
  fn=generate_system_prompt,
470
  inputs=[dataset_description, temperature],
471
+ outputs=[system_prompt, labels],
472
  show_progress=True,
473
  ).then(
474
  fn=generate_sample_dataset,
475
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels],
476
  outputs=[dataframe],
477
  show_progress=True,
478
  ).then(
 
 
 
 
479
  fn=update_max_num_labels,
480
  inputs=[labels],
481
  outputs=[num_labels],
482
  )
483
 
484
+ labels.input(
485
+ fn=update_max_num_labels,
486
+ inputs=[labels],
487
+ outputs=[num_labels],
488
+ )
489
+
490
+ btn_apply_to_sample_dataset.click(
491
+ fn=generate_sample_dataset,
492
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels],
493
+ outputs=[dataframe],
494
+ show_progress=True,
495
+ )
496
+
497
  btn_push_to_hub.click(
498
  fn=validate_argilla_user_workspace_dataset,
499
  inputs=[repo_name],
src/distilabel_dataset_generator/pipelines/textcat.py CHANGED
@@ -1,4 +1,5 @@
1
  import random
 
2
  from typing import List
3
 
4
  from distilabel.llms import InferenceEndpointsLLM
@@ -22,25 +23,27 @@ The prompt you write should follow the same style and structure as the following
22
 
23
  If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
24
 
25
- Classify the following customer review of a cinema as either 'positive' or 'negative'.
26
 
27
- Classify the following news article into one or more of the following categories: 'politics', 'sports', 'technology', 'entertainment', 'health', 'business', 'environment', 'education', 'science', 'international'.
28
 
29
- Determine the sentiment of the following social media post: 'ambiguous', 'sarcastic', 'informative', 'emotional'.
30
 
31
- Identify the issue category for the following technical support ticket: 'billing', 'technical', 'account', 'shipping', 'returns', 'installation', 'subscription'.
32
 
33
- Classify the following movie review into one of the following categories: 'critical', 'praise', 'disappointed', 'enthusiastic'.
34
 
35
- Determine the level of customer satisfaction from the following customer service transcript: 'satisfied', 'dissatisfied', 'highly-satisfied', 'somewhat-dissatisfied', 'indifferent'.
36
 
37
- Categorize the following product description into one of the following product types: 'smartphone', 'laptop', 'tablet', 'smartwatch', 'e-reader', 'headphones'.
38
 
39
- Classify the following tweet as expressing either 'support' or 'opposition' to the political event discussed.
40
 
41
- Classify the following restaurant review into one of the following categories: 'food-quality', 'service', 'ambiance', or 'price'.
42
 
43
- Classify the following blog post based on its primary fashion trend or style: 'casual', 'formal', 'streetwear', 'vintage' or 'sustainable-fashion'.
 
 
44
 
45
  User dataset description:
46
  """
@@ -51,6 +54,82 @@ DEFAULT_DATASET_DESCRIPTIONS = [
51
  ]
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def generate_pipeline_code(
55
  system_prompt: str,
56
  difficulty: str = None,
@@ -146,63 +225,3 @@ with Pipeline(name="textcat") as pipeline:
146
  distiset = pipeline.run()
147
  """
148
  )
149
-
150
-
151
- def get_prompt_generator(temperature):
152
- prompt_generator = TextGeneration(
153
- llm=InferenceEndpointsLLM(
154
- api_key=_get_next_api_key(),
155
- model_id=MODEL,
156
- tokenizer_id=MODEL,
157
- generation_kwargs={
158
- "temperature": temperature,
159
- "max_new_tokens": 2048,
160
- "do_sample": True,
161
- },
162
- ),
163
- use_system_prompt=True,
164
- )
165
- prompt_generator.load()
166
- return prompt_generator
167
-
168
-
169
- def get_textcat_generator(difficulty, clarity, is_sample):
170
- textcat_generator = GenerateTextClassificationData(
171
- llm=InferenceEndpointsLLM(
172
- model_id=MODEL,
173
- tokenizer_id=MODEL,
174
- api_key=_get_next_api_key(),
175
- generation_kwargs={
176
- "temperature": 0.9,
177
- "max_new_tokens": 256 if is_sample else 2048,
178
- "do_sample": True,
179
- "top_k": 50,
180
- "top_p": 0.95,
181
- },
182
- ),
183
- difficulty=None if difficulty == "mixed" else difficulty,
184
- clarity=None if clarity == "mixed" else clarity,
185
- seed=random.randint(0, 2**32 - 1),
186
- )
187
- textcat_generator.load()
188
- return textcat_generator
189
-
190
-
191
- def get_labeller_generator(system_prompt, labels, num_labels):
192
- labeller_generator = TextClassification(
193
- llm=InferenceEndpointsLLM(
194
- model_id=MODEL,
195
- tokenizer_id=MODEL,
196
- api_key=_get_next_api_key(),
197
- generation_kwargs={
198
- "temperature": 0.7,
199
- "max_new_tokens": 2048,
200
- },
201
- ),
202
- context=system_prompt,
203
- available_labels=labels,
204
- n=num_labels,
205
- default_label="unknown",
206
- )
207
- labeller_generator.load()
208
- return labeller_generator
 
1
  import random
2
+ from pydantic import BaseModel, Field
3
  from typing import List
4
 
5
  from distilabel.llms import InferenceEndpointsLLM
 
23
 
24
  If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
25
 
26
+ {"classification_task": "Classify the following customer review of a cinema as", "labels": ["positive", "negative"]}
27
 
28
+ {"classification_task": "Categorize the following news article into one or more of the following categories:", "labels": ["politics", "sports", "technology", "entertainment", "health", "business", "environment", "education", "science", "international"]}
29
 
30
+ {"classification_task": "Classify the following news article into one or more of the following categories:", "labels": ['politics', 'sports', 'technology', 'entertainment', 'health', 'business', 'environment', 'education', 'science', 'international']}
31
 
32
+ {"classification_task": "Determine the sentiment of the following social media post:", "labels": ['ambiguous', 'sarcastic', 'informative', 'emotional']}
33
 
34
+ {"classification_task": "Identify the issue category for the following technical support ticket:", "labels": ['billing', 'technical', 'account', 'shipping', 'returns', 'installation', 'subscription']}
35
 
36
+ {"classification_task": "Classify the following movie review into one of the following categories:", "labels": ['critical', 'praise', 'disappointed', 'enthusiastic']}
37
 
38
+ {"classification_task": "Categorize the following customer service transcript into one of the following categories:", "labels": ['satisfied', 'dissatisfied', 'highly-satisfied', 'somewhat-dissatisfied', 'indifferent']}
39
 
40
+ {"classification_task": "Classify the following product description into one of the following product types:", "labels": ['smartphone', 'laptop', 'tablet', 'smartwatch', 'e-reader', 'headphones']}
41
 
42
+ {"classification_task": "Categorize the following tweet expressing the political event discussed as", "labels": ['support', 'opposition']}
43
 
44
+ {"classification_task": "Classify the following restaurant review into one of the following categories:", "labels": ['food-quality', 'service', 'ambiance', 'price']}
45
+
46
+ {"classification_task": "Categorize the following blog post based on its primary fashion trend or style:", "labels": ['casual', 'formal', 'streetwear', 'vintage', 'sustainable-fashion']}
47
 
48
  User dataset description:
49
  """
 
54
  ]
55
 
56
 
57
+ class TextClassificationTask(BaseModel):
58
+ classification_task: str = Field(
59
+ ...,
60
+ title="classification_task",
61
+ description="The classification task to be performed.",
62
+ )
63
+
64
+ labels: list[str] = Field(
65
+ ...,
66
+ title="Labels",
67
+ description="The possible labels for the classification task.",
68
+ )
69
+
70
+
71
+ def get_prompt_generator(temperature):
72
+ prompt_generator = TextGeneration(
73
+ llm=InferenceEndpointsLLM(
74
+ api_key=_get_next_api_key(),
75
+ model_id=MODEL,
76
+ tokenizer_id=MODEL,
77
+ structured_output={"format": "json", "schema": TextClassificationTask},
78
+ generation_kwargs={
79
+ "temperature": temperature,
80
+ "max_new_tokens": 2048,
81
+ "do_sample": True,
82
+ },
83
+ ),
84
+ system_prompt=PROMPT_CREATION_PROMPT,
85
+ use_system_prompt=True,
86
+ )
87
+ prompt_generator.load()
88
+ return prompt_generator
89
+
90
+
91
+ def get_textcat_generator(difficulty, clarity, is_sample):
92
+ textcat_generator = GenerateTextClassificationData(
93
+ llm=InferenceEndpointsLLM(
94
+ model_id=MODEL,
95
+ tokenizer_id=MODEL,
96
+ api_key=_get_next_api_key(),
97
+ generation_kwargs={
98
+ "temperature": 0.9,
99
+ "max_new_tokens": 256 if is_sample else 2048,
100
+ "do_sample": True,
101
+ "top_k": 50,
102
+ "top_p": 0.95,
103
+ },
104
+ ),
105
+ difficulty=None if difficulty == "mixed" else difficulty,
106
+ clarity=None if clarity == "mixed" else clarity,
107
+ seed=random.randint(0, 2**32 - 1),
108
+ )
109
+ textcat_generator.load()
110
+ return textcat_generator
111
+
112
+
113
+ def get_labeller_generator(system_prompt, labels, num_labels):
114
+ labeller_generator = TextClassification(
115
+ llm=InferenceEndpointsLLM(
116
+ model_id=MODEL,
117
+ tokenizer_id=MODEL,
118
+ api_key=_get_next_api_key(),
119
+ generation_kwargs={
120
+ "temperature": 0.7,
121
+ "max_new_tokens": 2048,
122
+ },
123
+ ),
124
+ context=system_prompt,
125
+ available_labels=labels,
126
+ n=num_labels,
127
+ default_label="unknown",
128
+ )
129
+ labeller_generator.load()
130
+ return labeller_generator
131
+
132
+
133
  def generate_pipeline_code(
134
  system_prompt: str,
135
  difficulty: str = None,
 
225
  distiset = pipeline.run()
226
  """
227
  )