Commit
•
288d796
1
Parent(s):
adc79ce
feat: Add buttons to align with textcat and textcatgenerator arguments
Browse files
src/distilabel_dataset_generator/apps/textcat.py
CHANGED
@@ -30,7 +30,12 @@ def generate_system_prompt(dataset_description: str) -> str:
|
|
30 |
|
31 |
|
32 |
def generate_dataset(
|
33 |
-
system_prompt: str,
|
|
|
|
|
|
|
|
|
|
|
34 |
) -> pd.DataFrame:
|
35 |
return pd.DataFrame({"prompt": [system_prompt], "completion": [system_prompt]})
|
36 |
|
@@ -69,6 +74,29 @@ def generate_dataset(
|
|
69 |
with app:
|
70 |
with main_ui:
|
71 |
with custom_input_ui:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
labels = gr.Dropdown(
|
73 |
choices=[],
|
74 |
allow_custom_value=True,
|
@@ -77,14 +105,21 @@ with app:
|
|
77 |
multiselect=True,
|
78 |
)
|
79 |
num_labels = gr.Number(
|
80 |
-
label="Number of labels", value=
|
81 |
)
|
82 |
num_rows = gr.Number(
|
83 |
label="Number of rows", value=10, minimum=1, maximum=500
|
84 |
)
|
85 |
|
86 |
pipeline_code = get_pipeline_code_ui(
|
87 |
-
generate_pipeline_code(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
)
|
89 |
|
90 |
# define app triggers
|
@@ -97,7 +132,7 @@ with app:
|
|
97 |
outputs=[success_message],
|
98 |
).then(
|
99 |
fn=generate_dataset,
|
100 |
-
inputs=[system_prompt, labels,
|
101 |
outputs=[final_dataset],
|
102 |
show_progress=True,
|
103 |
)
|
@@ -112,7 +147,7 @@ with app:
|
|
112 |
outputs=[success_message],
|
113 |
).success(
|
114 |
fn=generate_dataset,
|
115 |
-
inputs=[system_prompt, labels,
|
116 |
outputs=[final_dataset],
|
117 |
show_progress=True,
|
118 |
).success(
|
@@ -131,7 +166,7 @@ with app:
|
|
131 |
outputs=[success_message],
|
132 |
).then(
|
133 |
fn=generate_dataset,
|
134 |
-
inputs=[system_prompt, labels,
|
135 |
outputs=[final_dataset],
|
136 |
show_progress=True,
|
137 |
).then(
|
@@ -190,16 +225,26 @@ with app:
|
|
190 |
|
191 |
system_prompt.change(
|
192 |
fn=generate_pipeline_code,
|
193 |
-
inputs=[system_prompt, labels,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
outputs=[pipeline_code],
|
195 |
)
|
196 |
labels.change(
|
197 |
fn=generate_pipeline_code,
|
198 |
-
inputs=[system_prompt, labels,
|
199 |
outputs=[pipeline_code],
|
200 |
)
|
201 |
-
|
202 |
fn=generate_pipeline_code,
|
203 |
-
inputs=[system_prompt, labels,
|
204 |
outputs=[pipeline_code],
|
205 |
)
|
|
|
30 |
|
31 |
|
32 |
def generate_dataset(
|
33 |
+
system_prompt: str,
|
34 |
+
difficulty: str,
|
35 |
+
clarity: str,
|
36 |
+
labels: List[str],
|
37 |
+
num_labels: int,
|
38 |
+
num_rows: int,
|
39 |
) -> pd.DataFrame:
|
40 |
return pd.DataFrame({"prompt": [system_prompt], "completion": [system_prompt]})
|
41 |
|
|
|
74 |
with app:
|
75 |
with main_ui:
|
76 |
with custom_input_ui:
|
77 |
+
difficulty = gr.Dropdown(
|
78 |
+
choices=[
|
79 |
+
("High School", "high school"),
|
80 |
+
("College", "college"),
|
81 |
+
("PhD", "phd"),
|
82 |
+
("Mixed", "mixed"),
|
83 |
+
],
|
84 |
+
value="mixed",
|
85 |
+
label="Difficulty",
|
86 |
+
)
|
87 |
+
clarity = gr.Dropdown(
|
88 |
+
choices=[
|
89 |
+
("Clear", "Clear"),
|
90 |
+
(
|
91 |
+
"Understandable",
|
92 |
+
"understandable with some effort",
|
93 |
+
),
|
94 |
+
("Ambiguous", "Ambiguous"),
|
95 |
+
("Mixed", "mixed"),
|
96 |
+
],
|
97 |
+
value="mixed",
|
98 |
+
label="Clarity",
|
99 |
+
)
|
100 |
labels = gr.Dropdown(
|
101 |
choices=[],
|
102 |
allow_custom_value=True,
|
|
|
105 |
multiselect=True,
|
106 |
)
|
107 |
num_labels = gr.Number(
|
108 |
+
label="Number of labels", value=1, minimum=1, maximum=10
|
109 |
)
|
110 |
num_rows = gr.Number(
|
111 |
label="Number of rows", value=10, minimum=1, maximum=500
|
112 |
)
|
113 |
|
114 |
pipeline_code = get_pipeline_code_ui(
|
115 |
+
generate_pipeline_code(
|
116 |
+
system_prompt.value,
|
117 |
+
difficulty=difficulty.value,
|
118 |
+
clarity=clarity.value,
|
119 |
+
labels=labels.value,
|
120 |
+
num_labels=num_labels.value,
|
121 |
+
num_rows=num_rows.value,
|
122 |
+
)
|
123 |
)
|
124 |
|
125 |
# define app triggers
|
|
|
132 |
outputs=[success_message],
|
133 |
).then(
|
134 |
fn=generate_dataset,
|
135 |
+
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
136 |
outputs=[final_dataset],
|
137 |
show_progress=True,
|
138 |
)
|
|
|
147 |
outputs=[success_message],
|
148 |
).success(
|
149 |
fn=generate_dataset,
|
150 |
+
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
151 |
outputs=[final_dataset],
|
152 |
show_progress=True,
|
153 |
).success(
|
|
|
166 |
outputs=[success_message],
|
167 |
).then(
|
168 |
fn=generate_dataset,
|
169 |
+
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
170 |
outputs=[final_dataset],
|
171 |
show_progress=True,
|
172 |
).then(
|
|
|
225 |
|
226 |
system_prompt.change(
|
227 |
fn=generate_pipeline_code,
|
228 |
+
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
229 |
+
outputs=[pipeline_code],
|
230 |
+
)
|
231 |
+
difficulty.change(
|
232 |
+
fn=generate_pipeline_code,
|
233 |
+
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
234 |
+
outputs=[pipeline_code],
|
235 |
+
)
|
236 |
+
clarity.change(
|
237 |
+
fn=generate_pipeline_code,
|
238 |
+
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
239 |
outputs=[pipeline_code],
|
240 |
)
|
241 |
labels.change(
|
242 |
fn=generate_pipeline_code,
|
243 |
+
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
244 |
outputs=[pipeline_code],
|
245 |
)
|
246 |
+
num_labels.change(
|
247 |
fn=generate_pipeline_code,
|
248 |
+
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
249 |
outputs=[pipeline_code],
|
250 |
)
|
src/distilabel_dataset_generator/pipelines/textcat.py
CHANGED
@@ -35,7 +35,12 @@ DEFAULT_SYSTEM_PROMPTS = [
|
|
35 |
|
36 |
|
37 |
def generate_pipeline_code(
|
38 |
-
system_prompt: str,
|
|
|
|
|
|
|
|
|
|
|
39 |
) -> str:
|
40 |
return """
|
41 |
from distilabel import Distilabel
|
|
|
35 |
|
36 |
|
37 |
def generate_pipeline_code(
|
38 |
+
system_prompt: str,
|
39 |
+
difficulty: str,
|
40 |
+
clarity: str,
|
41 |
+
labels: List[str],
|
42 |
+
num_labels: int,
|
43 |
+
num_rows: int,
|
44 |
) -> str:
|
45 |
return """
|
46 |
from distilabel import Distilabel
|