# LLM_DataGen / gradio_app.py
# Author: lhoestq (HF staff)
# Commit 451395b: "revert batching"
# Page metadata: raw · history · blame · No virus · 4.16 kB
from pathlib import Path
from urllib.parse import urlparse, parse_qs
import gradio as gr
import io
import pandas as pd
import spaces
from generate import model_id, stream_jsonl_file
# Rows generated per batch when the query string has no ``size`` parameter.
DEFAULT_SIZE = 3
# Base seed when the query string has no ``seed`` parameter
# (stream_output offsets it by the number of batches already generated).
DEFAULT_SEED = 42
# Largest accepted ``size``; stream_output rejects anything above it.
MAX_SIZE = 20
def _parse_query(query):
    """Split ``name.jsonl?key=value&...`` into generation parameters.

    Recognized query parameters are ``prompt``, ``columns`` (comma-separated
    names; blank entries are dropped), ``size`` and ``seed`` (integers).
    Missing ones default to ``""``, ``[]``, ``DEFAULT_SIZE`` and ``DEFAULT_SEED``.

    Returns:
        ``(filename, prompt, columns, size, seed)``

    Raises:
        gr.Error: if the file-name part is empty, if ``size``/``seed`` are not
            integers, or if ``size`` is outside ``[1, MAX_SIZE]``.
    """
    parsed = urlparse(query)
    filename = parsed.path
    if not filename:
        # Fail before spending GPU time; the final open(query) would crash anyway.
        raise gr.Error("Please enter a file name to generate, e.g. `movies_data.jsonl`.")
    params = parse_qs(parsed.query)
    prompt = params["prompt"][0] if "prompt" in params else ""
    columns = (
        [column.strip() for column in params["columns"][0].split(",") if column.strip()]
        if "columns" in params
        else []
    )
    try:
        size = int(params["size"][0]) if "size" in params else DEFAULT_SIZE
        seed = int(params["seed"][0]) if "seed" in params else DEFAULT_SEED
    except ValueError:
        raise gr.Error("The `size` and `seed` parameters must be integers.") from None
    # size <= 0 is meaningless, and size == 0 would crash later with
    # ZeroDivisionError when computing the per-batch seed offset.
    if size < 1:
        raise gr.Error("Minimum size is 1.")
    if size > MAX_SIZE:
        raise gr.Error(f"Maximum size is {MAX_SIZE}. Duplicate this Space to remove this limit.")
    return filename, prompt, columns, size, seed


# Runs under a `spaces.GPU` allocation (duration=120; presumably seconds, per the HF `spaces` package).
@spaces.GPU(duration=120)
def stream_output(query: str, continue_content: str = ""):
    """Generate rows for the file described by ``query`` and stream UI updates.

    Args:
        query: File name, optionally followed by a query string, e.g.
            ``"common_first_names.jsonl?columns=first_name,popularity&size=10"``.
            Only the last path component is kept.
        continue_content: JSON Lines already generated for this file; new rows
            are appended after it.

    Yields:
        5-tuples ``(DataFrame of rows so far, Markdown code block of the raw
        file, main button, "one more batch" button, download button)``.
        While generating, the main button shows progress and the other two are
        non-interactive; the last tuple re-enables them and points the download
        button at the saved file.

    Raises:
        gr.Error: on an empty file name or invalid ``size``/``seed`` parameters.
    """
    # Drop any directory components so the user-supplied name is only ever
    # written in the current working directory.
    query = Path(query).name
    filename, prompt, columns, size, seed = _parse_query(query)

    content = continue_content
    df = pd.read_json(io.StringIO(content), lines=True, convert_dates=False)
    continue_content_size = len(df)
    state_msg = f"⚙️ Generating... [{continue_content_size + 1}/{continue_content_size + size}]"
    if list(df.columns):
        # Continuing an existing file: reuse its columns so new rows share the schema.
        columns = list(df.columns)
    else:
        # Placeholder columns so the table renders while the first rows stream in.
        df = pd.DataFrame({"1": [], "2": [], "3": []})
    yield df, "```json\n" + content + "\n```", gr.Button(state_msg), gr.Button("Generate one more batch", interactive=False), gr.DownloadButton("⬇️ Download", interactive=False)

    for i, chunk in enumerate(stream_jsonl_file(
        filename=filename,
        prompt=prompt,
        columns=columns,
        # One seed per batch: offset by the number of batches already generated
        # (assumes earlier batches used the same size).
        seed=seed + (continue_content_size // size),
        size=size,
    )):
        content += chunk
        # Re-parse everything streamed so far; presumably each chunk is one
        # complete JSON line (TODO confirm in generate.stream_jsonl_file).
        df = pd.read_json(io.StringIO(content), lines=True, convert_dates=False)
        state_msg = f"⚙️ Generating... [{continue_content_size + i + 1}/{continue_content_size + size}]"
        yield df, "```json\n" + content + "\n```", gr.Button(state_msg), gr.Button("Generate one more batch", interactive=False), gr.DownloadButton("⬇️ Download", interactive=False)

    # Persist the result: stream_more_output re-reads it and the download button serves it.
    with open(query, "w", encoding="utf-8") as f:
        f.write(content)
    yield df, "```json\n" + content + "\n```", gr.Button("Generate dataset"), gr.Button("Generate one more batch", visible=True, interactive=True), gr.DownloadButton("⬇️ Download", value=query, visible=True, interactive=True)
def stream_more_output(query: str):
query = Path(query).name
with open(query, "r", encoding="utf-8") as f:
continue_content = f.read()
yield from stream_output(query=query, continue_content=continue_content)
title = "LLM DataGen"

# Markdown intro rendered under the title; links to the model the generator uses.
description = (
    "Generate and stream synthetic dataset files in `{JSON Lines}` format "
    f"(currently using [{model_id}](https://huggingface.co/{model_id}))"
    "\n\n"
    "Disclaimer: LLM data generation is an area of active research "
    "with known problems such as biased generation and incorrect information."
)

# Clickable example file names; the first one also pre-fills the file-name textbox.
# Anything after "?" is parsed by stream_output as generation parameters.
examples = [
    "movies_data.jsonl",
    "dungeon_and_dragon_characters.jsonl",
    "bad_amazon_reviews_on_defunct_products_that_people_hate.jsonl",
    "common_first_names.jsonl?columns=first_name,popularity&size=10",
]
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    # Optional generation parameters may follow the file name as a query string
    # (prompt, columns, size, seed), as in the last example.
    filename_comp = gr.Textbox(examples[0], placeholder=examples[0], label="File name to generate")
    generate_button = gr.Button("Generate dataset")
    with gr.Tab("Dataset"):
        dataframe_comp = gr.DataFrame()
    with gr.Tab("File content"):
        file_content_comp = gr.Markdown()
    with gr.Row():
        # Hidden until a first dataset has been generated (stream_output reveals them).
        generate_more_button = gr.Button("Generate one more batch", visible=False, interactive=False, scale=3)
        download_button = gr.DownloadButton("⬇️ Download", visible=False, interactive=False, scale=1)
    # Order must match the 5-tuples yielded by stream_output / stream_more_output.
    outputs = [dataframe_comp, file_content_comp, generate_button, generate_more_button, download_button]
    # Created inside the Blocks context, gr.Examples renders itself; no reference is
    # kept, so the module-level `examples` list is not shadowed.
    gr.Examples(examples, filename_comp, outputs, fn=stream_output, run_on_click=True)
    generate_button.click(stream_output, filename_comp, outputs)
    generate_more_button.click(stream_more_output, filename_comp, outputs)

demo.launch()