textcat-review (#12)
Browse files
- fix: apply feedback (3f2128047133bf339ddb5d4c4a6d6d41edb368da)
- fix: remove extra args (d27c1e6872d8be8b25d12d2b7e4baa5c075ed8c5)
- fix: add seed for more randomized samples (46f00bc57d59efb6274c287aa3b3ab0046d4d64e)
- fix: typo (a3f4be77171e6db5c804079dc82e0e30354bcec9)
- fix: correction label or labels (d59361703bc3414d4d5845cbe57ae760b52be7fc)
- fix: duplicated labels in labels and number of rows update listener in raw pipeline (b92482822c81a2a4330d54fd35640c6984ae8bda)
src/distilabel_dataset_generator/apps/base.py
CHANGED
@@ -38,8 +38,8 @@ def get_main_ui(
|
|
38 |
if task == TEXTCAT_TASK:
|
39 |
result = fn_generate_dataset(
|
40 |
system_prompt=system_prompt,
|
41 |
-
difficulty="
|
42 |
-
clarity="
|
43 |
labels=[],
|
44 |
num_labels=1,
|
45 |
num_rows=1,
|
@@ -271,7 +271,11 @@ def get_iterate_on_sample_dataset_ui(
|
|
271 |
with gr.Row():
|
272 |
sample_dataset = gr.Dataframe(
|
273 |
value=default_datasets[0],
|
274 |
-
label=
|
|
|
|
|
|
|
|
|
275 |
interactive=False,
|
276 |
wrap=True,
|
277 |
)
|
|
|
38 |
if task == TEXTCAT_TASK:
|
39 |
result = fn_generate_dataset(
|
40 |
system_prompt=system_prompt,
|
41 |
+
difficulty="high school",
|
42 |
+
clarity="clear",
|
43 |
labels=[],
|
44 |
num_labels=1,
|
45 |
num_rows=1,
|
|
|
271 |
with gr.Row():
|
272 |
sample_dataset = gr.Dataframe(
|
273 |
value=default_datasets[0],
|
274 |
+
label=(
|
275 |
+
"Sample dataset. Text truncated to 256 tokens."
|
276 |
+
if task == TEXTCAT_TASK
|
277 |
+
else "Sample dataset. Prompts and completions truncated to 256 tokens."
|
278 |
+
),
|
279 |
interactive=False,
|
280 |
wrap=True,
|
281 |
)
|
src/distilabel_dataset_generator/apps/textcat.py
CHANGED
@@ -215,7 +215,6 @@ def generate_dataset(
|
|
215 |
system_prompt=system_prompt,
|
216 |
labels=labels,
|
217 |
num_labels=num_labels,
|
218 |
-
is_sample=is_sample,
|
219 |
)
|
220 |
total_steps: int = num_rows * 2
|
221 |
batch_size = DEFAULT_BATCH_SIZE
|
@@ -280,11 +279,13 @@ def generate_dataset(
|
|
280 |
else:
|
281 |
dataframe["labels"] = dataframe["labels"].apply(
|
282 |
lambda x: (
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
|
|
|
|
288 |
if isinstance(x, list)
|
289 |
else None
|
290 |
)
|
@@ -309,6 +310,9 @@ def validate_input_labels(labels):
|
|
309 |
)
|
310 |
return labels
|
311 |
|
|
|
|
|
|
|
312 |
|
313 |
(
|
314 |
app,
|
@@ -354,7 +358,7 @@ with app:
|
|
354 |
],
|
355 |
value="mixed",
|
356 |
label="Difficulty",
|
357 |
-
info="
|
358 |
)
|
359 |
clarity = gr.Dropdown(
|
360 |
choices=[
|
@@ -368,7 +372,7 @@ with app:
|
|
368 |
],
|
369 |
value="mixed",
|
370 |
label="Clarity",
|
371 |
-
info="
|
372 |
)
|
373 |
with gr.Column():
|
374 |
labels = gr.Dropdown(
|
@@ -385,18 +389,18 @@ with app:
|
|
385 |
size="sm",
|
386 |
)
|
387 |
num_labels = gr.Number(
|
388 |
-
label="Number of labels",
|
389 |
value=1,
|
390 |
minimum=1,
|
391 |
maximum=10,
|
392 |
-
info="
|
393 |
)
|
394 |
num_rows = gr.Number(
|
395 |
label="Number of rows",
|
396 |
value=10,
|
397 |
minimum=1,
|
398 |
maximum=500,
|
399 |
-
info="More rows will take
|
400 |
)
|
401 |
|
402 |
pipeline_code = get_pipeline_code_ui(
|
@@ -415,6 +419,10 @@ with app:
|
|
415 |
fn=update_suggested_labels,
|
416 |
inputs=[system_prompt],
|
417 |
outputs=labels,
|
|
|
|
|
|
|
|
|
418 |
)
|
419 |
|
420 |
gr.on(
|
@@ -540,9 +548,18 @@ with app:
|
|
540 |
fn=generate_pipeline_code,
|
541 |
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
542 |
outputs=[pipeline_code],
|
|
|
|
|
|
|
|
|
543 |
)
|
544 |
num_labels.change(
|
545 |
fn=generate_pipeline_code,
|
546 |
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
547 |
outputs=[pipeline_code],
|
548 |
)
|
|
|
|
|
|
|
|
|
|
|
|
215 |
system_prompt=system_prompt,
|
216 |
labels=labels,
|
217 |
num_labels=num_labels,
|
|
|
218 |
)
|
219 |
total_steps: int = num_rows * 2
|
220 |
batch_size = DEFAULT_BATCH_SIZE
|
|
|
279 |
else:
|
280 |
dataframe["labels"] = dataframe["labels"].apply(
|
281 |
lambda x: (
|
282 |
+
list(
|
283 |
+
set(
|
284 |
+
label.lower().strip()
|
285 |
+
for label in x
|
286 |
+
if label.lower().strip() in labels
|
287 |
+
)
|
288 |
+
)
|
289 |
if isinstance(x, list)
|
290 |
else None
|
291 |
)
|
|
|
310 |
)
|
311 |
return labels
|
312 |
|
313 |
+
def update_max_num_labels(labels):
|
314 |
+
return gr.update(maximum=len(labels) if labels else 1)
|
315 |
+
|
316 |
|
317 |
(
|
318 |
app,
|
|
|
358 |
],
|
359 |
value="mixed",
|
360 |
label="Difficulty",
|
361 |
+
info="Select the comprehension level for the text. Ensure it matches the task context.",
|
362 |
)
|
363 |
clarity = gr.Dropdown(
|
364 |
choices=[
|
|
|
372 |
],
|
373 |
value="mixed",
|
374 |
label="Clarity",
|
375 |
+
info="Set how easily the correct label or labels can be identified.",
|
376 |
)
|
377 |
with gr.Column():
|
378 |
labels = gr.Dropdown(
|
|
|
389 |
size="sm",
|
390 |
)
|
391 |
num_labels = gr.Number(
|
392 |
+
label="Number of labels per text",
|
393 |
value=1,
|
394 |
minimum=1,
|
395 |
maximum=10,
|
396 |
+
info="Select 1 for single-label and >1 for multi-label.",
|
397 |
)
|
398 |
num_rows = gr.Number(
|
399 |
label="Number of rows",
|
400 |
value=10,
|
401 |
minimum=1,
|
402 |
maximum=500,
|
403 |
+
info="Select the number of rows in the dataset. More rows will take more time.",
|
404 |
)
|
405 |
|
406 |
pipeline_code = get_pipeline_code_ui(
|
|
|
419 |
fn=update_suggested_labels,
|
420 |
inputs=[system_prompt],
|
421 |
outputs=labels,
|
422 |
+
).then(
|
423 |
+
fn=update_max_num_labels,
|
424 |
+
inputs=[labels],
|
425 |
+
outputs=[num_labels],
|
426 |
)
|
427 |
|
428 |
gr.on(
|
|
|
548 |
fn=generate_pipeline_code,
|
549 |
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
550 |
outputs=[pipeline_code],
|
551 |
+
).then(
|
552 |
+
fn=update_max_num_labels,
|
553 |
+
inputs=[labels],
|
554 |
+
outputs=[num_labels],
|
555 |
)
|
556 |
num_labels.change(
|
557 |
fn=generate_pipeline_code,
|
558 |
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
559 |
outputs=[pipeline_code],
|
560 |
)
|
561 |
+
num_rows.change(
|
562 |
+
fn=generate_pipeline_code,
|
563 |
+
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
564 |
+
outputs=[pipeline_code],
|
565 |
+
)
|
src/distilabel_dataset_generator/pipelines/textcat.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from typing import List
|
2 |
|
3 |
import pandas as pd
|
|
|
4 |
from distilabel.llms import InferenceEndpointsLLM
|
5 |
from distilabel.steps.tasks import (
|
6 |
GenerateTextClassificationData,
|
@@ -88,6 +89,7 @@ def generate_pipeline_code(
|
|
88 |
base_code = f"""
|
89 |
# Requirements: `pip install distilabel[hf-inference-endpoints]`
|
90 |
import os
|
|
|
91 |
from distilabel.llms import InferenceEndpointsLLM
|
92 |
from distilabel.pipeline import Pipeline
|
93 |
from distilabel.steps import LoadDataFromDicts, KeepColumns
|
@@ -111,6 +113,8 @@ with Pipeline(name="textcat") as pipeline:
|
|
111 |
generation_kwargs={{
|
112 |
"temperature": 0.8,
|
113 |
"max_new_tokens": 2048,
|
|
|
|
|
114 |
}},
|
115 |
),
|
116 |
difficulty={None if difficulty == "mixed" else repr(difficulty)},
|
@@ -175,8 +179,10 @@ def get_textcat_generator(difficulty, clarity, is_sample):
|
|
175 |
tokenizer_id=MODEL,
|
176 |
api_key=_get_next_api_key(),
|
177 |
generation_kwargs={
|
178 |
-
"temperature": 0.
|
179 |
-
"max_new_tokens": 256 if is_sample else
|
|
|
|
|
180 |
},
|
181 |
),
|
182 |
difficulty=None if difficulty == "mixed" else difficulty,
|
@@ -186,15 +192,15 @@ def get_textcat_generator(difficulty, clarity, is_sample):
|
|
186 |
return textcat_generator
|
187 |
|
188 |
|
189 |
-
def get_labeller_generator(system_prompt, labels, num_labels
|
190 |
labeller_generator = TextClassification(
|
191 |
llm=InferenceEndpointsLLM(
|
192 |
model_id=MODEL,
|
193 |
tokenizer_id=MODEL,
|
194 |
api_key=_get_next_api_key(),
|
195 |
generation_kwargs={
|
196 |
-
"temperature": 0.
|
197 |
-
"max_new_tokens":
|
198 |
},
|
199 |
),
|
200 |
context=system_prompt,
|
|
|
1 |
from typing import List
|
2 |
|
3 |
import pandas as pd
|
4 |
+
import random
|
5 |
from distilabel.llms import InferenceEndpointsLLM
|
6 |
from distilabel.steps.tasks import (
|
7 |
GenerateTextClassificationData,
|
|
|
89 |
base_code = f"""
|
90 |
# Requirements: `pip install distilabel[hf-inference-endpoints]`
|
91 |
import os
|
92 |
+
import random
|
93 |
from distilabel.llms import InferenceEndpointsLLM
|
94 |
from distilabel.pipeline import Pipeline
|
95 |
from distilabel.steps import LoadDataFromDicts, KeepColumns
|
|
|
113 |
generation_kwargs={{
|
114 |
"temperature": 0.8,
|
115 |
"max_new_tokens": 2048,
|
116 |
+
"do_sample": True,
|
117 |
+
"seed": random.randint(0, 2**32 - 1),
|
118 |
}},
|
119 |
),
|
120 |
difficulty={None if difficulty == "mixed" else repr(difficulty)},
|
|
|
179 |
tokenizer_id=MODEL,
|
180 |
api_key=_get_next_api_key(),
|
181 |
generation_kwargs={
|
182 |
+
"temperature": 0.9,
|
183 |
+
"max_new_tokens": 256 if is_sample else 2048,
|
184 |
+
"do_sample": True,
|
185 |
+
"seed": random.randint(0, 2**32 - 1),
|
186 |
},
|
187 |
),
|
188 |
difficulty=None if difficulty == "mixed" else difficulty,
|
|
|
192 |
return textcat_generator
|
193 |
|
194 |
|
195 |
+
def get_labeller_generator(system_prompt, labels, num_labels):
|
196 |
labeller_generator = TextClassification(
|
197 |
llm=InferenceEndpointsLLM(
|
198 |
model_id=MODEL,
|
199 |
tokenizer_id=MODEL,
|
200 |
api_key=_get_next_api_key(),
|
201 |
generation_kwargs={
|
202 |
+
"temperature": 0.7,
|
203 |
+
"max_new_tokens": 2048,
|
204 |
},
|
205 |
),
|
206 |
context=system_prompt,
|