lhoestq HF staff committed on
Commit
a8aff52
1 Parent(s): bcabfd9

better buttons

Browse files
Files changed (1) hide show
  1. app.py +37 -19
app.py CHANGED
@@ -29,8 +29,10 @@ NAMESPACE = "dataset-rewriter"
29
  URL = "https://huggingface.co/spaces/dataset-rewriter/dataset-rewriter"
30
 
31
  NUM_ROWS_PREVIEW = 3
 
32
  MAX_NUM_ROWS_TO_REWRITE = int(os.environ.get("MAX_NUM_ROWS_TO_REWRITE") or 1000)
33
- PARTIAL_SUFFIX = "-1k"
 
34
  NUM_PARALLEL_CALLS = 10
35
  NUM_ROWS_PER_CALL = 5
36
  MAX_PROGRESS_UPDATES_PER_SECOND = 4
@@ -88,6 +90,12 @@ a {
88
  }
89
  """
90
 
 
 
 
 
 
 
91
  with gr.Blocks(css=css) as demo:
92
  dataset_info_json = gr.JSON(visible=False)
93
  with gr.Row():
@@ -111,7 +119,7 @@ with gr.Blocks(css=css) as demo:
111
 
112
  gr.Markdown("### ReWrite")
113
  with gr.Group():
114
- input_prompt = gr.Textbox(label="Enter the adjustment or transformation to apply to the dataset:")
115
  with gr.Accordion("(Advanced) Edit columns", open=False):
116
  output_format_dataframe = gr.DataFrame(col_count=(2, "fixed"), headers=["column", "type"])
117
  rewrite_preview_button = gr.Button("Preview Results", variant="primary")
@@ -119,8 +127,9 @@ with gr.Blocks(css=css) as demo:
119
  gr.Markdown("#### Output")
120
  full_dataset_generation_label = gr.Label(visible=False, show_label=False)
121
  pretty_output_preview = gr.DataFrame(interactive=False)
122
- full_dataset_generation_success_markdown = gr.Markdown("")
123
  pretty_full_dataset_generation_output = gr.DataFrame(interactive=False, visible=False)
 
 
124
  gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_")
125
  with gr.Column(scale=4, min_width="200px"):
126
  with gr.Accordion("Settings", open=False, elem_classes="settings"):
@@ -130,8 +139,9 @@ with gr.Blocks(css=css) as demo:
130
  gr.Markdown("Save datasets as public or private datasets")
131
  visibility_radio = gr.Radio(["public", "private"], value="public", container=False, interactive=False)
132
  gr.Markdown("Maximum number of rows to ReWrite")
133
- gr.Dropdown(choices=[MAX_NUM_ROWS_TO_REWRITE], value=MAX_NUM_ROWS_TO_REWRITE, interactive=False, container=False)
134
- gr.Markdown(f"_[duplicate]({URL}?duplicate=true) this space to rewrite bigger datasets_")
 
135
 
136
 
137
  ############
@@ -187,10 +197,10 @@ with gr.Blocks(css=css) as demo:
187
  prompt = prompt[:1000] if prompt.strip() else ""
188
  messages = [{"role": "user", "content": REWRITE_DATASET_PREVIEW.format(
189
  dataset=dataset,
190
- rows=json.dumps({"data": rows}),
191
  prompt=prompt,
192
  )}]
193
- response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format}}, "required": ["data"]}}
194
  yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4)
195
 
196
 
@@ -198,13 +208,17 @@ with gr.Blocks(css=css) as demo:
198
  prompt = prompt[:1000] if prompt.strip() else ""
199
  messages = [{"role": "user", "content": REWRITE_DATASET.format(
200
  dataset=dataset,
201
- rows=json.dumps({"data": rows}),
202
  prompt=prompt,
203
- input_preview_rows=json.dumps({"data": input_preview_rows}),
204
- output_preview_rows=json.dumps({"data": output_preview_rows}),
205
  )}]
206
- response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format}}, "required": ["data"]}}
207
- yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4)
 
 
 
 
208
 
209
 
210
  def find_new_name(dataset: str, prompt: str) -> str:
@@ -311,11 +325,11 @@ with gr.Blocks(css=css) as demo:
311
  def show_input_from_dataset_search(dataset: str) -> dict:
312
  return _show_input_preview(dataset, default_subset="default", default_split="train")
313
 
314
- @subset_dropdown.change(inputs=[dataset_search, subset_dropdown], outputs=[pretty_input_preview, subset_dropdown, split_dropdown, output_format_dataframe, dataset_info_json])
315
  def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
316
  return _show_input_preview(dataset, default_subset=subset, default_split="train")
317
 
318
- @split_dropdown.change(inputs=[dataset_search, subset_dropdown, split_dropdown], outputs=[pretty_input_preview, subset_dropdown, split_dropdown, output_format_dataframe, dataset_info_json])
319
  def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
320
  return _show_input_preview(dataset, default_subset=subset, default_split=split)
321
 
@@ -344,14 +358,14 @@ with gr.Blocks(css=css) as demo:
344
  print(f"(preview) Done ReWriting {dataset} with instruction '{prompt}'")
345
 
346
 
347
- @rewrite_full_dataset_button.click(inputs=[dataset_search, subset_dropdown, split_dropdown, pretty_input_preview, pretty_output_preview, input_prompt, output_format_dataframe, dataset_info_json, select_namespace_dropdown], outputs=[full_dataset_generation_label, full_dataset_generation_success_markdown, pretty_output_preview, pretty_full_dataset_generation_output])
348
- def rewrite_full_dataset(dataset: str, subset: str, split: str, pretty_input_preview_df: pd.DataFrame, pretty_output_preview_df: pd.DataFrame, prompt: str, output_format_df: pd.DataFrame, dataset_info: dict[str, Any], namespace: str, oauth_token: Optional[gr.OAuthToken]) -> Iterator[pd.DataFrame]:
349
  input_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_input_preview_df.to_dict(orient="records")]
350
  output_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_output_preview_df.to_dict(orient="records")]
351
  format = output_format_df.to_dict(orient="records")
352
  format = {"properties": {x["column"]: json.loads(x["type"]) for x in format}, "required": [x["column"] for x in format]}
353
  num_examples = dataset_info["splits"][split]["num_examples"]
354
- total = min(num_examples, MAX_NUM_ROWS_TO_REWRITE)
355
  print(f"ReWriting {dataset} with instruction '{prompt}'")
356
  yield {full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": 0.}, visible=True)}
357
  yield {pretty_full_dataset_generation_output: empty_dataframe}
@@ -387,7 +401,7 @@ with gr.Blocks(css=css) as demo:
387
  print(f"Done ReWriting {dataset} with instruction '{prompt}'")
388
 
389
  output_rows = [{k: json.loads(row[k]) for k in output_format_df["column"]} for rows in parallel_output_rows for row in rows]
390
- new_dataset = find_new_name(dataset + (PARTIAL_SUFFIX if num_examples > total else ""), prompt)
391
  repo_id = namespace + "/" + new_dataset
392
  yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"⚙️ Saving to {repo_id}": 0.})}
393
  token = oauth_token.token if oauth_token else save_dataset_hf_token
@@ -396,7 +410,11 @@ with gr.Blocks(css=css) as demo:
396
  ds.push_to_hub(repo_id, config_name=subset, split=split, token=token)
397
  DatasetCard(DATASET_CARD_CONTENT.format(new_dataset=new_dataset, dataset=dataset, model_id=model_id, prompt=prompt, url=URL)).push_to_hub(repo_id=repo_id, repo_type="dataset", token=token)
398
  yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"✅ Saving to {repo_id}": 1.})}
399
- yield {full_dataset_generation_success_markdown: f"# Open the ReWriten dataset in a new tab: [{repo_id}](https://huggingface.co/datasets/{repo_id})"}
 
 
 
 
400
  print(f"Saved {repo_id}")
401
 
402
 
 
29
  URL = "https://huggingface.co/spaces/dataset-rewriter/dataset-rewriter"
30
 
31
  NUM_ROWS_PREVIEW = 3
32
+ PARTIAL_SUFFIX = {10: "-10", 100: "-100", 1000: "-1k", 10_000: "-10k", 100_000: "-100k", 1000_000: "-1M"}
33
  MAX_NUM_ROWS_TO_REWRITE = int(os.environ.get("MAX_NUM_ROWS_TO_REWRITE") or 1000)
34
+ assert MAX_NUM_ROWS_TO_REWRITE in PARTIAL_SUFFIX, "allowed max num rows are 100, 1000, 10000, 100000 and 1000000"
35
+
36
  NUM_PARALLEL_CALLS = 10
37
  NUM_ROWS_PER_CALL = 5
38
  MAX_PROGRESS_UPDATES_PER_SECOND = 4
 
90
  }
91
  """
