Spaces:
Runtime error
Runtime error
textcat-review (#12)
Browse files
- fix: apply feedback (3f2128047133bf339ddb5d4c4a6d6d41edb368da)
- fix: remove extra args (d27c1e6872d8be8b25d12d2b7e4baa5c075ed8c5)
- fix: add seed for more randomized samples (46f00bc57d59efb6274c287aa3b3ab0046d4d64e)
- fix: typo (a3f4be77171e6db5c804079dc82e0e30354bcec9)
- fix: correction label or labels (d59361703bc3414d4d5845cbe57ae760b52be7fc)
- fix: duplicated labels in labels and number of rows update listener in raw pipeline (b92482822c81a2a4330d54fd35640c6984ae8bda)
src/distilabel_dataset_generator/apps/base.py
CHANGED
|
@@ -38,8 +38,8 @@ def get_main_ui(
|
|
| 38 |
if task == TEXTCAT_TASK:
|
| 39 |
result = fn_generate_dataset(
|
| 40 |
system_prompt=system_prompt,
|
| 41 |
-
difficulty="
|
| 42 |
-
clarity="
|
| 43 |
labels=[],
|
| 44 |
num_labels=1,
|
| 45 |
num_rows=1,
|
|
@@ -271,7 +271,11 @@ def get_iterate_on_sample_dataset_ui(
|
|
| 271 |
with gr.Row():
|
| 272 |
sample_dataset = gr.Dataframe(
|
| 273 |
value=default_datasets[0],
|
| 274 |
-
label=
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
interactive=False,
|
| 276 |
wrap=True,
|
| 277 |
)
|
|
|
|
| 38 |
if task == TEXTCAT_TASK:
|
| 39 |
result = fn_generate_dataset(
|
| 40 |
system_prompt=system_prompt,
|
| 41 |
+
difficulty="high school",
|
| 42 |
+
clarity="clear",
|
| 43 |
labels=[],
|
| 44 |
num_labels=1,
|
| 45 |
num_rows=1,
|
|
|
|
| 271 |
with gr.Row():
|
| 272 |
sample_dataset = gr.Dataframe(
|
| 273 |
value=default_datasets[0],
|
| 274 |
+
label=(
|
| 275 |
+
"Sample dataset. Text truncated to 256 tokens."
|
| 276 |
+
if task == TEXTCAT_TASK
|
| 277 |
+
else "Sample dataset. Prompts and completions truncated to 256 tokens."
|
| 278 |
+
),
|
| 279 |
interactive=False,
|
| 280 |
wrap=True,
|
| 281 |
)
|
src/distilabel_dataset_generator/apps/textcat.py
CHANGED
|
@@ -215,7 +215,6 @@ def generate_dataset(
|
|
| 215 |
system_prompt=system_prompt,
|
| 216 |
labels=labels,
|
| 217 |
num_labels=num_labels,
|
| 218 |
-
is_sample=is_sample,
|
| 219 |
)
|
| 220 |
total_steps: int = num_rows * 2
|
| 221 |
batch_size = DEFAULT_BATCH_SIZE
|
|
@@ -280,11 +279,13 @@ def generate_dataset(
|
|
| 280 |
else:
|
| 281 |
dataframe["labels"] = dataframe["labels"].apply(
|
| 282 |
lambda x: (
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
|
|
|
|
|
|
| 288 |
if isinstance(x, list)
|
| 289 |
else None
|
| 290 |
)
|
|
@@ -309,6 +310,9 @@ def validate_input_labels(labels):
|
|
| 309 |
)
|
| 310 |
return labels
|
| 311 |
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
(
|
| 314 |
app,
|
|
@@ -354,7 +358,7 @@ with app:
|
|
| 354 |
],
|
| 355 |
value="mixed",
|
| 356 |
label="Difficulty",
|
| 357 |
-
info="
|
| 358 |
)
|
| 359 |
clarity = gr.Dropdown(
|
| 360 |
choices=[
|
|
@@ -368,7 +372,7 @@ with app:
|
|
| 368 |
],
|
| 369 |
value="mixed",
|
| 370 |
label="Clarity",
|
| 371 |
-
info="
|
| 372 |
)
|
| 373 |
with gr.Column():
|
| 374 |
labels = gr.Dropdown(
|
|
@@ -385,18 +389,18 @@ with app:
|
|
| 385 |
size="sm",
|
| 386 |
)
|
| 387 |
num_labels = gr.Number(
|
| 388 |
-
label="Number of labels",
|
| 389 |
value=1,
|
| 390 |
minimum=1,
|
| 391 |
maximum=10,
|
| 392 |
-
info="
|
| 393 |
)
|
| 394 |
num_rows = gr.Number(
|
| 395 |
label="Number of rows",
|
| 396 |
value=10,
|
| 397 |
minimum=1,
|
| 398 |
maximum=500,
|
| 399 |
-
info="More rows will take
|
| 400 |
)
|
| 401 |
|
| 402 |
pipeline_code = get_pipeline_code_ui(
|
|
@@ -415,6 +419,10 @@ with app:
|
|
| 415 |
fn=update_suggested_labels,
|
| 416 |
inputs=[system_prompt],
|
| 417 |
outputs=labels,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
)
|
| 419 |
|
| 420 |
gr.on(
|
|
@@ -540,9 +548,18 @@ with app:
|
|
| 540 |
fn=generate_pipeline_code,
|
| 541 |
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
| 542 |
outputs=[pipeline_code],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
)
|
| 544 |
num_labels.change(
|
| 545 |
fn=generate_pipeline_code,
|
| 546 |
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
| 547 |
outputs=[pipeline_code],
|
| 548 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
system_prompt=system_prompt,
|
| 216 |
labels=labels,
|
| 217 |
num_labels=num_labels,
|
|
|
|
| 218 |
)
|
| 219 |
total_steps: int = num_rows * 2
|
| 220 |
batch_size = DEFAULT_BATCH_SIZE
|
|
|
|
| 279 |
else:
|
| 280 |
dataframe["labels"] = dataframe["labels"].apply(
|
| 281 |
lambda x: (
|
| 282 |
+
list(
|
| 283 |
+
set(
|
| 284 |
+
label.lower().strip()
|
| 285 |
+
for label in x
|
| 286 |
+
if label.lower().strip() in labels
|
| 287 |
+
)
|
| 288 |
+
)
|
| 289 |
if isinstance(x, list)
|
| 290 |
else None
|
| 291 |
)
|
|
|
|
| 310 |
)
|
| 311 |
return labels
|
| 312 |
|
| 313 |
+
def update_max_num_labels(labels):
    """Build a Gradio update that caps a component's ``maximum`` by label count.

    Args:
        labels: The currently selected labels; may be empty or ``None``.

    Returns:
        The result of ``gr.update(maximum=...)``, where the maximum is
        ``len(labels)`` when ``labels`` is truthy and ``1`` otherwise.
    """
    # Fall back to 1 so the bound never drops to zero when nothing is selected.
    upper_bound = len(labels) if labels else 1
    return gr.update(maximum=upper_bound)
|
| 315 |
+
|
| 316 |
|
| 317 |
(
|
| 318 |
app,
|
|
|
|
| 358 |
],
|
| 359 |
value="mixed",
|
| 360 |
label="Difficulty",
|
| 361 |
+
info="Select the comprehension level for the text. Ensure it matches the task context.",
|
| 362 |
)
|
| 363 |
clarity = gr.Dropdown(
|
| 364 |
choices=[
|
|
|
|
| 372 |
],
|
| 373 |
value="mixed",
|
| 374 |
label="Clarity",
|
| 375 |
+
info="Set how easily the correct label or labels can be identified.",
|
| 376 |
)
|
| 377 |
with gr.Column():
|
| 378 |
labels = gr.Dropdown(
|
|
|
|
| 389 |
size="sm",
|
| 390 |
)
|
| 391 |
num_labels = gr.Number(
|
| 392 |
+
label="Number of labels per text",
|
| 393 |
value=1,
|
| 394 |
minimum=1,
|
| 395 |
maximum=10,
|
| 396 |
+
info="Select 1 for single-label and >1 for multi-label.",
|
| 397 |
)
|
| 398 |
num_rows = gr.Number(
|
| 399 |
label="Number of rows",
|
| 400 |
value=10,
|
| 401 |
minimum=1,
|
| 402 |
maximum=500,
|
| 403 |
+
info="Select the number of rows in the dataset. More rows will take more time.",
|
| 404 |
)
|
| 405 |
|
| 406 |
pipeline_code = get_pipeline_code_ui(
|
|
|
|
| 419 |
fn=update_suggested_labels,
|
| 420 |
inputs=[system_prompt],
|
| 421 |
outputs=labels,
|
| 422 |
+
).then(
|
| 423 |
+
fn=update_max_num_labels,
|
| 424 |
+
inputs=[labels],
|
| 425 |
+
outputs=[num_labels],
|
| 426 |
)
|
| 427 |
|
| 428 |
gr.on(
|
|
|
|
| 548 |
fn=generate_pipeline_code,
|
| 549 |
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
| 550 |
outputs=[pipeline_code],
|
| 551 |
+
).then(
|
| 552 |
+
fn=update_max_num_labels,
|
| 553 |
+
inputs=[labels],
|
| 554 |
+
outputs=[num_labels],
|
| 555 |
)
|
| 556 |
num_labels.change(
|
| 557 |
fn=generate_pipeline_code,
|
| 558 |
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
| 559 |
outputs=[pipeline_code],
|
| 560 |
)
|
| 561 |
+
num_rows.change(
|
| 562 |
+
fn=generate_pipeline_code,
|
| 563 |
+
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
| 564 |
+
outputs=[pipeline_code],
|
| 565 |
+
)
|
src/distilabel_dataset_generator/pipelines/textcat.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from typing import List
|
| 2 |
|
| 3 |
import pandas as pd
|
|
|
|
| 4 |
from distilabel.llms import InferenceEndpointsLLM
|
| 5 |
from distilabel.steps.tasks import (
|
| 6 |
GenerateTextClassificationData,
|
|
@@ -88,6 +89,7 @@ def generate_pipeline_code(
|
|
| 88 |
base_code = f"""
|
| 89 |
# Requirements: `pip install distilabel[hf-inference-endpoints]`
|
| 90 |
import os
|
|
|
|
| 91 |
from distilabel.llms import InferenceEndpointsLLM
|
| 92 |
from distilabel.pipeline import Pipeline
|
| 93 |
from distilabel.steps import LoadDataFromDicts, KeepColumns
|
|
@@ -111,6 +113,8 @@ with Pipeline(name="textcat") as pipeline:
|
|
| 111 |
generation_kwargs={{
|
| 112 |
"temperature": 0.8,
|
| 113 |
"max_new_tokens": 2048,
|
|
|
|
|
|
|
| 114 |
}},
|
| 115 |
),
|
| 116 |
difficulty={None if difficulty == "mixed" else repr(difficulty)},
|
|
@@ -175,8 +179,10 @@ def get_textcat_generator(difficulty, clarity, is_sample):
|
|
| 175 |
tokenizer_id=MODEL,
|
| 176 |
api_key=_get_next_api_key(),
|
| 177 |
generation_kwargs={
|
| 178 |
-
"temperature": 0.
|
| 179 |
-
"max_new_tokens": 256 if is_sample else
|
|
|
|
|
|
|
| 180 |
},
|
| 181 |
),
|
| 182 |
difficulty=None if difficulty == "mixed" else difficulty,
|
|
@@ -186,15 +192,15 @@ def get_textcat_generator(difficulty, clarity, is_sample):
|
|
| 186 |
return textcat_generator
|
| 187 |
|
| 188 |
|
| 189 |
-
def get_labeller_generator(system_prompt, labels, num_labels
|
| 190 |
labeller_generator = TextClassification(
|
| 191 |
llm=InferenceEndpointsLLM(
|
| 192 |
model_id=MODEL,
|
| 193 |
tokenizer_id=MODEL,
|
| 194 |
api_key=_get_next_api_key(),
|
| 195 |
generation_kwargs={
|
| 196 |
-
"temperature": 0.
|
| 197 |
-
"max_new_tokens":
|
| 198 |
},
|
| 199 |
),
|
| 200 |
context=system_prompt,
|
|
|
|
| 1 |
from typing import List
|
| 2 |
|
| 3 |
import pandas as pd
|
| 4 |
+
import random
|
| 5 |
from distilabel.llms import InferenceEndpointsLLM
|
| 6 |
from distilabel.steps.tasks import (
|
| 7 |
GenerateTextClassificationData,
|
|
|
|
| 89 |
base_code = f"""
|
| 90 |
# Requirements: `pip install distilabel[hf-inference-endpoints]`
|
| 91 |
import os
|
| 92 |
+
import random
|
| 93 |
from distilabel.llms import InferenceEndpointsLLM
|
| 94 |
from distilabel.pipeline import Pipeline
|
| 95 |
from distilabel.steps import LoadDataFromDicts, KeepColumns
|
|
|
|
| 113 |
generation_kwargs={{
|
| 114 |
"temperature": 0.8,
|
| 115 |
"max_new_tokens": 2048,
|
| 116 |
+
"do_sample": True,
|
| 117 |
+
"seed": random.randint(0, 2**32 - 1),
|
| 118 |
}},
|
| 119 |
),
|
| 120 |
difficulty={None if difficulty == "mixed" else repr(difficulty)},
|
|
|
|
| 179 |
tokenizer_id=MODEL,
|
| 180 |
api_key=_get_next_api_key(),
|
| 181 |
generation_kwargs={
|
| 182 |
+
"temperature": 0.9,
|
| 183 |
+
"max_new_tokens": 256 if is_sample else 2048,
|
| 184 |
+
"do_sample": True,
|
| 185 |
+
"seed": random.randint(0, 2**32 - 1),
|
| 186 |
},
|
| 187 |
),
|
| 188 |
difficulty=None if difficulty == "mixed" else difficulty,
|
|
|
|
| 192 |
return textcat_generator
|
| 193 |
|
| 194 |
|
| 195 |
+
def get_labeller_generator(system_prompt, labels, num_labels):
|
| 196 |
labeller_generator = TextClassification(
|
| 197 |
llm=InferenceEndpointsLLM(
|
| 198 |
model_id=MODEL,
|
| 199 |
tokenizer_id=MODEL,
|
| 200 |
api_key=_get_next_api_key(),
|
| 201 |
generation_kwargs={
|
| 202 |
+
"temperature": 0.7,
|
| 203 |
+
"max_new_tokens": 2048,
|
| 204 |
},
|
| 205 |
),
|
| 206 |
context=system_prompt,
|