|
from typing import List |
|
|
|
import gradio as gr |
|
import pandas as pd |
|
|
|
from src.distilabel_dataset_generator.apps.base import ( |
|
get_main_ui, |
|
get_pipeline_code_ui, |
|
hide_success_message, |
|
push_dataset_to_hub, |
|
push_pipeline_code_to_hub, |
|
show_success_message_argilla, |
|
show_success_message_hub, |
|
validate_argilla_user_workspace_dataset, |
|
) |
|
from src.distilabel_dataset_generator.pipelines.textcat import ( |
|
DEFAULT_DATASET_DESCRIPTIONS, |
|
DEFAULT_DATASETS, |
|
DEFAULT_SYSTEM_PROMPTS, |
|
generate_pipeline_code, |
|
) |
|
|
|
|
|
def push_dataset_to_argilla(dataset: pd.DataFrame, dataset_name: str) -> pd.DataFrame:
    """Return ``dataset`` unchanged (placeholder for the Argilla push step).

    Nothing is uploaded here: the frame is handed straight back, which keeps
    the output of the event chains that call this function wired up.

    Args:
        dataset: The generated records.
        dataset_name: Name of the target dataset; currently unused.

    Returns:
        The very same ``dataset`` object that was passed in.
    """
    return dataset
|
|
|
|
|
def generate_system_prompt(dataset_description: str) -> str:
    """Return the dataset description as the system prompt.

    Placeholder implementation: the description is passed through verbatim,
    with no generation or rewriting applied.

    Args:
        dataset_description: Free-text description of the desired dataset.

    Returns:
        ``dataset_description`` unchanged.
    """
    return dataset_description
|
|
|
|
|
def generate_dataset(
    system_prompt: str, labels: List[str], multi_label: bool
) -> pd.DataFrame:
    """Build a one-row placeholder dataset from ``system_prompt``.

    Placeholder implementation: ``labels`` and ``multi_label`` are accepted so
    the signature matches the inputs wired up in the UI, but neither is used.

    Args:
        system_prompt: Text copied into both columns of the single row.
        labels: Candidate class labels; currently unused.
        multi_label: Whether several labels may apply; currently unused.

    Returns:
        A DataFrame with columns ``prompt`` and ``completion`` and exactly one
        row, both cells holding ``system_prompt``.
    """
    return pd.DataFrame({"prompt": [system_prompt], "completion": [system_prompt]})
|
|
|
|
|
# Build the shared layout once at import time. ``get_main_ui`` hands back the
# app, its containers and every widget as one flat tuple, so the order of the
# names below must match the order in which ``get_main_ui`` returns them.
(
    app,
    main_ui,
    custom_input_ui,
    dataset_description,
    examples,
    btn_generate_system_prompt,
    system_prompt,
    sample_dataset,
    btn_generate_sample_dataset,
    dataset_name,
    add_to_existing_dataset,
    btn_generate_full_dataset_copy,
    btn_generate_and_push_to_argilla,
    btn_push_to_argilla,
    org_name,
    repo_name,
    private,
    btn_generate_full_dataset,
    btn_generate_and_push_to_hub,
    btn_push_to_hub,
    final_dataset,
    success_message,
) = get_main_ui(
    default_dataset_descriptions=DEFAULT_DATASET_DESCRIPTIONS,
    default_system_prompts=DEFAULT_SYSTEM_PROMPTS,
    default_datasets=DEFAULT_DATASETS,
    fn_generate_system_prompt=generate_system_prompt,
    fn_generate_dataset=generate_dataset,
)
|
|
|
# Text-classification specific widgets and event wiring, added on top of the
# layout returned by ``get_main_ui``.
# NOTE(review): the nesting below is reconstructed (the original indentation
# was lost); confirm it matches the intended placement of widgets in the layout.
with app:
    with main_ui:
        with custom_input_ui:
            labels = gr.Dropdown(
                choices=[],
                allow_custom_value=True,
                interactive=True,
                label="Labels",
                multiselect=True,
            )
            # ``multi_label`` is read by every handler below but was never
            # created, which raised NameError as soon as this module ran.
            multi_label = gr.Checkbox(
                label="Multi-label",
                value=False,
                interactive=True,
            )
            # ``num_labels`` and ``num_rows`` are displayed but not yet wired
            # into any handler in this module.
            num_labels = gr.Number(
                label="Number of labels", value=2, minimum=1, maximum=10
            )
            num_rows = gr.Number(
                label="Number of rows", value=10, minimum=1, maximum=500
            )

        # Seed the code viewer from the widgets' initial values.
        pipeline_code = get_pipeline_code_ui(
            generate_pipeline_code(system_prompt.value, labels.value, multi_label.value)
        )

    # Generate only: both "generate full dataset" buttons share one chain.
    gr.on(
        triggers=[
            btn_generate_full_dataset.click,
            btn_generate_full_dataset_copy.click,
        ],
        fn=hide_success_message,
        outputs=[success_message],
    ).then(
        fn=generate_dataset,
        inputs=[system_prompt, labels, multi_label],
        outputs=[final_dataset],
        show_progress=True,
    )

    # Generate, then push to Argilla. Every step is gated on the previous one
    # succeeding, so a failed validation stops the chain before generation.
    btn_generate_and_push_to_argilla.click(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[dataset_name, final_dataset, add_to_existing_dataset],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=generate_dataset,
        inputs=[system_prompt, labels, multi_label],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=push_dataset_to_argilla,
        inputs=[final_dataset, dataset_name],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=show_success_message_argilla,
        inputs=[],
        outputs=[success_message],
    )

    # Generate, then push to the Hub. The push steps use ``.success`` rather
    # than ``.then``: ``.then`` also fires after a failed step, which would
    # push stale data or report success after a failed dataset upload.
    btn_generate_and_push_to_hub.click(
        fn=hide_success_message,
        outputs=[success_message],
    ).then(
        fn=generate_dataset,
        inputs=[system_prompt, labels, multi_label],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=push_dataset_to_hub,
        inputs=[final_dataset, private, org_name, repo_name],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=push_pipeline_code_to_hub,
        inputs=[pipeline_code, org_name, repo_name],
        outputs=[],
        show_progress=True,
    ).success(
        fn=show_success_message_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    )

    # Push the already generated dataset to the Hub; the pipeline code is only
    # uploaded once the dataset upload itself went through.
    btn_push_to_hub.click(
        fn=hide_success_message,
        outputs=[success_message],
    ).then(
        fn=push_dataset_to_hub,
        inputs=[final_dataset, private, org_name, repo_name],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=push_pipeline_code_to_hub,
        inputs=[pipeline_code, org_name, repo_name],
        outputs=[],
        show_progress=True,
    ).success(
        fn=show_success_message_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    )

    # Push the already generated dataset to Argilla.
    btn_push_to_argilla.click(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[dataset_name, final_dataset, add_to_existing_dataset],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=push_dataset_to_argilla,
        inputs=[final_dataset, dataset_name],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=show_success_message_argilla,
        inputs=[],
        outputs=[success_message],
    )

    # Re-render the pipeline code whenever any of its three inputs changes.
    gr.on(
        triggers=[
            system_prompt.change,
            labels.change,
            multi_label.change,
        ],
        fn=generate_pipeline_code,
        inputs=[system_prompt, labels, multi_label],
        outputs=[pipeline_code],
    )
|
|