import os
import json
import time
from itertools import count, islice
from multiprocessing.pool import ThreadPool
from queue import Queue, Empty
from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar

import gradio as gr
import ijson
import pandas as pd
import requests
from datasets import Dataset, Features, Value, Sequence
from datasets.fingerprint import Hasher
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import DatasetCard, InferenceClient

from utils import StringIteratorIO

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
client = InferenceClient(model_id, token=os.environ.get("INFERENCE_API_HF_TOKEN"))
save_dataset_hf_token = os.environ.get("SAVE_DATASET_HF_TOKEN")
session = requests.Session()
empty_dataframe = pd.DataFrame({"1": [], "2": [], "3": []})
loading_dataframe = pd.DataFrame({"Loading...": ["..."]})

NAMESPACE = "dataset-rewriter"
URL = "https://huggingface.co/spaces/dataset-rewriter/dataset-rewriter"
NUM_ROWS_PREVIEW = 3
PARTIAL_SUFFIX = {10: "-10", 100: "-100", 1000: "-1k", 10_000: "-10k", 100_000: "-100k", 1_000_000: "-1M"}
MAX_NUM_ROWS_TO_REWRITE = int(os.environ.get("MAX_NUM_ROWS_TO_REWRITE") or 1000)
assert MAX_NUM_ROWS_TO_REWRITE in PARTIAL_SUFFIX, "allowed max num rows are 10, 100, 1000, 10000, 100000 and 1000000"
NUM_PARALLEL_CALLS = 10
NUM_ROWS_PER_CALL = 3
MAX_PROGRESS_UPDATES_PER_SECOND = 4
MAX_STRING_LENGTH = 1000
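
# With the defaults above, a full rewrite fans out across NUM_PARALLEL_CALLS
# workers, each sending NUM_ROWS_PER_CALL rows per chat-completion call: for
# example, a 1,000-row rewrite splits into 10 groups of 100 rows, and each
# worker issues roughly 34 calls of 3 rows each (see rewrite_full_dataset below).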

REWRITE_DATASET = (
    "A Machine Learning practitioner is looking for a dataset similar to '{dataset}' but slightly different. "
    "They want you to rewrite the dataset and apply this instruction, which can be about transforming, translating or filtering the rows: {prompt}. "
    "The first rows of the dataset are below in JSON format:\n\n{rows}\n\n"
    "Apply the instruction to those rows from the '{dataset}' dataset and output the resulting rows using the same JSON format. "
    "Try to keep some of the text or meaning intact, and apply the requested instruction '{prompt}'."
)

FIND_NEW_NAME = (
    "You are a helpful assistant specialized in transforming English sentences for machine learning practitioners. "
    "Your job is to take input sentences like 'Take this dataset and apply the instruction xxx' and rephrase them as 'The dataset should be yyy'. "
    "You should use adjectives and exactly follow the output formula 'The dataset should be yyy'. "
    "Here is your first job: rephrase the sentence 'Take this dataset and apply the instruction \"{prompt}\"'"
)
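
# find_new_name (defined below) parses the model's answer by splitting on
# "should be": e.g. for the prompt 'make the prompt 6 words long maximum' an
# answer like 'The dataset should be shortened to six words' (illustrative,
# not a recorded model output) becomes the suffix '-shortened-to-six-words'.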

DATASET_CARD_CONTENT = """
---
license: mit
tags:
- dataset-rewriter
- synthetic
---

# {new_dataset}

_Note: This is an AI-generated dataset so its content may be inaccurate or false_

**Source of the data:**

The dataset was generated using the [Dataset ReWriter]({url}) and {model_id} from the dataset {dataset} and using the prompt '{prompt}':

- **Original Dataset**: https://huggingface.co/datasets/{dataset}
- **Model**: https://huggingface.co/{model_id}
- **More Datasets**: https://huggingface.co/datasets?other=dataset-rewriter
"""

css = """
a {
    color: var(--body-text-color);
}
.settings {
    background: transparent;
}
.settings button span {
    color: var(--body-text-color-subdued);
}
"""

js = """
function load() {
    Array.from(document.getElementsByClassName("secondary")).filter(e => (e.innerText.includes("New row")))[0].innerText = "New column";
    return 'done';
}
"""

examples = [
    ["fka/awesome-chatgpt-prompts", "make the prompt 6 words long maximum"],
    ["lhoestq/CudyPokemonAdventures", "make Pikachu the main character"],
    ["infinite-dataset-hub/SmallTalkDialogues", "translate to proper French"],
]

with gr.Blocks(css=css, js=js) as demo:
    dataset_info_json = gr.JSON(visible=False)
    with gr.Row():
        with gr.Column(scale=10):
            gr.Markdown(
                "# 🤗 Dataset ReWriter ✍️✨\n\n"
                "Adjust, translate or transform datasets with a text instruction.\n\n"
            )
            with gr.Row():
                with gr.Column(scale=3):
                    dataset_search = HuggingfaceHubSearch(
                        label="Hub Dataset ID",
                        placeholder="Search for dataset id on Huggingface",
                        search_type="dataset",
                    )
                    subset_dropdown = gr.Dropdown(info="Subset", show_label=False, visible=False)
                    split_dropdown = gr.Dropdown(info="Split", show_label=False, visible=False)
            gr.Markdown("### Sample")
            pretty_input_preview = gr.DataFrame(interactive=False)
            gr.Markdown("### ReWrite")
            with gr.Group():
                input_prompt = gr.Textbox(label="Adjustment or transformation to apply to the dataset")
                with gr.Accordion("(Advanced) Edit columns", open=False):
                    output_format_dataframe = gr.DataFrame(col_count=(2, "fixed"), headers=["column", "type"])
                    column_to_remove_dropdown = gr.Dropdown(info="Select a column to remove", show_label=False)
                    with gr.Row():
                        with gr.Column(scale=99):
                            pass
                        with gr.Column(scale=1, min_width=88):
                            remove_column_button = gr.Button("Remove", size="sm", elem_id="remove_column_button")
            rewrite_preview_button = gr.Button("Preview Results", variant="primary")
            rewrite_full_dataset_button = gr.Button("ReWrite Full Dataset", interactive=False)
            gr.Markdown("#### Output")
            full_dataset_generation_label = gr.Label(visible=False, show_label=False)
            pretty_output_preview = gr.DataFrame(interactive=False)
            pretty_full_dataset_generation_output = gr.DataFrame(interactive=False, visible=False)
            full_dataset_generation_success_html = gr.HTML()
            gr.Examples(examples, inputs=[dataset_search, input_prompt])
            gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_")
        with gr.Column(scale=4, min_width="200px"):
            with gr.Accordion("Settings", open=False, elem_classes="settings"):
                gr.Markdown("Save datasets to your account")
                gr.LoginButton()
                select_namespace_dropdown = gr.Dropdown(choices=[NAMESPACE], value=NAMESPACE, label="Select user or organization", visible=False)
                gr.Markdown("Save datasets as public or private datasets")
                visibility_radio = gr.Radio(["public", "private"], value="public", container=False, interactive=False)
                gr.Markdown("Maximum number of rows to ReWrite")
                max_num_rows_dropdown = gr.Dropdown(choices=[num_rows for num_rows in PARTIAL_SUFFIX if num_rows <= MAX_NUM_ROWS_TO_REWRITE], value=MAX_NUM_ROWS_TO_REWRITE, container=False)
                gr.Markdown("Duplicate this space to ReWrite more rows")
                gr.HTML(f'<a href="{URL}?duplicate=true" target="_blank"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-xl.svg" alt="Duplicate this Space"></a>')

    ############
    #
    # Utils
    #
    ###########
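
    # stream_rows pages through the datasets-server /rows API (100 rows per
    # request by default) and yields plain row dicts until an empty page is
    # returned or the server reports an error.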
    def stream_rows(dataset: str, subset: str, split: str, batch_size: int = 100) -> Iterable[dict[str, Any]]:
        for i in count():
            rows_resp = session.get(f"https://datasets-server.huggingface.co/rows?dataset={dataset}&config={subset}&split={split}&offset={i * batch_size}&length={batch_size}", timeout=10).json()
            if "error" in rows_resp:
                raise RuntimeError(rows_resp["error"])
            if not rows_resp["rows"]:
                break
            for row_item in rows_resp["rows"]:
                yield row_item["row"]

    T = TypeVar("T")

    def batched(it: Iterable[T], n: int) -> Iterator[list[T]]:
        it = iter(it)
        while batch := list(islice(it, n)):
            yield batch
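
    # For instance, batched(range(7), 3) yields [0, 1, 2], [3, 4, 5], [6]:
    # the final batch may be shorter than n.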

    class ContextTooLongError(ValueError):
        pass

    def crop_text(obj: Any) -> str:
        if isinstance(obj, str):
            return obj[:MAX_STRING_LENGTH]
        else:
            raise TypeError()

    def stream_response(messages: list[dict[str, str]], response_format=None, max_tokens=5000) -> Iterator[str]:
        for _ in range(3):
            message = None
            try:
                for message in client.chat_completion(
                    messages=messages,
                    max_tokens=max_tokens,
                    stream=True,
                    top_p=0.8,
                    seed=42,
                    response_format=response_format,
                ):
                    if message is None or not message.choices or message.choices[0] is None or message.choices[0].delta is None or message.choices[0].delta.content is None:
                        raise ContextTooLongError(f"messages: {sum(len(m['content']) for m in messages)} chars, max_tokens: {max_tokens}")
                    yield message.choices[0].delta.content
            except requests.exceptions.ConnectionError as e:
                if message:
                    raise
                print(f"{e}\n\nRetrying in 1sec")
                time.sleep(1)
                continue
            break

    def stream_rewrite_dataset_preview_row_by_row(dataset: str, rows: list[dict[str, str]], prompt: str, format: dict) -> Iterator[dict[str, str]]:
        prompt = prompt[:1000] if prompt.strip() else ""
        messages = [{"role": "user", "content": REWRITE_DATASET.format(
            dataset=dataset,
            rows=json.dumps({"data": rows}, ensure_ascii=False, default=crop_text),
            prompt=prompt,
        )}]
        response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format, "minItems": len(rows), "maxItems": len(rows)}}, "required": ["data"]}}
        yield from ijson.items(StringIteratorIO(stream_response(messages, response_format=response_format)), "data.item", buf_size=4, use_float=True)

    def stream_rewrite_dataset_row_by_row(dataset: str, rows: list[dict[str, str]], prompt: str, format: dict) -> Iterator[dict[str, str]]:
        prompt = prompt[:1000] if prompt.strip() else ""
        messages = [{"role": "user", "content": REWRITE_DATASET.format(
            dataset=dataset,
            rows=json.dumps({"data": rows}, ensure_ascii=False, default=crop_text),
            prompt=prompt,
        )}]
        response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format, "minItems": len(rows), "maxItems": len(rows)}}, "required": ["data"]}}
        try:
            yield from ijson.items(StringIteratorIO(stream_response(messages, response_format=response_format)), "data.item", buf_size=4, use_float=True)
        except ijson.IncompleteJSONError as e:
            print(f"{type(e).__name__}: {e}")
            print("Warning: Some rows were missing during ReWriting.")

    def find_new_name(dataset: str, prompt: str, format: dict) -> str:
        messages = [{"role": "user", "content": FIND_NEW_NAME.format(prompt=prompt)}]
        out = "".join(stream_response(messages))
        if "should be" in out:
            out = dataset.split("/")[-1] + out.split("should be", 1)[1].replace(" ", "-").replace(".", "").replace(",", "")
        else:
            out = dataset.split("/")[-1] + prompt.replace(" ", "-")
        return out[:80] + "-" + Hasher.hash(prompt + str(format))[:4]

    def _write_generator_to_queue(queue: Queue, func: Callable[..., Iterable], kwargs: dict) -> None:
        for result in func(**kwargs):
            queue.put(result)
        return None

    def iflatmap_unordered(
        func: Callable[..., Iterable[T]],
        *,
        kwargs_iterable: Iterable[dict],
    ) -> Iterable[T]:
        queue = Queue()
        with ThreadPool() as pool:
            async_results = [pool.apply_async(_write_generator_to_queue, (queue, func, kwargs)) for kwargs in kwargs_iterable]
            try:
                while True:
                    try:
                        yield queue.get(timeout=0.05)
                    except Empty:
                        if all(async_result.ready() for async_result in async_results) and queue.empty():
                            break
            finally:  # in case there's an error to raise
                for async_result in async_results:
                    async_result.get(timeout=0.05)
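
    # iflatmap_unordered merges the outputs of several generators running in a
    # thread pool, yielding items in arrival order (not input order); it backs
    # the parallel fan-out in rewrite_full_dataset below, e.g.
    # iflatmap_unordered(run, kwargs_iterable=[{"i": i} for i in range(num_parallel_calls)]).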

    def features_to_format(features: Features) -> dict:
        def feature_to_format(feature):
            if isinstance(feature, Value):
                if "int" in feature.dtype:
                    return {"type": "integer"}
                elif "float" in feature.dtype:
                    return {"type": "number"}
                else:
                    return {"type": "string"}
            elif isinstance(feature, list):
                return {"type": "array", "items": feature_to_format(feature[0])}
            elif isinstance(feature, dict):
                return {"properties": {k: feature_to_format(v) for k, v in feature.items()}, "required": list(feature)}
            elif isinstance(feature, Sequence):
                if isinstance(feature.feature, dict):
                    # a Sequence of dicts is stored as a dict of lists, so map each
                    # sub-feature's format (not the {"properties", "required"} wrapper
                    # itself) to an array
                    return {"properties": {k: {"type": "array", "items": v} for k, v in feature_to_format(feature.feature)["properties"].items()}, "required": list(feature.feature)}
                else:
                    return {"type": "array", "items": feature_to_format(feature.feature)}
            else:
                return {"type": "string"}
        return feature_to_format(features)
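
    # For example, Features({"text": Value("string"), "label": Value("int64")})
    # maps to {"properties": {"text": {"type": "string"}, "label": {"type": "integer"}},
    # "required": ["text", "label"]}.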

    ############
    #
    # Events
    #
    ###########

    def _resolve_dataset_selection(dataset: str, default_subset: str, default_split: str) -> tuple[Optional[str], Optional[str], dict]:
        if "/" not in dataset.strip().strip("/"):
            return None, None, {
                subset_dropdown: gr.Dropdown(visible=False),
                split_dropdown: gr.Dropdown(visible=False),
            }
        info_resp = session.get(f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=3).json()
        if "error" in info_resp:
            return None, None, {
                subset_dropdown: gr.Dropdown(visible=False),
                split_dropdown: gr.Dropdown(visible=False),
            }
        subsets: list[str] = list(info_resp["dataset_info"])
        subset = default_subset if default_subset in subsets else subsets[0]
        splits: list[str] = info_resp["dataset_info"][subset]["splits"]
        split = default_split if default_split in splits else splits[0]
        dict_format = features_to_format(Features.from_dict(info_resp["dataset_info"][subset]["features"]))
        return subset, split, {
            dataset_info_json: info_resp["dataset_info"][subset],
            subset_dropdown: gr.Dropdown(value=subset, choices=subsets, visible=len(subsets) > 1),
            split_dropdown: gr.Dropdown(value=split, choices=splits, visible=len(splits) > 1),
            output_format_dataframe: pd.DataFrame([{"column": col, "type": json.dumps(format_type)} for col, format_type in dict_format["properties"].items()]),
        }

    def _show_input_preview(dataset: str, default_subset: str, default_split: str) -> dict:
        subset, split, output = _resolve_dataset_selection(dataset, default_subset=default_subset, default_split=default_split)
        if subset is None or split is None:
            return output
        print(f"Showing {dataset}")
        rows = list(islice(stream_rows(dataset, subset, split, batch_size=NUM_ROWS_PREVIEW), NUM_ROWS_PREVIEW))
        return {
            pretty_input_preview: gr.DataFrame(pd.DataFrame([{k: json.dumps(v, ensure_ascii=False, default=crop_text) for k, v in row.items()} for row in rows])),
            **output,
        }

    def show_input_from_dataset_search(dataset: str) -> dict:
        return _show_input_preview(dataset, default_subset="default", default_split="train")

    def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
        return _show_input_preview(dataset, default_subset=subset, default_split="train")

    def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
        return _show_input_preview(dataset, default_subset=subset, default_split=split)

    def disable_rewrite_full_dataset() -> dict:
        return {rewrite_full_dataset_button: gr.Button(interactive=False)}

    def update_columns_to_remove_dropdown(output_format_df: pd.DataFrame) -> gr.Dropdown:
        return gr.Dropdown(choices=output_format_df["column"].tolist())

    def update_output_format_dataframe(column: str, output_format_df: pd.DataFrame) -> pd.DataFrame:
        return output_format_df[output_format_df["column"] != column]

    def rewrite_preview(dataset: str, pretty_input_preview_df: pd.DataFrame, prompt: str, output_format_df: pd.DataFrame) -> Iterator[dict]:
        output_format_df = output_format_df[output_format_df["column"] != ""]
        format = output_format_df.to_dict(orient="records")
        format = {"properties": {x["column"]: json.loads(x["type"]) for x in format}, "required": [x["column"] for x in format]}
        rows = [{k: json.loads(row[k]) for k in output_format_df["column"] if k in row} for row in pretty_input_preview_df.to_dict(orient="records")]
        output_rows = []
        print(f"(preview) ReWriting {dataset} with instruction '{prompt}'")
        yield {rewrite_full_dataset_button: gr.Button(interactive=False), full_dataset_generation_label: gr.Label(visible=False)}
        yield {
            pretty_output_preview: gr.DataFrame(loading_dataframe, visible=True),
            pretty_full_dataset_generation_output: gr.DataFrame(visible=False),
            full_dataset_generation_success_html: "",
        }
        for row in stream_rewrite_dataset_preview_row_by_row(dataset=dataset, rows=rows, prompt=prompt, format=format):
            output_rows.append({k: json.dumps(row[k], ensure_ascii=False) for k in output_format_df["column"]})
            yield {pretty_output_preview: gr.DataFrame(pd.DataFrame(output_rows))}
        yield {rewrite_full_dataset_button: gr.Button(interactive=True)}
        print(f"(preview) Done ReWriting {dataset} with instruction '{prompt}'")

    def rewrite_full_dataset(dataset: str, subset: str, split: str, prompt: str, output_format_df: pd.DataFrame, dataset_info: dict[str, Any], namespace: str, max_num_rows: int, oauth_token: Optional[gr.OAuthToken]) -> Iterator[dict]:
        output_format_df = output_format_df[output_format_df["column"] != ""]
        format = output_format_df.to_dict(orient="records")
        format = {"properties": {x["column"]: json.loads(x["type"]) for x in format}, "required": [x["column"] for x in format]}
        num_examples = dataset_info["splits"][split]["num_examples"]
        total = min(num_examples, max_num_rows)
        print(f"ReWriting {dataset} with instruction '{prompt}'")
        yield {full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": 0.}, visible=True)}
        yield {pretty_full_dataset_generation_output: empty_dataframe}
        yield {
            pretty_output_preview: gr.DataFrame(visible=False),
            pretty_full_dataset_generation_output: gr.DataFrame(loading_dataframe, visible=True),
            full_dataset_generation_success_html: "",
        }
        num_parallel_calls = max(1, min(total // NUM_ROWS_PER_CALL, NUM_PARALLEL_CALLS))
        # split the rows into one group per worker; ceil division so the last
        # partial group is not silently dropped when total doesn't divide evenly
        rows_per_group = -(-total // num_parallel_calls)
        parallel_input_rows = list(batched(islice(({k: row[k] for k in output_format_df["column"] if k in row} for row in stream_rows(dataset=dataset, subset=subset, split=split)), total), n=rows_per_group))
        num_parallel_calls = len(parallel_input_rows)  # there may be fewer groups than workers
        parallel_output_rows = [[] for _ in range(num_parallel_calls)]

        def run(i):
            for batch_rows in batched(parallel_input_rows[i], n=NUM_ROWS_PER_CALL):
                for row in stream_rewrite_dataset_row_by_row(dataset=dataset, rows=batch_rows, prompt=prompt, format=format):
                    parallel_output_rows[i].append({k: json.dumps(row[k], ensure_ascii=False) for k in output_format_df["column"]})
                    yield 1

        current = 0
        _last_time = time.time()
        try:
            for step in iflatmap_unordered(run, kwargs_iterable=[{"i": i} for i in range(num_parallel_calls)]):
                current += step
                if _last_time + 1 / MAX_PROGRESS_UPDATES_PER_SECOND < time.time():
                    _last_time = time.time()
                    yield {
                        full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": current / total}),
                        pretty_full_dataset_generation_output: gr.DataFrame(pd.DataFrame([row for rows in parallel_output_rows for row in rows])),
                    }
        except ContextTooLongError:
            raise gr.Error("Input dataset has too long context for the model")
        yield {
            full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": current / total}),
            pretty_full_dataset_generation_output: gr.DataFrame(pd.DataFrame([row for rows in parallel_output_rows for row in rows])),
        }
        print(f"Done ReWriting {dataset} with instruction '{prompt}'")
        output_rows = [{k: json.loads(row[k]) for k in output_format_df["column"]} for rows in parallel_output_rows for row in rows]
        new_dataset = find_new_name(dataset + (PARTIAL_SUFFIX[max_num_rows] if num_examples > total else ""), prompt, format)
        repo_id = namespace + "/" + new_dataset
        yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"⚙️ Saving to {repo_id}": 0.})}
        token = oauth_token.token if oauth_token else save_dataset_hf_token
        print(f"Saving {repo_id}")
        ds = Dataset.from_list(output_rows)
        ds.push_to_hub(repo_id, config_name=subset, split=split, token=token)
        DatasetCard(DATASET_CARD_CONTENT.format(new_dataset=new_dataset, dataset=dataset, model_id=model_id, prompt=prompt, url=URL)).push_to_hub(repo_id=repo_id, repo_type="dataset", token=token)
        yield {full_dataset_generation_label: gr.Label({f"✅ ReWriting {dataset}": len(output_rows) / total, f"✅ Saving to {repo_id}": 1.})}
        yield {full_dataset_generation_success_html: (
            f'<a href="https://huggingface.co/datasets/{repo_id}" target="_blank">'
            '<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/dataset-on-hf-xl.svg" alt="Dataset on HF" style="margin-right: auto; margin-left: auto; max-width: fit-content;">'
            '</a>'
        )}
        print(f"Saved {repo_id}")
demo.launch() | |