davidberenstein1957 HF staff commited on
Commit
0d28c87
Β·
1 Parent(s): ed40758

feat: reintroduce oauth

Browse files

feat: create markdown notification on succes
feat: set magpie generation batch size to 1

app.py CHANGED
@@ -8,9 +8,20 @@ theme = gr.themes.Monochrome(
8
  font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
9
  )
10
 
 
 
 
 
 
 
 
 
 
 
11
  demo = gr.TabbedInterface(
12
  [sft_app, faq_app],
13
  ["Supervised Fine-Tuning", "FAQ"],
 
14
  title="βš—οΈ distilabel Dataset Generator",
15
  head="βš—οΈ distilabel Dataset Generator",
16
  theme=theme,
 
8
  font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
9
  )
10
 
11
+ css = """
12
+ h1{font-size: 2em}
13
+ h3{margin-top: 0}
14
+ #component-1{text-align:center}
15
+ .main_ui_logged_out{opacity: 0.3; pointer-events: none}
16
+ .tabitem{border: 0px}
17
+ .group_padding{padding: .55em}
18
+ #space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none}
19
+ """
20
+
21
  demo = gr.TabbedInterface(
22
  [sft_app, faq_app],
23
  ["Supervised Fine-Tuning", "FAQ"],
24
+ css=css,
25
  title="βš—οΈ distilabel Dataset Generator",
26
  head="βš—οΈ distilabel Dataset Generator",
27
  theme=theme,
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -15,6 +15,11 @@ from src.distilabel_dataset_generator.pipelines.sft import (
15
  get_pipeline,
16
  get_prompt_generation_step,
17
  )
 
 
 
 
 
18
 
19
 
20
  def _run_pipeline(result_queue, num_turns, num_rows, system_prompt):
@@ -59,18 +64,26 @@ def generate_dataset(
59
  num_turns=1,
60
  num_rows=5,
61
  private=True,
62
- repo_id=None,
 
63
  token=None,
64
  progress=gr.Progress(),
65
  ):
 
 
 
 
 
66
  if repo_id is not None:
67
  if not repo_id:
68
- raise gr.Error("Please provide a dataset name to push the dataset to.")
 
 
69
  try:
70
  whoami(token=token)
71
  except Exception:
72
  raise gr.Error(
73
- "Provide a Hugging Face to be able to push the dataset to the Hub."
74
  )
75
 
76
  if num_turns > 4:
@@ -111,7 +124,7 @@ def generate_dataset(
111
  (step + 1) / total_steps,
112
  desc=f"Generating dataset with {num_rows} rows",
113
  )
114
- time.sleep(0.5) # Adjust this value based on your needs
115
  p.join()
116
  except Exception as e:
117
  raise gr.Error(f"An error occurred during dataset generation: {str(e)}")
@@ -126,9 +139,6 @@ def generate_dataset(
126
  include_script=False,
127
  token=token,
128
  )
129
- gr.Info(
130
- f'Dataset pushed to Hugging Face Hub: <a href="https://huggingface.co/datasets/{repo_id}">https://huggingface.co/datasets/{repo_id}</a>'
131
- )
132
 
133
  # If not pushing to hub generate the dataset directly
134
  distiset = distiset["default"]["train"]
@@ -193,98 +203,103 @@ with gr.Blocks(
193
  title="βš—οΈ Distilabel Dataset Generator",
194
  head="βš—οΈ Distilabel Dataset Generator",
195
  ) as app:
 
196
  gr.Markdown("## Iterate on a sample dataset")
197
- dataset_description = gr.TextArea(
198
- label="Provide a description of the dataset",
199
- value=DEFAULT_DATASET_DESCRIPTION,
200
- )
201
- with gr.Row():
202
- gr.Column(scale=1)
203
- btn_generate_system_prompt = gr.Button(value="Generate sample dataset")
204
- gr.Column(scale=1)
205
-
206
- system_prompt = gr.TextArea(
207
- label="If you want to improve the dataset, you can tune the system prompt and regenerate the sample",
208
- value=DEFAULT_SYSTEM_PROMPT,
209
- )
210
-
211
- with gr.Row():
212
- gr.Column(scale=1)
213
- btn_generate_sample_dataset = gr.Button(
214
- value="Regenerate sample dataset",
215
  )
216
- gr.Column(scale=1)
217
-
218
- with gr.Row():
219
- table = gr.DataFrame(
220
- value=DEFAULT_DATASET,
221
- interactive=False,
222
- wrap=True,
 
223
  )
224
 
225
- result = btn_generate_system_prompt.click(
226
- fn=generate_system_prompt,
227
- inputs=[dataset_description],
228
- outputs=[system_prompt],
229
- show_progress=True,
230
- ).then(
231
- fn=generate_sample_dataset,
232
- inputs=[system_prompt],
233
- outputs=[table],
234
- show_progress=True,
235
- )
236
-
237
- btn_generate_sample_dataset.click(
238
- fn=generate_sample_dataset,
239
- inputs=[system_prompt],
240
- outputs=[table],
241
- show_progress=True,
242
- )
243
-
244
- # Add a header for the full dataset generation section
245
- gr.Markdown("## Generate full dataset")
246
- gr.Markdown(
247
- "Once you're satisfied with the sample, generate a larger dataset and push it to the hub. Get <a href='https://huggingface.co/settings/tokens' target='_blank'>a Hugging Face token</a> with write access to the organization you want to push the dataset to."
248
- )
249
-
250
- with gr.Column() as push_to_hub_ui:
251
- with gr.Row(variant="panel"):
252
- num_turns = gr.Number(
253
- value=1,
254
- label="Number of turns in the conversation",
255
- minimum=1,
256
- maximum=4,
257
- step=1,
258
- info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'conversation' column).",
259
  )
