import multiprocessing import time from typing import Union import gradio as gr import pandas as pd from distilabel.distiset import Distiset from huggingface_hub import whoami from src.distilabel_dataset_generator.pipelines.sft import ( DEFAULT_DATASET, DEFAULT_DATASET_DESCRIPTION, DEFAULT_SYSTEM_PROMPT, PROMPT_CREATION_PROMPT, get_pipeline, get_prompt_generation_step, ) from src.distilabel_dataset_generator.utils import ( get_login_button, get_org_dropdown, ) def _run_pipeline(result_queue, num_turns, num_rows, system_prompt): pipeline = get_pipeline( num_turns, num_rows, system_prompt, ) distiset: Distiset = pipeline.run(use_cache=False) result_queue.put(distiset) def generate_system_prompt(dataset_description, progress=gr.Progress()): progress(0.1, desc="Initializing text generation") generate_description = get_prompt_generation_step() progress(0.4, desc="Loading model") generate_description.load() progress(0.7, desc="Generating system prompt") result = next( generate_description.process( [ { "system_prompt": PROMPT_CREATION_PROMPT, "instruction": dataset_description, } ] ) )[0]["generation"] progress(1.0, desc="System prompt generated") return result def generate_sample_dataset(system_prompt, progress=gr.Progress()): progress(0.1, desc="Initializing sample dataset generation") result = generate_dataset(system_prompt, num_turns=1, num_rows=2, progress=progress) progress(1.0, desc="Sample dataset generated") return result def generate_dataset( system_prompt, num_turns=1, num_rows=5, private=True, org_name=None, repo_name=None, progress=gr.Progress(), oauth_token: Union[gr.OAuthToken, None] ): repo_id = ( f"{org_name}/{repo_name}" if repo_name is not None and org_name is not None else None ) if repo_id is not None: if not repo_id: raise gr.Error( "Please provide a repo_name and org_name to push the dataset to." ) if num_turns > 4: num_turns = 4 gr.Info("You can only generate a dataset with 4 or fewer turns. Setting to 4.") if num_rows > 5000: num_rows = 5000 gr.Info( "You can only generate a dataset with 5000 or fewer rows. Setting to 5000." ) if num_rows < 50: duration = 60 elif num_rows < 250: duration = 300 elif num_rows < 1000: duration = 500 else: duration = 1000 result_queue = multiprocessing.Queue() p = multiprocessing.Process( target=_run_pipeline, args=(result_queue, num_turns, num_rows, system_prompt), ) try: p.start() total_steps = 100 for step in range(total_steps): if not p.is_alive() or p._popen.poll() is not None: break progress( (step + 1) / total_steps, desc=f"Generating dataset with {num_rows} rows. Don't close this window.", ) time.sleep(duration / total_steps) # Adjust this value based on your needs p.join() except Exception as e: raise gr.Error(f"An error occurred during dataset generation: {str(e)}") distiset = result_queue.get() if repo_id is not None: progress(0.95, desc="Pushing dataset to Hugging Face Hub.") distiset.push_to_hub( repo_id=repo_id, private=private, include_script=False, token=oauth_token.token, ) # If not pushing to hub generate the dataset directly distiset = distiset["default"]["train"] if num_turns == 1: outputs = distiset.to_pandas()[["prompt", "completion"]] else: outputs = distiset.to_pandas()[["messages"]] progress(1.0, desc="Dataset generation completed") return pd.DataFrame(outputs) def generate_pipeline_code() -> str: with open("src/distilabel_dataset_generator/pipelines/sft.py", "r") as f: pipeline_code = f.read() return pipeline_code css = """ .main_ui_logged_out{opacity: 0.3; pointer-events: none} """ with gr.Blocks( title="⚗️ Distilabel Dataset Generator", head="⚗️ Distilabel Dataset Generator", css=css ) as app: get_login_button() gr.Markdown("## Iterate on a sample dataset") with gr.Column() as main_ui: dataset_description = gr.TextArea( label="Provide a description of the dataset", value=DEFAULT_DATASET_DESCRIPTION, ) with gr.Row(): gr.Column(scale=1) btn_generate_system_prompt = gr.Button(value="Generate sample dataset") gr.Column(scale=1) system_prompt = gr.TextArea( label="If you want to improve the dataset, you can tune the system prompt and regenerate the sample", value=DEFAULT_SYSTEM_PROMPT, ) with gr.Row(): gr.Column(scale=1) btn_generate_sample_dataset = gr.Button( value="Regenerate sample dataset", ) gr.Column(scale=1) with gr.Row(): table = gr.DataFrame( value=DEFAULT_DATASET, interactive=False, wrap=True, ) result = btn_generate_system_prompt.click( fn=generate_system_prompt, inputs=[dataset_description], outputs=[system_prompt], show_progress=True, ).then( fn=generate_sample_dataset, inputs=[system_prompt], outputs=[table], show_progress=True, ) btn_generate_sample_dataset.click( fn=generate_sample_dataset, inputs=[system_prompt], outputs=[table], show_progress=True, ) # Add a header for the full dataset generation section gr.Markdown("## Generate full dataset") gr.Markdown( "Once you're satisfied with the sample, generate a larger dataset and push it to the Hub." ) with gr.Column() as push_to_hub_ui: with gr.Row(variant="panel"): num_turns = gr.Number( value=1, label="Number of turns in the conversation", minimum=1, maximum=4, step=1, info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).", ) num_rows = gr.Number( value=100, label="Number of rows in the dataset", minimum=1, maximum=5000, info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.", ) with gr.Row(variant="panel"): org_name = get_org_dropdown() repo_name = gr.Textbox(label="Repo name", placeholder="dataset_name") private = gr.Checkbox( label="Private dataset", value=True, interactive=True, scale=0.5 ) btn_generate_full_dataset = gr.Button( value="⚗️ Generate Full Dataset", variant="primary" ) success_message = gr.Markdown(visible=False) def show_success_message(org_name, repo_name): return gr.Markdown( value=f"""

Dataset Published Successfully!

The generated dataset is in the right format for Fine-tuning with TRL, AutoTrain or other frameworks. Your dataset is now available at: https://huggingface.co/datasets/{org_name}/{repo_name}

""", visible=True, ) def hide_success_message(): return gr.Markdown(visible=False) btn_generate_full_dataset.click( fn=hide_success_message, outputs=[success_message], ).then( fn=generate_dataset, inputs=[ system_prompt, num_turns, num_rows, private, org_name, repo_name, ], outputs=[table], show_progress=True, ).then( fn=show_success_message, inputs=[org_name, repo_name], outputs=[success_message], ) gr.Markdown("## Or run this pipeline locally with distilabel") with gr.Accordion("Run this pipeline on Distilabel", open=False): pipeline_code = gr.Code( value=generate_pipeline_code(), language="python", label="Distilabel Pipeline Code", ) app.load(get_org_dropdown, outputs=[org_name])