update textcat (separate prompt and labels) and use input parameters
Browse files
src/distilabel_dataset_generator/apps/textcat.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import re
|
2 |
import uuid
|
3 |
from typing import List, Union
|
@@ -24,7 +25,6 @@ from src.distilabel_dataset_generator.pipelines.embeddings import (
|
|
24 |
)
|
25 |
from src.distilabel_dataset_generator.pipelines.textcat import (
|
26 |
DEFAULT_DATASET_DESCRIPTIONS,
|
27 |
-
PROMPT_CREATION_PROMPT,
|
28 |
generate_pipeline_code,
|
29 |
get_labeller_generator,
|
30 |
get_prompt_generator,
|
@@ -44,36 +44,33 @@ def generate_system_prompt(dataset_description, temperature, progress=gr.Progres
|
|
44 |
progress(0.3, desc="Initializing text generation")
|
45 |
generate_description = get_prompt_generator(temperature)
|
46 |
progress(0.7, desc="Generating text classification task")
|
47 |
-
|
48 |
generate_description.process(
|
49 |
[
|
50 |
{
|
51 |
-
"system_prompt": PROMPT_CREATION_PROMPT,
|
52 |
"instruction": dataset_description,
|
53 |
}
|
54 |
]
|
55 |
)
|
56 |
)[0]["generation"]
|
57 |
progress(1.0, desc="Text classification task generated")
|
58 |
-
|
|
|
|
|
|
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
df = generate_dataset(
|
63 |
system_prompt=system_prompt,
|
64 |
-
difficulty=
|
65 |
-
clarity=
|
66 |
-
labels=
|
67 |
-
num_labels=
|
68 |
num_rows=10,
|
69 |
progress=progress,
|
70 |
is_sample=True,
|
71 |
)
|
72 |
-
|
73 |
-
df = df[["label", "text"]]
|
74 |
-
elif "labels" in df.columns:
|
75 |
-
df = df[["labels", "text"]]
|
76 |
-
return df
|
77 |
|
78 |
|
79 |
def generate_dataset(
|
@@ -86,17 +83,13 @@ def generate_dataset(
|
|
86 |
is_sample: bool = False,
|
87 |
progress=gr.Progress(),
|
88 |
) -> pd.DataFrame:
|
89 |
-
if is_sample:
|
90 |
-
multiplier = 1
|
91 |
-
else:
|
92 |
-
multiplier = 2
|
93 |
progress(0.0, desc="(1/2) Generating text classification data")
|
94 |
labels = get_preprocess_labels(labels)
|
95 |
textcat_generator = get_textcat_generator(
|
96 |
difficulty=difficulty, clarity=clarity, is_sample=is_sample
|
97 |
)
|
98 |
labeller_generator = get_labeller_generator(
|
99 |
-
system_prompt=system_prompt,
|
100 |
labels=labels,
|
101 |
num_labels=num_labels,
|
102 |
)
|
@@ -108,13 +101,15 @@ def generate_dataset(
|
|
108 |
textcat_results = []
|
109 |
while n_processed < num_rows:
|
110 |
progress(
|
111 |
-
|
112 |
total=total_steps,
|
113 |
desc="(1/2) Generating text classification data",
|
114 |
)
|
115 |
remaining_rows = num_rows - n_processed
|
116 |
batch_size = min(batch_size, remaining_rows)
|
117 |
-
inputs = [
|
|
|
|
|
118 |
batch = list(textcat_generator.process(inputs=inputs))
|
119 |
textcat_results.extend(batch[0])
|
120 |
n_processed += batch_size
|
@@ -122,58 +117,41 @@ def generate_dataset(
|
|
122 |
result["text"] = result["input_text"]
|
123 |
|
124 |
# label text classification data
|
125 |
-
progress(
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
while n_processed < num_rows:
|
130 |
-
progress(
|
131 |
-
0.5 + 0.5 * n_processed / num_rows,
|
132 |
-
total=total_steps,
|
133 |
-
desc="(1/2) Labeling text classification data",
|
134 |
-
)
|
135 |
-
batch = textcat_results[n_processed : n_processed + batch_size]
|
136 |
-
labels_batch = list(labeller_generator.process(inputs=batch))
|
137 |
-
labeller_results.extend(labels_batch[0])
|
138 |
-
n_processed += batch_size
|
139 |
progress(
|
140 |
-
|
141 |
total=total_steps,
|
142 |
-
desc="(
|
143 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
|
145 |
# create final dataset
|
146 |
distiset_results = []
|
147 |
-
|
148 |
-
for result in source_results:
|
149 |
record = {
|
150 |
key: result[key]
|
151 |
-
for key in ["
|
152 |
if key in result
|
153 |
}
|
154 |
distiset_results.append(record)
|
155 |
|
156 |
dataframe = pd.DataFrame(distiset_results)
|
157 |
-
if
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
)
|
163 |
-
else:
|
164 |
-
dataframe["labels"] = dataframe["labels"].apply(
|
165 |
-
lambda x: (
|
166 |
-
list(
|
167 |
-
set(
|
168 |
-
label.lower().strip()
|
169 |
-
for label in x
|
170 |
-
if label.lower().strip() in labels
|
171 |
-
)
|
172 |
-
)
|
173 |
-
if isinstance(x, list)
|
174 |
-
else None
|
175 |
-
)
|
176 |
-
)
|
177 |
progress(1.0, desc="Dataset generation completed")
|
178 |
return dataframe
|
179 |
|
@@ -281,7 +259,7 @@ def push_dataset_to_argilla(
|
|
281 |
)
|
282 |
|
283 |
dataframe["text_length"] = dataframe["text"].apply(len)
|
284 |
-
dataframe["text_embeddings"] = get_embeddings(dataframe["text"])
|
285 |
|
286 |
progress(0.5, desc="Creating dataset")
|
287 |
rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
|
@@ -330,15 +308,6 @@ def push_dataset_to_argilla(
|
|
330 |
return ""
|
331 |
|
332 |
|
333 |
-
def update_suggested_labels(system_prompt):
|
334 |
-
new_labels = re.findall(r"'(\b[\w-]+\b)'", system_prompt)
|
335 |
-
if not new_labels:
|
336 |
-
return gr.Warning(
|
337 |
-
"No labels found in the system prompt. Please add labels manually."
|
338 |
-
)
|
339 |
-
return gr.update(choices=new_labels, value=new_labels)
|
340 |
-
|
341 |
-
|
342 |
def validate_input_labels(labels):
|
343 |
if not labels or len(labels) < 2:
|
344 |
raise gr.Error(
|
@@ -448,7 +417,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
|
|
448 |
)
|
449 |
with gr.Column(scale=3):
|
450 |
dataframe = gr.Dataframe(
|
451 |
-
headers=["labels", "text"], wrap=True, height=
|
452 |
)
|
453 |
|
454 |
gr.HTML("<hr>")
|
@@ -496,27 +465,35 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
|
|
496 |
label="Distilabel Pipeline Code",
|
497 |
)
|
498 |
|
499 |
-
|
500 |
-
triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
|
501 |
fn=generate_system_prompt,
|
502 |
inputs=[dataset_description, temperature],
|
503 |
-
outputs=[system_prompt,
|
504 |
show_progress=True,
|
505 |
).then(
|
506 |
fn=generate_sample_dataset,
|
507 |
-
inputs=[system_prompt],
|
508 |
outputs=[dataframe],
|
509 |
show_progress=True,
|
510 |
).then(
|
511 |
-
fn=update_suggested_labels,
|
512 |
-
inputs=[system_prompt],
|
513 |
-
outputs=labels,
|
514 |
-
).then(
|
515 |
fn=update_max_num_labels,
|
516 |
inputs=[labels],
|
517 |
outputs=[num_labels],
|
518 |
)
|
519 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
520 |
btn_push_to_hub.click(
|
521 |
fn=validate_argilla_user_workspace_dataset,
|
522 |
inputs=[repo_name],
|
|
|
1 |
+
import json
|
2 |
import re
|
3 |
import uuid
|
4 |
from typing import List, Union
|
|
|
25 |
)
|
26 |
from src.distilabel_dataset_generator.pipelines.textcat import (
|
27 |
DEFAULT_DATASET_DESCRIPTIONS,
|
|
|
28 |
generate_pipeline_code,
|
29 |
get_labeller_generator,
|
30 |
get_prompt_generator,
|
|
|
44 |
progress(0.3, desc="Initializing text generation")
|
45 |
generate_description = get_prompt_generator(temperature)
|
46 |
progress(0.7, desc="Generating text classification task")
|
47 |
+
result = next(
|
48 |
generate_description.process(
|
49 |
[
|
50 |
{
|
|
|
51 |
"instruction": dataset_description,
|
52 |
}
|
53 |
]
|
54 |
)
|
55 |
)[0]["generation"]
|
56 |
progress(1.0, desc="Text classification task generated")
|
57 |
+
data = json.loads(result)
|
58 |
+
system_prompt = data["classification_task"]
|
59 |
+
labels = data["labels"]
|
60 |
+
return system_prompt, labels
|
61 |
|
62 |
+
def generate_sample_dataset(system_prompt, difficulty, clarity, labels, num_labels, progress=gr.Progress()):
|
63 |
+
dataframe = generate_dataset(
|
|
|
64 |
system_prompt=system_prompt,
|
65 |
+
difficulty=difficulty,
|
66 |
+
clarity=clarity,
|
67 |
+
labels=labels,
|
68 |
+
num_labels=num_labels,
|
69 |
num_rows=10,
|
70 |
progress=progress,
|
71 |
is_sample=True,
|
72 |
)
|
73 |
+
return dataframe
|
|
|
|
|
|
|
|
|
74 |
|
75 |
|
76 |
def generate_dataset(
|
|
|
83 |
is_sample: bool = False,
|
84 |
progress=gr.Progress(),
|
85 |
) -> pd.DataFrame:
|
|
|
|
|
|
|
|
|
86 |
progress(0.0, desc="(1/2) Generating text classification data")
|
87 |
labels = get_preprocess_labels(labels)
|
88 |
textcat_generator = get_textcat_generator(
|
89 |
difficulty=difficulty, clarity=clarity, is_sample=is_sample
|
90 |
)
|
91 |
labeller_generator = get_labeller_generator(
|
92 |
+
system_prompt=f"{system_prompt} {', '.join(labels)}",
|
93 |
labels=labels,
|
94 |
num_labels=num_labels,
|
95 |
)
|
|
|
101 |
textcat_results = []
|
102 |
while n_processed < num_rows:
|
103 |
progress(
|
104 |
+
2 * 0.5 * n_processed / num_rows,
|
105 |
total=total_steps,
|
106 |
desc="(1/2) Generating text classification data",
|
107 |
)
|
108 |
remaining_rows = num_rows - n_processed
|
109 |
batch_size = min(batch_size, remaining_rows)
|
110 |
+
inputs = [
|
111 |
+
{"task": f"{system_prompt} {', '.join(labels)}"} for _ in range(batch_size)
|
112 |
+
]
|
113 |
batch = list(textcat_generator.process(inputs=inputs))
|
114 |
textcat_results.extend(batch[0])
|
115 |
n_processed += batch_size
|
|
|
117 |
result["text"] = result["input_text"]
|
118 |
|
119 |
# label text classification data
|
120 |
+
progress(2 * 0.5, desc="(1/2) Generating text classification data")
|
121 |
+
n_processed = 0
|
122 |
+
labeller_results = []
|
123 |
+
while n_processed < num_rows:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
progress(
|
125 |
+
0.5 + 0.5 * n_processed / num_rows,
|
126 |
total=total_steps,
|
127 |
+
desc="(1/2) Labeling text classification data",
|
128 |
)
|
129 |
+
batch = textcat_results[n_processed : n_processed + batch_size]
|
130 |
+
labels_batch = list(labeller_generator.process(inputs=batch))
|
131 |
+
labeller_results.extend(labels_batch[0])
|
132 |
+
n_processed += batch_size
|
133 |
+
progress(
|
134 |
+
1,
|
135 |
+
total=total_steps,
|
136 |
+
desc="(2/2) Creating dataset",
|
137 |
+
)
|
138 |
|
139 |
# create final dataset
|
140 |
distiset_results = []
|
141 |
+
for result in labeller_results:
|
|
|
142 |
record = {
|
143 |
key: result[key]
|
144 |
+
for key in ["labels", "text"]
|
145 |
if key in result
|
146 |
}
|
147 |
distiset_results.append(record)
|
148 |
|
149 |
dataframe = pd.DataFrame(distiset_results)
|
150 |
+
if num_labels == 1:
|
151 |
+
dataframe = dataframe.rename(columns={"labels": "label"})
|
152 |
+
dataframe["label"] = dataframe["label"].apply(
|
153 |
+
lambda x: x.lower().strip() if x.lower().strip() in labels else None
|
154 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
progress(1.0, desc="Dataset generation completed")
|
156 |
return dataframe
|
157 |
|
|
|
259 |
)
|
260 |
|
261 |
dataframe["text_length"] = dataframe["text"].apply(len)
|
262 |
+
dataframe["text_embeddings"] = get_embeddings(dataframe["text"].to_list())
|
263 |
|
264 |
progress(0.5, desc="Creating dataset")
|
265 |
rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
|
|
|
308 |
return ""
|
309 |
|
310 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
311 |
def validate_input_labels(labels):
|
312 |
if not labels or len(labels) < 2:
|
313 |
raise gr.Error(
|
|
|
417 |
)
|
418 |
with gr.Column(scale=3):
|
419 |
dataframe = gr.Dataframe(
|
420 |
+
headers=["labels", "text"], wrap=True, height=500, interactive=False
|
421 |
)
|
422 |
|
423 |
gr.HTML("<hr>")
|
|
|
465 |
label="Distilabel Pipeline Code",
|
466 |
)
|
467 |
|
468 |
+
load_btn.click(
|
|
|
469 |
fn=generate_system_prompt,
|
470 |
inputs=[dataset_description, temperature],
|
471 |
+
outputs=[system_prompt, labels],
|
472 |
show_progress=True,
|
473 |
).then(
|
474 |
fn=generate_sample_dataset,
|
475 |
+
inputs=[system_prompt, difficulty, clarity, labels, num_labels],
|
476 |
outputs=[dataframe],
|
477 |
show_progress=True,
|
478 |
).then(
|
|
|
|
|
|
|
|
|
479 |
fn=update_max_num_labels,
|
480 |
inputs=[labels],
|
481 |
outputs=[num_labels],
|
482 |
)
|
483 |
|
484 |
+
labels.input(
|
485 |
+
fn=update_max_num_labels,
|
486 |
+
inputs=[labels],
|
487 |
+
outputs=[num_labels],
|
488 |
+
)
|
489 |
+
|
490 |
+
btn_apply_to_sample_dataset.click(
|
491 |
+
fn=generate_sample_dataset,
|
492 |
+
inputs=[system_prompt, difficulty, clarity, labels, num_labels],
|
493 |
+
outputs=[dataframe],
|
494 |
+
show_progress=True,
|
495 |
+
)
|
496 |
+
|
497 |
btn_push_to_hub.click(
|
498 |
fn=validate_argilla_user_workspace_dataset,
|
499 |
inputs=[repo_name],
|
src/distilabel_dataset_generator/pipelines/textcat.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import random
|
|
|
2 |
from typing import List
|
3 |
|
4 |
from distilabel.llms import InferenceEndpointsLLM
|
@@ -22,25 +23,27 @@ The prompt you write should follow the same style and structure as the following
|
|
22 |
|
23 |
If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
|
24 |
|
25 |
-
Classify the following customer review of a cinema as
|
26 |
|
27 |
-
|
28 |
|
29 |
-
|
30 |
|
31 |
-
|
32 |
|
33 |
-
|
34 |
|
35 |
-
|
36 |
|
37 |
-
Categorize the following
|
38 |
|
39 |
-
Classify the following
|
40 |
|
41 |
-
|
42 |
|
43 |
-
Classify the following
|
|
|
|
|
44 |
|
45 |
User dataset description:
|
46 |
"""
|
@@ -51,6 +54,82 @@ DEFAULT_DATASET_DESCRIPTIONS = [
|
|
51 |
]
|
52 |
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
def generate_pipeline_code(
|
55 |
system_prompt: str,
|
56 |
difficulty: str = None,
|
@@ -146,63 +225,3 @@ with Pipeline(name="textcat") as pipeline:
|
|
146 |
distiset = pipeline.run()
|
147 |
"""
|
148 |
)
|
149 |
-
|
150 |
-
|
151 |
-
def get_prompt_generator(temperature):
|
152 |
-
prompt_generator = TextGeneration(
|
153 |
-
llm=InferenceEndpointsLLM(
|
154 |
-
api_key=_get_next_api_key(),
|
155 |
-
model_id=MODEL,
|
156 |
-
tokenizer_id=MODEL,
|
157 |
-
generation_kwargs={
|
158 |
-
"temperature": temperature,
|
159 |
-
"max_new_tokens": 2048,
|
160 |
-
"do_sample": True,
|
161 |
-
},
|
162 |
-
),
|
163 |
-
use_system_prompt=True,
|
164 |
-
)
|
165 |
-
prompt_generator.load()
|
166 |
-
return prompt_generator
|
167 |
-
|
168 |
-
|
169 |
-
def get_textcat_generator(difficulty, clarity, is_sample):
|
170 |
-
textcat_generator = GenerateTextClassificationData(
|
171 |
-
llm=InferenceEndpointsLLM(
|
172 |
-
model_id=MODEL,
|
173 |
-
tokenizer_id=MODEL,
|
174 |
-
api_key=_get_next_api_key(),
|
175 |
-
generation_kwargs={
|
176 |
-
"temperature": 0.9,
|
177 |
-
"max_new_tokens": 256 if is_sample else 2048,
|
178 |
-
"do_sample": True,
|
179 |
-
"top_k": 50,
|
180 |
-
"top_p": 0.95,
|
181 |
-
},
|
182 |
-
),
|
183 |
-
difficulty=None if difficulty == "mixed" else difficulty,
|
184 |
-
clarity=None if clarity == "mixed" else clarity,
|
185 |
-
seed=random.randint(0, 2**32 - 1),
|
186 |
-
)
|
187 |
-
textcat_generator.load()
|
188 |
-
return textcat_generator
|
189 |
-
|
190 |
-
|
191 |
-
def get_labeller_generator(system_prompt, labels, num_labels):
|
192 |
-
labeller_generator = TextClassification(
|
193 |
-
llm=InferenceEndpointsLLM(
|
194 |
-
model_id=MODEL,
|
195 |
-
tokenizer_id=MODEL,
|
196 |
-
api_key=_get_next_api_key(),
|
197 |
-
generation_kwargs={
|
198 |
-
"temperature": 0.7,
|
199 |
-
"max_new_tokens": 2048,
|
200 |
-
},
|
201 |
-
),
|
202 |
-
context=system_prompt,
|
203 |
-
available_labels=labels,
|
204 |
-
n=num_labels,
|
205 |
-
default_label="unknown",
|
206 |
-
)
|
207 |
-
labeller_generator.load()
|
208 |
-
return labeller_generator
|
|
|
1 |
import random
|
2 |
+
from pydantic import BaseModel, Field
|
3 |
from typing import List
|
4 |
|
5 |
from distilabel.llms import InferenceEndpointsLLM
|
|
|
23 |
|
24 |
If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
|
25 |
|
26 |
+
{"classification_task": "Classify the following customer review of a cinema as", "labels": ["positive", "negative"]}
|
27 |
|
28 |
+
{"classification_task": "Categorize the following news article into one or more of the following categories:", "labels": ["politics", "sports", "technology", "entertainment", "health", "business", "environment", "education", "science", "international"]}
|
29 |
|
30 |
+
{"classification_task": "Classify the following news article into one or more of the following categories:", "labels": ['politics', 'sports', 'technology', 'entertainment', 'health', 'business', 'environment', 'education', 'science', 'international']}
|
31 |
|
32 |
+
{"classification_task": "Determine the sentiment of the following social media post:", "labels": ['ambiguous', 'sarcastic', 'informative', 'emotional']}
|
33 |
|
34 |
+
{"classification_task": "Identify the issue category for the following technical support ticket:", "labels": ['billing', 'technical', 'account', 'shipping', 'returns', 'installation', 'subscription']}
|
35 |
|
36 |
+
{"classification_task": "Classify the following movie review into one of the following categories:", "labels": ['critical', 'praise', 'disappointed', 'enthusiastic']}
|
37 |
|
38 |
+
{"classification_task": "Categorize the following customer service transcript into one of the following categories:", "labels": ['satisfied', 'dissatisfied', 'highly-satisfied', 'somewhat-dissatisfied', 'indifferent']}
|
39 |
|
40 |
+
{"classification_task": "Classify the following product description into one of the following product types:", "labels": ['smartphone', 'laptop', 'tablet', 'smartwatch', 'e-reader', 'headphones']}
|
41 |
|
42 |
+
{"classification_task": "Categorize the following tweet expressing the political event discussed as", "labels": ['support', 'opposition']}
|
43 |
|
44 |
+
{"classification_task": "Classify the following restaurant review into one of the following categories:", "labels": ['food-quality', 'service', 'ambiance', 'price']}
|
45 |
+
|
46 |
+
{"classification_task": "Categorize the following blog post based on its primary fashion trend or style:", "labels": ['casual', 'formal', 'streetwear', 'vintage', 'sustainable-fashion']}
|
47 |
|
48 |
User dataset description:
|
49 |
"""
|
|
|
54 |
]
|
55 |
|
56 |
|
57 |
+
class TextClassificationTask(BaseModel):
|
58 |
+
classification_task: str = Field(
|
59 |
+
...,
|
60 |
+
title="classification_task",
|
61 |
+
description="The classification task to be performed.",
|
62 |
+
)
|
63 |
+
|
64 |
+
labels: list[str] = Field(
|
65 |
+
...,
|
66 |
+
title="Labels",
|
67 |
+
description="The possible labels for the classification task.",
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
def get_prompt_generator(temperature):
|
72 |
+
prompt_generator = TextGeneration(
|
73 |
+
llm=InferenceEndpointsLLM(
|
74 |
+
api_key=_get_next_api_key(),
|
75 |
+
model_id=MODEL,
|
76 |
+
tokenizer_id=MODEL,
|
77 |
+
structured_output={"format": "json", "schema": TextClassificationTask},
|
78 |
+
generation_kwargs={
|
79 |
+
"temperature": temperature,
|
80 |
+
"max_new_tokens": 2048,
|
81 |
+
"do_sample": True,
|
82 |
+
},
|
83 |
+
),
|
84 |
+
system_prompt=PROMPT_CREATION_PROMPT,
|
85 |
+
use_system_prompt=True,
|
86 |
+
)
|
87 |
+
prompt_generator.load()
|
88 |
+
return prompt_generator
|
89 |
+
|
90 |
+
|
91 |
+
def get_textcat_generator(difficulty, clarity, is_sample):
|
92 |
+
textcat_generator = GenerateTextClassificationData(
|
93 |
+
llm=InferenceEndpointsLLM(
|
94 |
+
model_id=MODEL,
|
95 |
+
tokenizer_id=MODEL,
|
96 |
+
api_key=_get_next_api_key(),
|
97 |
+
generation_kwargs={
|
98 |
+
"temperature": 0.9,
|
99 |
+
"max_new_tokens": 256 if is_sample else 2048,
|
100 |
+
"do_sample": True,
|
101 |
+
"top_k": 50,
|
102 |
+
"top_p": 0.95,
|
103 |
+
},
|
104 |
+
),
|
105 |
+
difficulty=None if difficulty == "mixed" else difficulty,
|
106 |
+
clarity=None if clarity == "mixed" else clarity,
|
107 |
+
seed=random.randint(0, 2**32 - 1),
|
108 |
+
)
|
109 |
+
textcat_generator.load()
|
110 |
+
return textcat_generator
|
111 |
+
|
112 |
+
|
113 |
+
def get_labeller_generator(system_prompt, labels, num_labels):
|
114 |
+
labeller_generator = TextClassification(
|
115 |
+
llm=InferenceEndpointsLLM(
|
116 |
+
model_id=MODEL,
|
117 |
+
tokenizer_id=MODEL,
|
118 |
+
api_key=_get_next_api_key(),
|
119 |
+
generation_kwargs={
|
120 |
+
"temperature": 0.7,
|
121 |
+
"max_new_tokens": 2048,
|
122 |
+
},
|
123 |
+
),
|
124 |
+
context=system_prompt,
|
125 |
+
available_labels=labels,
|
126 |
+
n=num_labels,
|
127 |
+
default_label="unknown",
|
128 |
+
)
|
129 |
+
labeller_generator.load()
|
130 |
+
return labeller_generator
|
131 |
+
|
132 |
+
|
133 |
def generate_pipeline_code(
|
134 |
system_prompt: str,
|
135 |
difficulty: str = None,
|
|
|
225 |
distiset = pipeline.run()
|
226 |
"""
|
227 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|