feat/text-classification

#11
by davidberenstein1957 HF staff - opened
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ synthetic-data-generator
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
 
3
  from src.distilabel_dataset_generator.apps.faq import app as faq_app
4
  from src.distilabel_dataset_generator.apps.sft import app as sft_app
 
5
 
6
  theme = gr.themes.Monochrome(
7
  spacing_size="md",
@@ -25,8 +26,8 @@ css = """
25
  """
26
 
27
  demo = gr.TabbedInterface(
28
- [sft_app, faq_app],
29
- ["Supervised Fine-Tuning", "FAQ"],
30
  css=css,
31
  title="""
32
  <style>
@@ -54,6 +55,17 @@ demo = gr.TabbedInterface(
54
  margin-bottom: 20px;
55
  }
56
  }
 
 
 
 
 
 
 
 
 
 
 
57
  </style>
58
  <div class="header-container">
59
  <div class="logo-container">
@@ -62,7 +74,7 @@ demo = gr.TabbedInterface(
62
  </a>
63
  </div>
64
  <div class="title-container">
65
- <h1 style="margin: 0; font-size: 2em;">🧬 Synthetic Data Generator</h1>
66
  <p style="margin: 10px 0 0 0; color: #666; font-size: 1.1em;">Build datasets using natural language</p>
67
  </div>
68
  </div>
 
2
 
3
  from src.distilabel_dataset_generator.apps.faq import app as faq_app
4
  from src.distilabel_dataset_generator.apps.sft import app as sft_app
5
+ from src.distilabel_dataset_generator.apps.textcat import app as textcat_app
6
 
7
  theme = gr.themes.Monochrome(
8
  spacing_size="md",
 
26
  """
27
 
28
  demo = gr.TabbedInterface(
29
+ [textcat_app, sft_app, faq_app],
30
+ ["Text Classification", "Supervised Fine-Tuning", "FAQ"],
31
  css=css,
32
  title="""
33
  <style>
 
55
  margin-bottom: 20px;
56
  }
57
  }
58
+ button[role="tab"].selected,
59
+ button[role="tab"][aria-selected="true"],
60
+ button[role="tab"][data-tab-id][aria-selected="true"] {
61
+ background-color: #000000;
62
+ color: white;
63
+ border: none;
64
+ font-size: 16px;
65
+ font-weight: bold;
66
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
67
+ transition: background-color 0.3s ease, color 0.3s ease;
68
+ }
69
  </style>
70
  <div class="header-container">
71
  <div class="logo-container">
 
74
  </a>
75
  </div>
76
  <div class="title-container">
77
+ <h1 style="margin: 0; font-size: 2em;">🧬 Synthetic Data Generator</h1>
78
  <p style="margin: 10px 0 0 0; color: #666; font-size: 1.1em;">Build datasets using natural language</p>
79
  </div>
80
  </div>
pdm.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -6,11 +6,13 @@ authors = [
6
  {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"},
7
  ]
8
  dependencies = [
9
- "distilabel[hf-inference-endpoints] @ git+https://github.com/argilla-io/distilabel.git@develop",
10
  "gradio[oauth]<5,>=4.38",
11
  "transformers>=4.44.2",
 
 
12
  ]
13
- requires-python = ">=3.10"
14
  readme = "README.md"
15
  license = {text = "apache 2"}
16
 
 
6
  {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"},
7
  ]
8
  dependencies = [
9
+ "distilabel[hf-inference-endpoints,argilla]==1.4.0",
10
  "gradio[oauth]<5,>=4.38",
11
  "transformers>=4.44.2",
12
+ "sentence-transformers>=3.2.0",
13
+ "model2vec>=0.2.4",
14
  ]
15
+ requires-python = "<3.13,>=3.10"
16
  readme = "README.md"
17
  license = {text = "apache 2"}
18
 
requirements.txt CHANGED
@@ -1,4 +1,7 @@
1
  transformers
2
  gradio[oauth]
3
- distilabel[hf-inference-endpoints] @ git+https://github.com/argilla-io/distilabel.git@develop
4
- beautifulsoup4
 
 
 
 
1
  transformers
2
  gradio[oauth]
3
+ distilabel[hf-inference-endpoints,argilla]
4
+ beautifulsoup4
5
+ sentence-transformers
6
+ model2vec
7
+ outlines
src/distilabel_dataset_generator/apps/base.py ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import uuid
3
+ from typing import Any, Callable, List, Tuple, Union
4
+
5
+ import argilla as rg
6
+ import gradio as gr
7
+ import pandas as pd
8
+ from datasets import ClassLabel, Dataset, Features, Sequence, Value
9
+ from distilabel.distiset import Distiset
10
+ from gradio import OAuthToken
11
+ from huggingface_hub import HfApi, upload_file
12
+
13
+ from src.distilabel_dataset_generator.utils import (
14
+ _LOGGED_OUT_CSS,
15
+ get_argilla_client,
16
+ list_orgs,
17
+ swap_visibilty,
18
+ get_login_button,
19
+ )
20
+
21
+ TEXTCAT_TASK = "text_classification"
22
+ SFT_TASK = "supervised_fine_tuning"
23
+
24
+
25
+ def get_main_ui(
26
+ default_dataset_descriptions: List[str],
27
+ default_system_prompts: List[str],
28
+ default_datasets: List[pd.DataFrame],
29
+ fn_generate_system_prompt: Callable,
30
+ fn_generate_dataset: Callable,
31
+ task: str,
32
+ ):
33
+ def fn_generate_sample_dataset(system_prompt, progress=gr.Progress()):
34
+ if system_prompt in default_system_prompts:
35
+ index = default_system_prompts.index(system_prompt)
36
+ if index < len(default_datasets):
37
+ return default_datasets[index]
38
+ if task == TEXTCAT_TASK:
39
+ result = fn_generate_dataset(
40
+ system_prompt=system_prompt,
41
+ difficulty="mixed",
42
+ clarity="mixed",
43
+ labels=[],
44
+ num_labels=1,
45
+ num_rows=1,
46
+ progress=progress,
47
+ is_sample=True,
48
+ )
49
+ else:
50
+ result = fn_generate_dataset(
51
+ system_prompt=system_prompt,
52
+ num_turns=1,
53
+ num_rows=1,
54
+ progress=progress,
55
+ is_sample=True,
56
+ )
57
+ return result
58
+
59
+ with gr.Blocks(
60
+ title="🧬 Synthetic Data Generator",
61
+ head="🧬 Synthetic Data Generator",
62
+ css=_LOGGED_OUT_CSS,
63
+ ) as app:
64
+ with gr.Row():
65
+ gr.Markdown(
66
+ "Want to run this locally or with other LLMs? Take a look at the FAQ tab. distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation."
67
+ )
68
+ with gr.Row():
69
+ gr.Column()
70
+ get_login_button()
71
+ gr.Column()
72
+
73
+ gr.Markdown("## Iterate on a sample dataset")
74
+ with gr.Column() as main_ui:
75
+ (
76
+ dataset_description,
77
+ examples,
78
+ btn_generate_system_prompt,
79
+ system_prompt,
80
+ sample_dataset,
81
+ btn_generate_sample_dataset,
82
+ ) = get_iterate_on_sample_dataset_ui(
83
+ default_dataset_descriptions=default_dataset_descriptions,
84
+ default_system_prompts=default_system_prompts,
85
+ default_datasets=default_datasets,
86
+ task=task,
87
+ )
88
+ gr.Markdown("## Generate full dataset")
89
+ gr.Markdown(
90
+ "Once you're satisfied with the sample, generate a larger dataset and push it to Argilla or the Hugging Face Hub."
91
+ )
92
+ with gr.Row(variant="panel") as custom_input_ui:
93
+ pass
94
+
95
+ (
96
+ dataset_name,
97
+ add_to_existing_dataset,
98
+ btn_generate_full_dataset_argilla,
99
+ btn_generate_and_push_to_argilla,
100
+ btn_push_to_argilla,
101
+ org_name,
102
+ repo_name,
103
+ private,
104
+ btn_generate_full_dataset,
105
+ btn_generate_and_push_to_hub,
106
+ btn_push_to_hub,
107
+ final_dataset,
108
+ success_message,
109
+ ) = get_push_to_ui(default_datasets)
110
+
111
+ sample_dataset.change(
112
+ fn=lambda x: x,
113
+ inputs=[sample_dataset],
114
+ outputs=[final_dataset],
115
+ )
116
+
117
+ btn_generate_system_prompt.click(
118
+ fn=fn_generate_system_prompt,
119
+ inputs=[dataset_description],
120
+ outputs=[system_prompt],
121
+ show_progress=True,
122
+ ).then(
123
+ fn=fn_generate_sample_dataset,
124
+ inputs=[system_prompt],
125
+ outputs=[sample_dataset],
126
+ show_progress=True,
127
+ )
128
+
129
+ btn_generate_sample_dataset.click(
130
+ fn=fn_generate_sample_dataset,
131
+ inputs=[system_prompt],
132
+ outputs=[sample_dataset],
133
+ show_progress=True,
134
+ )
135
+
136
+ app.load(fn=swap_visibilty, outputs=main_ui)
137
+ app.load(get_org_dropdown, outputs=[org_name])
138
+
139
+ return (
140
+ app,
141
+ main_ui,
142
+ custom_input_ui,
143
+ dataset_description,
144
+ examples,
145
+ btn_generate_system_prompt,
146
+ system_prompt,
147
+ sample_dataset,
148
+ btn_generate_sample_dataset,
149
+ dataset_name,
150
+ add_to_existing_dataset,
151
+ btn_generate_full_dataset_argilla,
152
+ btn_generate_and_push_to_argilla,
153
+ btn_push_to_argilla,
154
+ org_name,
155
+ repo_name,
156
+ private,
157
+ btn_generate_full_dataset,
158
+ btn_generate_and_push_to_hub,
159
+ btn_push_to_hub,
160
+ final_dataset,
161
+ success_message,
162
+ )
163
+
164
+
165
+ def validate_argilla_user_workspace_dataset(
166
+ dataset_name: str,
167
+ final_dataset: pd.DataFrame,
168
+ add_to_existing_dataset: bool,
169
+ oauth_token: Union[OAuthToken, None] = None,
170
+ progress=gr.Progress(),
171
+ ) -> str:
172
+ progress(0, desc="Validating dataset configuration")
173
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
174
+ client = get_argilla_client()
175
+ if dataset_name is None or dataset_name == "":
176
+ raise gr.Error("Dataset name is required")
177
+ # Create user if it doesn't exist
178
+ rg_user = client.users(username=hf_user)
179
+ if rg_user is None:
180
+ rg_user = client.users.add(
181
+ rg.User(username=hf_user, role="admin", password=str(uuid.uuid4()))
182
+ )
183
+ # Create workspace if it doesn't exist
184
+ workspace = client.workspaces(name=hf_user)
185
+ if workspace is None:
186
+ workspace = client.workspaces.add(rg.Workspace(name=hf_user))
187
+ workspace.add_user(hf_user)
188
+ # Check if dataset exists
189
+ dataset = client.datasets(name=dataset_name, workspace=hf_user)
190
+ if dataset and not add_to_existing_dataset:
191
+ raise gr.Error(f"Dataset {dataset_name} already exists")
192
+ return final_dataset
193
+
194
+
195
+ def get_org_dropdown(oauth_token: OAuthToken = None):
196
+ orgs = list_orgs(oauth_token)
197
+ return gr.Dropdown(
198
+ label="Organization",
199
+ choices=orgs,
200
+ value=orgs[0] if orgs else None,
201
+ allow_custom_value=True,
202
+ )
203
+
204
+
205
+ def get_push_to_ui(default_datasets):
206
+ with gr.Column() as push_to_ui:
207
+ (
208
+ dataset_name,
209
+ add_to_existing_dataset,
210
+ btn_generate_full_dataset_argilla,
211
+ btn_generate_and_push_to_argilla,
212
+ btn_push_to_argilla,
213
+ ) = get_argilla_tab()
214
+ (
215
+ org_name,
216
+ repo_name,
217
+ private,
218
+ btn_generate_full_dataset,
219
+ btn_generate_and_push_to_hub,
220
+ btn_push_to_hub,
221
+ ) = get_hf_tab()
222
+ final_dataset = get_final_dataset_row(default_datasets)
223
+ success_message = get_success_message_row()
224
+ return (
225
+ dataset_name,
226
+ add_to_existing_dataset,
227
+ btn_generate_full_dataset_argilla,
228
+ btn_generate_and_push_to_argilla,
229
+ btn_push_to_argilla,
230
+ org_name,
231
+ repo_name,
232
+ private,
233
+ btn_generate_full_dataset,
234
+ btn_generate_and_push_to_hub,
235
+ btn_push_to_hub,
236
+ final_dataset,
237
+ success_message,
238
+ )
239
+
240
+
241
+ def get_iterate_on_sample_dataset_ui(
242
+ default_dataset_descriptions: List[str],
243
+ default_system_prompts: List[str],
244
+ default_datasets: List[pd.DataFrame],
245
+ task: str,
246
+ ):
247
+ with gr.Column():
248
+ dataset_description = gr.TextArea(
249
+ label="Give a precise description of your desired application. Check the examples for inspiration.",
250
+ value=default_dataset_descriptions[0],
251
+ lines=2,
252
+ )
253
+ examples = gr.Examples(
254
+ elem_id="system_prompt_examples",
255
+ examples=[[example] for example in default_dataset_descriptions],
256
+ inputs=[dataset_description],
257
+ )
258
+ with gr.Row():
259
+ gr.Column(scale=1)
260
+ btn_generate_system_prompt = gr.Button(
261
+ value="Generate system prompt and sample dataset"
262
+ )
263
+ gr.Column(scale=1)
264
+
265
+ system_prompt = gr.TextArea(
266
+ label="System prompt for dataset generation. You can tune it and regenerate the sample.",
267
+ value=default_system_prompts[0],
268
+ lines=2 if task == TEXTCAT_TASK else 5,
269
+ )
270
+
271
+ with gr.Row():
272
+ sample_dataset = gr.Dataframe(
273
+ value=default_datasets[0],
274
+ label="Sample dataset. Prompts and completions truncated to 256 tokens.",
275
+ interactive=False,
276
+ wrap=True,
277
+ )
278
+
279
+ with gr.Row():
280
+ gr.Column(scale=1)
281
+ btn_generate_sample_dataset = gr.Button(
282
+ value="Generate sample dataset",
283
+ )
284
+ gr.Column(scale=1)
285
+
286
+ return (
287
+ dataset_description,
288
+ examples,
289
+ btn_generate_system_prompt,
290
+ system_prompt,
291
+ sample_dataset,
292
+ btn_generate_sample_dataset,
293
+ )
294
+
295
+
296
+ def get_pipeline_code_ui(pipeline_code: str) -> gr.Code:
297
+ gr.Markdown("## Or run this pipeline locally with distilabel")
298
+ gr.Markdown(
299
+ "You can run this pipeline locally with distilabel. For more information, please refer to the [distilabel documentation](https://distilabel.argilla.io/) or go to the FAQ tab at the top of the page for more information."
300
+ )
301
+ with gr.Accordion(
302
+ "Run this pipeline using distilabel",
303
+ open=False,
304
+ ):
305
+ pipeline_code = gr.Code(
306
+ value=pipeline_code,
307
+ language="python",
308
+ label="Distilabel Pipeline Code",
309
+ )
310
+ return pipeline_code
311
+
312
+
313
+ def get_argilla_tab() -> Tuple[Any]:
314
+ with gr.Tab(label="Argilla"):
315
+ if get_argilla_client() is not None:
316
+ with gr.Row(variant="panel"):
317
+ dataset_name = gr.Textbox(
318
+ label="Dataset name",
319
+ placeholder="dataset_name",
320
+ value="my-distiset",
321
+ )
322
+ add_to_existing_dataset = gr.Checkbox(
323
+ label="Allow adding records to existing dataset",
324
+ info="When selected, you do need to ensure the dataset options are the same as in the existing dataset.",
325
+ value=False,
326
+ interactive=True,
327
+ scale=1,
328
+ )
329
+
330
+ with gr.Row(variant="panel"):
331
+ btn_generate_full_dataset_argilla = gr.Button(
332
+ value="Generate", variant="primary", scale=2
333
+ )
334
+ btn_generate_and_push_to_argilla = gr.Button(
335
+ value="Generate and Push to Argilla",
336
+ variant="primary",
337
+ scale=2,
338
+ )
339
+ btn_push_to_argilla = gr.Button(
340
+ value="Push to Argilla", variant="primary", scale=2
341
+ )
342
+ else:
343
+ gr.Markdown(
344
+ "Please add `ARGILLA_API_URL` and `ARGILLA_API_KEY` to use Argilla or export the dataset to the Hugging Face Hub."
345
+ )
346
+ return (
347
+ dataset_name,
348
+ add_to_existing_dataset,
349
+ btn_generate_full_dataset_argilla,
350
+ btn_generate_and_push_to_argilla,
351
+ btn_push_to_argilla,
352
+ )
353
+
354
+
355
+ def get_hf_tab() -> Tuple[Any]:
356
+ with gr.Tab("Hugging Face Hub"):
357
+ with gr.Row(variant="panel"):
358
+ org_name = get_org_dropdown()
359
+ repo_name = gr.Textbox(
360
+ label="Repo name",
361
+ placeholder="dataset_name",
362
+ value="my-distiset",
363
+ )
364
+ private = gr.Checkbox(
365
+ label="Private dataset",
366
+ value=True,
367
+ interactive=True,
368
+ scale=1,
369
+ )
370
+ with gr.Row(variant="panel"):
371
+ btn_generate_full_dataset = gr.Button(
372
+ value="Generate", variant="primary", scale=2
373
+ )
374
+ btn_generate_and_push_to_hub = gr.Button(
375
+ value="Generate and Push to Hub", variant="primary", scale=2
376
+ )
377
+ btn_push_to_hub = gr.Button(value="Push to Hub", variant="primary", scale=2)
378
+ return (
379
+ org_name,
380
+ repo_name,
381
+ private,
382
+ btn_generate_full_dataset,
383
+ btn_generate_and_push_to_hub,
384
+ btn_push_to_hub,
385
+ )
386
+
387
+
388
+ def push_pipeline_code_to_hub(
389
+ pipeline_code: str,
390
+ org_name: str,
391
+ repo_name: str,
392
+ oauth_token: Union[OAuthToken, None] = None,
393
+ progress=gr.Progress(),
394
+ ):
395
+ repo_id = _check_push_to_hub(org_name, repo_name)
396
+ progress(0.1, desc="Uploading pipeline code")
397
+ with io.BytesIO(pipeline_code.encode("utf-8")) as f:
398
+ upload_file(
399
+ path_or_fileobj=f,
400
+ path_in_repo="pipeline.py",
401
+ repo_id=repo_id,
402
+ repo_type="dataset",
403
+ token=oauth_token.token,
404
+ commit_message="Include pipeline script",
405
+ create_pr=False,
406
+ )
407
+ progress(1.0, desc="Pipeline code uploaded")
408
+
409
+
410
+ def push_dataset_to_hub(
411
+ dataframe: pd.DataFrame,
412
+ private: bool = True,
413
+ org_name: str = None,
414
+ repo_name: str = None,
415
+ oauth_token: Union[OAuthToken, None] = None,
416
+ progress=gr.Progress(),
417
+ labels: List[str] = None,
418
+ num_labels: int = None,
419
+ task: str = TEXTCAT_TASK,
420
+ ) -> pd.DataFrame:
421
+ progress(0.1, desc="Setting up dataset")
422
+ repo_id = _check_push_to_hub(org_name, repo_name)
423
+
424
+ if task == TEXTCAT_TASK:
425
+ if num_labels == 1:
426
+ features = Features(
427
+ {"text": Value("string"), "label": ClassLabel(names=labels)}
428
+ )
429
+ else:
430
+ features = Features({
431
+ "text": Value("string"),
432
+ "labels": Sequence(feature=ClassLabel(names=labels))
433
+ })
434
+ distiset = Distiset({
435
+ "default": Dataset.from_pandas(dataframe, features=features)
436
+ })
437
+ else:
438
+ distiset = Distiset({
439
+ "default": Dataset.from_pandas(dataframe)
440
+ })
441
+ progress(0.2, desc="Pushing dataset to hub")
442
+ distiset.push_to_hub(
443
+ repo_id=repo_id,
444
+ private=private,
445
+ include_script=False,
446
+ token=oauth_token.token,
447
+ create_pr=False,
448
+ )
449
+ progress(1.0, desc="Dataset pushed to hub")
450
+ return dataframe
451
+
452
+
453
+ def _check_push_to_hub(org_name, repo_name):
454
+ repo_id = (
455
+ f"{org_name}/{repo_name}"
456
+ if repo_name is not None and org_name is not None
457
+ else None
458
+ )
459
+ if repo_id is not None:
460
+ if not all([repo_id, org_name, repo_name]):
461
+ raise gr.Error(
462
+ "Please provide a `repo_name` and `org_name` to push the dataset to."
463
+ )
464
+ return repo_id
465
+
466
+
467
+ def get_final_dataset_row(default_datasets) -> gr.Dataframe:
468
+ with gr.Row():
469
+ final_dataset = gr.Dataframe(
470
+ value=default_datasets[0],
471
+ label="Generated dataset",
472
+ interactive=False,
473
+ wrap=True,
474
+ min_width=300,
475
+ )
476
+ return final_dataset
477
+
478
+
479
+ def get_success_message_row() -> gr.Markdown:
480
+ with gr.Row():
481
+ success_message = gr.Markdown(visible=False)
482
+ return success_message
483
+
484
+
485
+ def show_success_message_argilla() -> gr.Markdown:
486
+ client = get_argilla_client()
487
+ argilla_api_url = client.api_url
488
+ return gr.Markdown(
489
+ value=f"""
490
+ <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
491
+ <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
492
+ <p style="margin-top: 0.5em;">
493
+ Your dataset is now available at:
494
+ <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
495
+ {argilla_api_url}
496
+ </a>
497
+ <br>Unfamiliar with Argilla? Here are some docs to help you get started:
498
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
499
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
500
+ </p>
501
+ </div>
502
+ """,
503
+ visible=True,
504
+ )
505
+
506
+
507
+ def show_success_message_hub(org_name, repo_name) -> gr.Markdown:
508
+ return gr.Markdown(
509
+ value=f"""
510
+ <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
511
+ <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
512
+ <p style="margin-top: 0.5em;">
513
+ The generated dataset is in the right format for fine-tuning with TRL, AutoTrain or other frameworks.
514
+ Your dataset is now available at:
515
+ <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
516
+ https://huggingface.co/datasets/{org_name}/{repo_name}
517
+ </a>
518
+ </p>
519
+ </div>
520
+ """,
521
+ visible=True,
522
+ )
523
+
524
+
525
+ def hide_success_message() -> gr.Markdown:
526
+ return gr.Markdown(visible=False)
src/distilabel_dataset_generator/apps/faq.py CHANGED
@@ -15,7 +15,7 @@ with gr.Blocks() as app:
15
  <p>This tool simplifies the process of creating custom datasets, enabling you to:</p>
16
  <ul>
17
  <li>Define the characteristics of your desired application</li>
18
- <li>Generate system prompts automatically</li>
19
  <li>Create sample datasets for quick iteration</li>
20
  <li>Produce full-scale datasets with customizable parameters</li>
21
  <li>Push your generated datasets directly to the Hugging Face Hub</li>
 
15
  <p>This tool simplifies the process of creating custom datasets, enabling you to:</p>
16
  <ul>
17
  <li>Define the characteristics of your desired application</li>
18
+ <li>Generate system prompts and tasks automatically</li>
19
  <li>Create sample datasets for quick iteration</li>
20
  <li>Produce full-scale datasets with customizable parameters</li>
21
  <li>Push your generated datasets directly to the Hugging Face Hub</li>
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -1,16 +1,34 @@
1
- import io
2
- from typing import Union
3
 
 
4
  import gradio as gr
5
  import pandas as pd
6
  from datasets import Dataset
7
  from distilabel.distiset import Distiset
8
- from distilabel.steps.tasks.text_generation import TextGeneration
9
- from gradio.oauth import OAuthToken
10
- from huggingface_hub import upload_file
11
-
12
- from src.distilabel_dataset_generator.pipelines.sft import (
 
 
 
 
 
 
 
 
 
 
 
13
  DEFAULT_BATCH_SIZE,
 
 
 
 
 
 
14
  DEFAULT_DATASET_DESCRIPTIONS,
15
  DEFAULT_DATASETS,
16
  DEFAULT_SYSTEM_PROMPTS,
@@ -20,11 +38,169 @@ from src.distilabel_dataset_generator.pipelines.sft import (
20
  get_prompt_generator,
21
  get_response_generator,
22
  )
23
- from src.distilabel_dataset_generator.utils import (
24
- get_login_button,
25
- get_org_dropdown,
26
- swap_visibilty,
27
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
 
30
  def generate_system_prompt(dataset_description, progress=gr.Progress()):
@@ -35,7 +211,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
35
  return DEFAULT_SYSTEM_PROMPTS[index]
36
 
37
  progress(0.3, desc="Initializing text generation")
38
- generate_description: TextGeneration = get_prompt_generator()
39
  progress(0.7, desc="Generating system prompt")
40
  result = next(
41
  generate_description.process(
@@ -51,38 +227,13 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
51
  return result
52
 
53
 
54
- def generate_sample_dataset(system_prompt, progress=gr.Progress()):
55
- if system_prompt in DEFAULT_SYSTEM_PROMPTS:
56
- index = DEFAULT_SYSTEM_PROMPTS.index(system_prompt)
57
- if index < len(DEFAULT_DATASETS):
58
- return DEFAULT_DATASETS[index]
59
- result = generate_dataset(
60
- system_prompt, num_turns=1, num_rows=1, progress=progress, is_sample=True
61
- )
62
- return result
63
-
64
-
65
- def _check_push_to_hub(org_name, repo_name):
66
- repo_id = (
67
- f"{org_name}/{repo_name}"
68
- if repo_name is not None and org_name is not None
69
- else None
70
- )
71
- if repo_id is not None:
72
- if not all([repo_id, org_name, repo_name]):
73
- raise gr.Error(
74
- "Please provide a `repo_name` and `org_name` to push the dataset to."
75
- )
76
- return repo_id
77
-
78
-
79
  def generate_dataset(
80
  system_prompt: str,
81
  num_turns: int = 1,
82
  num_rows: int = 5,
83
  is_sample: bool = False,
84
  progress=gr.Progress(),
85
- ):
86
  progress(0.0, desc="(1/2) Generating instructions")
87
  magpie_generator = get_magpie_generator(
88
  num_turns, num_rows, system_prompt, is_sample
@@ -149,7 +300,7 @@ def generate_dataset(
149
  progress(
150
  1,
151
  total=total_steps,
152
- desc="(2/2) Generating responses",
153
  )
154
 
155
  # create distiset
@@ -184,238 +335,98 @@ def generate_dataset(
184
  return dataframe
185
 
186
 
187
- def push_to_hub(
188
- dataframe: pd.DataFrame,
189
- private: bool = True,
190
- org_name: str = None,
191
- repo_name: str = None,
192
- oauth_token: Union[OAuthToken, None] = None,
193
- progress=gr.Progress(),
194
- ):
195
- progress(0.1, desc="Setting up dataset")
196
- repo_id = _check_push_to_hub(org_name, repo_name)
197
- distiset = Distiset(
198
- {
199
- "default": Dataset.from_pandas(dataframe),
200
- }
201
- )
202
- progress(0.2, desc="Pushing dataset to hub")
203
- distiset.push_to_hub(
204
- repo_id=repo_id,
205
- private=private,
206
- include_script=False,
207
- token=oauth_token.token,
208
- create_pr=False,
209
- )
210
- progress(1.0, desc="Dataset pushed to hub")
211
- return dataframe
212
-
213
-
214
- def upload_pipeline_code(
215
- pipeline_code,
216
  org_name,
217
  repo_name,
218
- oauth_token: Union[OAuthToken, None] = None,
219
- progress=gr.Progress(),
220
- ):
221
- repo_id = _check_push_to_hub(org_name, repo_name)
222
- progress(0.1, desc="Uploading pipeline code")
223
- with io.BytesIO(pipeline_code.encode("utf-8")) as f:
224
- upload_file(
225
- path_or_fileobj=f,
226
- path_in_repo="pipeline.py",
227
- repo_id=repo_id,
228
- repo_type="dataset",
229
- token=oauth_token.token,
230
- commit_message="Include pipeline script",
231
- create_pr=False,
232
- )
233
- progress(1.0, desc="Pipeline code uploaded")
234
-
235
-
236
- css = """
237
- .main_ui_logged_out{opacity: 0.3; pointer-events: none}
238
- """
239
-
240
- with gr.Blocks(
241
- title="🧬 Synthetic Data Generator",
242
- head="🧬 Synthetic Data Generator",
243
- css=css,
244
- ) as app:
245
- with gr.Row():
246
- gr.Markdown(
247
- "Want to run this locally or with other LLMs? Take a look at the FAQ tab. distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation."
248
- )
249
- with gr.Row():
250
- gr.Column()
251
- get_login_button()
252
- gr.Column()
253
-
254
- gr.Markdown("## Iterate on a sample dataset")
255
- with gr.Column() as main_ui:
256
- dataset_description = gr.TextArea(
257
- label="Give a precise description of the assistant or tool. Don't describe the dataset",
258
- value=DEFAULT_DATASET_DESCRIPTIONS[0],
259
- lines=2,
260
- )
261
- examples = gr.Examples(
262
- elem_id="system_prompt_examples",
263
- examples=[[example] for example in DEFAULT_DATASET_DESCRIPTIONS],
264
- inputs=[dataset_description],
265
- )
266
- with gr.Row():
267
- gr.Column(scale=1)
268
- btn_generate_system_prompt = gr.Button(
269
- value="Generate system prompt and sample dataset"
270
- )
271
- gr.Column(scale=1)
272
-
273
- system_prompt = gr.TextArea(
274
- label="System prompt for dataset generation. You can tune it and regenerate the sample",
275
- value=DEFAULT_SYSTEM_PROMPTS[0],
276
- lines=5,
277
- )
278
 
279
- with gr.Row():
280
- sample_dataset = gr.Dataframe(
281
- value=DEFAULT_DATASETS[0],
282
- label="Sample dataset. Prompts and completions truncated to 256 tokens.",
283
- interactive=False,
284
- wrap=True,
 
 
 
 
285
  )
286
-
287
- with gr.Row():
288
- gr.Column(scale=1)
289
- btn_generate_sample_dataset = gr.Button(
290
- value="Generate sample dataset",
 
291
  )
292
- gr.Column(scale=1)
293
-
294
- result = btn_generate_system_prompt.click(
295
- fn=generate_system_prompt,
296
- inputs=[dataset_description],
297
- outputs=[system_prompt],
298
- show_progress=True,
299
- ).then(
300
- fn=generate_sample_dataset,
301
- inputs=[system_prompt],
302
- outputs=[sample_dataset],
303
- show_progress=True,
304
- )
305
-
306
- btn_generate_sample_dataset.click(
307
- fn=generate_sample_dataset,
308
- inputs=[system_prompt],
309
- outputs=[sample_dataset],
310
- show_progress=True,
311
- )
312
-
313
- # Add a header for the full dataset generation section
314
- gr.Markdown("## Generate full dataset")
315
- gr.Markdown(
316
- "Once you're satisfied with the sample, generate a larger dataset and push it to the Hub."
317
- )
318
-
319
- with gr.Column() as push_to_hub_ui:
320
- with gr.Row(variant="panel"):
321
- num_turns = gr.Number(
322
- value=1,
323
- label="Number of turns in the conversation",
324
- minimum=1,
325
- maximum=4,
326
- step=1,
327
- info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
328
- )
329
- num_rows = gr.Number(
330
- value=10,
331
- label="Number of rows in the dataset",
332
- minimum=1,
333
- maximum=500,
334
- info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
335
- )
336
- with gr.Row(variant="panel"):
337
- org_name = get_org_dropdown()
338
- repo_name = gr.Textbox(
339
- label="Repo name", placeholder="dataset_name", value="my-distiset"
340
- )
341
- private = gr.Checkbox(
342
- label="Private dataset",
343
- value=True,
344
- interactive=True,
345
- scale=0.5,
346
- )
347
- with gr.Row() as regenerate_row:
348
- btn_generate_full_dataset = gr.Button(
349
- value="Generate", variant="primary", scale=2
350
- )
351
- btn_generate_and_push_to_hub = gr.Button(
352
- value="Generate and Push to Hub", variant="primary", scale=2
353
- )
354
- btn_push_to_hub = gr.Button(
355
- value="Push to Hub", variant="primary", scale=2
356
- )
357
- with gr.Row():
358
- final_dataset = gr.Dataframe(
359
- value=DEFAULT_DATASETS[0],
360
- label="Generated dataset",
361
- interactive=False,
362
- wrap=True,
363
- )
364
-
365
- with gr.Row():
366
- success_message = gr.Markdown(visible=False)
367
-
368
- def show_success_message(org_name, repo_name):
369
- return gr.Markdown(
370
- value=f"""
371
- <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
372
- <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
373
- <p style="margin-top: 0.5em;">
374
- The generated dataset is in the right format for fine-tuning with TRL, AutoTrain or other frameworks.
375
- Your dataset is now available at:
376
- <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
377
- https://huggingface.co/datasets/{org_name}/{repo_name}
378
- </a>
379
- </p>
380
- </div>
381
- """,
382
- visible=True,
383
- )
384
-
385
- def hide_success_message():
386
- return gr.Markdown(visible=False)
387
 
388
- gr.Markdown("## Or run this pipeline locally with distilabel")
389
- gr.Markdown(
390
- "You can run this pipeline locally with distilabel. For more information, please refer to the [distilabel documentation](https://distilabel.argilla.io/) or go to the FAQ tab at the top of the page for more information."
391
- )
392
-
393
- with gr.Accordion(
394
- "Run this pipeline using distilabel",
395
- open=False,
396
- ):
397
- pipeline_code = gr.Code(
398
- value=generate_pipeline_code(
399
- system_prompt.value, num_turns.value, num_rows.value
400
- ),
401
- language="python",
402
- label="Distilabel Pipeline Code",
403
  )
404
 
405
- sample_dataset.change(
406
- fn=lambda x: x,
407
- inputs=[sample_dataset],
 
 
 
 
 
 
 
 
408
  outputs=[final_dataset],
 
409
  )
410
 
411
- btn_generate_full_dataset.click(
 
 
 
 
 
412
  fn=hide_success_message,
413
  outputs=[success_message],
414
- ).then(
415
  fn=generate_dataset,
416
  inputs=[system_prompt, num_turns, num_rows],
417
  outputs=[final_dataset],
418
  show_progress=True,
 
 
 
 
 
 
 
 
 
419
  )
420
 
421
  btn_generate_and_push_to_hub.click(
@@ -427,17 +438,17 @@ with gr.Blocks(
427
  outputs=[final_dataset],
428
  show_progress=True,
429
  ).then(
430
- fn=push_to_hub,
431
  inputs=[final_dataset, private, org_name, repo_name],
432
  outputs=[final_dataset],
433
  show_progress=True,
434
  ).then(
435
- fn=upload_pipeline_code,
436
  inputs=[pipeline_code, org_name, repo_name],
437
  outputs=[],
438
  show_progress=True,
439
  ).success(
440
- fn=show_success_message,
441
  inputs=[org_name, repo_name],
442
  outputs=[success_message],
443
  )
@@ -446,21 +457,40 @@ with gr.Blocks(
446
  fn=hide_success_message,
447
  outputs=[success_message],
448
  ).then(
449
- fn=push_to_hub,
450
  inputs=[final_dataset, private, org_name, repo_name],
451
  outputs=[final_dataset],
452
  show_progress=True,
453
  ).then(
454
- fn=upload_pipeline_code,
455
  inputs=[pipeline_code, org_name, repo_name],
456
  outputs=[],
457
  show_progress=True,
458
  ).success(
459
- fn=show_success_message,
460
  inputs=[org_name, repo_name],
461
  outputs=[success_message],
462
  )
463
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  system_prompt.change(
465
  fn=generate_pipeline_code,
466
  inputs=[system_prompt, num_turns, num_rows],
@@ -476,5 +506,3 @@ with gr.Blocks(
476
  inputs=[system_prompt, num_turns, num_rows],
477
  outputs=[pipeline_code],
478
  )
479
- app.load(get_org_dropdown, outputs=[org_name])
480
- app.load(fn=swap_visibilty, outputs=main_ui)
 
1
+ import ast
2
+ from typing import Dict, List, Union
3
 
4
+ import argilla as rg
5
  import gradio as gr
6
  import pandas as pd
7
  from datasets import Dataset
8
  from distilabel.distiset import Distiset
9
+ from huggingface_hub import HfApi
10
+
11
+ from src.distilabel_dataset_generator.apps.base import (
12
+ get_argilla_client,
13
+ get_main_ui,
14
+ get_pipeline_code_ui,
15
+ hide_success_message,
16
+ push_pipeline_code_to_hub,
17
+ show_success_message_argilla,
18
+ show_success_message_hub,
19
+ validate_argilla_user_workspace_dataset,
20
+ )
21
+ from src.distilabel_dataset_generator.apps.base import (
22
+ push_dataset_to_hub as push_to_hub_base,
23
+ )
24
+ from src.distilabel_dataset_generator.pipelines.base import (
25
  DEFAULT_BATCH_SIZE,
26
+ )
27
+ from src.distilabel_dataset_generator.pipelines.embeddings import (
28
+ get_embeddings,
29
+ get_sentence_embedding_dimensions,
30
+ )
31
+ from src.distilabel_dataset_generator.pipelines.sft import (
32
  DEFAULT_DATASET_DESCRIPTIONS,
33
  DEFAULT_DATASETS,
34
  DEFAULT_SYSTEM_PROMPTS,
 
38
  get_prompt_generator,
39
  get_response_generator,
40
  )
41
+
42
+ TASK = "supervised_fine_tuning"
43
+
44
+
45
+ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
46
+ def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
47
+ return ast.literal_eval(
48
+ messages.replace("'user'}", "'user'},")
49
+ .replace("'system'}", "'system'},")
50
+ .replace("'assistant'}", "'assistant'},")
51
+ )
52
+
53
+ if "messages" in dataframe.columns:
54
+ dataframe["messages"] = dataframe["messages"].apply(
55
+ lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
56
+ )
57
+ return dataframe
58
+
59
+
60
+ def push_dataset_to_hub(
61
+ dataframe: pd.DataFrame,
62
+ private: bool = True,
63
+ org_name: str = None,
64
+ repo_name: str = None,
65
+ oauth_token: Union[gr.OAuthToken, None] = None,
66
+ progress=gr.Progress(),
67
+ ):
68
+ original_dataframe = dataframe.copy(deep=True)
69
+ dataframe = convert_dataframe_messages(dataframe)
70
+ try:
71
+ push_to_hub_base(
72
+ dataframe, private, org_name, repo_name, oauth_token, progress, task=TASK
73
+ )
74
+ except Exception as e:
75
+ raise gr.Error(f"Error pushing dataset to the Hub: {e}")
76
+ return original_dataframe
77
+
78
+
79
+ def push_dataset_to_argilla(
80
+ dataframe: pd.DataFrame,
81
+ dataset_name: str,
82
+ oauth_token: Union[gr.OAuthToken, None] = None,
83
+ progress=gr.Progress(),
84
+ ) -> pd.DataFrame:
85
+ original_dataframe = dataframe.copy(deep=True)
86
+ dataframe = convert_dataframe_messages(dataframe)
87
+ try:
88
+ progress(0.1, desc="Setting up user and workspace")
89
+ client = get_argilla_client()
90
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
91
+ if "messages" in dataframe.columns:
92
+ settings = rg.Settings(
93
+ fields=[
94
+ rg.ChatField(
95
+ name="messages",
96
+ description="The messages in the conversation",
97
+ title="Messages",
98
+ ),
99
+ ],
100
+ questions=[
101
+ rg.RatingQuestion(
102
+ name="rating",
103
+ title="Rating",
104
+ description="The rating of the conversation",
105
+ values=list(range(1, 6)),
106
+ ),
107
+ ],
108
+ metadata=[
109
+ rg.IntegerMetadataProperty(
110
+ name="user_message_length", title="User Message Length"
111
+ ),
112
+ rg.IntegerMetadataProperty(
113
+ name="assistant_message_length",
114
+ title="Assistant Message Length",
115
+ ),
116
+ ],
117
+ vectors=[
118
+ rg.VectorField(
119
+ name="messages_embeddings",
120
+ dimensions=get_sentence_embedding_dimensions(),
121
+ )
122
+ ],
123
+ guidelines="Please review the conversation and provide a score for the assistant's response.",
124
+ )
125
+
126
+ dataframe["user_message_length"] = dataframe["messages"].apply(
127
+ lambda x: sum([len(y["content"]) for y in x if y["role"] == "user"])
128
+ )
129
+ dataframe["assistant_message_length"] = dataframe["messages"].apply(
130
+ lambda x: sum(
131
+ [len(y["content"]) for y in x if y["role"] == "assistant"]
132
+ )
133
+ )
134
+ dataframe["messages_embeddings"] = get_embeddings(
135
+ dataframe["messages"].apply(
136
+ lambda x: " ".join([y["content"] for y in x])
137
+ )
138
+ )
139
+ else:
140
+ settings = rg.Settings(
141
+ fields=[
142
+ rg.TextField(
143
+ name="system_prompt",
144
+ title="System Prompt",
145
+ description="The system prompt used for the conversation",
146
+ required=False,
147
+ ),
148
+ rg.TextField(
149
+ name="prompt",
150
+ title="Prompt",
151
+ description="The prompt used for the conversation",
152
+ ),
153
+ rg.TextField(
154
+ name="completion",
155
+ title="Completion",
156
+ description="The completion from the assistant",
157
+ ),
158
+ ],
159
+ questions=[
160
+ rg.RatingQuestion(
161
+ name="rating",
162
+ title="Rating",
163
+ description="The rating of the conversation",
164
+ values=list(range(1, 6)),
165
+ ),
166
+ ],
167
+ metadata=[
168
+ rg.IntegerMetadataProperty(
169
+ name="prompt_length", title="Prompt Length"
170
+ ),
171
+ rg.IntegerMetadataProperty(
172
+ name="completion_length", title="Completion Length"
173
+ ),
174
+ ],
175
+ vectors=[
176
+ rg.VectorField(
177
+ name="prompt_embeddings",
178
+ dimensions=get_sentence_embedding_dimensions(),
179
+ )
180
+ ],
181
+ guidelines="Please review the conversation and correct the prompt and completion where needed.",
182
+ )
183
+ dataframe["prompt_length"] = dataframe["prompt"].apply(len)
184
+ dataframe["completion_length"] = dataframe["completion"].apply(len)
185
+ dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])
186
+
187
+ progress(0.5, desc="Creating dataset")
188
+ rg_dataset = client.datasets(name=dataset_name, workspace=hf_user)
189
+ if rg_dataset is None:
190
+ rg_dataset = rg.Dataset(
191
+ name=dataset_name,
192
+ workspace=hf_user,
193
+ settings=settings,
194
+ client=client,
195
+ )
196
+ rg_dataset = rg_dataset.create()
197
+ progress(0.7, desc="Pushing dataset to Argilla")
198
+ hf_dataset = Dataset.from_pandas(dataframe)
199
+ rg_dataset.records.log(records=hf_dataset)
200
+ progress(1.0, desc="Dataset pushed to Argilla")
201
+ except Exception as e:
202
+ raise gr.Error(f"Error pushing dataset to Argilla: {e}")
203
+ return original_dataframe
204
 
205
 
206
  def generate_system_prompt(dataset_description, progress=gr.Progress()):
 
211
  return DEFAULT_SYSTEM_PROMPTS[index]
212
 
213
  progress(0.3, desc="Initializing text generation")
214
+ generate_description = get_prompt_generator()
215
  progress(0.7, desc="Generating system prompt")
216
  result = next(
217
  generate_description.process(
 
227
  return result
228
 
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  def generate_dataset(
231
  system_prompt: str,
232
  num_turns: int = 1,
233
  num_rows: int = 5,
234
  is_sample: bool = False,
235
  progress=gr.Progress(),
236
+ ) -> pd.DataFrame:
237
  progress(0.0, desc="(1/2) Generating instructions")
238
  magpie_generator = get_magpie_generator(
239
  num_turns, num_rows, system_prompt, is_sample
 
300
  progress(
301
  1,
302
  total=total_steps,
303
+ desc="(2/2) Creating dataset",
304
  )
305
 
306
  # create distiset
 
335
  return dataframe
336
 
337
 
338
+ (
339
+ app,
340
+ main_ui,
341
+ custom_input_ui,
342
+ dataset_description,
343
+ examples,
344
+ btn_generate_system_prompt,
345
+ system_prompt,
346
+ sample_dataset,
347
+ btn_generate_sample_dataset,
348
+ dataset_name,
349
+ add_to_existing_dataset,
350
+ btn_generate_full_dataset_argilla,
351
+ btn_generate_and_push_to_argilla,
352
+ btn_push_to_argilla,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  org_name,
354
  repo_name,
355
+ private,
356
+ btn_generate_full_dataset,
357
+ btn_generate_and_push_to_hub,
358
+ btn_push_to_hub,
359
+ final_dataset,
360
+ success_message,
361
+ ) = get_main_ui(
362
+ default_dataset_descriptions=DEFAULT_DATASET_DESCRIPTIONS,
363
+ default_system_prompts=DEFAULT_SYSTEM_PROMPTS,
364
+ default_datasets=DEFAULT_DATASETS,
365
+ fn_generate_system_prompt=generate_system_prompt,
366
+ fn_generate_dataset=generate_dataset,
367
+ task=TASK,
368
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
+ with app:
371
+ with main_ui:
372
+ with custom_input_ui:
373
+ num_turns = gr.Number(
374
+ value=1,
375
+ label="Number of turns in the conversation",
376
+ minimum=1,
377
+ maximum=4,
378
+ step=1,
379
+ info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
380
  )
381
+ num_rows = gr.Number(
382
+ value=10,
383
+ label="Number of rows in the dataset",
384
+ minimum=1,
385
+ maximum=500,
386
+ info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
387
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
 
389
+ pipeline_code = get_pipeline_code_ui(
390
+ generate_pipeline_code(system_prompt.value, num_turns.value, num_rows.value)
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  )
392
 
393
+ # define app triggers
394
+ gr.on(
395
+ triggers=[
396
+ btn_generate_full_dataset.click,
397
+ btn_generate_full_dataset_argilla.click,
398
+ ],
399
+ fn=hide_success_message,
400
+ outputs=[success_message],
401
+ ).then(
402
+ fn=generate_dataset,
403
+ inputs=[system_prompt, num_turns, num_rows],
404
  outputs=[final_dataset],
405
+ show_progress=True,
406
  )
407
 
408
+ btn_generate_and_push_to_argilla.click(
409
+ fn=validate_argilla_user_workspace_dataset,
410
+ inputs=[dataset_name, final_dataset, add_to_existing_dataset],
411
+ outputs=[final_dataset],
412
+ show_progress=True,
413
+ ).success(
414
  fn=hide_success_message,
415
  outputs=[success_message],
416
+ ).success(
417
  fn=generate_dataset,
418
  inputs=[system_prompt, num_turns, num_rows],
419
  outputs=[final_dataset],
420
  show_progress=True,
421
+ ).success(
422
+ fn=push_dataset_to_argilla,
423
+ inputs=[final_dataset, dataset_name],
424
+ outputs=[final_dataset],
425
+ show_progress=True,
426
+ ).success(
427
+ fn=show_success_message_argilla,
428
+ inputs=[],
429
+ outputs=[success_message],
430
  )
431
 
432
  btn_generate_and_push_to_hub.click(
 
438
  outputs=[final_dataset],
439
  show_progress=True,
440
  ).then(
441
+ fn=push_dataset_to_hub,
442
  inputs=[final_dataset, private, org_name, repo_name],
443
  outputs=[final_dataset],
444
  show_progress=True,
445
  ).then(
446
+ fn=push_pipeline_code_to_hub,
447
  inputs=[pipeline_code, org_name, repo_name],
448
  outputs=[],
449
  show_progress=True,
450
  ).success(
451
+ fn=show_success_message_hub,
452
  inputs=[org_name, repo_name],
453
  outputs=[success_message],
454
  )
 
457
  fn=hide_success_message,
458
  outputs=[success_message],
459
  ).then(
460
+ fn=push_dataset_to_hub,
461
  inputs=[final_dataset, private, org_name, repo_name],
462
  outputs=[final_dataset],
463
  show_progress=True,
464
  ).then(
465
+ fn=push_pipeline_code_to_hub,
466
  inputs=[pipeline_code, org_name, repo_name],
467
  outputs=[],
468
  show_progress=True,
469
  ).success(
470
+ fn=show_success_message_hub,
471
  inputs=[org_name, repo_name],
472
  outputs=[success_message],
473
  )
474
 
475
+ btn_push_to_argilla.click(
476
+ fn=hide_success_message,
477
+ outputs=[success_message],
478
+ ).success(
479
+ fn=validate_argilla_user_workspace_dataset,
480
+ inputs=[dataset_name, final_dataset, add_to_existing_dataset],
481
+ outputs=[final_dataset],
482
+ show_progress=True,
483
+ ).success(
484
+ fn=push_dataset_to_argilla,
485
+ inputs=[final_dataset, dataset_name],
486
+ outputs=[final_dataset],
487
+ show_progress=True,
488
+ ).success(
489
+ fn=show_success_message_argilla,
490
+ inputs=[],
491
+ outputs=[success_message],
492
+ )
493
+
494
  system_prompt.change(
495
  fn=generate_pipeline_code,
496
  inputs=[system_prompt, num_turns, num_rows],
 
506
  inputs=[system_prompt, num_turns, num_rows],
507
  outputs=[pipeline_code],
508
  )
 
 
src/distilabel_dataset_generator/apps/textcat.py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List, Union
3
+
4
+ import argilla as rg
5
+ import gradio as gr
6
+ import pandas as pd
7
+ from datasets import Dataset
8
+ from huggingface_hub import HfApi
9
+
10
+ from src.distilabel_dataset_generator.apps.base import (
11
+ get_argilla_client,
12
+ get_main_ui,
13
+ get_pipeline_code_ui,
14
+ hide_success_message,
15
+ push_pipeline_code_to_hub,
16
+ show_success_message_argilla,
17
+ show_success_message_hub,
18
+ validate_argilla_user_workspace_dataset,
19
+ )
20
+ from src.distilabel_dataset_generator.apps.base import (
21
+ push_dataset_to_hub as push_to_hub_base,
22
+ )
23
+ from src.distilabel_dataset_generator.pipelines.base import (
24
+ DEFAULT_BATCH_SIZE,
25
+ )
26
+ from src.distilabel_dataset_generator.pipelines.embeddings import (
27
+ get_embeddings,
28
+ get_sentence_embedding_dimensions,
29
+ )
30
+ from src.distilabel_dataset_generator.pipelines.textcat import (
31
+ DEFAULT_DATASET_DESCRIPTIONS,
32
+ DEFAULT_DATASETS,
33
+ DEFAULT_SYSTEM_PROMPTS,
34
+ PROMPT_CREATION_PROMPT,
35
+ generate_pipeline_code,
36
+ get_labeller_generator,
37
+ get_prompt_generator,
38
+ get_textcat_generator,
39
+ )
40
+ from src.distilabel_dataset_generator.utils import get_preprocess_labels
41
+
42
+ TASK = "text_classification"
43
+
44
+
45
+ def push_dataset_to_hub(
46
+ dataframe: pd.DataFrame,
47
+ private: bool = True,
48
+ org_name: str = None,
49
+ repo_name: str = None,
50
+ oauth_token: Union[gr.OAuthToken, None] = None,
51
+ progress=gr.Progress(),
52
+ labels: List[str] = None,
53
+ num_labels: int = 1,
54
+ ):
55
+ original_dataframe = dataframe.copy(deep=True)
56
+ labels = get_preprocess_labels(labels)
57
+ try:
58
+ push_to_hub_base(
59
+ dataframe,
60
+ private,
61
+ org_name,
62
+ repo_name,
63
+ oauth_token,
64
+ progress,
65
+ labels,
66
+ num_labels,
67
+ task=TASK,
68
+ )
69
+ except Exception as e:
70
+ raise gr.Error(f"Error pushing dataset to the Hub: {e}")
71
+ return original_dataframe
72
+
73
+
74
+ def push_dataset_to_argilla(
75
+ dataframe: pd.DataFrame,
76
+ dataset_name: str,
77
+ oauth_token: Union[gr.OAuthToken, None] = None,
78
+ progress=gr.Progress(),
79
+ num_labels: int = 1,
80
+ labels: List[str] = None,
81
+ ) -> pd.DataFrame:
82
+ original_dataframe = dataframe.copy(deep=True)
83
+ try:
84
+ progress(0.1, desc="Setting up user and workspace")
85
+ client = get_argilla_client()
86
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
87
+ labels = get_preprocess_labels(labels)
88
+ settings = rg.Settings(
89
+ fields=[
90
+ rg.TextField(
91
+ name="text",
92
+ description="The text classification data",
93
+ title="Text",
94
+ ),
95
+ ],
96
+ questions=[
97
+ (
98
+ rg.LabelQuestion(
99
+ name="label",
100
+ title="Label",
101
+ description="The label of the text",
102
+ labels=labels,
103
+ )
104
+ if num_labels == 1
105
+ else rg.MultiLabelQuestion(
106
+ name="labels",
107
+ title="Labels",
108
+ description="The labels of the conversation",
109
+ labels=labels,
110
+ )
111
+ ),
112
+ ],
113
+ metadata=[
114
+ rg.IntegerMetadataProperty(name="text_length", title="Text Length"),
115
+ ],
116
+ vectors=[
117
+ rg.VectorField(
118
+ name="text_embeddings",
119
+ dimensions=get_sentence_embedding_dimensions(),
120
+ )
121
+ ],
122
+ guidelines="Please review the text and provide or correct the label where needed.",
123
+ )
124
+
125
+ dataframe["text_length"] = dataframe["text"].apply(len)
126
+ dataframe["text_embeddings"] = get_embeddings(dataframe["text"])
127
+
128
+ progress(0.5, desc="Creating dataset")
129
+ rg_dataset = client.datasets(name=dataset_name, workspace=hf_user)
130
+ if rg_dataset is None:
131
+ rg_dataset = rg.Dataset(
132
+ name=dataset_name,
133
+ workspace=hf_user,
134
+ settings=settings,
135
+ client=client,
136
+ )
137
+ rg_dataset = rg_dataset.create()
138
+ progress(0.7, desc="Pushing dataset to Argilla")
139
+ hf_dataset = Dataset.from_pandas(dataframe)
140
+ records = [
141
+ rg.Record(
142
+ fields={
143
+ "text": sample["text"],
144
+ },
145
+ metadata={"text_length": sample["text_length"]},
146
+ vectors={"text_embeddings": sample["text_embeddings"]},
147
+ suggestions=(
148
+ [
149
+ rg.Suggestion(
150
+ question_name="label" if num_labels == 1 else "labels",
151
+ value=(
152
+ sample["label"] if num_labels == 1 else sample["labels"]
153
+ ),
154
+ )
155
+ ]
156
+ if (
157
+ (num_labels == 1 and sample["label"] in labels)
158
+ or (
159
+ num_labels > 1
160
+ and all(label in labels for label in sample["labels"])
161
+ )
162
+ )
163
+ else []
164
+ ),
165
+ )
166
+ for sample in hf_dataset
167
+ ]
168
+ rg_dataset.records.log(records=records)
169
+ progress(1.0, desc="Dataset pushed to Argilla")
170
+ except Exception as e:
171
+ raise gr.Error(f"Error pushing dataset to Argilla: {e}")
172
+ return original_dataframe
173
+
174
+
175
+ def generate_system_prompt(dataset_description, progress=gr.Progress()):
176
+ progress(0.0, desc="Generating text classification task")
177
+ if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
178
+ index = DEFAULT_DATASET_DESCRIPTIONS.index(dataset_description)
179
+ if index < len(DEFAULT_SYSTEM_PROMPTS):
180
+ return DEFAULT_SYSTEM_PROMPTS[index]
181
+
182
+ progress(0.3, desc="Initializing text generation")
183
+ generate_description = get_prompt_generator()
184
+ progress(0.7, desc="Generating text classification task")
185
+ result = next(
186
+ generate_description.process(
187
+ [
188
+ {
189
+ "system_prompt": PROMPT_CREATION_PROMPT,
190
+ "instruction": dataset_description,
191
+ }
192
+ ]
193
+ )
194
+ )[0]["generation"]
195
+ progress(1.0, desc="Text classification task generated")
196
+ return result
197
+
198
+
199
+ def generate_dataset(
200
+ system_prompt: str,
201
+ difficulty: str,
202
+ clarity: str,
203
+ labels: List[str] = None,
204
+ num_labels: int = 1,
205
+ num_rows: int = 10,
206
+ is_sample: bool = False,
207
+ progress=gr.Progress(),
208
+ ) -> pd.DataFrame:
209
+ progress(0.0, desc="(1/2) Generating text classification data")
210
+ labels = get_preprocess_labels(labels)
211
+ textcat_generator = get_textcat_generator(
212
+ difficulty=difficulty, clarity=clarity, is_sample=is_sample
213
+ )
214
+ labeller_generator = get_labeller_generator(
215
+ system_prompt=system_prompt,
216
+ labels=labels,
217
+ num_labels=num_labels,
218
+ is_sample=is_sample,
219
+ )
220
+ total_steps: int = num_rows * 2
221
+ batch_size = DEFAULT_BATCH_SIZE
222
+
223
+ # create text classification data
224
+ n_processed = 0
225
+ textcat_results = []
226
+ while n_processed < num_rows:
227
+ progress(
228
+ 0.5 * n_processed / num_rows,
229
+ total=total_steps,
230
+ desc="(1/2) Generating text classification data",
231
+ )
232
+ remaining_rows = num_rows - n_processed
233
+ batch_size = min(batch_size, remaining_rows)
234
+ inputs = [{"task": system_prompt} for _ in range(batch_size)]
235
+ batch = list(textcat_generator.process(inputs=inputs))
236
+ textcat_results.extend(batch[0])
237
+ n_processed += batch_size
238
+ for result in textcat_results:
239
+ result["text"] = result["input_text"]
240
+
241
+ # label text classification data
242
+ progress(0.5, desc="(1/2) Generating text classification data")
243
+ if not is_sample:
244
+ n_processed = 0
245
+ labeller_results = []
246
+ while n_processed < num_rows:
247
+ progress(
248
+ 0.5 + 0.5 * n_processed / num_rows,
249
+ total=total_steps,
250
+ desc="(1/2) Labeling text classification data",
251
+ )
252
+ batch = textcat_results[n_processed : n_processed + batch_size]
253
+ labels_batch = list(labeller_generator.process(inputs=batch))
254
+ labeller_results.extend(labels_batch[0])
255
+ n_processed += batch_size
256
+ progress(
257
+ 1,
258
+ total=total_steps,
259
+ desc="(2/2) Creating dataset",
260
+ )
261
+
262
+ # create final dataset
263
+ distiset_results = []
264
+ source_results = textcat_results if is_sample else labeller_results
265
+ for result in source_results:
266
+ record = {
267
+ key: result[key]
268
+ for key in ["text", "label" if is_sample else "labels"]
269
+ if key in result
270
+ }
271
+ distiset_results.append(record)
272
+
273
+ dataframe = pd.DataFrame(distiset_results)
274
+ if not is_sample:
275
+ if num_labels == 1:
276
+ dataframe = dataframe.rename(columns={"labels": "label"})
277
+ dataframe["label"] = dataframe["label"].apply(
278
+ lambda x: x.lower().strip() if x.lower().strip() in labels else None
279
+ )
280
+ else:
281
+ dataframe["labels"] = dataframe["labels"].apply(
282
+ lambda x: (
283
+ [
284
+ label.lower().strip()
285
+ for label in x
286
+ if label.lower().strip() in labels
287
+ ]
288
+ if isinstance(x, list)
289
+ else None
290
+ )
291
+ )
292
+ progress(1.0, desc="Dataset generation completed")
293
+ return dataframe
294
+
295
+
296
+ def update_suggested_labels(system_prompt):
297
+ new_labels = re.findall(r"'(\b[\w-]+\b)'", system_prompt)
298
+ if not new_labels:
299
+ return gr.Warning(
300
+ "No labels found in the system prompt. Please add labels manually."
301
+ )
302
+ return gr.update(choices=new_labels, value=new_labels)
303
+
304
+
305
+ def validate_input_labels(labels):
306
+ if not labels or len(labels) < 2:
307
+ raise gr.Error(
308
+ f"Please select at least 2 labels to classify your text. You selected {len(labels) if labels else 0}."
309
+ )
310
+ return labels
311
+
312
+
313
+ (
314
+ app,
315
+ main_ui,
316
+ custom_input_ui,
317
+ dataset_description,
318
+ examples,
319
+ btn_generate_system_prompt,
320
+ system_prompt,
321
+ sample_dataset,
322
+ btn_generate_sample_dataset,
323
+ dataset_name,
324
+ add_to_existing_dataset,
325
+ btn_generate_full_dataset_argilla,
326
+ btn_generate_and_push_to_argilla,
327
+ btn_push_to_argilla,
328
+ org_name,
329
+ repo_name,
330
+ private,
331
+ btn_generate_full_dataset,
332
+ btn_generate_and_push_to_hub,
333
+ btn_push_to_hub,
334
+ final_dataset,
335
+ success_message,
336
+ ) = get_main_ui(
337
+ default_dataset_descriptions=DEFAULT_DATASET_DESCRIPTIONS,
338
+ default_system_prompts=DEFAULT_SYSTEM_PROMPTS,
339
+ default_datasets=DEFAULT_DATASETS,
340
+ fn_generate_system_prompt=generate_system_prompt,
341
+ fn_generate_dataset=generate_dataset,
342
+ task=TASK,
343
+ )
344
+
345
+ with app:
346
+ with main_ui:
347
+ with custom_input_ui:
348
+ difficulty = gr.Dropdown(
349
+ choices=[
350
+ ("High School", "high school"),
351
+ ("College", "college"),
352
+ ("PhD", "PhD"),
353
+ ("Mixed", "mixed"),
354
+ ],
355
+ value="mixed",
356
+ label="Difficulty",
357
+ info="The difficulty of the text to be generated.",
358
+ )
359
+ clarity = gr.Dropdown(
360
+ choices=[
361
+ ("Clear", "clear"),
362
+ (
363
+ "Understandable",
364
+ "understandable with some effort",
365
+ ),
366
+ ("Ambiguous", "ambiguous"),
367
+ ("Mixed", "mixed"),
368
+ ],
369
+ value="mixed",
370
+ label="Clarity",
371
+ info="The clarity of the text to be generated.",
372
+ )
373
+ with gr.Column():
374
+ labels = gr.Dropdown(
375
+ choices=[],
376
+ allow_custom_value=True,
377
+ interactive=True,
378
+ label="Labels",
379
+ multiselect=True,
380
+ info="Add the labels to classify the text.",
381
+ )
382
+ with gr.Blocks():
383
+ btn_suggested_labels = gr.Button(
384
+ value="Add suggested labels",
385
+ size="sm",
386
+ )
387
+ num_labels = gr.Number(
388
+ label="Number of labels",
389
+ value=1,
390
+ minimum=1,
391
+ maximum=10,
392
+ info="The number of labels to classify the text.",
393
+ )
394
+ num_rows = gr.Number(
395
+ label="Number of rows",
396
+ value=10,
397
+ minimum=1,
398
+ maximum=500,
399
+ info="More rows will take longer to generate.",
400
+ )
401
+
402
+ pipeline_code = get_pipeline_code_ui(
403
+ generate_pipeline_code(
404
+ system_prompt.value,
405
+ difficulty=difficulty.value,
406
+ clarity=clarity.value,
407
+ labels=labels.value,
408
+ num_labels=num_labels.value,
409
+ num_rows=num_rows.value,
410
+ )
411
+ )
412
+
413
+ # define app triggers
414
+ btn_suggested_labels.click(
415
+ fn=update_suggested_labels,
416
+ inputs=[system_prompt],
417
+ outputs=labels,
418
+ )
419
+
420
+ gr.on(
421
+ triggers=[
422
+ btn_generate_full_dataset.click,
423
+ btn_generate_full_dataset_argilla.click,
424
+ ],
425
+ fn=hide_success_message,
426
+ outputs=[success_message],
427
+ ).then(
428
+ fn=validate_input_labels,
429
+ inputs=[labels],
430
+ outputs=[labels],
431
+ ).success(
432
+ fn=generate_dataset,
433
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
434
+ outputs=[final_dataset],
435
+ show_progress=True,
436
+ )
437
+
438
+ btn_generate_and_push_to_argilla.click(
439
+ fn=validate_argilla_user_workspace_dataset,
440
+ inputs=[dataset_name, final_dataset, add_to_existing_dataset],
441
+ outputs=[final_dataset],
442
+ show_progress=True,
443
+ ).success(
444
+ fn=hide_success_message,
445
+ outputs=[success_message],
446
+ ).success(
447
+ fn=generate_dataset,
448
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
449
+ outputs=[final_dataset],
450
+ show_progress=True,
451
+ ).success(
452
+ fn=push_dataset_to_argilla,
453
+ inputs=[final_dataset, dataset_name, num_labels, labels],
454
+ outputs=[final_dataset],
455
+ show_progress=True,
456
+ ).success(
457
+ fn=show_success_message_argilla,
458
+ inputs=[],
459
+ outputs=[success_message],
460
+ )
461
+
462
+ btn_generate_and_push_to_hub.click(
463
+ fn=hide_success_message,
464
+ outputs=[success_message],
465
+ ).then(
466
+ fn=generate_dataset,
467
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
468
+ outputs=[final_dataset],
469
+ show_progress=True,
470
+ ).then(
471
+ fn=push_dataset_to_hub,
472
+ inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
473
+ outputs=[final_dataset],
474
+ show_progress=True,
475
+ ).then(
476
+ fn=push_pipeline_code_to_hub,
477
+ inputs=[pipeline_code, org_name, repo_name],
478
+ outputs=[],
479
+ show_progress=True,
480
+ ).success(
481
+ fn=show_success_message_hub,
482
+ inputs=[org_name, repo_name],
483
+ outputs=[success_message],
484
+ )
485
+
486
+ btn_push_to_hub.click(
487
+ fn=hide_success_message,
488
+ outputs=[success_message],
489
+ ).then(
490
+ fn=push_dataset_to_hub,
491
+ inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
492
+ outputs=[final_dataset],
493
+ show_progress=True,
494
+ ).then(
495
+ fn=push_pipeline_code_to_hub,
496
+ inputs=[pipeline_code, org_name, repo_name],
497
+ outputs=[],
498
+ show_progress=True,
499
+ ).success(
500
+ fn=show_success_message_hub,
501
+ inputs=[org_name, repo_name],
502
+ outputs=[success_message],
503
+ )
504
+
505
+ btn_push_to_argilla.click(
506
+ fn=hide_success_message,
507
+ outputs=[success_message],
508
+ ).success(
509
+ fn=validate_argilla_user_workspace_dataset,
510
+ inputs=[dataset_name, final_dataset, add_to_existing_dataset],
511
+ outputs=[final_dataset],
512
+ show_progress=True,
513
+ ).success(
514
+ fn=push_dataset_to_argilla,
515
+ inputs=[final_dataset, dataset_name, num_labels, labels],
516
+ outputs=[final_dataset],
517
+ show_progress=True,
518
+ ).success(
519
+ fn=show_success_message_argilla,
520
+ inputs=[],
521
+ outputs=[success_message],
522
+ )
523
+
524
+ system_prompt.change(
525
+ fn=generate_pipeline_code,
526
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
527
+ outputs=[pipeline_code],
528
+ )
529
+ difficulty.change(
530
+ fn=generate_pipeline_code,
531
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
532
+ outputs=[pipeline_code],
533
+ )
534
+ clarity.change(
535
+ fn=generate_pipeline_code,
536
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
537
+ outputs=[pipeline_code],
538
+ )
539
+ labels.change(
540
+ fn=generate_pipeline_code,
541
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
542
+ outputs=[pipeline_code],
543
+ )
544
+ num_labels.change(
545
+ fn=generate_pipeline_code,
546
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
547
+ outputs=[pipeline_code],
548
+ )
src/distilabel_dataset_generator/pipelines/base.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.distilabel_dataset_generator.utils import HF_TOKENS
2
+
3
+ DEFAULT_BATCH_SIZE = 5
4
+ TOKEN_INDEX = 0
5
+ MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
6
+
7
+
8
+ def _get_next_api_key():
9
+ global TOKEN_INDEX
10
+ api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
11
+ TOKEN_INDEX += 1
12
+ return api_key
src/distilabel_dataset_generator/pipelines/embeddings.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from sentence_transformers import SentenceTransformer
4
+ from sentence_transformers.models import StaticEmbedding
5
+
6
+ # Initialize a StaticEmbedding module
7
+ static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
8
+ model = SentenceTransformer(modules=[static_embedding])
9
+
10
+
11
+ def get_embeddings(texts: List[str]) -> List[List[float]]:
12
+ return [embedding.tolist() for embedding in model.encode(texts)]
13
+
14
+
15
+ def get_sentence_embedding_dimensions() -> int:
16
+ return model.get_sentence_embedding_dimension()
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -1,12 +1,11 @@
1
  import pandas as pd
2
- from datasets import Dataset
3
- from distilabel.distiset import Distiset
4
  from distilabel.llms import InferenceEndpointsLLM
5
- from distilabel.pipeline import Pipeline
6
- from distilabel.steps import KeepColumns
7
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
8
 
9
- from src.distilabel_dataset_generator.utils import HF_TOKENS
 
 
 
10
 
11
  INFORMATION_SEEKING_PROMPT = (
12
  "You are an AI assistant designed to provide accurate and concise information on a wide"
@@ -120,7 +119,6 @@ The prompt you write should follow the same style and structure as the following
120
  User dataset description:
121
  """
122
 
123
- MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
124
  DEFAULT_DATASET_DESCRIPTIONS = (
125
  "rude customer assistant for a phone company",
126
  "assistant that solves math puzzles using python",
@@ -157,8 +155,6 @@ _STOP_SEQUENCES = [
157
  "assistant",
158
  " \n\n",
159
  ]
160
- DEFAULT_BATCH_SIZE = 5
161
- TOKEN_INDEX = 0
162
 
163
 
164
  def _get_output_mappings(num_turns):
@@ -189,7 +185,7 @@ with Pipeline(name="sft") as pipeline:
189
  tokenizer_id=MODEL,
190
  magpie_pre_query_template="llama3",
191
  generation_kwargs={{
192
- "temperature": 0.8,
193
  "do_sample": True,
194
  "max_new_tokens": 2048,
195
  "stop_sequences": {_STOP_SEQUENCES}
@@ -213,13 +209,6 @@ if __name__ == "__main__":
213
  return code
214
 
215
 
216
- def _get_next_api_key():
217
- global TOKEN_INDEX
218
- api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
219
- TOKEN_INDEX += 1
220
- return api_key
221
-
222
-
223
  def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
224
  input_mappings = _get_output_mappings(num_turns)
225
  output_mappings = input_mappings.copy()
@@ -231,7 +220,7 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
231
  api_key=_get_next_api_key(),
232
  magpie_pre_query_template="llama3",
233
  generation_kwargs={
234
- "temperature": 0.8,
235
  "do_sample": True,
236
  "max_new_tokens": 256 if is_sample else 512,
237
  "stop_sequences": _STOP_SEQUENCES,
@@ -250,7 +239,7 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
250
  api_key=_get_next_api_key(),
251
  magpie_pre_query_template="llama3",
252
  generation_kwargs={
253
- "temperature": 0.8,
254
  "do_sample": True,
255
  "max_new_tokens": 256 if is_sample else 1024,
256
  "stop_sequences": _STOP_SEQUENCES,
@@ -300,12 +289,9 @@ def get_response_generator(num_turns, system_prompt, is_sample):
300
 
301
 
302
  def get_prompt_generator():
303
- global TOKEN_INDEX
304
- api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
305
- TOKEN_INDEX += 1
306
  prompt_generator = TextGeneration(
307
  llm=InferenceEndpointsLLM(
308
- api_key=api_key,
309
  model_id=MODEL,
310
  tokenizer_id=MODEL,
311
  generation_kwargs={
@@ -318,95 +304,3 @@ def get_prompt_generator():
318
  )
319
  prompt_generator.load()
320
  return prompt_generator
321
-
322
-
323
- def get_pipeline(num_turns, num_rows, system_prompt, is_sample):
324
- input_mappings = _get_output_mappings(num_turns)
325
- output_mappings = input_mappings
326
-
327
- with Pipeline(name="sft") as pipeline:
328
- magpie = get_magpie_generator(num_turns, num_rows, system_prompt, is_sample)
329
- generate_response = get_response_generator(system_prompt, is_sample)
330
-
331
- keep_columns = KeepColumns(
332
- columns=list(output_mappings.values()) + ["model_name"],
333
- )
334
-
335
- magpie.connect(generate_response)
336
- generate_response.connect(keep_columns)
337
- return pipeline
338
-
339
-
340
- if __name__ == "__main__":
341
- prompt_generation_step = get_prompt_generator()
342
- system_prompt = next(
343
- prompt_generation_step.process(
344
- [
345
- {
346
- "system_prompt": PROMPT_CREATION_PROMPT,
347
- "instruction": DEFAULT_DATASET_DESCRIPTIONS[0],
348
- }
349
- ]
350
- )
351
- )[0]["generation"]
352
- num_rows = 2
353
- num_turns = 1
354
- magpie_generator = get_magpie_generator(num_turns, num_rows, system_prompt, False)
355
- response_generator = get_response_generator(num_turns, system_prompt, False)
356
- total_steps = num_rows * 2
357
- batch_size = 5 # Adjust this value as needed
358
-
359
- # create instructions
360
- magpie_results = []
361
- for i in range(0, num_rows, batch_size):
362
- batch = list(magpie_generator.process())[:batch_size]
363
- magpie_results.extend([item[0] for item in batch])
364
-
365
- # generate responses
366
- response_results = []
367
- if num_turns == 1:
368
- for i in range(0, len(magpie_results), batch_size):
369
- batch = magpie_results[i : i + batch_size]
370
- batch = [entry[0] for entry in batch]
371
- responses = list(response_generator.process(inputs=batch))
372
- response_results.extend(responses)
373
- for result in response_results:
374
- result[0]["prompt"] = result[0]["instruction"]
375
- result[0]["completion"] = result[0]["generation"]
376
- result[0]["system_prompt"] = system_prompt
377
- else:
378
- for result in magpie_results:
379
- result[0]["conversation"].insert(
380
- 0, {"role": "system", "content": system_prompt}
381
- )
382
- result[0]["messages"] = result[0]["conversation"]
383
- for i in range(0, len(magpie_results), batch_size):
384
- batch = magpie_results[i : i + batch_size]
385
- batch = [entry[0] for entry in batch]
386
- responses = list(response_generator.process(inputs=batch))
387
- response_results.extend(responses)
388
-
389
- for result in response_results:
390
- result[0]["messages"].append(
391
- {"role": "assistant", "content": result[0]["generation"]}
392
- )
393
-
394
- distiset_results = []
395
- for result in response_results[0]:
396
- record = {}
397
- for relevant_keys in [
398
- "messages",
399
- "prompt",
400
- "completion",
401
- "model_name",
402
- "system_prompt",
403
- ]:
404
- if relevant_keys in result:
405
- record[relevant_keys] = result[relevant_keys]
406
- distiset_results.append(record)
407
-
408
- distiset = Distiset(
409
- {
410
- "default": Dataset.from_list(distiset_results),
411
- }
412
- )
 
1
  import pandas as pd
 
 
2
  from distilabel.llms import InferenceEndpointsLLM
 
 
3
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
4
 
5
+ from src.distilabel_dataset_generator.pipelines.base import (
6
+ MODEL,
7
+ _get_next_api_key,
8
+ )
9
 
10
  INFORMATION_SEEKING_PROMPT = (
11
  "You are an AI assistant designed to provide accurate and concise information on a wide"
 
119
  User dataset description:
120
  """
121
 
 
122
  DEFAULT_DATASET_DESCRIPTIONS = (
123
  "rude customer assistant for a phone company",
124
  "assistant that solves math puzzles using python",
 
155
  "assistant",
156
  " \n\n",
157
  ]
 
 
158
 
159
 
160
  def _get_output_mappings(num_turns):
 
185
  tokenizer_id=MODEL,
186
  magpie_pre_query_template="llama3",
187
  generation_kwargs={{
188
+ "temperature": 0.9,
189
  "do_sample": True,
190
  "max_new_tokens": 2048,
191
  "stop_sequences": {_STOP_SEQUENCES}
 
209
  return code
210
 
211
 
 
 
 
 
 
 
 
212
  def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
213
  input_mappings = _get_output_mappings(num_turns)
214
  output_mappings = input_mappings.copy()
 
220
  api_key=_get_next_api_key(),
221
  magpie_pre_query_template="llama3",
222
  generation_kwargs={
223
+ "temperature": 0.9,
224
  "do_sample": True,
225
  "max_new_tokens": 256 if is_sample else 512,
226
  "stop_sequences": _STOP_SEQUENCES,
 
239
  api_key=_get_next_api_key(),
240
  magpie_pre_query_template="llama3",
241
  generation_kwargs={
242
+ "temperature": 0.9,
243
  "do_sample": True,
244
  "max_new_tokens": 256 if is_sample else 1024,
245
  "stop_sequences": _STOP_SEQUENCES,
 
289
 
290
 
291
  def get_prompt_generator():
 
 
 
292
  prompt_generator = TextGeneration(
293
  llm=InferenceEndpointsLLM(
294
+ api_key=_get_next_api_key(),
295
  model_id=MODEL,
296
  tokenizer_id=MODEL,
297
  generation_kwargs={
 
304
  )
305
  prompt_generator.load()
306
  return prompt_generator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/distilabel_dataset_generator/pipelines/textcat.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import pandas as pd
4
+ from distilabel.llms import InferenceEndpointsLLM
5
+ from distilabel.steps.tasks import (
6
+ GenerateTextClassificationData,
7
+ TextClassification,
8
+ TextGeneration,
9
+ )
10
+ from src.distilabel_dataset_generator.pipelines.base import (
11
+ MODEL,
12
+ _get_next_api_key,
13
+ )
14
+ from src.distilabel_dataset_generator.utils import get_preprocess_labels
15
+
16
+ PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
17
+
18
+ Your task is to write a prompt following the instruction of the user. Respond with the prompt and nothing else.
19
+
20
+ The prompt you write should follow the same style and structure as the following example prompts, clearly specifying the possible classification labels.
21
+
22
+ If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
23
+
24
+ Classify the following customer review of a cinema as either 'positive' or 'negative'.
25
+
26
+ Classify the following news article into one or more of the following categories: 'politics', 'sports', 'technology', 'entertainment', 'health', 'business', 'environment', 'education', 'science', 'international'.
27
+
28
+ Determine the sentiment of the following social media post: 'ambiguous', 'sarcastic', 'informative', 'emotional'.
29
+
30
+ Identify the issue category for the following technical support ticket: 'billing', 'technical', 'account', 'shipping', 'returns', 'installation', 'subscription'.
31
+
32
+ Classify the following movie review into one of the following categories: 'critical', 'praise', 'disappointed', 'enthusiastic'.
33
+
34
+ Determine the level of customer satisfaction from the following customer service transcript: 'satisfied', 'dissatisfied', 'highly-satisfied', 'somewhat-dissatisfied', 'indifferent'.
35
+
36
+ Categorize the following product description into one of the following product types: 'smartphone', 'laptop', 'tablet', 'smartwatch', 'e-reader', 'headphones'.
37
+
38
+ Classify the following tweet as expressing either 'support' or 'opposition' to the political event discussed.
39
+
40
+ Classify the following restaurant review into one of the following categories: 'food-quality', 'service', 'ambiance', or 'price'.
41
+
42
+ Classify the following blog post based on its primary fashion trend or style: 'casual', 'formal', 'streetwear', 'vintage' or 'sustainable-fashion'.
43
+
44
+ User dataset description:
45
+ """
46
+
47
+ DEFAULT_DATASET_DESCRIPTIONS = [
48
+ "A dataset covering customer reviews for an e-commerce website.",
49
+ "A dataset covering news articles about various topics.",
50
+ ]
51
+
52
+ DEFAULT_DATASETS = [
53
+ pd.DataFrame.from_dict(
54
+ {
55
+ "text": [
56
+ "I love the product! It's amazing and I'll buy it again.",
57
+ "The product was okay, but I wouldn't buy it again.",
58
+ ],
59
+ "label": ["positive", "negative"],
60
+ }
61
+ ),
62
+ pd.DataFrame.from_dict(
63
+ {
64
+ "text": [
65
+ "Yesterday, the US stock market had a significant increase.",
66
+ "New research suggests that the Earth is not a perfect sphere.",
67
+ ],
68
+ "labels": [["economy", "politics"], ["science", "environment"]],
69
+ }
70
+ ),
71
+ ]
72
+
73
+ DEFAULT_SYSTEM_PROMPTS = [
74
+ "Classify the following customer review as either 'positive' or 'negative'.",
75
+ "Classify the following news article into one of the following categories: 'politics', 'economy', 'environment', 'science', 'health'.",
76
+ ]
77
+
78
+
79
+ def generate_pipeline_code(
80
+ system_prompt: str,
81
+ difficulty: str = None,
82
+ clarity: str = None,
83
+ labels: List[str] = None,
84
+ num_labels: int = 1,
85
+ num_rows: int = 10,
86
+ ) -> str:
87
+ labels = get_preprocess_labels(labels)
88
+ base_code = f"""
89
+ # Requirements: `pip install distilabel[hf-inference-endpoints]`
90
+ import os
91
+ from distilabel.llms import InferenceEndpointsLLM
92
+ from distilabel.pipeline import Pipeline
93
+ from distilabel.steps import LoadDataFromDicts, KeepColumns
94
+ from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, TextClassification"}
95
+
96
+ MODEL = "{MODEL}"
97
+ TEXT_CLASSIFICATION_TASK = "{system_prompt}"
98
+ os.environ["HF_TOKEN"] = (
99
+ "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
100
+ )
101
+
102
+ with Pipeline(name="textcat") as pipeline:
103
+
104
+ task_generator = LoadDataFromDicts(data=[{{"task": TEXT_CLASSIFICATION_TASK}}])
105
+
106
+ textcat_generation = GenerateTextClassificationData(
107
+ llm=InferenceEndpointsLLM(
108
+ model_id=MODEL,
109
+ tokenizer_id=MODEL,
110
+ api_key=os.environ["HF_TOKEN"],
111
+ generation_kwargs={{
112
+ "temperature": 0.8,
113
+ "max_new_tokens": 2048,
114
+ }},
115
+ ),
116
+ difficulty={None if difficulty == "mixed" else repr(difficulty)},
117
+ clarity={None if clarity == "mixed" else repr(clarity)},
118
+ num_generations={num_rows},
119
+ output_mappings={{"input_text": "text"}},
120
+ )
121
+ """
122
+
123
+ if num_labels == 1:
124
+ return (
125
+ base_code
126
+ + """
127
+ keep_columns = KeepColumns(
128
+ columns=["text", "label"],
129
+ )
130
+
131
+ # Connect steps in the pipeline
132
+ task_generator >> textcat_generation >> keep_columns
133
+
134
+ if __name__ == "__main__":
135
+ distiset = pipeline.run()
136
+ """
137
+ )
138
+
139
+ return (
140
+ base_code
141
+ + f"""
142
+ keep_columns = KeepColumns(
143
+ columns=["text"],
144
+ )
145
+
146
+ textcat_labeller = TextClassification(
147
+ llm=InferenceEndpointsLLM(
148
+ model_id=MODEL,
149
+ tokenizer_id=MODEL,
150
+ api_key=os.environ["HF_TOKEN"],
151
+ generation_kwargs={{
152
+ "temperature": 0.8,
153
+ "max_new_tokens": 2048,
154
+ }},
155
+ ),
156
+ n={num_labels},
157
+ available_labels={labels},
158
+ context=TEXT_CLASSIFICATION_TASK,
159
+ default_label="unknown"
160
+ )
161
+
162
+ # Connect steps in the pipeline
163
+ task_generator >> textcat_generation >> keep_columns >> textcat_labeller
164
+
165
+ if __name__ == "__main__":
166
+ distiset = pipeline.run()
167
+ """
168
+ )
169
+
170
+
171
+ def get_textcat_generator(difficulty, clarity, is_sample):
172
+ textcat_generator = GenerateTextClassificationData(
173
+ llm=InferenceEndpointsLLM(
174
+ model_id=MODEL,
175
+ tokenizer_id=MODEL,
176
+ api_key=_get_next_api_key(),
177
+ generation_kwargs={
178
+ "temperature": 0.8,
179
+ "max_new_tokens": 256 if is_sample else 1024,
180
+ },
181
+ ),
182
+ difficulty=None if difficulty == "mixed" else difficulty,
183
+ clarity=None if clarity == "mixed" else clarity,
184
+ )
185
+ textcat_generator.load()
186
+ return textcat_generator
187
+
188
+
189
+ def get_labeller_generator(system_prompt, labels, num_labels, is_sample):
190
+ labeller_generator = TextClassification(
191
+ llm=InferenceEndpointsLLM(
192
+ model_id=MODEL,
193
+ tokenizer_id=MODEL,
194
+ api_key=_get_next_api_key(),
195
+ generation_kwargs={
196
+ "temperature": 0.8,
197
+ "max_new_tokens": 256 if is_sample else 1024,
198
+ },
199
+ ),
200
+ context=system_prompt,
201
+ available_labels=labels,
202
+ n=num_labels,
203
+ default_label="unknown",
204
+ )
205
+ labeller_generator.load()
206
+ return labeller_generator
207
+
208
+
209
+ def get_prompt_generator():
210
+ prompt_generator = TextGeneration(
211
+ llm=InferenceEndpointsLLM(
212
+ api_key=_get_next_api_key(),
213
+ model_id=MODEL,
214
+ tokenizer_id=MODEL,
215
+ generation_kwargs={
216
+ "temperature": 0.8,
217
+ "max_new_tokens": 2048,
218
+ "do_sample": True,
219
+ },
220
+ ),
221
+ use_system_prompt=True,
222
+ )
223
+ prompt_generator.load()
224
+ return prompt_generator
src/distilabel_dataset_generator/utils.py CHANGED
@@ -1,5 +1,7 @@
1
  import os
 
2
 
 
3
  import gradio as gr
4
  from gradio.oauth import (
5
  OAUTH_CLIENT_ID,
@@ -10,6 +12,8 @@ from gradio.oauth import (
10
  )
11
  from huggingface_hub import whoami
12
 
 
 
13
  HF_TOKENS = [os.getenv("HF_TOKEN")] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
14
  HF_TOKENS = [token for token in HF_TOKENS if token]
15
 
@@ -76,8 +80,48 @@ def get_token(oauth_token: OAuthToken = None):
76
  return ""
77
 
78
 
79
- def swap_visibilty(oauth_token: OAuthToken = None):
80
  if oauth_token:
81
  return gr.update(elem_classes=["main_ui_logged_in"])
82
  else:
83
  return gr.update(elem_classes=["main_ui_logged_out"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from typing import Union, List, Optional
3
 
4
+ import argilla as rg
5
  import gradio as gr
6
  from gradio.oauth import (
7
  OAUTH_CLIENT_ID,
 
12
  )
13
  from huggingface_hub import whoami
14
 
15
+ _LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}"
16
+
17
  HF_TOKENS = [os.getenv("HF_TOKEN")] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
18
  HF_TOKENS = [token for token in HF_TOKENS if token]
19
 
 
80
  return ""
81
 
82
 
83
+ def swap_visibilty(oauth_token: Optional[OAuthToken] = None):
84
  if oauth_token:
85
  return gr.update(elem_classes=["main_ui_logged_in"])
86
  else:
87
  return gr.update(elem_classes=["main_ui_logged_out"])
88
+
89
+
90
+ def get_base_app():
91
+ with gr.Blocks(
92
+ title="🧬 Synthetic Data Generator",
93
+ head="🧬 Synthetic Data Generator",
94
+ css=_LOGGED_OUT_CSS,
95
+ ) as app:
96
+ with gr.Row():
97
+ gr.Markdown(
98
+ "Want to run this locally or with other LLMs? Take a look at the FAQ tab. distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation."
99
+ )
100
+ with gr.Row():
101
+ gr.Column()
102
+ get_login_button()
103
+ gr.Column()
104
+
105
+ gr.Markdown("## Iterate on a sample dataset")
106
+ with gr.Column() as main_ui:
107
+ pass
108
+
109
+ return app
110
+
111
+
112
+ def get_argilla_client() -> Union[rg.Argilla, None]:
113
+ try:
114
+ api_url = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
115
+ api_key = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
116
+ if api_url is None or api_key is None:
117
+ api_url = os.getenv("ARGILLA_API_URL")
118
+ api_key = os.getenv("ARGILLA_API_KEY")
119
+ return rg.Argilla(
120
+ api_url=api_url,
121
+ api_key=api_key,
122
+ )
123
+ except Exception:
124
+ return None
125
+
126
+ def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
127
+ return [label.lower().strip() for label in labels] if labels else []