lhoestq HF staff commited on
Commit
3a7f10a
1 Parent(s): 063480a

column feature

Browse files
Files changed (1) hide show
  1. app.py +25 -4
app.py CHANGED
@@ -90,6 +90,12 @@ a {
90
  color: var(--body-text-color-subdued);
91
  }
92
  """
 
 
 
 
 
 
93
 
94
  examples = [
95
  ["fka/awesome-chatgpt-prompts", "make the prompt 6 words long maximum"],
@@ -97,7 +103,7 @@ examples = [
97
  ["infinite-dataset-hub/SmallTalkDialogues", "translate to proper French"],
98
  ]
99
 
100
- with gr.Blocks(css=css) as demo:
101
  dataset_info_json = gr.JSON(visible=False)
102
  with gr.Row():
103
  with gr.Column(scale=10):
@@ -123,6 +129,12 @@ with gr.Blocks(css=css) as demo:
123
  input_prompt = gr.Textbox(label="Adjustment or transformation to apply to the dataset")
124
  with gr.Accordion("(Advanced) Edit columns", open=False):
125
  output_format_dataframe = gr.DataFrame(col_count=(2, "fixed"), headers=["column", "type"])
 
 
 
 
 
 
126
  rewrite_preview_button = gr.Button("Preview Results", variant="primary")
127
  rewrite_full_dataset_button = gr.Button("ReWrite Full Dataset", interactive=False)
128
  gr.Markdown("#### Output")
@@ -226,14 +238,14 @@ with gr.Blocks(css=css) as demo:
226
  print("Warning: Some rows were missing during ReWriting.")
227
 
228
 
229
- def find_new_name(dataset: str, prompt: str) -> str:
230
  messages = [{"role": "user", "content": FIND_NEW_NAME.format(prompt=prompt)}]
231
  out = "".join(stream_reponse(messages))
232
  if "should be" in out:
233
  out = dataset.split("/")[-1] + out.split("should be", 1)[1].replace(" ", "-").replace(".", "").replace(",", "")
234
  else:
235
  out = dataset.split("/")[-1] + prompt.replace(" ", "-")
236
- return out[:80] + "-" + Hasher.hash(prompt)[:4]
237
 
238
  def _write_generator_to_queue(queue: Queue, func: Callable[..., Iterable], kwargs: dict) -> None:
239
  for i, result in enumerate(func(**kwargs)):
@@ -342,6 +354,14 @@ with gr.Blocks(css=css) as demo:
342
  @input_prompt.change(outputs=[rewrite_full_dataset_button])
343
  def disable_rewrite_full_dataset() -> dict:
344
  return {rewrite_full_dataset_button: gr.Button(interactive=False)}
 
 
 
 
 
 
 
 
345
 
346
 
347
  @rewrite_preview_button.click(inputs=[dataset_search, pretty_input_preview, input_prompt, output_format_dataframe], outputs=[pretty_output_preview, rewrite_full_dataset_button, full_dataset_generation_label, full_dataset_generation_success_html, pretty_full_dataset_generation_output])
@@ -368,6 +388,7 @@ with gr.Blocks(css=css) as demo:
368
  def rewrite_full_dataset(dataset: str, subset: str, split: str, pretty_input_preview_df: pd.DataFrame, pretty_output_preview_df: pd.DataFrame, prompt: str, output_format_df: pd.DataFrame, dataset_info: dict[str, Any], namespace: str, max_num_rows: int, oauth_token: Optional[gr.OAuthToken]) -> Iterator[pd.DataFrame]:
369
  input_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_input_preview_df.to_dict(orient="records")]
370
  output_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_output_preview_df.to_dict(orient="records")]
 
371
  format = output_format_df.to_dict(orient="records")
372
  format = {"properties": {x["column"]: json.loads(x["type"]) for x in format}, "required": [x["column"] for x in format]}
373
  num_examples = dataset_info["splits"][split]["num_examples"]
@@ -411,7 +432,7 @@ with gr.Blocks(css=css) as demo:
411
  print(f"Done ReWriting {dataset} with instruction '{prompt}'")
412
 
413
  output_rows = [{k: json.loads(row[k]) for k in output_format_df["column"]} for rows in parallel_output_rows for row in rows]
414
- new_dataset = find_new_name(dataset + (PARTIAL_SUFFIX[max_num_rows] if num_examples > total else ""), prompt)
415
  repo_id = namespace + "/" + new_dataset
416
  yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"⚙️ Saving to {repo_id}": 0.})}
417
  token = oauth_token.token if oauth_token else save_dataset_hf_token
 
90
  color: var(--body-text-color-subdued);
91
  }
92
  """
93
+ js = """
94
+ function load() {
95
+ Array.from(document.getElementsByClassName("secondary")).filter(e => (e.innerText.includes("New row")))[0].innerText = "New column"
96
+ return 'done';
97
+ }
98
+ """
99
 