260
- num_rows = gr.Number(
261
- value=100,
262
- label="Number of rows in the dataset",
263
- minimum=1,
264
- maximum=5000,
265
- info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
 
266
  )
267
 
268
- with gr.Row(variant="panel"):
269
- hf_token = gr.Textbox(label="HF token", type="password")
270
- repo_id = gr.Textbox(label="HF repo ID", placeholder="owner/dataset_name")
271
- private = gr.Checkbox(label="Private dataset", value=True, interactive=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
- btn_generate_full_dataset = gr.Button(
274
- value="βš—οΈ Generate Full Dataset", variant="primary"
 
 
275
  )
276
 
277
- success_message = gr.Markdown(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
- def show_success_message(repo_id_value):
280
  return gr.Markdown(
281
  value=f"""
282
  <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
283
  <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
284
  <p style="margin-top: 0.5em;">
285
  Your dataset is now available at:
286
- <a href="https://huggingface.co/datasets/{repo_id_value}" target="_blank" style="color: #1565c0; text-decoration: none;">
287
- https://huggingface.co/datasets/{repo_id_value}
288
  </a>
289
  </p>
290
  </div>
@@ -294,10 +309,22 @@ with gr.Blocks(
294
 
295
  btn_generate_full_dataset.click(
296
  fn=generate_dataset,
297
- inputs=[system_prompt, num_turns, num_rows, private, repo_id, hf_token],
 
 
 
 
 
 
 
 
298
  outputs=[table],
299
  show_progress=True,
300
- ).then(fn=show_success_message, inputs=[repo_id], outputs=[success_message])
 
 
 
 
301
 
302
  gr.Markdown("## Or run this pipeline locally with distilabel")
303
 
@@ -309,3 +336,6 @@ with gr.Blocks(
309
  inputs=[system_prompt],
310
  outputs=[pipeline_code],
311
  )
 
 
 
 
15
  get_pipeline,
16
  get_prompt_generation_step,
17
  )
18
+ from src.distilabel_dataset_generator.utils import (
19
+ get_login_button,
20
+ get_org_dropdown,
21
+ get_token,
22
+ )
23
 
24
 
25
  def _run_pipeline(result_queue, num_turns, num_rows, system_prompt):
 
64
  num_turns=1,
65
  num_rows=5,
66
  private=True,
67
+ org_name=None,
68
+ repo_name=None,
69
  token=None,
70
  progress=gr.Progress(),
71
  ):
72
+ repo_id = (
73
+ f"{org_name}/{repo_name}"
74
+ if repo_name is not None and org_name is not None
75
+ else None
76
+ )
77
  if repo_id is not None:
78
  if not repo_id:
79
+ raise gr.Error(
80
+ "Please provide a repo_name and org_name to push the dataset to."
81
+ )
82
  try:
83
  whoami(token=token)
84
  except Exception:
85
  raise gr.Error(
86
+ "Provide a Hugging Face token with write access to the organization you want to push the dataset to."
87
  )
88
 
89
  if num_turns > 4:
 
124
  (step + 1) / total_steps,
125
  desc=f"Generating dataset with {num_rows} rows",
126
  )
127
+ time.sleep(duration / total_steps) # Adjust this value based on your needs
128
  p.join()
129
  except Exception as e:
130
  raise gr.Error(f"An error occurred during dataset generation: {str(e)}")
 
139
  include_script=False,
140
  token=token,
141
  )
 
 
 
142
 
143
  # If not pushing to hub generate the dataset directly
144
  distiset = distiset["default"]["train"]
 
203
  title="βš—οΈ Distilabel Dataset Generator",
204
  head="βš—οΈ Distilabel Dataset Generator",
205
  ) as app:
206
+ get_login_button()
207
  gr.Markdown("## Iterate on a sample dataset")
208
+ with gr.Column() as main_ui:
209
+ dataset_description = gr.TextArea(
210
+ label="Provide a description of the dataset",
211
+ value=DEFAULT_DATASET_DESCRIPTION,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  )
213
+ with gr.Row():
214
+ gr.Column(scale=1)
215
+ btn_generate_system_prompt = gr.Button(value="Generate sample dataset")
216
+ gr.Column(scale=1)
217
+
218
+ system_prompt = gr.TextArea(
219
+ label="If you want to improve the dataset, you can tune the system prompt and regenerate the sample",
220
+ value=DEFAULT_SYSTEM_PROMPT,
221
  )
222
 
223
+ with gr.Row():
224
+ gr.Column(scale=1)
225
+ btn_generate_sample_dataset = gr.Button(
226
+ value="Regenerate sample dataset",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  )
228
+ gr.Column(scale=1)
229
+
230
+ with gr.Row():
231
+ table = gr.DataFrame(
232
+ value=DEFAULT_DATASET,
233
+ interactive=False,
234
+ wrap=True,
235
  )
236
 
237
+ result = btn_generate_system_prompt.click(
238
+ fn=generate_system_prompt,
239
+ inputs=[dataset_description],
240
+ outputs=[system_prompt],
241
+ show_progress=True,
242
+ ).then(
243
+ fn=generate_sample_dataset,
244
+ inputs=[system_prompt],
245
+ outputs=[table],
246
+ show_progress=True,
247
+ )
248
+
249
+ btn_generate_sample_dataset.click(
250
+ fn=generate_sample_dataset,
251
+ inputs=[system_prompt],
252
+ outputs=[table],
253
+ show_progress=True,
254
+ )
255
 
256
+ # Add a header for the full dataset generation section
257
+ gr.Markdown("## Generate full dataset")
258
+ gr.Markdown(
259
+ "Once you're satisfied with the sample, generate a larger dataset and push it to the hub. Get <a href='https://huggingface.co/settings/tokens' target='_blank'>a Hugging Face token</a> with write access to the organization you want to push the dataset to."
260
  )
261
 
262
+ with gr.Column() as push_to_hub_ui:
263
+ with gr.Row(variant="panel"):
264
+ num_turns = gr.Number(
265
+ value=1,
266
+ label="Number of turns in the conversation",
267
+ minimum=1,
268
+ maximum=4,
269
+ step=1,
270
+ info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'conversation' column).",
271
+ )
272
+ num_rows = gr.Number(
273
+ value=100,
274
+ label="Number of rows in the dataset",
275
+ minimum=1,
276
+ maximum=5000,
277
+ info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
278
+ )
279
+
280
+ with gr.Row(variant="panel"):
281
+ hf_token = gr.Textbox(label="HF token", type="password")
282
+ org_name = get_org_dropdown()
283
+ repo_name = gr.Textbox(label="Repo name", placeholder="dataset_name")
284
+ private = gr.Checkbox(
285
+ label="Private dataset", value=True, interactive=True, scale=0.5
286
+ )
287
+
288
+ btn_generate_full_dataset = gr.Button(
289
+ value="βš—οΈ Generate Full Dataset", variant="primary"
290
+ )
291
+
292
+ success_message = gr.Markdown(visible=False)
293
 
294
+ def show_success_message(org_name, repo_name):
295
  return gr.Markdown(
296
  value=f"""
297
  <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
298
  <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
299
  <p style="margin-top: 0.5em;">
300
  Your dataset is now available at:
301
+ <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
302
+ https://huggingface.co/datasets/{org_name}/{repo_name}
303
  </a>
304
  </p>
305
  </div>
 
309
 
310
  btn_generate_full_dataset.click(
311
  fn=generate_dataset,
312
+ inputs=[
313
+ system_prompt,
314
+ num_turns,
315
+ num_rows,
316
+ private,
317
+ org_name,
318
+ repo_name,
319
+ hf_token,
320
+ ],
321
  outputs=[table],
322
  show_progress=True,
323
+ ).then(
324
+ fn=show_success_message,
325
+ inputs=[org_name, repo_name],
326
+ outputs=[success_message],
327
+ )
328
 
329
  gr.Markdown("## Or run this pipeline locally with distilabel")
330
 
 
336
  inputs=[system_prompt],
337
  outputs=[pipeline_code],
338
  )
339
+
340
+ app.load(get_token, outputs=[hf_token])
341
+ app.load(get_org_dropdown, outputs=[org_name])
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -156,7 +156,7 @@ def get_pipeline(num_turns, num_rows, system_prompt):
156
  ],
157
  },