92
 
93
+ examples = [
94
+ ["fka/awesome-chatgpt-prompts", "make the prompt 6 words long maximum"],
95
+ ["lhoestq/CudyPokemonAdventures", "Pikachu as main character"],
96
+ ["infinite-dataset-hub/SmallTalkDialogues", "translate to proper French"],
97
+ ]
98
+
99
  with gr.Blocks(css=css) as demo:
100
  dataset_info_json = gr.JSON(visible=False)
101
  with gr.Row():
 
119
 
120
  gr.Markdown("### ReWrite")
121
  with gr.Group():
122
+ input_prompt = gr.Textbox(label="Adjustment or transformation to apply to the dataset")
123
  with gr.Accordion("(Advanced) Edit columns", open=False):
124
  output_format_dataframe = gr.DataFrame(col_count=(2, "fixed"), headers=["column", "type"])
125
  rewrite_preview_button = gr.Button("Preview Results", variant="primary")
 
127
  gr.Markdown("#### Output")
128
  full_dataset_generation_label = gr.Label(visible=False, show_label=False)
129
  pretty_output_preview = gr.DataFrame(interactive=False)
 
130
  pretty_full_dataset_generation_output = gr.DataFrame(interactive=False, visible=False)
131
+ full_dataset_generation_success_html = gr.HTML()
132
+ gr.Examples(examples, inputs=[dataset_search, input_prompt])
133
  gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_")
