dvilasuero HF staff committed on
Commit
2723bd3
·
verified ·
1 Parent(s): cbf899b

Update src/distilabel_dataset_generator/sft.py

Browse files
src/distilabel_dataset_generator/sft.py CHANGED
@@ -1,4 +1,5 @@
1
  import multiprocessing
 
2
 
3
  import gradio as gr
4
  import pandas as pd
@@ -179,7 +180,8 @@ def _run_pipeline(result_queue, num_turns, num_rows, system_prompt, token: str =
179
  result_queue.put(distiset)
180
 
181
 
182
- def generate_system_prompt(dataset_description, token: OAuthToken = None):
 
183
  generate_description = TextGeneration(
184
  llm=InferenceEndpointsLLM(
185
  model_id=MODEL,
@@ -192,8 +194,10 @@ def generate_system_prompt(dataset_description, token: OAuthToken = None):
192
  ),
193
  use_system_prompt=True,
194
  )
 
195
  generate_description.load()
196
- return next(
 
197
  generate_description.process(
198
  [
199
  {
@@ -203,6 +207,15 @@ def generate_system_prompt(dataset_description, token: OAuthToken = None):
203
  ]
204
  )
205
  )[0]["generation"]
 
 
 
 
 
 
 
 
 
206
 
207
 
208
  def generate_dataset(
@@ -213,6 +226,7 @@ def generate_dataset(
213
  orgs_selector=None,
214
  dataset_name=None,
215
  token: OAuthToken = None,
 
216
  ):
217
  if dataset_name is not None:
218
  if not dataset_name:
@@ -242,7 +256,7 @@ def generate_dataset(
242
  duration = 1000
243
 
244
  gr.Info(
245
- "Started pipeline execution. This might take a while, depending on the number of rows and turns you have selected. Don't close this page.",
246
  duration=duration,
247
  )
248
  result_queue = multiprocessing.Queue()
@@ -250,15 +264,24 @@ def generate_dataset(
250
  target=_run_pipeline,
251
  args=(result_queue, num_turns, num_rows, system_prompt),
252
  )
 
253
  try:
254
  p.start()
 
 
 
 
 
 
255
  p.join()
256
  except Exception as e:
257
  raise gr.Error(f"An error occurred during dataset generation: {str(e)}")
 
 
258
  distiset = result_queue.get()
259
 
260
  if dataset_name is not None:
261
- gr.Info("Pushing dataset to Hugging Face Hub.")
262
  repo_id = f"{orgs_selector}/{dataset_name}"
263
  distiset.push_to_hub(
264
  repo_id=repo_id,
@@ -269,31 +292,30 @@ def generate_dataset(
269
  gr.Info(
270
  f'Dataset pushed to Hugging Face Hub: <a href="https://huggingface.co/datasets/{repo_id}">https://huggingface.co/datasets/{repo_id}</a>'
271
  )
 
 
 
 
 
272
  else:
273
- # If not pushing to hub generate the dataset directly
274
- distiset = distiset["default"]["train"]
275
- if num_turns == 1:
276
- outputs = distiset.to_pandas()[["prompt", "completion"]]
277
- else:
278
- outputs = {"conversation_id": [], "role": [], "content": []}
279
- conversations = distiset["messages"]
280
- for idx, entry in enumerate(conversations):
281
- for message in entry["messages"]:
282
- outputs["conversation_id"].append(idx + 1)
283
- outputs["role"].append(message["role"])
284
- outputs["content"].append(message["content"])
285
- return pd.DataFrame(outputs)
286
 
287
 
288
  with gr.Blocks(
289
  title="⚗️ Distilabel Dataset Generator",
290
  head="⚗️ Distilabel Dataset Generator",
291
  ) as app:
292
- gr.Markdown(
293
- """
294
-
295
- """
296
- )
297
  dataset_description = gr.TextArea(
298
  label="Provide a description of the dataset",
299
  value=DEFAULT_SYSTEM_PROMPT_DESCRIPTION,
@@ -316,25 +338,38 @@ with gr.Blocks(
316
  value="Regenerate sample dataset",
317
  )
318
  gr.Column(scale=1)
319
-
320
- table = gr.Dataframe(label="Generated Dataset", wrap=True, value=DEFAULT_DATASET)
 
 
 
 
 
 
321
 
322
  btn_generate_system_prompt.click(
323
  fn=generate_system_prompt,
324
  inputs=[dataset_description],
325
  outputs=[system_prompt],
 
326
  ).then(
327
- fn=generate_dataset,
328
  inputs=[system_prompt],
329
  outputs=[table],
 
330
  )
331
 
332
  btn_generate_sample_dataset.click(
333
- fn=generate_dataset,
334
  inputs=[system_prompt],
335
  outputs=[table],
 
336
  )
337
 
 
 
 
 
338
  btn_login: gr.LoginButton | None = get_login_button()
339
  with gr.Column() as push_to_hub_ui:
340
  with gr.Row(variant="panel"):
@@ -371,7 +406,9 @@ with gr.Blocks(
371
  orgs_selector,
372
  dataset_name_push_to_hub,
373
  ],
 
 
374
  )
375
 
376
  app.load(get_org_dropdown, outputs=[orgs_selector])
377
- app.load(fn=swap_visibilty, outputs=push_to_hub_ui)
 
1
  import multiprocessing
2
+ import time
3
 
4
  import gradio as gr
5
  import pandas as pd
 
180
  result_queue.put(distiset)
181
 
182
 
183
+ def generate_system_prompt(dataset_description, token: OAuthToken = None, progress=gr.Progress()):
184
+ progress(0.1, desc="Initializing text generation")
185
  generate_description = TextGeneration(
186
  llm=InferenceEndpointsLLM(
187
  model_id=MODEL,
 
194
  ),
195
  use_system_prompt=True,
196
  )
197
+ progress(0.4, desc="Loading model")
198
  generate_description.load()
199
+ progress(0.7, desc="Generating system prompt")
200
+ result = next(
201
  generate_description.process(
202
  [
203
  {
 
207
  ]
208
  )
209
  )[0]["generation"]
210
+ progress(1.0, desc="System prompt generated")
211
+ return result
212
+
213
+
214
+ def generate_sample_dataset(system_prompt, progress=gr.Progress()):
215
+ progress(0.1, desc="Initializing sample dataset generation")
216
+ result = generate_dataset(system_prompt, num_turns=1, num_rows=2, progress=progress)
217
+ progress(1.0, desc="Sample dataset generated")
218
+ return result
219
 
220
 
221
  def generate_dataset(
 
226
  orgs_selector=None,
227
  dataset_name=None,
228
  token: OAuthToken = None,
229
+ progress=gr.Progress(),
230
  ):
231
  if dataset_name is not None:
232
  if not dataset_name:
 
256
  duration = 1000
257
 
258
  gr.Info(
259
+ "Dataset generation started. This might take a while. Don't close the page.",
260
  duration=duration,
261
  )
262
  result_queue = multiprocessing.Queue()
 
264
  target=_run_pipeline,
265
  args=(result_queue, num_turns, num_rows, system_prompt),
266
  )
267
+
268
  try:
269
  p.start()
270
+ total_steps = 100
271
+ for step in range(total_steps):
272
+ if not p.is_alive():
273
+ break
274
+ progress((step + 1) / total_steps, desc=f"Generating dataset with {num_rows} rows")
275
+ time.sleep(0.5) # Adjust this value based on your needs
276
  p.join()
277
  except Exception as e:
278
  raise gr.Error(f"An error occurred during dataset generation: {str(e)}")
279
+
280
+
281
  distiset = result_queue.get()
282
 
283
  if dataset_name is not None:
284
+ progress(0.95, desc="Pushing dataset to Hugging Face Hub.")
285
  repo_id = f"{orgs_selector}/{dataset_name}"
286
  distiset.push_to_hub(
287
  repo_id=repo_id,
 
292
  gr.Info(
293
  f'Dataset pushed to Hugging Face Hub: <a href="https://huggingface.co/datasets/{repo_id}">https://huggingface.co/datasets/{repo_id}</a>'
294
  )
295
+
296
+ # If not pushing to hub generate the dataset directly
297
+ distiset = distiset["default"]["train"]
298
+ if num_turns == 1:
299
+ outputs = distiset.to_pandas()[["prompt", "completion"]]
300
  else:
301
+ outputs = distiset.to_pandas()[["messages"]]
302
+ # outputs = {"conversation_id": [], "role": [], "content": []}
303
+ # conversations = distiset["messages"]
304
+ # for idx, entry in enumerate(conversations):
305
+ # for message in entry["messages"]:
306
+ # outputs["conversation_id"].append(idx + 1)
307
+ # outputs["role"].append(message["role"])
308
+ # outputs["content"].append(message["content"])
309
+
310
+ progress(1.0, desc="Dataset generation completed")
311
+ return pd.DataFrame(outputs)
 
 
312
 
313
 
314
  with gr.Blocks(
315
  title="⚗️ Distilabel Dataset Generator",
316
  head="⚗️ Distilabel Dataset Generator",
317
  ) as app:
318
+ gr.Markdown("## Iterate on a sample dataset")
 
 
 
 
319
  dataset_description = gr.TextArea(
320
  label="Provide a description of the dataset",
321
  value=DEFAULT_SYSTEM_PROMPT_DESCRIPTION,
 
338
  value="Regenerate sample dataset",
339
  )
340
  gr.Column(scale=1)
341
+
342
+ #table = gr.HTML(_format_dataframe_as_html(DEFAULT_DATASET))
343
+ table = gr.DataFrame(
344
+ value=DEFAULT_DATASET,
345
+ interactive=False,
346
+ wrap=True,
347
+
348
+ )
349
 
350
  btn_generate_system_prompt.click(
351
  fn=generate_system_prompt,
352
  inputs=[dataset_description],
353
  outputs=[system_prompt],
354
+ show_progress=True,
355
  ).then(
356
+ fn=generate_sample_dataset,
357
  inputs=[system_prompt],
358
  outputs=[table],
359
+ show_progress=True,
360
  )
361
 
362
  btn_generate_sample_dataset.click(
363
+ fn=generate_sample_dataset,
364
  inputs=[system_prompt],
365
  outputs=[table],
366
+ show_progress=True,
367
  )
368
 
369
+ # Add a header for the full dataset generation section
370
+ gr.Markdown("## Generate full dataset and push to hub")
371
+ gr.Markdown("Once you're satisfied with the sample, generate a larger dataset and push it to the hub.")
372
+
373
  btn_login: gr.LoginButton | None = get_login_button()
374
  with gr.Column() as push_to_hub_ui:
375
  with gr.Row(variant="panel"):
 
406
  orgs_selector,
407
  dataset_name_push_to_hub,
408
  ],
409
+ outputs=[table],
410
+ show_progress=True,
411
  )
412
 
413
  app.load(get_org_dropdown, outputs=[orgs_selector])
414
+ app.load(fn=swap_visibilty, outputs=push_to_hub_ui)