158
  ),
159
- batch_size=1,
160
  n_turns=num_turns,
161
  num_rows=num_rows,
162
  system_prompt=system_prompt,
 
156
  ],
157
  },
158
  ),
159
+ batch_size=2,
160
  n_turns=num_turns,
161
  num_rows=num_rows,
162
  system_prompt=system_prompt,
src/distilabel_dataset_generator/utils.py CHANGED
@@ -66,7 +66,10 @@ def list_orgs(token: OAuthToken = None):
66
  def get_org_dropdown(token: OAuthToken = None):
67
  orgs = list_orgs(token)
68
  return gr.Dropdown(
69
- label="Organization", choices=orgs, value=orgs[0] if orgs else None
 
 
 
70
  )
71
 
72
 
@@ -75,3 +78,17 @@ def swap_visibilty(profile: Union[gr.OAuthProfile, None]):
75
  return gr.Column(visible=False)
76
  else:
77
  return gr.Column(visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def get_org_dropdown(token: OAuthToken = None):
67
  orgs = list_orgs(token)
68
  return gr.Dropdown(
69
+ label="Organization",
70
+ choices=orgs,
71
+ value=orgs[0] if orgs else None,
72
+ allow_custom_value=True,
73
  )
74
 
75
 
 
78
  return gr.Column(visible=False)
79
  else:
80
  return gr.Column(visible=True)
81
+
82
+
83
+ def swap_visibilty_classes(profile: Union[gr.OAuthProfile, None]):
84
+ if profile is None:
85
+ return gr.update(elem_classes=["main_ui_logged_out"])
86
+ else:
87
+ return gr.update(elem_classes=["main_ui_logged_in"])
88
+
89
+
90
+ def get_token(token: OAuthToken = None):
91
+ if token:
92
+ return token.token
93
+ else:
94
+ return ""