Spaces:
Sleeping
Sleeping
better buttons
Browse files
app.py
CHANGED
@@ -29,8 +29,10 @@ NAMESPACE = "dataset-rewriter"
|
|
29 |
URL = "https://huggingface.co/spaces/dataset-rewriter/dataset-rewriter"
|
30 |
|
31 |
NUM_ROWS_PREVIEW = 3
|
|
|
32 |
MAX_NUM_ROWS_TO_REWRITE = int(os.environ.get("MAX_NUM_ROWS_TO_REWRITE") or 1000)
|
33 |
-
PARTIAL_SUFFIX
|
|
|
34 |
NUM_PARALLEL_CALLS = 10
|
35 |
NUM_ROWS_PER_CALL = 5
|
36 |
MAX_PROGRESS_UPDATES_PER_SECOND = 4
|
@@ -88,6 +90,12 @@ a {
|
|
88 |
}
|
89 |
"""
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
with gr.Blocks(css=css) as demo:
|
92 |
dataset_info_json = gr.JSON(visible=False)
|
93 |
with gr.Row():
|
@@ -111,7 +119,7 @@ with gr.Blocks(css=css) as demo:
|
|
111 |
|
112 |
gr.Markdown("### ReWrite")
|
113 |
with gr.Group():
|
114 |
-
input_prompt = gr.Textbox(label="
|
115 |
with gr.Accordion("(Advanced) Edit columns", open=False):
|
116 |
output_format_dataframe = gr.DataFrame(col_count=(2, "fixed"), headers=["column", "type"])
|
117 |
rewrite_preview_button = gr.Button("Preview Results", variant="primary")
|
@@ -119,8 +127,9 @@ with gr.Blocks(css=css) as demo:
|
|
119 |
gr.Markdown("#### Output")
|
120 |
full_dataset_generation_label = gr.Label(visible=False, show_label=False)
|
121 |
pretty_output_preview = gr.DataFrame(interactive=False)
|
122 |
-
full_dataset_generation_success_markdown = gr.Markdown("")
|
123 |
pretty_full_dataset_generation_output = gr.DataFrame(interactive=False, visible=False)
|
|
|
|
|
124 |
gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_")
|
125 |
with gr.Column(scale=4, min_width="200px"):
|
126 |
with gr.Accordion("Settings", open=False, elem_classes="settings"):
|
@@ -130,8 +139,9 @@ with gr.Blocks(css=css) as demo:
|
|
130 |
gr.Markdown("Save datasets as public or private datasets")
|
131 |
visibility_radio = gr.Radio(["public", "private"], value="public", container=False, interactive=False)
|
132 |
gr.Markdown("Maximum number of rows to ReWrite")
|
133 |
-
gr.Dropdown(choices=[MAX_NUM_ROWS_TO_REWRITE], value=MAX_NUM_ROWS_TO_REWRITE,
|
134 |
-
gr.Markdown(
|
|
|
135 |
|
136 |
|
137 |
############
|
@@ -187,10 +197,10 @@ with gr.Blocks(css=css) as demo:
|
|
187 |
prompt = prompt[:1000] if prompt.strip() else ""
|
188 |
messages = [{"role": "user", "content": REWRITE_DATASET_PREVIEW.format(
|
189 |
dataset=dataset,
|
190 |
-
rows=json.dumps({"data": rows}),
|
191 |
prompt=prompt,
|
192 |
)}]
|
193 |
-
response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format}}, "required": ["data"]}}
|
194 |
yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4)
|
195 |
|
196 |
|
@@ -198,13 +208,17 @@ with gr.Blocks(css=css) as demo:
|
|
198 |
prompt = prompt[:1000] if prompt.strip() else ""
|
199 |
messages = [{"role": "user", "content": REWRITE_DATASET.format(
|
200 |
dataset=dataset,
|
201 |
-
rows=json.dumps({"data": rows}),
|
202 |
prompt=prompt,
|
203 |
-
input_preview_rows=json.dumps({"data": input_preview_rows}),
|
204 |
-
output_preview_rows=json.dumps({"data": output_preview_rows}),
|
205 |
)}]
|
206 |
-
response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format}}, "required": ["data"]}}
|
207 |
-
|
|
|
|
|
|
|
|
|
208 |
|
209 |
|
210 |
def find_new_name(dataset: str, prompt: str) -> str:
|
@@ -311,11 +325,11 @@ with gr.Blocks(css=css) as demo:
|
|
311 |
def show_input_from_dataset_search(dataset: str) -> dict:
|
312 |
return _show_input_preview(dataset, default_subset="default", default_split="train")
|
313 |
|
314 |
-
@subset_dropdown.
|
315 |
def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
|
316 |
return _show_input_preview(dataset, default_subset=subset, default_split="train")
|
317 |
|
318 |
-
@split_dropdown.
|
319 |
def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
|
320 |
return _show_input_preview(dataset, default_subset=subset, default_split=split)
|
321 |
|
@@ -344,14 +358,14 @@ with gr.Blocks(css=css) as demo:
|
|
344 |
print(f"(preview) Done ReWriting {dataset} with instruction '{prompt}'")
|
345 |
|
346 |
|
347 |
-
@rewrite_full_dataset_button.click(inputs=[dataset_search, subset_dropdown, split_dropdown, pretty_input_preview, pretty_output_preview, input_prompt, output_format_dataframe, dataset_info_json, select_namespace_dropdown], outputs=[full_dataset_generation_label,
|
348 |
-
def rewrite_full_dataset(dataset: str, subset: str, split: str, pretty_input_preview_df: pd.DataFrame, pretty_output_preview_df: pd.DataFrame, prompt: str, output_format_df: pd.DataFrame, dataset_info: dict[str, Any], namespace: str, oauth_token: Optional[gr.OAuthToken]) -> Iterator[pd.DataFrame]:
|
349 |
input_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_input_preview_df.to_dict(orient="records")]
|
350 |
output_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_output_preview_df.to_dict(orient="records")]
|
351 |
format = output_format_df.to_dict(orient="records")
|
352 |
format = {"properties": {x["column"]: json.loads(x["type"]) for x in format}, "required": [x["column"] for x in format]}
|
353 |
num_examples = dataset_info["splits"][split]["num_examples"]
|
354 |
-
total = min(num_examples,
|
355 |
print(f"ReWriting {dataset} with instruction '{prompt}'")
|
356 |
yield {full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": 0.}, visible=True)}
|
357 |
yield {pretty_full_dataset_generation_output: empty_dataframe}
|
@@ -387,7 +401,7 @@ with gr.Blocks(css=css) as demo:
|
|
387 |
print(f"Done ReWriting {dataset} with instruction '{prompt}'")
|
388 |
|
389 |
output_rows = [{k: json.loads(row[k]) for k in output_format_df["column"]} for rows in parallel_output_rows for row in rows]
|
390 |
-
new_dataset = find_new_name(dataset + (PARTIAL_SUFFIX if num_examples > total else ""), prompt)
|
391 |
repo_id = namespace + "/" + new_dataset
|
392 |
yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"⚙️ Saving to {repo_id}": 0.})}
|
393 |
token = oauth_token.token if oauth_token else save_dataset_hf_token
|
@@ -396,7 +410,11 @@ with gr.Blocks(css=css) as demo:
|
|
396 |
ds.push_to_hub(repo_id, config_name=subset, split=split, token=token)
|
397 |
DatasetCard(DATASET_CARD_CONTENT.format(new_dataset=new_dataset, dataset=dataset, model_id=model_id, prompt=prompt, url=URL)).push_to_hub(repo_id=repo_id, repo_type="dataset", token=token)
|
398 |
yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"✅ Saving to {repo_id}": 1.})}
|
399 |
-
yield {
|
|
|
|
|
|
|
|
|
400 |
print(f"Saved {repo_id}")
|
401 |
|
402 |
|
|
|
29 |
URL = "https://huggingface.co/spaces/dataset-rewriter/dataset-rewriter"
|
30 |
|
31 |
NUM_ROWS_PREVIEW = 3
|
32 |
+
PARTIAL_SUFFIX = {10: "-10", 100: "-100", 1000: "-1k", 10_000: "-10k", 100_000: "-100k", 1000_000: "-1M"}
|
33 |
MAX_NUM_ROWS_TO_REWRITE = int(os.environ.get("MAX_NUM_ROWS_TO_REWRITE") or 1000)
|
34 |
+
# Validate the configured row cap at start-up: it is later used as a key into
# PARTIAL_SUFFIX to name partially rewritten datasets. An explicit raise is used
# instead of `assert` so the check is not stripped under `python -O`, and the
# message is built from the actual keys so it cannot drift from the mapping
# (the previous hard-coded list omitted the allowed value 10).
if MAX_NUM_ROWS_TO_REWRITE not in PARTIAL_SUFFIX:
    raise ValueError(
        f"MAX_NUM_ROWS_TO_REWRITE must be one of {sorted(PARTIAL_SUFFIX)}, "
        f"got {MAX_NUM_ROWS_TO_REWRITE}"
    )
|
35 |
+
|
36 |
NUM_PARALLEL_CALLS = 10
|
37 |
NUM_ROWS_PER_CALL = 5
|
38 |
MAX_PROGRESS_UPDATES_PER_SECOND = 4
|
|
|
90 |
}
|
91 |
"""
|
92 |
|
93 |
+
examples = [
|
94 |
+
["fka/awesome-chatgpt-prompts", "make the prompt 6 words long maximum"],
|
95 |
+
["lhoestq/CudyPokemonAdventures", "Pikachu as main character"],
|
96 |
+
["infinite-dataset-hub/SmallTalkDialogues", "translate to proper French"],
|
97 |
+
]
|
98 |
+
|
99 |
with gr.Blocks(css=css) as demo:
|
100 |
dataset_info_json = gr.JSON(visible=False)
|
101 |
with gr.Row():
|
|
|
119 |
|
120 |
gr.Markdown("### ReWrite")
|
121 |
with gr.Group():
|
122 |
+
input_prompt = gr.Textbox(label="Adjustment or transformation to apply to the dataset")
|
123 |
with gr.Accordion("(Advanced) Edit columns", open=False):
|
124 |
output_format_dataframe = gr.DataFrame(col_count=(2, "fixed"), headers=["column", "type"])
|
125 |
rewrite_preview_button = gr.Button("Preview Results", variant="primary")
|
|
|
127 |
gr.Markdown("#### Output")
|
128 |
full_dataset_generation_label = gr.Label(visible=False, show_label=False)
|
129 |
pretty_output_preview = gr.DataFrame(interactive=False)
|
|
|
130 |
pretty_full_dataset_generation_output = gr.DataFrame(interactive=False, visible=False)
|
131 |
+
full_dataset_generation_success_html = gr.HTML()
|
132 |
+
gr.Examples(examples, inputs=[dataset_search, input_prompt])
|
133 |
gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_")
|
134 |
with gr.Column(scale=4, min_width="200px"):
|
135 |
with gr.Accordion("Settings", open=False, elem_classes="settings"):
|
|
|
139 |
gr.Markdown("Save datasets as public or private datasets")
|
140 |
visibility_radio = gr.Radio(["public", "private"], value="public", container=False, interactive=False)
|
141 |
gr.Markdown("Maximum number of rows to ReWrite")
|
142 |
+
max_num_rows_dropdown = gr.Dropdown(choices=[num_rows for num_rows in PARTIAL_SUFFIX if num_rows <= MAX_NUM_ROWS_TO_REWRITE], value=MAX_NUM_ROWS_TO_REWRITE, container=False)
|
143 |
+
gr.Markdown("Duplicate this space to ReWrite more rows")
|
144 |
+
gr.HTML(f'<a href="{URL}?duplicate=true" target="_blank"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-xl.svg" alt="Duplicate this Space"></a>')
|
145 |
|
146 |
|
147 |
############
|
|
|
197 |
prompt = prompt[:1000] if prompt.strip() else ""
|
198 |
messages = [{"role": "user", "content": REWRITE_DATASET_PREVIEW.format(
|
199 |
dataset=dataset,
|
200 |
+
rows=json.dumps({"data": rows}, ensure_ascii=False),
|
201 |
prompt=prompt,
|
202 |
)}]
|
203 |
+
response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format, "minItems": len(rows), "maxItems": len(rows)}}, "required": ["data"]}}
|
204 |
yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4)
|
205 |
|
206 |
|
|
|
208 |
prompt = prompt[:1000] if prompt.strip() else ""
|
209 |
messages = [{"role": "user", "content": REWRITE_DATASET.format(
|
210 |
dataset=dataset,
|
211 |
+
rows=json.dumps({"data": rows}, ensure_ascii=False),
|
212 |
prompt=prompt,
|
213 |
+
input_preview_rows=json.dumps({"data": input_preview_rows}, ensure_ascii=False),
|
214 |
+
output_preview_rows=json.dumps({"data": output_preview_rows}, ensure_ascii=False),
|
215 |
)}]
|
216 |
+
response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format, "minItems": len(rows), "maxItems": len(rows)}}, "required": ["data"]}}
|
217 |
+
try:
|
218 |
+
yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4)
|
219 |
+
except ijson.IncompleteJSONError as e:
|
220 |
+
print(f"{type(e).__name__}: {e}")
|
221 |
+
print("Warning: Some rows were missing during ReWriting.")
|
222 |
|
223 |
|
224 |
def find_new_name(dataset: str, prompt: str) -> str:
|
|
|
325 |
def show_input_from_dataset_search(dataset: str) -> dict:
|
326 |
return _show_input_preview(dataset, default_subset="default", default_split="train")
|
327 |
|
328 |
+
@subset_dropdown.select(inputs=[dataset_search, subset_dropdown], outputs=[pretty_input_preview, subset_dropdown, split_dropdown, output_format_dataframe, dataset_info_json])
|
329 |
def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
|
330 |
return _show_input_preview(dataset, default_subset=subset, default_split="train")
|
331 |
|
332 |
+
@split_dropdown.select(inputs=[dataset_search, subset_dropdown, split_dropdown], outputs=[pretty_input_preview, subset_dropdown, split_dropdown, output_format_dataframe, dataset_info_json])
|
333 |
def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
|
334 |
return _show_input_preview(dataset, default_subset=subset, default_split=split)
|
335 |
|
|
|
358 |
print(f"(preview) Done ReWriting {dataset} with instruction '{prompt}'")
|
359 |
|
360 |
|
361 |
+
@rewrite_full_dataset_button.click(inputs=[dataset_search, subset_dropdown, split_dropdown, pretty_input_preview, pretty_output_preview, input_prompt, output_format_dataframe, dataset_info_json, select_namespace_dropdown, max_num_rows_dropdown], outputs=[full_dataset_generation_label, full_dataset_generation_success_html, pretty_output_preview, pretty_full_dataset_generation_output])
|
362 |
+
def rewrite_full_dataset(dataset: str, subset: str, split: str, pretty_input_preview_df: pd.DataFrame, pretty_output_preview_df: pd.DataFrame, prompt: str, output_format_df: pd.DataFrame, dataset_info: dict[str, Any], namespace: str, max_num_rows: int, oauth_token: Optional[gr.OAuthToken]) -> Iterator[pd.DataFrame]:
|
363 |
input_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_input_preview_df.to_dict(orient="records")]
|
364 |
output_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_output_preview_df.to_dict(orient="records")]
|
365 |
format = output_format_df.to_dict(orient="records")
|
366 |
format = {"properties": {x["column"]: json.loads(x["type"]) for x in format}, "required": [x["column"] for x in format]}
|
367 |
num_examples = dataset_info["splits"][split]["num_examples"]
|
368 |
+
total = min(num_examples, max_num_rows)
|
369 |
print(f"ReWriting {dataset} with instruction '{prompt}'")
|
370 |
yield {full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": 0.}, visible=True)}
|
371 |
yield {pretty_full_dataset_generation_output: empty_dataframe}
|
|
|
401 |
print(f"Done ReWriting {dataset} with instruction '{prompt}'")
|
402 |
|
403 |
output_rows = [{k: json.loads(row[k]) for k in output_format_df["column"]} for rows in parallel_output_rows for row in rows]
|
404 |
+
new_dataset = find_new_name(dataset + (PARTIAL_SUFFIX[max_num_rows] if num_examples > total else ""), prompt)
|
405 |
repo_id = namespace + "/" + new_dataset
|
406 |
yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"⚙️ Saving to {repo_id}": 0.})}
|
407 |
token = oauth_token.token if oauth_token else save_dataset_hf_token
|
|
|
410 |
ds.push_to_hub(repo_id, config_name=subset, split=split, token=token)
|
411 |
DatasetCard(DATASET_CARD_CONTENT.format(new_dataset=new_dataset, dataset=dataset, model_id=model_id, prompt=prompt, url=URL)).push_to_hub(repo_id=repo_id, repo_type="dataset", token=token)
|
412 |
yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"✅ Saving to {repo_id}": 1.})}
|
413 |
+
yield {full_dataset_generation_success_html: (
|
414 |
+
f'<a href="https://huggingface.co/datasets/{repo_id}" target="_blank">'
|
415 |
+
'<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/dataset-on-hf-xl.svg" alt="Dataset on HF" style="margin-right: auto; margin-left: auto; max-width: fit-content;">'
|
416 |
+
'</a>'
|
417 |
+
)}
|
418 |
print(f"Saved {repo_id}")
|
419 |
|
420 |
|