Commit 136bd13 (parent d2df8be): "update message"
src/synthetic_dataset_generator/app.py CHANGED

@@ -1,4 +1,5 @@
 from synthetic_dataset_generator._tabbedinterface import TabbedInterface
+
 # from synthetic_dataset_generator.apps.eval import app as eval_app
 from synthetic_dataset_generator.apps.readme import app as readme_app
 from synthetic_dataset_generator.apps.sft import app as sft_app

@@ -15,9 +16,6 @@ button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-prima
 #system_prompt_examples { color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
 .container {padding-inline: 0 !important}
 #sign_in_button { flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto; }
-.table-view .table-wrap {
-max-height: 450px;
-}
 """

 image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
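For context, the css string edited above is module-level CSS that the app hands to Gradio; dropping the .table-view .table-wrap rule removes the 450px height cap on tables. A minimal sketch of how such a CSS string is typically attached to a Gradio app (the trimmed css value and the demo name are illustrative, not this repo's exact wiring):

import gradio as gr

# Illustrative, trimmed-down CSS string in the spirit of the one edited above.
css = """
.container {padding-inline: 0 !important}
#sign_in_button { flex-grow: 0; width: auto !important; }
"""

# Gradio applies the string app-wide when passed via the `css` argument.
with gr.Blocks(css=css) as demo:
    gr.Markdown("## Synthetic Data Generator")

if __name__ == "__main__":
    demo.launch()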
src/synthetic_dataset_generator/apps/eval.py CHANGED

@@ -750,7 +750,6 @@ with gr.Blocks() as app:
 headers=["prompt", "completion", "evaluation"],
 wrap=True,
 interactive=False,
-elem_classes="table-view",
 )

 gr.HTML(value="<hr>")
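The removed elem_classes="table-view" argument is the hook that had tied this gr.Dataframe to the deleted CSS rule, so both sides of the pairing go away together. A hedged sketch of that pattern (the surrounding Blocks context here is illustrative):

import gradio as gr

with gr.Blocks(css=".table-view .table-wrap { max-height: 450px; }") as demo:
    # elem_classes adds a CSS class to the rendered component so app-level
    # rules such as ".table-view .table-wrap" can target it.
    gr.Dataframe(
        headers=["prompt", "completion", "evaluation"],
        wrap=True,
        interactive=False,
        elem_classes="table-view",  # this commit drops the class and the rule
    )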
src/synthetic_dataset_generator/apps/sft.py CHANGED

@@ -55,10 +55,10 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:


 def generate_system_prompt(dataset_description, progress=gr.Progress()):
-progress(0.0, desc="
-progress(0.3, desc="Initializing
+progress(0.0, desc="Starting")
+progress(0.3, desc="Initializing")
 generate_description = get_prompt_generator()
-progress(0.7, desc="Generating
+progress(0.7, desc="Generating")
 result = next(
 generate_description.process(
 [

@@ -68,7 +68,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
 ]
 )
 )[0]["generation"]
-progress(1.0, desc="
+progress(1.0, desc="Prompt generated")
 return result


@@ -88,7 +88,6 @@ def _get_dataframe():
 headers=["prompt", "completion"],
 wrap=True,
 interactive=False,
-elem_classes="table-view",
 )
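The change above only rewrites the desc labels passed to Gradio's progress tracker: a function that declares progress=gr.Progress() as a keyword default receives a callable it can invoke with a completion fraction and a short description. A minimal sketch of that pattern, with time.sleep standing in for the real prompt-generator calls:

import time
import gradio as gr

def generate_system_prompt(dataset_description, progress=gr.Progress()):
    # Each call moves the progress bar shown next to the output component.
    progress(0.0, desc="Starting")
    progress(0.3, desc="Initializing")
    time.sleep(0.5)  # stand-in for building the generator
    progress(0.7, desc="Generating")
    time.sleep(0.5)  # stand-in for the actual generation step
    progress(1.0, desc="Prompt generated")
    return f"You are an assistant for the task: {dataset_description}"

with gr.Blocks() as demo:
    description = gr.Textbox(label="Dataset description")
    system_prompt = gr.Textbox(label="System prompt")
    gr.Button("Generate").click(generate_system_prompt, description, system_prompt)

if __name__ == "__main__":
    demo.launch()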
src/synthetic_dataset_generator/apps/textcat.py CHANGED

@@ -42,15 +42,14 @@ def _get_dataframe():
 headers=["labels", "text"],
 wrap=True,
 interactive=False,
-elem_classes="table-view",
 )


 def generate_system_prompt(dataset_description, progress=gr.Progress()):
-progress(0.0, desc="
-progress(0.3, desc="Initializing
+progress(0.0, desc="Starting")
+progress(0.3, desc="Initializing")
 generate_description = get_prompt_generator()
-progress(0.7, desc="Generating
+progress(0.7, desc="Generating")
 result = next(
 generate_description.process(
 [

@@ -60,7 +59,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
 ]
 )
 )[0]["generation"]
-progress(1.0, desc="
+progress(1.0, desc="Prompt generated")
 data = json.loads(result)
 system_prompt = data["classification_task"]
 labels = data["labels"]

@@ -94,7 +93,7 @@ def generate_dataset(
 is_sample: bool = False,
 progress=gr.Progress(),
 ) -> pd.DataFrame:
-progress(0.0, desc="(1/2) Generating
+progress(0.0, desc="(1/2) Generating dataset")
 labels = get_preprocess_labels(labels)
 textcat_generator = get_textcat_generator(
 difficulty=difficulty,

@@ -117,7 +116,7 @@ def generate_dataset(
 progress(
 2 * 0.5 * n_processed / num_rows,
 total=total_steps,
-desc="(1/2) Generating
+desc="(1/2) Generating dataset",
 )
 remaining_rows = num_rows - n_processed
 batch_size = min(batch_size, remaining_rows)

@@ -139,14 +138,14 @@ def generate_dataset(
 result["text"] = result["input_text"]

 # label text classification data
-progress(2 * 0.5, desc="(
+progress(2 * 0.5, desc="(2/2) Labeling dataset")
 n_processed = 0
 labeller_results = []
 while n_processed < num_rows:
 progress(
 0.5 + 0.5 * n_processed / num_rows,
 total=total_steps,
-desc="(
+desc="(2/2) Labeling dataset",
 )
 batch = textcat_results[n_processed : n_processed + batch_size]
 labels_batch = list(labeller_generator.process(inputs=batch))

@@ -182,7 +181,7 @@ def generate_dataset(
 )
 )
 )
-progress(1.0, desc="Dataset
+progress(1.0, desc="Dataset created")
 return dataframe


@@ -316,7 +315,7 @@ def push_dataset(
 client=client,
 )
 rg_dataset = rg_dataset.create()
-progress(0.7, desc="Pushing dataset
+progress(0.7, desc="Pushing dataset")
 hf_dataset = Dataset.from_pandas(dataframe)
 records = [
 rg.Record(

@@ -347,7 +346,7 @@ def push_dataset(
 for sample in hf_dataset
 ]
 rg_dataset.records.log(records=records)
-progress(1.0, desc="Dataset pushed
+progress(1.0, desc="Dataset pushed")
 except Exception as e:
 raise gr.Error(f"Error pushing dataset to Argilla: {e}")
 return ""

@@ -406,61 +405,64 @@ with gr.Blocks() as app:

 gr.HTML("<hr>")
 gr.Markdown("## 2. Configure your dataset")
-with gr.Row(equal_height=
-with gr.
+with gr.Row(equal_height=True):
+with gr.Row(equal_height=False):
+with gr.Column(scale=2):
+system_prompt = gr.Textbox(
+label="System prompt",
+placeholder="You are a helpful assistant.",
+visible=True,
+)
+labels = gr.Dropdown(
+choices=[],
+allow_custom_value=True,
+interactive=True,
+label="Labels",
+multiselect=True,
+info="Add the labels to classify the text.",
+)
+num_labels = gr.Number(
+label="Number of labels per text",
+value=1,
+minimum=1,
+maximum=10,
+info="Select 1 for single-label and >1 for multi-label.",
+interactive=True,
+)
+clarity = gr.Dropdown(
+choices=[
+("Clear", "clear"),
+(
+"Understandable",
+"understandable with some effort",
+),
+("Ambiguous", "ambiguous"),
+("Mixed", "mixed"),
+],
+value="understandable with some effort",
+label="Clarity",
+info="Set how easily the correct label or labels can be identified.",
+interactive=True,
+)
+difficulty = gr.Dropdown(
+choices=[
+("High School", "high school"),
+("College", "college"),
+("PhD", "PhD"),
+("Mixed", "mixed"),
+],
+value="high school",
+label="Difficulty",
+info="Select the comprehension level for the text. Ensure it matches the task context.",
+interactive=True,
+)
+with gr.Row():
+clear_btn_full = gr.Button("Clear", variant="secondary")
+btn_apply_to_sample_dataset = gr.Button(
+"Save", variant="primary"
+)
+with gr.Column(scale=3):
+dataframe = _get_dataframe()

 gr.HTML("<hr>")
 gr.Markdown("## 3. Generate your dataset")
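The large added block rebuilds section 2 of the textcat UI as rows and columns of input components feeding a sample dataframe. A hedged sketch of that nested-layout idea (nesting depth and the exact component set here are illustrative, not a copy of the final file):

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("## 2. Configure your dataset")
    with gr.Row(equal_height=False):
        # Left column: task configuration inputs.
        with gr.Column(scale=2):
            system_prompt = gr.Textbox(label="System prompt")
            labels = gr.Dropdown(
                choices=[], allow_custom_value=True, multiselect=True, label="Labels"
            )
            num_labels = gr.Number(
                label="Number of labels per text", value=1, minimum=1, maximum=10
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                save_btn = gr.Button("Save", variant="primary")
        # Right column: the sample dataset preview, wider than the inputs.
        with gr.Column(scale=3):
            preview = gr.Dataframe(
                headers=["labels", "text"], wrap=True, interactive=False
            )

if __name__ == "__main__":
    demo.launch()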