100
  examples = [
101
  ["fka/awesome-chatgpt-prompts", "make the prompt 6 words long maximum"],
 
103
  ["infinite-dataset-hub/SmallTalkDialogues", "translate to proper French"],
104
  ]
105
 
106
+ with gr.Blocks(css=css, js=js) as demo:
107
  dataset_info_json = gr.JSON(visible=False)
108
  with gr.Row():
109
  with gr.Column(scale=10):
 
129
  input_prompt = gr.Textbox(label="Adjustment or transformation to apply to the dataset")
130
  with gr.Accordion("(Advanced) Edit columns", open=False):
131
  output_format_dataframe = gr.DataFrame(col_count=(2, "fixed"), headers=["column", "type"])
132
+ column_ro_remove_dropdown = gr.Dropdown(info="Select a column to remove", show_label=False)
133
+ with gr.Row():
134
+ with gr.Column(scale=99):
135
+ pass
136
+ with gr.Column(scale=1, min_width=88):
137
+ remove_column_button = gr.Button("Remove", size="sm", elem_id="remove_column_button")
138
  rewrite_preview_button = gr.Button("Preview Results", variant="primary")
139
  rewrite_full_dataset_button = gr.Button("ReWrite Full Dataset", interactive=False)
140
  gr.Markdown("#### Output")
 
238
  print("Warning: Some rows were missing during ReWriting.")
239
 
240
 
241
+ def find_new_name(dataset: str, prompt: str, format: dict) -> str:
242
  messages = [{"role": "user", "content": FIND_NEW_NAME.format(prompt=prompt)}]
243
  out = "".join(stream_reponse(messages))
244
  if "should be" in out:
245
  out = dataset.split("/")[-1] + out.split("should be", 1)[1].replace(" ", "-").replace(".", "").replace(",", "")
246
  else:
247
  out = dataset.split("/")[-1] + prompt.replace(" ", "-")
248
+ return out[:80] + "-" + Hasher.hash(prompt + str(format))[:4]
249
 
250
  def _write_generator_to_queue(queue: Queue, func: Callable[..., Iterable], kwargs: dict) -> None:
251
  for i, result in enumerate(func(**kwargs)):
 
354
  @input_prompt.change(outputs=[rewrite_full_dataset_button])
355
  def disable_rewrite_full_dataset() -> dict:
356
  return {rewrite_full_dataset_button: gr.Button(interactive=False)}
357
+
358
+ @output_format_dataframe.change(inputs=[output_format_dataframe], outputs=[column_ro_remove_dropdown])
359
+ def update_columns_to_remove_dropdown(output_format_df: pd.DataFrame) -> dict:
360
+ return gr.Dropdown(choices=output_format_df["column"].tolist())
361
+
362
+ @remove_column_button.click(inputs=[column_ro_remove_dropdown, output_format_dataframe], outputs=[output_format_dataframe])
363
+ def update_output_format_dataframe(column: str, output_format_df: pd.DataFrame) -> pd.DataFrame:
364
+ return output_format_df[output_format_df["column"] != column]
365
 
366
 
367
  @rewrite_preview_button.click(inputs=[dataset_search, pretty_input_preview, input_prompt, output_format_dataframe], outputs=[pretty_output_preview, rewrite_full_dataset_button, full_dataset_generation_label, full_dataset_generation_success_html, pretty_full_dataset_generation_output])
 
388
  def rewrite_full_dataset(dataset: str, subset: str, split: str, pretty_input_preview_df: pd.DataFrame, pretty_output_preview_df: pd.DataFrame, prompt: str, output_format_df: pd.DataFrame, dataset_info: dict[str, Any], namespace: str, max_num_rows: int, oauth_token: Optional[gr.OAuthToken]) -> Iterator[pd.DataFrame]:
389
  input_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_input_preview_df.to_dict(orient="records")]
390
  output_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_output_preview_df.to_dict(orient="records")]
391
+ output_format_df = output_format_df[output_format_df["column"] != ""]
392
  format = output_format_df.to_dict(orient="records")
393
  format = {"properties": {x["column"]: json.loads(x["type"]) for x in format}, "required": [x["column"] for x in format]}
394
  num_examples = dataset_info["splits"][split]["num_examples"]
 
432
  print(f"Done ReWriting {dataset} with instruction '{prompt}'")
433
 
434
  output_rows = [{k: json.loads(row[k]) for k in output_format_df["column"]} for rows in parallel_output_rows for row in rows]
435
+ new_dataset = find_new_name(dataset + (PARTIAL_SUFFIX[max_num_rows] if num_examples > total else ""), prompt, format)
436
  repo_id = namespace + "/" + new_dataset
437
  yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"⚙️ Saving to {repo_id}": 0.})}
438
  token = oauth_token.token if oauth_token else save_dataset_hf_token