lhoestq (HF staff) committed
Commit 3de98b2
Parent: 231073c

limit max string length and fix input columns

Files changed (1): app.py (+20, -24)
app.py CHANGED
@@ -37,21 +37,14 @@ assert MAX_NUM_ROWS_TO_REWRITE in PARTIAL_SUFFIX, "allowed max num rows are 100,
 NUM_PARALLEL_CALLS = 10
 NUM_ROWS_PER_CALL = 3
 MAX_PROGRESS_UPDATES_PER_SECOND = 4
-REWRITE_DATASET_PREVIEW = (
+MAX_STRING_LENGTH = 1000
+REWRITE_DATASET = (
     "A Machine Learning practitioner is looking for a dataset similar to '{dataset}' but slightly different. "
     "They want you to rewrite the dataset and apply this instruction, which can be about transforming, translating or filtering the rows: {prompt}."
     "The first rows of the dataset are below in JSON format:\n\n{rows}\n\n"
     "Apply the instruction to those rows from the '{dataset}' dataset and output the resulting rows using the same JSON format. "
     "Try to keep some of the text or meaning intact, and apply the requested instruction '{prompt}'."
 )
-REWRITE_DATASET = (
-    "A Machine Learning practitioner is looking for a dataset similar to '{dataset}' but slightly different. "
-    "They want you to rewrite the dataset and apply this instruction, which can be about transforming, translating or filtering the rows: {prompt}."
-    "Here is an example:\n\nOriginal rows:\n{input_preview_rows}\n\nResulting rows:\n{output_preview_rows}\n\n"
-    "The rows of the dataset are below in JSON format:\n\n{rows}\n\n"
-    "Apply the instruction to those rows from the '{dataset}' dataset and output the resulting rows using the same JSON format. "
-    "Try to keep some of the text or meaning intact, and apply the requested instruction '{prompt}'."
-)
 FIND_NEW_NAME = (
     "You are a helpful assistant specialized in transforming English sentences for machine learning practitioners."
     "Your job is to take input sentences like 'Take this dataset and apply the instruction xxx' and rephrase them as 'The dataset should be yyy'. "
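The renamed REWRITE_DATASET template (formerly REWRITE_DATASET_PREVIEW) is now the single prompt used for both the preview and the full rewrite. As a rough sketch of how it gets filled, with made-up values (the dataset name, instruction and row below are illustrative only):

    prompt_text = REWRITE_DATASET.format(
        dataset="imdb",  # hypothetical dataset name
        prompt="translate the reviews to French",  # hypothetical instruction
        rows='{"data": [{"text": "Great movie!", "label": "pos"}]}',  # rows pre-serialized as JSON
    )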
@@ -186,6 +179,13 @@ with gr.Blocks(css=css, js=js) as demo:
    class ContextTooLongError(ValueError):
        pass

+    def crop_text(obj: Any) -> str:
+        if isinstance(obj, str):
+            return obj[:MAX_STRING_LENGTH]
+        else:
+            raise TypeError()
+
+
    def stream_reponse(messages: list[dict[str: str]], response_format=None, max_tokens=5000) -> Iterator[str]:
        for _ in range(3):
            message = None
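The new crop_text helper is wired into json.dumps as its default hook below. Note that json.dumps only invokes default for values it cannot serialize natively, so values that are already str are not routed through it; an eager pass over each row guarantees the cap. A minimal self-contained sketch of that idea (the crop_row helper is hypothetical, not part of this commit):

    import json

    MAX_STRING_LENGTH = 1000

    def crop_row(row: dict) -> dict:
        # Cap every string value so one long text field cannot blow up the prompt.
        return {k: v[:MAX_STRING_LENGTH] if isinstance(v, str) else v for k, v in row.items()}

    rows = [{"text": "x" * 5000, "label": "positive"}]
    payload = json.dumps({"data": [crop_row(r) for r in rows]}, ensure_ascii=False)
    assert len(payload) < 1100  # the 5000-char string was cut to MAX_STRING_LENGTH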
@@ -212,23 +212,21 @@ with gr.Blocks(css=css, js=js) as demo:

    def stream_rewrite_dataset_preview_row_by_row(dataset: str, rows: list[dict[str, str]], prompt: str, format: str) -> Iterator[dict[str, str]]:
        prompt = prompt[:1000] if prompt.strip() else ""
-        messages = [{"role": "user", "content": REWRITE_DATASET_PREVIEW.format(
+        messages = [{"role": "user", "content": REWRITE_DATASET.format(
            dataset=dataset,
-            rows=json.dumps({"data": rows}, ensure_ascii=False),
+            rows=json.dumps({"data": rows}, ensure_ascii=False, default=crop_text),
            prompt=prompt,
        )}]
        response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format, "minItems": len(rows), "maxItems": len(rows)}}, "required": ["data"]}}
        yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4, use_float=True)


-    def stream_rewrite_dataset_row_by_row(dataset: str, rows: list[dict[str, str]], prompt: str, format: str, input_preview_rows: list[dict[str, str]], output_preview_rows: list[dict[str, str]]) -> Iterator[dict[str, str]]:
+    def stream_rewrite_dataset_row_by_row(dataset: str, rows: list[dict[str, str]], prompt: str, format: str) -> Iterator[dict[str, str]]:
        prompt = prompt[:1000] if prompt.strip() else ""
        messages = [{"role": "user", "content": REWRITE_DATASET.format(
            dataset=dataset,
-            rows=json.dumps({"data": rows}, ensure_ascii=False),
+            rows=json.dumps({"data": rows}, ensure_ascii=False, default=crop_text),
            prompt=prompt,
-            input_preview_rows=json.dumps({"data": input_preview_rows}, ensure_ascii=False),
-            output_preview_rows=json.dumps({"data": output_preview_rows}, ensure_ascii=False),
        )}]
        response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format, "minItems": len(rows), "maxItems": len(rows)}}, "required": ["data"]}}
        try:
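Both functions parse the model response incrementally with ijson, yielding each row as soon as its JSON object is complete rather than waiting for the full payload. A self-contained sketch of the same pattern, with io.StringIO standing in for the streamed response:

    import io
    import ijson

    # Pretend this is the model response arriving over the wire.
    response = io.StringIO('{"data": [{"text": "hello"}, {"text": "world"}]}')
    for row in ijson.items(response, "data.item", use_float=True):
        print(row)  # each dict is yielded as soon as its closing brace is parsed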
@@ -333,7 +331,7 @@ with gr.Blocks(css=css, js=js) as demo:
        print(f"Showing {dataset}")
        rows = list(islice((stream_rows(dataset, subset, split, batch_size=NUM_ROWS_PREVIEW)), NUM_ROWS_PREVIEW))
        return {
-            pretty_input_preview: gr.DataFrame(pd.DataFrame([{k: json.dumps(v, ensure_ascii=False) for k, v in row.items()} for row in rows])),
+            pretty_input_preview: gr.DataFrame(pd.DataFrame([{k: json.dumps(v, ensure_ascii=False, default=crop_text) for k, v in row.items()} for row in rows])),
            **output
        }

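The preview table stringifies every cell with json.dumps so nested values render uniformly, and the new default=crop_text hook applies here too. Roughly (illustrative row only):

    import json
    import pandas as pd

    rows = [{"text": "hello", "meta": {"lang": "en"}}]
    df = pd.DataFrame([{k: json.dumps(v, ensure_ascii=False) for k, v in row.items()} for row in rows])
    # every cell is now a JSON string, e.g. '"hello"' and '{"lang": "en"}'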
@@ -379,19 +377,17 @@ with gr.Blocks(css=css, js=js) as demo:
            full_dataset_generation_success_html: "",
        }
        for row in stream_rewrite_dataset_preview_row_by_row(dataset=dataset, rows=rows, prompt=prompt, format=format):
-            output_rows.append({k: json.dumps(row[k], ensure_ascii=False) for k in output_format_df["column"]})
+            output_rows.append({k: json.dumps(row[k], ensure_ascii=False, default=crop_text) for k in output_format_df["column"]})
            yield {pretty_output_preview: gr.DataFrame(pd.DataFrame(output_rows))}
        yield {rewrite_full_dataset_button: gr.Button(interactive=True)}
        print(f"(preview) Done ReWriting {dataset} with instruction '{prompt}'")


-    @rewrite_full_dataset_button.click(inputs=[dataset_search, subset_dropdown, split_dropdown, pretty_input_preview, pretty_output_preview, input_prompt, output_format_dataframe, dataset_info_json, select_namespace_dropdown, max_num_rows_dropdown], outputs=[full_dataset_generation_label, full_dataset_generation_success_html, pretty_output_preview, pretty_full_dataset_generation_output])
-    def rewrite_full_dataset(dataset: str, subset: str, split: str, pretty_input_preview_df: pd.DataFrame, pretty_output_preview_df: pd.DataFrame, prompt: str, output_format_df: pd.DataFrame, dataset_info: dict[str, Any], namespace: str, max_num_rows: int, oauth_token: Optional[gr.OAuthToken]) -> Iterator[pd.DataFrame]:
+    @rewrite_full_dataset_button.click(inputs=[dataset_search, subset_dropdown, split_dropdown, input_prompt, output_format_dataframe, dataset_info_json, select_namespace_dropdown, max_num_rows_dropdown], outputs=[full_dataset_generation_label, full_dataset_generation_success_html, pretty_output_preview, pretty_full_dataset_generation_output])
+    def rewrite_full_dataset(dataset: str, subset: str, split: str, prompt: str, output_format_df: pd.DataFrame, dataset_info: dict[str, Any], namespace: str, max_num_rows: int, oauth_token: Optional[gr.OAuthToken]) -> Iterator[pd.DataFrame]:
        output_format_df = output_format_df[output_format_df["column"] != ""]
        format = output_format_df.to_dict(orient="records")
        format = {"properties": {x["column"]: json.loads(x["type"]) for x in format}, "required": [x["column"] for x in format]}
-        input_preview_rows = [{k: json.loads(row[k]) for k in output_format_df["column"] if k in row} for row in pretty_input_preview_df.to_dict(orient="records")]
-        output_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_output_preview_df.to_dict(orient="records")]
        num_examples = dataset_info["splits"][split]["num_examples"]
        total = min(num_examples, max_num_rows)
        print(f"ReWriting {dataset} with instruction '{prompt}'")
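With the preview DataFrames dropped from the handler inputs, the JSON schema sent as response_format is built solely from the output format table. With hypothetical columns and types, the two format lines above produce:

    import json
    import pandas as pd

    output_format_df = pd.DataFrame({
        "column": ["text", "label"],  # hypothetical output columns
        "type": ['{"type": "string"}', '{"type": "string"}'],
    })
    records = output_format_df.to_dict(orient="records")
    format = {
        "properties": {x["column"]: json.loads(x["type"]) for x in records},
        "required": [x["column"] for x in records],
    }
    # -> {'properties': {'text': {'type': 'string'}, 'label': {'type': 'string'}},
    #     'required': ['text', 'label']}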
@@ -404,13 +400,13 @@ with gr.Blocks(css=css, js=js) as demo:
        }

        num_parallel_calls = max(1, min(total // NUM_ROWS_PER_CALL, NUM_PARALLEL_CALLS))
-        parallel_input_rows = list(batched(islice(stream_rows(dataset=dataset, subset=subset, split=split), total), n=total // num_parallel_calls))
+        parallel_input_rows = list(batched(islice(({k: row[k] for k in output_format_df["column"] if k in row} for row in stream_rows(dataset=dataset, subset=subset, split=split)), total), n=total // num_parallel_calls))
        parallel_output_rows = [[] for _ in range(num_parallel_calls)]

        def run(i):
            for batch_rows in batched(parallel_input_rows[i], n=NUM_ROWS_PER_CALL):
-                for row in stream_rewrite_dataset_row_by_row(dataset=dataset, rows=batch_rows, prompt=prompt, format=format, input_preview_rows=input_preview_rows, output_preview_rows=output_preview_rows):
-                    parallel_output_rows[i].append({k: json.dumps(row[k], ensure_ascii=False) for k in output_format_df["column"]})
+                for row in stream_rewrite_dataset_row_by_row(dataset=dataset, rows=batch_rows, prompt=prompt, format=format):
+                    parallel_output_rows[i].append({k: json.dumps(row[k], ensure_ascii=False, default=crop_text) for k in output_format_df["column"]})
                    yield 1

        current = 0
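The "fix input columns" half of the commit is the generator expression added above: streamed input rows are projected onto the requested output columns before being split across parallel calls, so stray columns never reach the model. A self-contained sketch of that projection plus the batching, with a minimal stand-in for the app's batched helper (itertools.batched ships with Python 3.12+):

    from itertools import islice

    def batched(iterable, n):
        # Minimal stand-in for itertools.batched (Python 3.12+).
        it = iter(iterable)
        while batch := list(islice(it, n)):
            yield batch

    columns = ["text", "label"]  # hypothetical output columns
    stream = iter([
        {"text": "a", "label": "x", "extra": 1},
        {"text": "b", "label": "y", "extra": 2},
        {"text": "c", "label": "z", "extra": 3},
    ])
    filtered = ({k: row[k] for k in columns if k in row} for row in stream)
    parallel_input_rows = list(batched(islice(filtered, 3), n=1))
    # -> [[{'text': 'a', 'label': 'x'}], [{'text': 'b', 'label': 'y'}], [{'text': 'c', 'label': 'z'}]]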
 