134
  with gr.Column(scale=4, min_width="200px"):
135
  with gr.Accordion("Settings", open=False, elem_classes="settings"):
 
139
  gr.Markdown("Save datasets as public or private datasets")
140
  visibility_radio = gr.Radio(["public", "private"], value="public", container=False, interactive=False)
141
  gr.Markdown("Maximum number of rows to ReWrite")
142
+ max_num_rows_dropdown = gr.Dropdown(choices=[num_rows for num_rows in PARTIAL_SUFFIX if num_rows <= MAX_NUM_ROWS_TO_REWRITE], value=MAX_NUM_ROWS_TO_REWRITE, container=False)
143
+ gr.Markdown("Duplicate this space to ReWrite more rows")
144
+ gr.HTML(f'<a href="{URL}?duplicate=true" target="_blank"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-xl.svg" alt="Duplicate this Space"></a>')
145
 
146
 
147
  ############
 
197
  prompt = prompt[:1000] if prompt.strip() else ""
198
  messages = [{"role": "user", "content": REWRITE_DATASET_PREVIEW.format(
199
  dataset=dataset,
200
+ rows=json.dumps({"data": rows}, ensure_ascii=False),
201
  prompt=prompt,
202
  )}]
203
+ response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format, "minItems": len(rows), "maxItems": len(rows)}}, "required": ["data"]}}
204
  yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4)
205
 
206
 
 
208
  prompt = prompt[:1000] if prompt.strip() else ""
209
  messages = [{"role": "user", "content": REWRITE_DATASET.format(
210
  dataset=dataset,
211
+ rows=json.dumps({"data": rows}, ensure_ascii=False),
212
  prompt=prompt,
213
+ input_preview_rows=json.dumps({"data": input_preview_rows}, ensure_ascii=False),
214
+ output_preview_rows=json.dumps({"data": output_preview_rows}, ensure_ascii=False),
215
  )}]
