davidberenstein1957 HF staff commited on
Commit
288d796
1 Parent(s): adc79ce

feat: Add buttons to align with textcat and textcatgenerator arguments

Browse files
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -30,7 +30,12 @@ def generate_system_prompt(dataset_description: str) -> str:
30
 
31
 
32
  def generate_dataset(
33
- system_prompt: str, labels: List[str], multi_label: bool
 
 
 
 
 
34
  ) -> pd.DataFrame:
35
  return pd.DataFrame({"prompt": [system_prompt], "completion": [system_prompt]})
36
 
@@ -69,6 +74,29 @@ def generate_dataset(
69
  with app:
70
  with main_ui:
71
  with custom_input_ui:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  labels = gr.Dropdown(
73
  choices=[],
74
  allow_custom_value=True,
@@ -77,14 +105,21 @@ with app:
77
  multiselect=True,
78
  )
79
  num_labels = gr.Number(
80
- label="Number of labels", value=2, minimum=1, maximum=10
81
  )
82
  num_rows = gr.Number(
83
  label="Number of rows", value=10, minimum=1, maximum=500
84
  )
85
 
86
  pipeline_code = get_pipeline_code_ui(
87
- generate_pipeline_code(system_prompt.value, labels.value, multi_label.value)
 
 
 
 
 
 
 
88
  )
89
 
90
  # define app triggers
@@ -97,7 +132,7 @@ with app:
97
  outputs=[success_message],
98
  ).then(
99
  fn=generate_dataset,
100
- inputs=[system_prompt, labels, multi_label],
101
  outputs=[final_dataset],
102
  show_progress=True,
103
  )
@@ -112,7 +147,7 @@ with app:
112
  outputs=[success_message],
113
  ).success(
114
  fn=generate_dataset,
115
- inputs=[system_prompt, labels, multi_label],
116
  outputs=[final_dataset],
117
  show_progress=True,
118
  ).success(
@@ -131,7 +166,7 @@ with app:
131
  outputs=[success_message],
132
  ).then(
133
  fn=generate_dataset,
134
- inputs=[system_prompt, labels, multi_label],
135
  outputs=[final_dataset],
136
  show_progress=True,
137
  ).then(
@@ -190,16 +225,26 @@ with app:
190
 
191
  system_prompt.change(
192
  fn=generate_pipeline_code,
193
- inputs=[system_prompt, labels, multi_label],
 
 
 
 
 
 
 
 
 
 
194
  outputs=[pipeline_code],
195
  )
196
  labels.change(
197
  fn=generate_pipeline_code,
198
- inputs=[system_prompt, labels, multi_label],
199
  outputs=[pipeline_code],
200
  )
201
- multi_label.change(
202
  fn=generate_pipeline_code,
203
- inputs=[system_prompt, labels, multi_label],
204
  outputs=[pipeline_code],
205
  )
 
30
 
31
 
32
  def generate_dataset(
33
+ system_prompt: str,
34
+ difficulty: str,
35
+ clarity: str,
36
+ labels: List[str],
37
+ num_labels: int,
38
+ num_rows: int,
39
  ) -> pd.DataFrame:
40
  return pd.DataFrame({"prompt": [system_prompt], "completion": [system_prompt]})
41
 
 
74
  with app:
75
  with main_ui:
76
  with custom_input_ui:
77
+ difficulty = gr.Dropdown(
78
+ choices=[
79
+ ("High School", "high school"),
80
+ ("College", "college"),
81
+ ("PhD", "phd"),
82
+ ("Mixed", "mixed"),
83
+ ],
84
+ value="mixed",
85
+ label="Difficulty",
86
+ )
87
+ clarity = gr.Dropdown(
88
+ choices=[
89
+ ("Clear", "Clear"),
90
+ (
91
+ "Understandable",
92
+ "understandable with some effort",
93
+ ),
94
+ ("Ambiguous", "Ambiguous"),
95
+ ("Mixed", "mixed"),
96
+ ],
97
+ value="mixed",
98
+ label="Clarity",
99
+ )
100
  labels = gr.Dropdown(
101
  choices=[],
102
  allow_custom_value=True,
 
105
  multiselect=True,
106
  )
107
  num_labels = gr.Number(
108
+ label="Number of labels", value=1, minimum=1, maximum=10
109
  )
110
  num_rows = gr.Number(
111
  label="Number of rows", value=10, minimum=1, maximum=500
112
  )
113
 
114
  pipeline_code = get_pipeline_code_ui(
115
+ generate_pipeline_code(
116
+ system_prompt.value,
117
+ difficulty=difficulty.value,
118
+ clarity=clarity.value,
119
+ labels=labels.value,
120
+ num_labels=num_labels.value,
121
+ num_rows=num_rows.value,
122
+ )
123
  )
124
 
125
  # define app triggers
 
132
  outputs=[success_message],
133
  ).then(
134
  fn=generate_dataset,
135
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
136
  outputs=[final_dataset],
137
  show_progress=True,
138
  )
 
147
  outputs=[success_message],
148
  ).success(
149
  fn=generate_dataset,
150
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
151
  outputs=[final_dataset],
152
  show_progress=True,
153
  ).success(
 
166
  outputs=[success_message],
167
  ).then(
168
  fn=generate_dataset,
169
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
170
  outputs=[final_dataset],
171
  show_progress=True,
172
  ).then(
 
225
 
226
  system_prompt.change(
227
  fn=generate_pipeline_code,
228
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
229
+ outputs=[pipeline_code],
230
+ )
231
+ difficulty.change(
232
+ fn=generate_pipeline_code,
233
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
234
+ outputs=[pipeline_code],
235
+ )
236
+ clarity.change(
237
+ fn=generate_pipeline_code,
238
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
239
  outputs=[pipeline_code],
240
  )
241
  labels.change(
242
  fn=generate_pipeline_code,
243
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
244
  outputs=[pipeline_code],
245
  )
246
+ num_labels.change(
247
  fn=generate_pipeline_code,
248
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
249
  outputs=[pipeline_code],
250
  )
src/distilabel_dataset_generator/pipelines/textcat.py CHANGED
@@ -35,7 +35,12 @@ DEFAULT_SYSTEM_PROMPTS = [
35
 
36
 
37
  def generate_pipeline_code(
38
- system_prompt: str, labels: List[str], multi_label: bool
 
 
 
 
 
39
  ) -> str:
40
  return """
41
  from distilabel import Distilabel
 
35
 
36
 
37
  def generate_pipeline_code(
38
+ system_prompt: str,
39
+ difficulty: str,
40
+ clarity: str,
41
+ labels: List[str],
42
+ num_labels: int,
43
+ num_rows: int,
44
  ) -> str:
45
  return """
46
  from distilabel import Distilabel