limit max string length and fix input columns
app.py
CHANGED
@@ -37,21 +37,14 @@ assert MAX_NUM_ROWS_TO_REWRITE in PARTIAL_SUFFIX, "allowed max num rows are 100,
 NUM_PARALLEL_CALLS = 10
 NUM_ROWS_PER_CALL = 3
 MAX_PROGRESS_UPDATES_PER_SECOND = 4
-
+MAX_STRING_LENGTH = 1000
+REWRITE_DATASET = (
     "A Machine Learning practitioner is looking for a dataset similar to '{dataset}' but slightly different. "
     "They want you to rewrite the dataset and apply this instruction, which can be about transforming, translating or filtering the rows: {prompt}."
     "The first rows of the dataset are below in JSON format:\n\n{rows}\n\n"
     "Apply the instruction to those rows from the '{dataset}' dataset and output the resulting rows using the same JSON format. "
     "Try to keep some of the text or meaning intact, and apply the requested instruction '{prompt}'."
 )
-REWRITE_DATASET= (
-    "A Machine Learning practitioner is looking for a dataset similar to '{dataset}' but slightly different. "
-    "They want you to rewrite the dataset and apply this instruction, which can be about transforming, translating or filtering the rows: {prompt}."
-    "Here is an example:\n\nOriginal rows:\n{input_preview_rows}\n\Resulting rows:\n{output_preview_rows}\n\n"
-    "The rows of the dataset are below in JSON format:\n\n{rows}\n\n"
-    "Apply the instruction to those rows from the '{dataset}' dataset and output the resulting rows using the same JSON format. "
-    "Try to keep some of the text or meaning intact, and apply the requested instruction '{prompt}'."
-)
 FIND_NEW_NAME = (
     "You are a helpful assistant specialized in transforming english sentences for machine learning practitioners."
     "Your job is to take input sentences like 'Take this dataset and apply the instruction xxx' and rephrase them as 'The dataset should be yyy'. "
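Note on the surviving template: `REWRITE_DATASET` is a plain Python format string, and `{prompt}` appears twice, which `str.format` fills at every occurrence. A minimal sketch of how the user message is assembled, using a shortened stand-in template and hypothetical `dataset`, `prompt`, and `rows` values in place of what the Gradio UI and the rows API actually provide:

```python
import json

# Shortened stand-in for the REWRITE_DATASET template above.
REWRITE_DATASET = (
    "Rewrite rows from '{dataset}' per this instruction: {prompt}. "
    "Rows:\n\n{rows}\n\nApply '{prompt}' and keep the same JSON format."
)

# Hypothetical inputs; in app.py these come from the UI and stream_rows().
dataset = "stanfordnlp/imdb"
prompt = "translate the reviews to French"
rows = [{"text": "A surprisingly moving film.", "label": 1}]

content = REWRITE_DATASET.format(
    dataset=dataset,
    rows=json.dumps({"data": rows}, ensure_ascii=False),
    prompt=prompt,  # fills both {prompt} placeholders
)
print(content)  # the single user message sent to the model
```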
@@ -186,6 +179,13 @@ with gr.Blocks(css=css, js=js) as demo:
     class ContextTooLongError(ValueError):
         pass
 
+    def crop_text(obj: Any) -> str:
+        if isinstance(obj, str):
+            return obj[:MAX_STRING_LENGTH]
+        else:
+            raise TypeError()
+
+
     def stream_reponse(messages: list[dict[str: str]], response_format=None, max_tokens=5000) -> Iterator[str]:
         for _ in range(3):
             message = None
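One thing worth keeping in mind about the `default=crop_text` pattern this commit applies to every `json.dumps` call: the `default` hook is only invoked for values the encoder cannot serialize natively, never for plain `str` values. A self-contained sketch of those semantics:

```python
import json
from datetime import datetime

MAX_STRING_LENGTH = 1000

def crop_text(obj) -> str:
    # Called by json.dumps only for objects it cannot encode natively.
    if isinstance(obj, str):
        return obj[:MAX_STRING_LENGTH]
    raise TypeError(f"{type(obj).__name__} is not JSON serializable")

# Native types, including long strings, never reach the hook:
print(len(json.dumps({"text": "x" * 5000}, default=crop_text)))  # 5012, uncropped

# Non-native values do reach it; a non-str value re-raises, matching
# the stock json.dumps behavior:
try:
    json.dumps({"when": datetime.now()}, default=crop_text)
except TypeError as err:
    print(err)
```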
@@ -212,23 +212,21 @@
 
     def stream_rewrite_dataset_preview_row_by_row(dataset: str, rows: list[dict[str, str]], prompt: str, format: str) -> Iterator[dict[str, str]]:
         prompt = prompt[:1000] if prompt.strip() else ""
-        messages = [{"role": "user", "content":
+        messages = [{"role": "user", "content": REWRITE_DATASET.format(
             dataset=dataset,
-            rows=json.dumps({"data": rows}, ensure_ascii=False),
+            rows=json.dumps({"data": rows}, ensure_ascii=False, default=crop_text),
             prompt=prompt,
         )}]
         response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format, "minItems": len(rows), "maxItems": len(rows)}}, "required": ["data"]}}
         yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4, use_float=True)
 
 
-    def stream_rewrite_dataset_row_by_row(dataset: str, rows: list[dict[str, str]], prompt: str, format: str
+    def stream_rewrite_dataset_row_by_row(dataset: str, rows: list[dict[str, str]], prompt: str, format: str) -> Iterator[dict[str, str]]:
         prompt = prompt[:1000] if prompt.strip() else ""
         messages = [{"role": "user", "content": REWRITE_DATASET.format(
             dataset=dataset,
-            rows=json.dumps({"data": rows}, ensure_ascii=False),
+            rows=json.dumps({"data": rows}, ensure_ascii=False, default=crop_text),
             prompt=prompt,
-            input_preview_rows=json.dumps({"data": input_preview_rows}, ensure_ascii=False),
-            output_preview_rows=json.dumps({"data": output_preview_rows}, ensure_ascii=False),
         )}]
         response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format, "minItems": len(rows), "maxItems": len(rows)}}, "required": ["data"]}}
         try:
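The `yield from ijson.items(...)` line is what makes the preview incremental: the grammar in `response_format` constrains the model to emit a `{"data": [...]}` object with exactly `len(rows)` items, and `ijson` parses the response as it streams in, yielding each element of `data` as soon as its closing brace has been read. A runnable sketch, with a hypothetical `StringIteratorIO` standing in for the helper app.py imports (the real one may differ):

```python
import io
import ijson

# Stand-in for the app's streaming LLM response: an iterator of JSON
# text fragments arriving a few characters at a time.
def fake_stream():
    payload = '{"data": [{"text": "hello"}, {"text": "world"}]}'
    for i in range(0, len(payload), 8):
        yield payload[i:i + 8]

class StringIteratorIO(io.TextIOBase):
    """Minimal file-like wrapper over an iterator of strings (a sketch,
    not the exact helper used by the app)."""
    def __init__(self, it):
        self._it = it
        self._buf = ""
    def readable(self):
        return True
    def read(self, size=-1):
        while size < 0 or len(self._buf) < size:
            try:
                self._buf += next(self._it)
            except StopIteration:
                break
        out, self._buf = (self._buf, "") if size < 0 else (self._buf[:size], self._buf[size:])
        return out

# Each row is yielded as soon as it is fully parsed, which is what lets
# the UI update row by row while the model is still generating.
for row in ijson.items(StringIteratorIO(fake_stream()), "data.item", buf_size=4):
    print(row)
```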
@@ -333,7 +331,7 @@
         print(f"Showing {dataset}")
         rows = list(islice((stream_rows(dataset, subset, split, batch_size=NUM_ROWS_PREVIEW)), NUM_ROWS_PREVIEW))
         return {
-            pretty_input_preview: gr.DataFrame(pd.DataFrame([{k: json.dumps(v, ensure_ascii=False) for k, v in row.items()} for row in rows])),
+            pretty_input_preview: gr.DataFrame(pd.DataFrame([{k: json.dumps(v, ensure_ascii=False, default=crop_text) for k, v in row.items()} for row in rows])),
             **output
         }
 
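The preview change keeps the same per-cell pattern used elsewhere: every value is JSON-encoded before being handed to `gr.DataFrame`, so nested lists and dicts from the rows API render as readable text. A tiny sketch of that encoding with made-up rows:

```python
import json
import pandas as pd

rows = [
    {"text": "bonjour", "tags": ["greeting", "fr"]},
    {"text": "hello", "tags": []},
]
df = pd.DataFrame([{k: json.dumps(v, ensure_ascii=False) for k, v in row.items()} for row in rows])
print(df)
# each cell is a JSON literal, e.g. "bonjour" (quoted) and ["greeting", "fr"]
```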
@@ -379,19 +377,17 @@
                 full_dataset_generation_success_html: "",
             }
         for row in stream_rewrite_dataset_preview_row_by_row(dataset=dataset, rows=rows, prompt=prompt, format=format):
-            output_rows.append({k: json.dumps(row[k], ensure_ascii=False) for k in output_format_df["column"]})
+            output_rows.append({k: json.dumps(row[k], ensure_ascii=False, default=crop_text) for k in output_format_df["column"]})
             yield {pretty_output_preview: gr.DataFrame(pd.DataFrame(output_rows))}
         yield {rewrite_full_dataset_button: gr.Button(interactive=True)}
         print(f"(preview) Done ReWriting {dataset} with instruction '{prompt}'")
 
 
-    @rewrite_full_dataset_button.click(inputs=[dataset_search, subset_dropdown, split_dropdown,
-    def rewrite_full_dataset(dataset: str, subset: str, split: str,
+    @rewrite_full_dataset_button.click(inputs=[dataset_search, subset_dropdown, split_dropdown, input_prompt, output_format_dataframe, dataset_info_json, select_namespace_dropdown, max_num_rows_dropdown], outputs=[full_dataset_generation_label, full_dataset_generation_success_html, pretty_output_preview, pretty_full_dataset_generation_output])
+    def rewrite_full_dataset(dataset: str, subset: str, split: str, prompt: str, output_format_df: pd.DataFrame, dataset_info: dict[str, Any], namespace: str, max_num_rows: int, oauth_token: Optional[gr.OAuthToken]) -> Iterator[pd.DataFrame]:
         output_format_df = output_format_df[output_format_df["column"] != ""]
         format = output_format_df.to_dict(orient="records")
         format = {"properties": {x["column"]: json.loads(x["type"]) for x in format}, "required": [x["column"] for x in format]}
-        input_preview_rows = [{k: json.loads(row[k]) for k in output_format_df["column"] if k in row} for row in pretty_input_preview_df.to_dict(orient="records")]
-        output_preview_rows = [{k: json.loads(v) for k, v in row.items()} for row in pretty_output_preview_df.to_dict(orient="records")]
         num_examples = dataset_info["splits"][split]["num_examples"]
         total = min(num_examples, max_num_rows)
         print(f"ReWriting {dataset} with instruction '{prompt}'")
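The completed decorator above is the standard Gradio listener-as-decorator form: the wrapped function receives one positional argument per `inputs` component (the `oauth_token` parameter is injected by Gradio based on its `gr.OAuthToken` annotation rather than listed in `inputs`), and because the function is a generator, each `yield` of a `{component: value}` dict pushes a partial UI update. A minimal sketch of that wiring, with hypothetical component names:

```python
import gradio as gr

with gr.Blocks() as demo:
    name_box = gr.Textbox(label="Name")
    greet_btn = gr.Button("Greet")
    out_box = gr.Textbox(label="Greeting")

    # Same decorator pattern as @rewrite_full_dataset_button.click(...):
    # one argument per input component, dict-keyed yields for outputs.
    @greet_btn.click(inputs=[name_box], outputs=[out_box])
    def greet(name: str):
        for partial in ("Hello", f"Hello, {name}", f"Hello, {name}!"):
            yield {out_box: partial}  # streamed updates, like the row-by-row preview

demo.launch()
```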
@@ -404,13 +400,13 @@
         }
 
         num_parallel_calls = max(1, min(total // NUM_ROWS_PER_CALL, NUM_PARALLEL_CALLS))
-        parallel_input_rows = list(batched(islice(stream_rows(dataset=dataset, subset=subset, split=split), total), n=total // num_parallel_calls))
+        parallel_input_rows = list(batched(islice(({k: row[k] for k in output_format_df["column"] if k in row} for row in stream_rows(dataset=dataset, subset=subset, split=split)), total), n=total // num_parallel_calls))
         parallel_output_rows = [[] for _ in range(num_parallel_calls)]
 
         def run(i):
             for batch_rows in batched(parallel_input_rows[i], n=NUM_ROWS_PER_CALL):
-                for row in stream_rewrite_dataset_row_by_row(dataset=dataset, rows=batch_rows, prompt=prompt, format=format
-                    parallel_output_rows[i].append({k: json.dumps(row[k], ensure_ascii=False) for k in output_format_df["column"]})
+                for row in stream_rewrite_dataset_row_by_row(dataset=dataset, rows=batch_rows, prompt=prompt, format=format):
+                    parallel_output_rows[i].append({k: json.dumps(row[k], ensure_ascii=False, default=crop_text) for k in output_format_df["column"]})
                     yield 1
 
         current = 0
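This is the "fix input columns" half of the commit: the full-dataset stream is now filtered to the configured output columns before being split across workers. The fan-out itself caps the stream at `total` rows, slices it into `num_parallel_calls` contiguous chunks, and lets each worker rewrite its chunk `NUM_ROWS_PER_CALL` rows at a time. A simplified sketch, under the assumptions that `total` divides evenly into chunks and that a thread pool drives the workers (the real app drives the `run(i)` generators itself so each `yield 1` becomes a progress tick):

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import islice

def batched(iterable, n):
    # Backport of itertools.batched (Python 3.12+) so the sketch runs anywhere.
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch

NUM_PARALLEL_CALLS = 10
NUM_ROWS_PER_CALL = 3

rows = ({"text": f"row {i}"} for i in range(18))  # stand-in for the filtered stream_rows() generator
total = 18  # min(num_examples, max_num_rows) in app.py; chosen here to divide evenly

num_parallel_calls = max(1, min(total // NUM_ROWS_PER_CALL, NUM_PARALLEL_CALLS))
parallel_input_rows = list(batched(islice(rows, total), total // num_parallel_calls))
parallel_output_rows = [[] for _ in range(num_parallel_calls)]

def run(i):
    # Stand-in for the per-batch model call in stream_rewrite_dataset_row_by_row.
    for batch in batched(parallel_input_rows[i], NUM_ROWS_PER_CALL):
        parallel_output_rows[i].extend({"text": r["text"].upper()} for r in batch)

with ThreadPoolExecutor(num_parallel_calls) as pool:
    list(pool.map(run, range(num_parallel_calls)))

print(sum(len(chunk) for chunk in parallel_output_rows))  # 18: every row rewritten once
```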
|