216
+ response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format, "minItems": len(rows), "maxItems": len(rows)}}, "required": ["data"]}}
217
+ try:
218
+ yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4)
219
+ except ijson.IncompleteJSONError as e:
220
+ print(f"{type(e).__name__}: {e}")
221
+ print("Warning: Some rows were missing during ReWriting.")
222
 
223
 
224
  def find_new_name(dataset: str, prompt: str) -> str:
 
325
  def show_input_from_dataset_search(dataset: str) -> dict:
326
  return _show_input_preview(dataset, default_subset="default", default_split="train")
327
 
328
+ @subset_dropdown.select(inputs=[dataset_search, subset_dropdown], outputs=[pretty_input_preview, subset_dropdown, split_dropdown, output_format_dataframe, dataset_info_json])
329
  def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
330
  return _show_input_preview(dataset, default_subset=subset, default_split="train")
331
 
332
+ @split_dropdown.select(inputs=[dataset_search, subset_dropdown, split_dropdown], outputs=[pretty_input_preview, subset_dropdown, split_dropdown, output_format_dataframe, dataset_info_json])
333
  def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
334
  return _show_input_preview(dataset, default_subset=subset, default_split=split)
335
 
 
358
  print(f"(preview) Done ReWriting {dataset} with instruction '{prompt}'")
359
 
360
 
361
+ @rewrite_full_dataset_button.click(inputs=[dataset_search, subset_dropdown, split_dropdown, pretty_input_preview, pretty_output_preview, input_prompt, output_format_dataframe, dataset_info_json, select_namespace_dropdown, max_num_rows_dropdown], outputs=[full_dataset_generation_label, full_dataset_generation_success_html, pretty_output_preview, pretty_full_dataset_generation_output])
362
+ def rewrite_full_dataset(dataset: str, subset: str, split: str, pretty_input_preview_df: pd.DataFrame, pretty_output_preview_df: pd.DataFrame, prompt: str, output_format_df: pd.DataFrame, dataset_info: dict[str, Any], namespace: str, max_num_rows: int, oauth_token: Optional[gr.OAuthToken]) -> Iterator[pd.DataFrame]:
363
  input_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_input_preview_df.to_dict(orient="records")]
364
  output_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_output_preview_df.to_dict(orient="records")]
365
  format = output_format_df.to_dict(orient="records")
366
  format = {"properties": {x["column"]: json.loads(x["type"]) for x in format}, "required": [x["column"] for x in format]}
367
  num_examples = dataset_info["splits"][split]["num_examples"]
368
+ total = min(num_examples, max_num_rows)
369
  print(f"ReWriting {dataset} with instruction '{prompt}'")
370
  yield {full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": 0.}, visible=True)}
371
  yield {pretty_full_dataset_generation_output: empty_dataframe}
 
401
  print(f"Done ReWriting {dataset} with instruction '{prompt}'")
402
 
403
  output_rows = [{k: json.loads(row[k]) for k in output_format_df["column"]} for rows in parallel_output_rows for row in rows]
404
+ new_dataset = find_new_name(dataset + (PARTIAL_SUFFIX[max_num_rows] if num_examples > total else ""), prompt)
405
  repo_id = namespace + "/" + new_dataset
406
  yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"⚙️ Saving to {repo_id}": 0.})}
407
  token = oauth_token.token if oauth_token else save_dataset_hf_token
 
410
  ds.push_to_hub(repo_id, config_name=subset, split=split, token=token)
411
  DatasetCard(DATASET_CARD_CONTENT.format(new_dataset=new_dataset, dataset=dataset, model_id=model_id, prompt=prompt, url=URL)).push_to_hub(repo_id=repo_id, repo_type="dataset", token=token)
412
  yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"✅ Saving to {repo_id}": 1.})}
413
+ yield {full_dataset_generation_success_html: (
414
+ f'<a href="https://huggingface.co/datasets/{repo_id}" target="_blank">'
415
+ '<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/dataset-on-hf-xl.svg" alt="Dataset on HF", style="margin-right: auto; margin-left: auto; max-width: fit-content;">'
416
+ '</a>'
417
+ )}
418
  print(f"Saved {repo_id}")
419
 
420