sdiazlor HF staff committed on
Commit
fb096d2
·
verified ·
1 Parent(s): 7afacd5

update-layout-add-evaluation (#17)

Browse files

- add comment to divide functions/ui (c2fbbc311616f60f3dbc1e546d20517329b41486)
- fix typo (3ef1fed40b702eb482f9017241a2d3641cde3946)
- move sign in button to another column (ea29202ac74e3b567f8bbc83f2dd01edeb7e00b5)
- make sign in button smaller (3b5e775206b9e2276d8d5afd14d5a2ae6eb0f852)
- remove repeated import (9c1769a069aa5cadade2e120ebbfbb3524b4a71c)
- move sign in button to the right (5d91425bc46bbed12e76279e32eed828738f2e78)
- modify column width and typos (2b5c2e3fa953fcd1e7bbc5d50fa1eaa18cf51a7f)
- update successful message and pipeline code (4234ad816ad254da4dc0a2bdf4f6ee901f3ab647)
- update dataframe visualizations (7350fc6b3cdabd3c88562f7ebea772ea936b293b)
- update text and order parameter layout (45693e1d9d5340197d0a5298329a1b176836e5e9)
- typo (2673ebc69f8b3ee53ca0bea400abe8b18dcec6c7)
- add temperature for system prompt (857f1ba71f10ddb10f045923601746daed130b19)
- update textcat (separate prompt and labels) and use input parameters (4e193106207eda3f59650448038a680c25075972)
- update sft and use input parameters (dea11022bc5c78e08481e4e90bbb73b0402cdadc)
- update push dataset (49d5948eb076fc8b3354a9d4acdaac477fc0c398)
- add evaluation task (34371d30aa99cdb709c9def84739ab3b8b7fa611)
- hide pipeline ui each time it generates (c26510fcff621c6a144917e1a56d5f87dd41fd41)
- move order hide pipeline ui (1b00519115b913bff86a6f2ba061f97eb860e78a)
- merge remote tracking branch (1c412e2113c3889b13572af931a3be19fc93df5a)

app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
3
  from src.distilabel_dataset_generator._tabbedinterface import TabbedInterface
4
  from src.distilabel_dataset_generator.apps.faq import app as faq_app
5
  from src.distilabel_dataset_generator.apps.sft import app as sft_app
 
6
  from src.distilabel_dataset_generator.apps.textcat import app as textcat_app
7
 
8
  theme ='argilla/argilla-theme'
@@ -25,12 +26,11 @@ button.hf-login:hover {background: var(--neutral-700); color: white}
25
  """
26
 
27
  demo = TabbedInterface(
28
- [textcat_app, sft_app, faq_app],
29
- ["Text Classification", "Supervised Fine-Tuning", "FAQ"],
30
  css=css,
31
  title="""
32
  <h1>Synthetic Data Generator</h1>
33
- <h3>Build datasets using natural language</h3>
34
  """,
35
  head="Synthetic Data Generator",
36
  theme=theme,
 
3
  from src.distilabel_dataset_generator._tabbedinterface import TabbedInterface
4
  from src.distilabel_dataset_generator.apps.faq import app as faq_app
5
  from src.distilabel_dataset_generator.apps.sft import app as sft_app
6
+ from src.distilabel_dataset_generator.apps.eval import app as eval_app
7
  from src.distilabel_dataset_generator.apps.textcat import app as textcat_app
8
 
9
  theme ='argilla/argilla-theme'
 
26
  """
27
 
28
  demo = TabbedInterface(
29
+ [textcat_app, sft_app, eval_app, faq_app],
30
+ ["Text Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"],
31
  css=css,
32
  title="""
33
  <h1>Synthetic Data Generator</h1>
 
34
  """,
35
  head="Synthetic Data Generator",
36
  theme=theme,
pyproject.toml CHANGED
@@ -6,7 +6,7 @@ authors = [
6
  {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"},
7
  ]
8
  dependencies = [
9
- "distilabel[hf-inference-endpoints,argilla,outlines]>=1.4.1",
10
  "gradio[oauth]<5.0.0",
11
  "transformers>=4.44.2",
12
  "sentence-transformers>=3.2.0",
 
6
  {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"},
7
  ]
8
  dependencies = [
9
+ "distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1",
10
  "gradio[oauth]<5.0.0",
11
  "transformers>=4.44.2",
12
  "sentence-transformers>=3.2.0",
src/distilabel_dataset_generator/_tabbedinterface.py CHANGED
@@ -63,10 +63,12 @@ class TabbedInterface(Blocks):
63
  if title:
64
  HTML(value=title)
65
  with gr.Row():
66
- with gr.Column(scale=1):
67
- gr.LoginButton(value="Sign in!", variant="hf-login", size="sm", scale=2)
68
  with gr.Column(scale=3):
69
  pass
 
 
70
  with Tabs():
71
  for interface, tab_name in zip(interface_list, tab_names, strict=False):
72
  with Tab(label=tab_name):
 
63
  if title:
64
  HTML(value=title)
65
  with gr.Row():
66
+ with gr.Column(scale=2):
67
+ gr.Markdown("### Build datasets using natural language")
68
  with gr.Column(scale=3):
69
  pass
70
+ with gr.Column(scale=2):
71
+ gr.LoginButton(value="Sign in!", variant="hf-login", size="sm", scale=2)
72
  with Tabs():
73
  for interface, tab_name in zip(interface_list, tab_names, strict=False):
74
  with Tab(label=tab_name):
src/distilabel_dataset_generator/apps/base.py CHANGED
@@ -15,7 +15,7 @@ from src.distilabel_dataset_generator.utils import (
15
  get_argilla_client,
16
  get_login_button,
17
  list_orgs,
18
- swap_visibilty,
19
  )
20
 
21
  TEXTCAT_TASK = "text_classification"
@@ -137,7 +137,7 @@ def get_main_ui(
137
  show_progress=True,
138
  )
139
 
140
- app.load(fn=swap_visibilty, outputs=main_ui)
141
  app.load(get_org_dropdown, outputs=[org_name])
142
 
143
  return (
@@ -300,25 +300,6 @@ def get_iterate_on_sample_dataset_ui(
300
  )
301
 
302
 
303
- def get_pipeline_code_ui(pipeline_code: str) -> gr.Code:
304
- gr.Markdown("## Customize and run with distilabel")
305
- gr.HTML("<hr>")
306
-
307
- with gr.Accordion(
308
- "Run this pipeline using distilabel",
309
- open=False,
310
- ):
311
- gr.Markdown(
312
- "You can run this pipeline locally with distilabel. For more information, please refer to the [distilabel documentation](https://distilabel.argilla.io/) or go to the FAQ tab at the top of the page for more information."
313
- )
314
- pipeline_code = gr.Code(
315
- value=pipeline_code,
316
- language="python",
317
- label="Distilabel Pipeline Code",
318
- )
319
- return pipeline_code
320
-
321
-
322
  def get_argilla_tab() -> Tuple[Any]:
323
  with gr.Tab(label="Argilla"):
324
  if get_argilla_client() is not None:
@@ -492,7 +473,7 @@ def get_success_message_row() -> gr.Markdown:
492
  return success_message
493
 
494
 
495
- def show_success_message_hub(org_name, repo_name) -> gr.Markdown:
496
  client = get_argilla_client()
497
  argilla_api_url = client.api_url
498
  return gr.Markdown(
@@ -500,25 +481,27 @@ def show_success_message_hub(org_name, repo_name) -> gr.Markdown:
500
  <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
501
  <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
502
  <p style="margin-top: 0.5em;">
503
- Your dataset is now available the Hugging Face Hub:
504
- <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
505
- https://huggingface.co/datasets/{org_name}/{repo_name}
506
- </a>
 
507
  </p>
508
  <p style="margin-top: 0.5em;">
509
- Your dataset is now available within Argilla:
510
- <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
511
- {argilla_api_url}
512
  </a>
513
- <br>Unfamiliar with Argilla? Here are some docs to help you get started:
514
- <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
515
- <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
516
  </p>
517
  </div>
 
 
 
 
 
518
  """,
519
  visible=True,
520
  )
521
 
522
-
523
  def hide_success_message() -> gr.Markdown:
524
  return gr.Markdown(value="")
 
15
  get_argilla_client,
16
  get_login_button,
17
  list_orgs,
18
+ swap_visibility,
19
  )
20
 
21
  TEXTCAT_TASK = "text_classification"
 
137
  show_progress=True,
138
  )
139
 
140
+ app.load(fn=swap_visibility, outputs=main_ui)
141
  app.load(get_org_dropdown, outputs=[org_name])
142
 
143
  return (
 
300
  )
301
 
302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  def get_argilla_tab() -> Tuple[Any]:
304
  with gr.Tab(label="Argilla"):
305
  if get_argilla_client() is not None:
 
473
  return success_message
474
 
475
 
476
+ def show_success_message(org_name, repo_name) -> gr.Markdown:
477
  client = get_argilla_client()
478
  argilla_api_url = client.api_url
479
  return gr.Markdown(
 
481
  <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
482
  <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
483
  <p style="margin-top: 0.5em;">
484
+ <strong>
485
+ <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
486
+ Open your dataset in the Argilla space
487
+ </a>
488
+ </strong>
489
  </p>
490
  <p style="margin-top: 0.5em;">
491
+ The generated dataset is in the right format for fine-tuning with TRL, AutoTrain, or other frameworks. Your dataset is now available at:
492
+ <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
493
+ https://huggingface.co/datasets/{org_name}/{repo_name}
494
  </a>
 
 
 
495
  </p>
496
  </div>
497
+ <p style="margin-top: 1em; font-size: 0.9em; color: #333;">
498
+ Unfamiliar with Argilla? Here are some docs to help you get started:
499
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
500
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
501
+ </p>
502
  """,
503
  visible=True,
504
  )
505
 
 
def hide_success_message() -> gr.Markdown:
    """Build a Markdown component whose value is the empty string.

    Returns:
        A ``gr.Markdown`` constructed with ``value=""``.
    """
    blank_message = gr.Markdown(value="")
    return blank_message
src/distilabel_dataset_generator/apps/eval.py CHANGED
@@ -1,70 +1,106 @@
1
  import json
 
 
2
 
 
3
  import gradio as gr
 
4
  import pandas as pd
5
- from datasets import load_dataset
 
 
 
 
 
 
6
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
 
7
 
8
- from src.distilabel_dataset_generator.utils import get_org_dropdown
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
- def get_iframe(hub_repo_id) -> str:
12
  if not hub_repo_id:
13
- raise gr.Error("Hub repo id is required")
 
14
  url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
15
  iframe = f"""
16
  <iframe
17
- src="{url}"
18
- frameborder="0"
19
- width="100%"
20
- height="600px"
21
- ></iframe>
22
- """
23
  return iframe
24
 
25
 
26
- def get_valid_columns(df: pd.DataFrame):
27
- valid_columns = []
28
- for col in df.columns:
29
- sample_val = df[col].iloc[0]
 
 
30
  if isinstance(sample_val, str) or (
31
- isinstance(sample_val, list)
32
- and all(isinstance(item, dict) for item in sample_val)
33
  ):
34
- valid_columns.append(col)
35
- return valid_columns
 
 
 
 
36
 
 
37
 
38
- def load_dataset_from_hub(hub_repo_id: str, n_rows: int = 10):
39
- gr.Info(message="Loading dataset ...")
40
- if not hub_repo_id:
41
  raise gr.Error("Hub repo id is required")
42
- ds_dict = load_dataset(hub_repo_id)
43
- splits = list(ds_dict.keys())
 
44
  ds = ds_dict[splits[0]]
45
- if n_rows:
46
- ds = ds.select(range(n_rows))
47
- df = ds.to_pandas()
48
- # Get columns that contain either strings or lists of dictionaries
49
- valid_columns = get_valid_columns(df)
50
  return (
51
- df,
52
- gr.Dropdown(choices=valid_columns, label="Instruction Column"),
53
- gr.Dropdown(choices=valid_columns, label="Instruction Column"),
54
- gr.Dropdown(choices=valid_columns, label="Response Column"),
55
  )
56
 
57
 
58
  def define_evaluation_aspects(task_type: str):
59
- if task_type == "instruction":
60
- return gr.Dropdown(
61
- value=["overall-rating"],
62
- choices=["complexity", "quality"],
63
- label="Evaluation Aspects",
64
- multiselect=True,
65
- interactive=True,
66
- )
67
- elif task_type == "instruction-response":
68
  return gr.Dropdown(
69
  value=["overall-rating"],
70
  choices=["helpfulness", "truthfulness", "overall-rating", "honesty"],
@@ -76,226 +112,635 @@ def define_evaluation_aspects(task_type: str):
76
  return gr.Dropdown(interactive=False, visible=False)
77
 
78
 
79
- def evaluate_instruction(df: pd.DataFrame, aspects: list[str], instruction_column: str):
80
- pass
81
-
82
-
83
  def evaluate_instruction_response(
84
- df: pd.DataFrame, aspects: list[str], instruction_column: str, response_column: str
 
 
 
 
 
 
85
  ):
86
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
 
89
  def evaluate_custom(
90
- df: pd.DataFrame, aspects: list[str], prompt_template: str, structured_output: dict
 
 
 
 
 
91
  ):
92
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
 
 
 
94
 
95
- def _apply_to_dataset(
96
- df: pd.DataFrame,
 
 
 
 
 
 
 
 
 
 
 
97
  eval_type: str,
98
- aspects_instruction: list[str],
99
- instruction_column: str,
100
  aspects_instruction_response: list[str],
101
- instruction_column_response: str,
102
- response_column_response: str,
103
- aspects_custom: list[str],
104
  prompt_template: str,
105
  structured_output: dict,
 
 
106
  ):
107
- if eval_type == "instruction":
108
- df = evaluate_instruction(df, aspects_instruction, instruction_column)
109
- elif eval_type == "instruction-response":
110
- df = evaluate_instruction_response(
111
- df,
112
- aspects_instruction_response,
113
- instruction_column_response,
114
- response_column_response,
 
 
 
 
 
 
 
 
115
  )
116
- elif eval_type == "custom":
117
- df = evaluate_custom(df, aspects_custom, prompt_template, structured_output)
118
- return df
119
 
120
 
121
- def apply_to_sample_dataset(
122
  repo_id: str,
123
  eval_type: str,
124
- aspects_instruction: list[str],
125
  aspects_instruction_response: list[str],
126
- aspects_custom: list[str],
127
- instruction_instruction: str,
128
  instruction_instruction_response: str,
129
  response_instruction_response: str,
130
  prompt_template: str,
131
  structured_output: dict,
132
  ):
133
- df, _, _, _ = load_dataset_from_hub(repo_id, n_rows=10)
134
- df = _apply_to_dataset(
135
- df,
136
- eval_type,
137
- aspects_instruction,
138
- instruction_instruction,
139
- aspects_instruction_response,
140
- instruction_instruction_response,
141
- response_instruction_response,
142
- aspects_custom,
143
- prompt_template,
144
- structured_output,
145
  )
146
- return df
147
 
148
 
149
- def push_to_hub(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  org_name: str,
151
  repo_name: str,
152
  private: bool,
153
- n_rows: int,
154
  original_repo_id: str,
155
  eval_type: str,
156
- aspects_instruction: list[str],
157
  aspects_instruction_response: list[str],
158
- aspects_custom: list[str],
159
- instruction_instruction: str,
160
  instruction_instruction_response: str,
161
  response_instruction_response: str,
162
  prompt_template: str,
163
  structured_output: dict,
164
- ):
165
- df, _, _, _ = load_dataset_from_hub(original_repo_id, n_rows=n_rows)
166
- df = _apply_to_dataset(
167
- df,
168
- eval_type,
169
- aspects_instruction,
170
- instruction_instruction,
171
- aspects_instruction_response,
172
- instruction_instruction_response,
173
- response_instruction_response,
174
- aspects_custom,
175
- prompt_template,
176
- structured_output,
177
  )
178
- new_repo_id = f"{org_name}/{repo_name}"
179
-
180
-
181
- with gr.Blocks() as app:
182
- gr.Markdown("## 1. Select your input dataset")
183
- with gr.Row():
184
- with gr.Column(scale=1):
185
- search_in = HuggingfaceHubSearch(
186
- label="Search",
187
- placeholder="Search for a Dataset",
188
- search_type="dataset",
189
- sumbit_on_select=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  )
191
- load_btn = gr.Button("Load dataset", variant="primary")
192
- with gr.Column(scale=3):
193
- search_out = gr.HTML(label="Dataset Preview")
194
-
195
- gr.HTML("<hr>")
196
- gr.Markdown("## 2. Configure your task")
197
- with gr.Row():
198
- with gr.Column(scale=1):
199
- eval_type = gr.Dropdown(
200
- label="Evaluation Type",
201
- choices=["instruction", "instruction-response", "custom-template"],
202
- visible=False,
203
  )
204
- with gr.Tab("instruction") as tab_instruction:
205
- aspects_instruction = define_evaluation_aspects("instruction")
206
- instruction_instruction = gr.Dropdown(
207
- label="Instruction Column", interactive=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  )
209
- tab_instruction.select(
210
- lambda: "instruction",
211
- inputs=[],
212
- outputs=[eval_type],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  )
214
- with gr.Tab("instruction-response") as tab_instruction_response:
215
- aspects_instruction_response = define_evaluation_aspects(
216
- "instruction-response"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  )
218
- instruction_instruction_response = gr.Dropdown(
219
- label="Instruction Column", interactive=True
 
 
 
 
 
 
 
 
 
 
 
 
220
  )
221
- response_instruction_response = gr.Dropdown(
222
- label="Response Column", interactive=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  )
224
- tab_instruction_response.select(
225
- lambda: "instruction-response",
226
- inputs=[],
227
- outputs=[eval_type],
 
 
228
  )
229
- with gr.Tab("custom") as tab_custom:
230
- aspects_custom = define_evaluation_aspects("custom")
231
- prompt_template = gr.Code(
232
- label="Prompt Template",
233
- value="{{column_1}} based on {{column_2}}",
234
- language="markdown",
 
 
 
 
235
  interactive=True,
236
  )
237
- structured_output = gr.Code(
238
- label="Structured Output",
239
- value=json.dumps({"eval_aspect": "str"}),
240
- language="json",
241
  interactive=True,
 
242
  )
243
- tab_custom.select(
244
- lambda: "custom-template",
245
- inputs=[],
246
- outputs=[eval_type],
 
247
  )
248
- btn_apply_to_sample_dataset = gr.Button("Refresh dataset")
249
- with gr.Column(scale=3):
250
- dataframe = gr.Dataframe()
251
-
252
- gr.HTML("<hr>")
253
- gr.Markdown("## 3. Generate your dataset")
254
- with gr.Row():
255
- with gr.Column(scale=1):
256
- org_name = get_org_dropdown()
257
- repo_name = gr.Textbox(
258
- label="Repo name",
259
- placeholder="dataset_name",
260
- value="my-distiset",
261
- interactive=True,
262
- )
263
- n_rows = gr.Number(
264
- label="Number of rows",
265
- value=10,
266
- interactive=True,
267
- scale=1,
268
- )
269
- private = gr.Checkbox(
270
- label="Private dataset",
271
- value=False,
272
- interactive=True,
273
- scale=1,
274
- )
275
- btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
276
- with gr.Column(scale=3):
277
- success_message = gr.Markdown(visible=False)
278
 
279
- search_in.submit(get_iframe, inputs=search_in, outputs=search_out)
280
  load_btn.click(
281
- load_dataset_from_hub,
282
  inputs=[search_in],
283
  outputs=[
284
  dataframe,
285
- instruction_instruction,
286
  instruction_instruction_response,
287
  response_instruction_response,
288
  ],
289
  )
 
290
  btn_apply_to_sample_dataset.click(
291
- apply_to_sample_dataset,
292
  inputs=[
293
  search_in,
294
  eval_type,
295
- aspects_instruction,
296
  aspects_instruction_response,
297
- aspects_custom,
298
- instruction_instruction,
299
  instruction_instruction_response,
300
  response_instruction_response,
301
  prompt_template,
@@ -303,24 +748,64 @@ with gr.Blocks() as app:
303
  ],
304
  outputs=dataframe,
305
  )
 
306
  btn_push_to_hub.click(
307
- push_to_hub,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  inputs=[
309
  org_name,
310
  repo_name,
311
  private,
312
- n_rows,
313
  search_in,
314
  eval_type,
315
- aspects_instruction,
316
  aspects_instruction_response,
317
- aspects_custom,
318
- instruction_instruction,
319
  instruction_instruction_response,
320
  response_instruction_response,
321
  prompt_template,
322
  structured_output,
323
  ],
324
- outputs=success_message,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  )
 
 
326
  app.load(fn=get_org_dropdown, outputs=[org_name])
 
1
  import json
2
+ import uuid
3
+ from typing import Union
4
 
5
+ import argilla as rg
6
  import gradio as gr
7
+ import numpy as np
8
  import pandas as pd
9
+ from datasets import (
10
+ Dataset,
11
+ get_dataset_config_names,
12
+ get_dataset_split_names,
13
+ load_dataset,
14
+ )
15
+ from distilabel.distiset import Distiset
16
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
17
+ from huggingface_hub import HfApi
18
 
19
+ from src.distilabel_dataset_generator.apps.base import (
20
+ hide_success_message,
21
+ show_success_message,
22
+ validate_argilla_user_workspace_dataset,
23
+ validate_push_to_hub,
24
+ )
25
+ from src.distilabel_dataset_generator.pipelines.base import (
26
+ DEFAULT_BATCH_SIZE,
27
+ )
28
+ from src.distilabel_dataset_generator.pipelines.embeddings import (
29
+ get_embeddings,
30
+ get_sentence_embedding_dimensions,
31
+ )
32
+ from src.distilabel_dataset_generator.pipelines.eval import (
33
+ generate_pipeline_code,
34
+ get_custom_evaluator,
35
+ get_ultrafeedback_evaluator,
36
+ )
37
+ from src.distilabel_dataset_generator.utils import (
38
+ column_to_list,
39
+ extract_column_names,
40
+ get_argilla_client,
41
+ get_org_dropdown,
42
+ process_columns,
43
+ swap_visibility,
44
+ pad_or_truncate_list,
45
+ )
46
 
47
 
48
def get_iframe(hub_repo_id: str) -> str:
    """Build the HTML snippet embedding the Hub dataset viewer for a repo.

    Args:
        hub_repo_id: Dataset repository id; it is interpolated verbatim into
            the ``https://huggingface.co/datasets/<id>/embed/viewer`` URL.

    Returns:
        An ``<iframe>`` element (as a string) pointing at the embed viewer.

    Raises:
        gr.Error: If ``hub_repo_id`` is empty or otherwise falsy.
    """
    if hub_repo_id:
        viewer_url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
    else:
        raise gr.Error("Hub repository ID is required.")

    return f"""
    <iframe
        src="{viewer_url}"
        frameborder="0"
        width="100%"
        height="600px"
    ></iframe>
    """
62
 
63
 
64
def get_valid_columns(dataframe: pd.DataFrame):
    """Classify columns as usable for the instruction and/or response dropdowns.

    Only the first row of each column is inspected:

    * a ``str``, or a list/array whose items are all dicts containing a
      ``"role"`` key, qualifies the column for both instruction and response;
    * a list/array whose items are all ``str`` qualifies it for response only.

    Args:
        dataframe: The data whose columns are classified.

    Returns:
        A ``(instruction_valid_columns, response_valid_columns)`` tuple of
        column-name lists, in the dataframe's column order and without
        duplicates. Both lists are empty when the dataframe has no rows.
    """
    instruction_valid_columns = []
    response_valid_columns = []

    # With zero rows there is no sample to inspect; `.iloc[0]` would raise.
    if dataframe.empty:
        return instruction_valid_columns, response_valid_columns

    for col in dataframe.columns:
        sample_val = dataframe[col].iloc[0]
        is_sequence = isinstance(sample_val, (list, np.ndarray))
        is_chat = is_sequence and all(
            isinstance(item, dict) and "role" in item for item in sample_val
        )
        is_str_list = is_sequence and all(
            isinstance(item, str) for item in sample_val
        )

        if isinstance(sample_val, str) or is_chat:
            instruction_valid_columns.append(col)
            response_valid_columns.append(col)
        elif is_str_list:
            # `elif`, not `if`: an empty sequence satisfies both `all(...)`
            # checks and must not be listed twice as a response column.
            response_valid_columns.append(col)

    return instruction_valid_columns, response_valid_columns
82
 
83
+
84
def load_dataset_from_hub(repo_id: str, num_rows: int = 10):
    """Load the first split of a Hub dataset's first config as a dataframe.

    Args:
        repo_id: Hub dataset repository id.
        num_rows: Maximum number of leading rows to keep. A falsy value
            (``0``/``None``) keeps the whole split.

    Returns:
        A 3-tuple ``(dataframe, instruction_dropdown, response_dropdown)``:
        the loaded rows plus two ``gr.Dropdown`` components whose choices are
        the columns accepted by ``get_valid_columns``.

    Raises:
        gr.Error: If ``repo_id`` is empty or otherwise falsy.
    """
    if not repo_id:
        raise gr.Error("Hub repo id is required")

    subsets = get_dataset_config_names(repo_id)
    ds_dict = load_dataset(repo_id, subsets[0])
    splits = get_dataset_split_names(repo_id, subsets[0])
    ds = ds_dict[splits[0]]
    if num_rows:
        # Clamp to the split size: selecting indices past the end raises.
        ds = ds.select(range(min(num_rows, len(ds))))
    dataframe = ds.to_pandas()
    instruction_valid_columns, response_valid_columns = get_valid_columns(dataframe)

    return (
        dataframe,
        gr.Dropdown(choices=instruction_valid_columns, label="Instruction column"),
        gr.Dropdown(choices=response_valid_columns, label="Response column"),
    )
100
 
101
 
102
  def define_evaluation_aspects(task_type: str):
103
+ if task_type == "ultrafeedback":
 
 
 
 
 
 
 
 
104
  return gr.Dropdown(
105
  value=["overall-rating"],
106
  choices=["helpfulness", "truthfulness", "overall-rating", "honesty"],
 
112
  return gr.Dropdown(interactive=False, visible=False)
113
 
114
 
 
 
 
 
def evaluate_instruction_response(
    dataframe: pd.DataFrame,
    aspects: list[str],
    instruction_column: str,
    response_columns: str,
    num_rows: int = 10,
    is_sample: bool = False,
    progress=gr.Progress(),
):
    """Rate instruction/response rows on every requested aspect.

    The rows produced by ``process_columns`` are fed, batch by batch, to the
    evaluator returned by ``get_ultrafeedback_evaluator`` for each aspect, and
    the per-aspect outputs are merged into one record per row.

    Args:
        dataframe: Source data.
        aspects: Aspect names to evaluate; one evaluator is built per aspect.
        instruction_column: Column passed to ``process_columns`` as the
            instruction source.
        response_columns: Column(s) passed to ``process_columns`` as the
            response source.
        num_rows: Number of leading rows to evaluate (capped at the number of
            available rows).
        is_sample: Forwarded to ``get_ultrafeedback_evaluator``.
        progress: Gradio progress tracker.

    Returns:
        A dataframe with ``instruction``, ``generations``, ``model_name`` and,
        per aspect, ``ratings_<aspect>`` and ``rationale_for_ratings_<aspect>``.
        For "truthfulness" and "helpfulness" it also holds ``type_<aspect>``
        and ``rationale_for_type_<aspect>``.
    """
    progress(0.0, desc="Evaluating instructions and responses")
    data = process_columns(dataframe, instruction_column, response_columns)
    num_generations = len(data[0]["generations"])
    # Never ask for more rows than were actually prepared; otherwise the
    # batching loop below would keep submitting empty batches.
    num_rows = min(num_rows, len(data))

    # Pre-build one result record per row with every output key set to None,
    # so all rows share the same columns even if an evaluation yields nothing.
    evaluated_results = []
    for entry in data:
        result_row = {
            "instruction": entry["instruction"],
            "generations": entry["generations"],
        }
        for aspect in aspects:
            result_row[f"ratings_{aspect}"] = None
            result_row[f"rationale_for_ratings_{aspect}"] = None
            if aspect in ["truthfulness", "helpfulness"]:
                result_row[f"type_{aspect}"] = None
                result_row[f"rationale_for_type_{aspect}"] = None
        result_row["model_name"] = None
        evaluated_results.append(result_row)

    total_steps: int = len(aspects) * num_rows

    # evaluate instructions and responses
    for aspect_idx, aspect in enumerate(aspects):
        ultrafeedback_evaluator = get_ultrafeedback_evaluator(aspect, is_sample)
        n_processed = 0

        while n_processed < num_rows:
            progress(
                # Offset by the aspects already finished so the bar advances
                # monotonically instead of restarting for every aspect.
                (aspect_idx * num_rows + n_processed) / total_steps,
                total=total_steps,
                desc=f"Evaluating aspect: {aspect}",
            )

            # Recomputed from the default each time: a short final batch must
            # not shrink the batch size used for the following aspects.
            batch_size = min(DEFAULT_BATCH_SIZE, num_rows - n_processed)
            inputs = data[n_processed : n_processed + batch_size]
            batch_results = list(ultrafeedback_evaluator.process(inputs=inputs))
            for j, result in enumerate(batch_results[0]):
                idx = n_processed + j
                evaluated_results[idx][f"ratings_{aspect}"] = pad_or_truncate_list(
                    result.get("ratings"), num_generations
                )
                evaluated_results[idx]["model_name"] = result.get("model_name")
                if aspect in ["truthfulness", "helpfulness"]:
                    # These aspects return a type classification as well, and
                    # use a separate key for the rating rationales.
                    evaluated_results[idx][f"type_{aspect}"] = pad_or_truncate_list(
                        result.get("types"), num_generations
                    )
                    evaluated_results[idx][f"rationale_for_type_{aspect}"] = (
                        pad_or_truncate_list(result.get("rationales"), num_generations)
                    )
                    evaluated_results[idx][f"rationale_for_ratings_{aspect}"] = (
                        pad_or_truncate_list(
                            result.get("rationales-for-ratings"), num_generations
                        )
                    )
                else:
                    evaluated_results[idx][f"rationale_for_ratings_{aspect}"] = (
                        pad_or_truncate_list(result.get("rationales"), num_generations)
                    )
            n_processed += batch_size

    # create final dataset
    dataframe = pd.DataFrame(evaluated_results)
    progress(1.0, desc="Dataset evaluation completed")
    return dataframe
189
 
190
 
def evaluate_custom(
    dataframe: pd.DataFrame,
    prompt_template: str,
    structured_output: dict,
    num_rows: int = 10,
    is_sample: bool = False,
    progress=gr.Progress(),
):
    """Evaluate rows with a user-supplied prompt template.

    The columns referenced by ``prompt_template`` (as reported by
    ``extract_column_names``) are pulled from ``dataframe`` and sent in
    batches to the evaluator built by ``get_custom_evaluator``.

    Args:
        dataframe: Source data.
        prompt_template: Template whose referenced columns become the inputs.
        structured_output: Output schema forwarded to ``get_custom_evaluator``.
        num_rows: Number of leading rows to evaluate.
        is_sample: Forwarded to ``get_custom_evaluator``.
        progress: Gradio progress tracker.

    Returns:
        A dataframe built from the evaluator results with the
        ``distilabel_metadata`` key dropped from every record.
    """
    progress(0.0, desc="Evaluating dataset")
    columns = extract_column_names(prompt_template)
    values_by_column = {name: column_to_list(dataframe, name) for name in columns}

    custom_evaluator = get_custom_evaluator(
        prompt_template, structured_output, columns, is_sample
    )
    batch_size = DEFAULT_BATCH_SIZE

    # evaluate the data
    raw_results = []
    n_processed = 0
    while n_processed < num_rows:
        progress(n_processed / num_rows, desc="Evaluating dataset")
        batch_size = min(batch_size, num_rows - n_processed)

        # One input record per row index in the current batch window.
        batch_inputs = [
            {name: values[row] for name, values in values_by_column.items()}
            for row in range(n_processed, n_processed + batch_size)
        ]

        outputs = list(custom_evaluator.process(inputs=batch_inputs))
        raw_results.extend(outputs[0])
        n_processed += batch_size

    # create final dataset
    cleaned_records = [
        {key: item[key] for key in item if key != "distilabel_metadata"}
        for item in raw_results
    ]

    dataframe = pd.DataFrame(cleaned_records)
    progress(1.0, desc="Dataset evaluation completed")
    return dataframe
237
+
238
+
239
def _evaluate_dataset(
    dataframe: pd.DataFrame,
    eval_type: str,
    aspects_instruction_response: list[str],
    instruction_instruction_response: str,
    response_instruction_response: str,
    prompt_template: str,
    structured_output: dict,
    num_rows: int = 10,
    is_sample: bool = False,
):
    """Route the dataframe to the evaluator matching ``eval_type``.

    ``"ultrafeedback"`` selects ``evaluate_instruction_response``; any other
    value selects ``evaluate_custom``.

    Returns:
        The dataframe produced by the selected evaluation function.
    """
    if eval_type != "ultrafeedback":
        # Every non-ultrafeedback type is handled by the custom evaluator.
        return evaluate_custom(
            dataframe=dataframe,
            prompt_template=prompt_template,
            structured_output=structured_output,
            num_rows=num_rows,
            is_sample=is_sample,
        )

    return evaluate_instruction_response(
        dataframe=dataframe,
        aspects=aspects_instruction_response,
        instruction_column=instruction_instruction_response,
        response_columns=response_instruction_response,
        num_rows=num_rows,
        is_sample=is_sample,
    )
 
 
268
 
269
 
270
def evaluate_sample_dataset(
    repo_id: str,
    eval_type: str,
    aspects_instruction_response: list[str],
    instruction_instruction_response: str,
    response_instruction_response: str,
    prompt_template: str,
    structured_output: dict,
):
    """Evaluate a 10-row sample of a Hub dataset.

    Loads the first 10 rows via ``load_dataset_from_hub`` and runs them
    through ``_evaluate_dataset`` with ``is_sample=True``.

    Returns:
        The evaluated sample as a dataframe.
    """
    sample_frame, _instruction_choices, _response_choices = load_dataset_from_hub(
        repo_id, num_rows=10
    )
    evaluation_kwargs = {
        "dataframe": sample_frame,
        "eval_type": eval_type,
        "aspects_instruction_response": aspects_instruction_response,
        "instruction_instruction_response": instruction_instruction_response,
        "response_instruction_response": response_instruction_response,
        "prompt_template": prompt_template,
        "structured_output": structured_output,
        "num_rows": 10,
        "is_sample": True,
    }
    return _evaluate_dataset(**evaluation_kwargs)
292
 
293
 
294
def push_dataset_to_hub(
    dataframe: pd.DataFrame, org_name: str, repo_name: str, oauth_token, private
):
    """Publish ``dataframe`` to the Hub as a single-config ``Distiset``.

    Args:
        dataframe: Rows to publish; wrapped as the ``"default"`` config.
        org_name: Passed to ``validate_push_to_hub`` with ``repo_name`` to
            obtain the target repo id.
        repo_name: See ``org_name``.
        oauth_token: Object exposing a ``.token`` attribute used for auth.
        private: Forwarded as the ``private`` flag of the push.
    """
    repo_id = validate_push_to_hub(org_name, repo_name)
    hf_dataset = Dataset.from_pandas(dataframe)
    Distiset({"default": hf_dataset}).push_to_hub(
        repo_id=repo_id,
        private=private,
        include_script=False,
        token=oauth_token.token,
        create_pr=False,
    )
306
+
307
+
308
def _typed_item(values, index, expected_type):
    """Return ``values[index]`` if it exists and is an ``expected_type``, else ``None``.

    Guards against per-generation lists that are missing, ``None`` or shorter
    than the number of generations, so a sparse row yields "no suggestion"
    instead of an ``IndexError``.
    """
    if not values or index >= len(values):
        return None
    item = values[index]
    return item if isinstance(item, expected_type) else None


def push_dataset(
    org_name: str,
    repo_name: str,
    private: bool,
    num_rows: int,
    original_repo_id: str,
    eval_type: str,
    aspects_instruction_response: list[str],
    instruction_instruction_response: str,
    response_instruction_response: str,
    prompt_template: str,
    structured_output: dict,
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
) -> str:
    """Evaluate a Hub dataset, push the result to the Hub and log it to Argilla.

    The dataset is loaded with ``load_dataset_from_hub``, evaluated with
    ``_evaluate_dataset`` and pushed with ``push_dataset_to_hub``. Afterwards an
    Argilla dataset named ``repo_name`` is looked up (and created if missing) in
    the workspace named after the Hugging Face user, and the records are logged.

    Args:
        org_name: Hub organisation/user forwarded to ``push_dataset_to_hub``.
        repo_name: Hub repository name, also used as the Argilla dataset name.
        private: Forwarded to ``push_dataset_to_hub``.
        num_rows: Number of rows forwarded to the loader and the evaluator.
        original_repo_id: Repository id of the dataset to evaluate.
        eval_type: ``"ultrafeedback"`` selects the chat/rating layout; any other
            value selects the prompt-template ("custom") layout.
        aspects_instruction_response: Aspects evaluated in ultrafeedback mode.
        instruction_instruction_response: Forwarded to ``_evaluate_dataset``.
        response_instruction_response: Forwarded to ``_evaluate_dataset``.
        prompt_template: Template whose column names define the custom fields.
        structured_output: Forwarded to ``_evaluate_dataset``.
        oauth_token: OAuth token; its ``.token`` is used for the ``whoami`` call.
        progress: Gradio progress tracker.

    Returns:
        An empty string.

    Raises:
        gr.Error: If any step of the Argilla upload fails. Errors raised while
            loading, evaluating or pushing to the Hub propagate unchanged.
    """
    dataframe, _, _ = load_dataset_from_hub(original_repo_id, num_rows=num_rows)
    dataframe = _evaluate_dataset(
        dataframe=dataframe,
        eval_type=eval_type,
        aspects_instruction_response=aspects_instruction_response,
        instruction_instruction_response=instruction_instruction_response,
        response_instruction_response=response_instruction_response,
        prompt_template=prompt_template,
        structured_output=structured_output,
        num_rows=num_rows,
    )
    push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
    try:
        progress(0.1, desc="Setting up user and workspace")
        client = get_argilla_client()
        hf_user = HfApi().whoami(token=oauth_token.token)["name"]
        if eval_type == "ultrafeedback":
            # The first row decides how many chat fields/questions are created.
            num_generations = len((dataframe["generations"][0]))
            fields = [
                rg.ChatField(
                    name=f"chat_{i}",
                    title=f"Chat {i+1}",
                    description=f"User and assistant conversation for generation {i+1}",
                )
                for i in range(num_generations)
            ]
            questions = []
            for i in range(num_generations):
                for aspect in aspects_instruction_response:
                    questions.append(
                        rg.RatingQuestion(
                            name=f"ratings_{aspect}_{i}",
                            values=list(range(11)),
                            title=f"Ratings for {aspect} for response {i+1}",
                            required=True,
                        )
                    )
                    questions.append(
                        rg.TextQuestion(
                            name=f"rationale_for_ratings_{aspect}_{i}",
                            title=f"Rationale for ratings for {aspect} for response {i+1}",
                            required=False,
                            use_markdown=True,
                        )
                    )
                    # Only these two aspects get the extra "type" questions.
                    if aspect in ["truthfulness", "helpfulness"]:
                        questions.append(
                            rg.RatingQuestion(
                                name=f"type_{aspect}_{i}",
                                values=list(range(1, 6)),
                                title=f"The type of the response {i+1} for {aspect}",
                                required=True,
                            )
                        )
                        questions.append(
                            rg.TextQuestion(
                                name=f"rationale_for_type_{aspect}_{i}",
                                title=f"Rationale for type of the response {i+1} for {aspect}",
                                required=False,
                                use_markdown=True,
                            )
                        )
            metadata = [
                rg.IntegerMetadataProperty(
                    name="instruction_length", title="Instruction length"
                ),
            ]
            for i in range(num_generations):
                metadata.append(
                    rg.IntegerMetadataProperty(
                        name=f"response_{i}_length", title=f"Response {i+1} length"
                    )
                )
            vectors = [
                rg.VectorField(
                    name="instruction_embeddings",
                    dimensions=get_sentence_embedding_dimensions(),
                )
            ]
            settings = rg.Settings(
                fields=fields,
                questions=questions,
                metadata=metadata,
                vectors=vectors,
                guidelines="Please review the conversation and provide an evaluation.",
            )

            dataframe["instruction_length"] = dataframe["instruction"].apply(len)
            for i in range(num_generations):
                # Rows with fewer generations than the first row get length 0.
                dataframe[f"response_{i}_length"] = dataframe["generations"].apply(
                    lambda gens, idx=i: len(gens[idx]) if idx < len(gens) else 0
                )
            dataframe["instruction_embeddings"] = get_embeddings(
                dataframe["instruction"].to_list()
            )

            progress(0.5, desc="Creating dataset")
            rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
            if rg_dataset is None:
                rg_dataset = rg.Dataset(
                    name=repo_name,
                    workspace=hf_user,
                    settings=settings,
                    client=client,
                )
                rg_dataset = rg_dataset.create()

            progress(0.7, desc="Pushing dataset to Argilla")
            hf_dataset = Dataset.from_pandas(dataframe)
            records = []
            for sample in hf_dataset:
                fields = {}
                metadata = {"instruction_length": sample.get("instruction_length", 0)}
                vectors = {
                    "instruction_embeddings": sample.get("instruction_embeddings", [])
                }
                suggestions = []
                generations = sample.get("generations") or []
                for i in range(num_generations):
                    fields[f"chat_{i}"] = [
                        {"role": "user", "content": sample.get("instruction", "")},
                        {
                            "role": "assistant",
                            # Mirror the length metadata above: a row with fewer
                            # generations contributes an empty response instead
                            # of raising IndexError.
                            "content": generations[i] if i < len(generations) else "",
                        },
                    ]
                    metadata[f"response_{i}_length"] = sample.get(
                        f"response_{i}_length", 0
                    )

                    for aspect in aspects_instruction_response:
                        ratings = sample.get(f"ratings_{aspect}", [])
                        # NOTE(review): the original only read the
                        # double-underscore key, which does not match the naming
                        # of the sibling columns or of the question. Prefer the
                        # single-underscore key and fall back to the original
                        # one; confirm against what _evaluate_dataset emits.
                        rationales = sample.get(
                            f"rationale_for_ratings_{aspect}"
                        ) or sample.get(f"rationale_for_ratings__{aspect}", [])

                        rating_value = _typed_item(ratings, i, int)
                        rationale_value = _typed_item(rationales, i, str)

                        if rating_value is not None:
                            suggestions.append(
                                rg.Suggestion(
                                    question_name=f"ratings_{aspect}_{i}",
                                    value=rating_value,
                                )
                            )
                        if rationale_value is not None:
                            suggestions.append(
                                rg.Suggestion(
                                    question_name=f"rationale_for_ratings_{aspect}_{i}",
                                    value=rationale_value,
                                )
                            )

                        if aspect in ["truthfulness", "helpfulness"]:
                            types = sample.get(f"type_{aspect}", [])
                            rationale_types = sample.get(
                                f"rationale_for_type_{aspect}", []
                            )

                            type_value = _typed_item(types, i, int)
                            rationale_type_value = _typed_item(
                                rationale_types, i, str
                            )
                            if type_value is not None:
                                suggestions.append(
                                    rg.Suggestion(
                                        question_name=f"type_{aspect}_{i}",
                                        value=type_value,
                                    )
                                )
                            if rationale_type_value is not None:
                                suggestions.append(
                                    rg.Suggestion(
                                        question_name=f"rationale_for_type_{aspect}_{i}",
                                        value=rationale_type_value,
                                    )
                                )
                records.append(
                    rg.Record(
                        fields=fields,
                        metadata=metadata,
                        vectors=vectors,
                        suggestions=suggestions,
                    )
                )
            rg_dataset.records.log(records=records)
            progress(1.0, desc="Dataset pushed to Argilla")
        else:
            # Custom evaluation: one text field per column named in the template.
            columns = extract_column_names(prompt_template)
            settings = rg.Settings(
                fields=[
                    rg.TextField(
                        name=column,
                        title=column.capitalize(),
                        description="The column content",
                    )
                    for column in columns
                ],
                questions=[
                    rg.TextQuestion(
                        name="evaluation",
                        title="Evaluation",
                        description="The generated evaluation",
                        use_markdown=True,
                    ),
                ],
                metadata=[
                    rg.IntegerMetadataProperty(
                        name=f"{column}_length", title=f"{column.capitalize()} length"
                    )
                    for column in columns
                ],
                vectors=[
                    rg.VectorField(
                        name=f"{column}_embeddings",
                        dimensions=get_sentence_embedding_dimensions(),
                    )
                    for column in columns
                ],
                guidelines="Please review, correct and provide an accurate evaluation.",
            )
            for column in columns:
                dataframe[f"{column}_length"] = dataframe[column].apply(len)
                # Pass a plain list, as the ultrafeedback branch does.
                dataframe[f"{column}_embeddings"] = get_embeddings(
                    dataframe[column].to_list()
                )

            progress(0.5, desc="Creating dataset")
            rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
            if rg_dataset is None:
                rg_dataset = rg.Dataset(
                    name=repo_name,
                    workspace=hf_user,
                    settings=settings,
                    client=client,
                )
                rg_dataset = rg_dataset.create()
            progress(0.7, desc="Pushing dataset to Argilla")
            hf_dataset = Dataset.from_pandas(dataframe)
            rg_dataset.records.log(
                records=hf_dataset, mapping={"generation": "evaluation"}
            )
            progress(1.0, desc="Dataset pushed to Argilla")
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"Error pushing dataset to Argilla: {e}") from e
    return ""
578
+
579
+
580
def show_pipeline_code_visibility():
    """Return an update mapping that makes the ``pipeline_code_ui`` accordion visible."""
    visible_accordion = gr.Accordion(visible=True)
    return {pipeline_code_ui: visible_accordion}
582
+
583
def hide_pipeline_code_visibility():
    """Return an update mapping that hides the ``pipeline_code_ui`` accordion."""
    hidden_accordion = gr.Accordion(visible=False)
    return {pipeline_code_ui: hidden_accordion}
585
+
586
+
587
+ ######################
588
+ # Gradio UI
589
+ ######################
590
+
591
+
592
# Evaluation tab: pick a Hub dataset, configure the evaluation, then push the
# evaluated dataset to the Hub and Argilla.
with gr.Blocks() as app:
    with gr.Column() as main_ui:
        gr.Markdown("## 1. Select your input dataset")
        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                search_in = HuggingfaceHubSearch(
                    label="Search",
                    placeholder="Search for a dataset",
                    search_type="dataset",
                    sumbit_on_select=True,
                )
                load_btn = gr.Button("Load dataset", variant="primary")
            with gr.Column(scale=3):
                search_out = gr.HTML(label="Dataset preview")

        gr.HTML(value="<hr>")
        gr.Markdown(value="## 2. Configure your task")
        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                # Hidden dropdown holding the active evaluation type; it is
                # written by the tab `select` handlers below.
                eval_type = gr.Dropdown(
                    label="Evaluation type",
                    choices=["ultrafeedback", "custom"],
                    value="ultrafeedback",
                    multiselect=False,
                    visible=False,
                )
                with gr.Tab("ultrafeedback") as tab_instruction_response:
                    aspects_instruction_response = define_evaluation_aspects(
                        "ultrafeedback"
                    )
                    instruction_instruction_response = gr.Dropdown(
                        label="Instruction Column",
                        interactive=True,
                        multiselect=False,
                        allow_custom_value=False,
                    )
                    response_instruction_response = gr.Dropdown(
                        label="Response Column",
                        interactive=True,
                        multiselect=True,
                        allow_custom_value=False,
                    )
                    tab_instruction_response.select(
                        fn=lambda: "ultrafeedback",
                        inputs=[],
                        outputs=[eval_type],
                    )
                with gr.Tab("custom") as tab_custom:
                    aspects_custom = define_evaluation_aspects("custom")
                    prompt_template = gr.Code(
                        label="Prompt template",
                        value="Evaluate {{column_1}} based on {{column_2}}.",
                        language="markdown",
                        interactive=True,
                    )
                    structured_output = gr.Code(
                        label="Structured output",
                        value=json.dumps(
                            {
                                "type": "object",
                                "properties": {
                                    "quality": {"type": "integer"},
                                    "clarity": {"type": "integer"},
                                    "relevance": {"type": "integer"},
                                },
                            },
                            indent=4,
                        ),
                        language="json",
                        interactive=True,
                    )
                    tab_custom.select(
                        fn=lambda: "custom",
                        inputs=[],
                        outputs=[eval_type],
                    )
                btn_apply_to_sample_dataset = gr.Button(
                    "Refresh dataset", variant="secondary", size="sm"
                )
            with gr.Column(scale=3):
                dataframe = gr.Dataframe(
                    headers=["prompt", "completion", "evaluation"],
                    wrap=False,
                    height=500,
                    interactive=False,
                )

        gr.HTML(value="<hr>")
        gr.Markdown(value="## 3. Evaluate your dataset")
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                org_name = get_org_dropdown()
                repo_name = gr.Textbox(
                    label="Repo name",
                    placeholder="dataset_name",
                    value=f"my-distiset-{str(uuid.uuid4())[:8]}",
                    interactive=True,
                )
                num_rows = gr.Number(
                    label="Number of rows",
                    value=10,
                    interactive=True,
                    scale=1,
                )
                private = gr.Checkbox(
                    label="Private dataset",
                    value=False,
                    interactive=True,
                    scale=1,
                )
                btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
            with gr.Column(scale=3):
                success_message = gr.Markdown(visible=True)
                with gr.Accordion(
                    "Do you want to go further? Customize and run with Distilabel",
                    open=False,
                    visible=False,
                ) as pipeline_code_ui:
                    # Initial snippet built from the components' start-up values.
                    # Pass `.value` for the column dropdowns as well, so the
                    # generator receives values rather than component objects,
                    # like every other argument here.
                    code = generate_pipeline_code(
                        repo_id=search_in.value,
                        aspects=aspects_instruction_response.value,
                        instruction_column=instruction_instruction_response.value,
                        response_columns=response_instruction_response.value,
                        prompt_template=prompt_template.value,
                        structured_output=structured_output.value,
                        num_rows=num_rows.value,
                        eval_type=eval_type.value,
                    )
                    pipeline_code = gr.Code(
                        value=code,
                        language="python",
                        label="Distilabel Pipeline Code",
                    )

    search_in.submit(fn=get_iframe, inputs=search_in, outputs=search_out)

    load_btn.click(
        fn=load_dataset_from_hub,
        inputs=[search_in],
        outputs=[
            dataframe,
            instruction_instruction_response,
            response_instruction_response,
        ],
    )

    btn_apply_to_sample_dataset.click(
        fn=evaluate_sample_dataset,
        inputs=[
            search_in,
            eval_type,
            aspects_instruction_response,
            instruction_instruction_response,
            response_instruction_response,
            prompt_template,
            # NOTE(review): this entry sits on a line the diff view elided
            # (numbering jumps over it); restored to match the inputs passed
            # to push_dataset below. Confirm against the file.
            structured_output,
        ],
        outputs=dataframe,
    )

    # Validate first; every later `.success` step only runs if the previous
    # step finished without raising.
    btn_push_to_hub.click(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[repo_name],
        outputs=[success_message],
        show_progress=True,
    ).then(
        fn=validate_push_to_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
        show_progress=True,
    ).success(
        fn=hide_success_message,
        outputs=[success_message],
        show_progress=True,
    ).success(
        fn=hide_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    ).success(
        fn=push_dataset,
        inputs=[
            org_name,
            repo_name,
            private,
            num_rows,
            search_in,
            eval_type,
            aspects_instruction_response,
            instruction_instruction_response,
            response_instruction_response,
            prompt_template,
            structured_output,
        ],
        outputs=[success_message],
        show_progress=True,
    ).success(
        fn=show_success_message,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    ).success(
        fn=generate_pipeline_code,
        inputs=[
            search_in,
            aspects_instruction_response,
            instruction_instruction_response,
            response_instruction_response,
            prompt_template,
            structured_output,
            num_rows,
            eval_type,
        ],
        outputs=[pipeline_code],
    ).success(
        fn=show_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    )

    app.load(fn=swap_visibility, outputs=main_ui)
    app.load(fn=get_org_dropdown, outputs=[org_name])
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -10,10 +10,8 @@ from distilabel.distiset import Distiset
10
  from huggingface_hub import HfApi
11
 
12
  from src.distilabel_dataset_generator.apps.base import (
13
- get_argilla_client,
14
- get_pipeline_code_ui,
15
  hide_success_message,
16
- show_success_message_hub,
17
  validate_argilla_user_workspace_dataset,
18
  validate_push_to_hub,
19
  )
@@ -26,7 +24,6 @@ from src.distilabel_dataset_generator.pipelines.embeddings import (
26
  )
27
  from src.distilabel_dataset_generator.pipelines.sft import (
28
  DEFAULT_DATASET_DESCRIPTIONS,
29
- PROMPT_CREATION_PROMPT,
30
  generate_pipeline_code,
31
  get_magpie_generator,
32
  get_prompt_generator,
@@ -36,7 +33,7 @@ from src.distilabel_dataset_generator.utils import (
36
  _LOGGED_OUT_CSS,
37
  get_argilla_client,
38
  get_org_dropdown,
39
- swap_visibilty,
40
  )
41
 
42
 
@@ -55,35 +52,33 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
55
  return dataframe
56
 
57
 
58
- def generate_system_prompt(dataset_description, progress=gr.Progress()):
59
  progress(0.0, desc="Generating system prompt")
60
-
61
  progress(0.3, desc="Initializing text generation")
62
- generate_description = get_prompt_generator()
63
  progress(0.7, desc="Generating system prompt")
64
  result = next(
65
  generate_description.process(
66
  [
67
  {
68
- "system_prompt": PROMPT_CREATION_PROMPT,
69
  "instruction": dataset_description,
70
  }
71
  ]
72
  )
73
  )[0]["generation"]
74
  progress(1.0, desc="System prompt generated")
75
- return result, pd.DataFrame()
76
 
77
 
78
- def generate_sample_dataset(system_prompt, progress=gr.Progress()):
79
- df = generate_dataset(
80
  system_prompt=system_prompt,
81
- num_turns=1,
82
  num_rows=10,
83
  progress=progress,
84
  is_sample=True,
85
  )
86
- return df
87
 
88
 
89
  def generate_dataset(
@@ -94,10 +89,8 @@ def generate_dataset(
94
  progress=gr.Progress(),
95
  ) -> pd.DataFrame:
96
  progress(0.0, desc="(1/2) Generating instructions")
97
- magpie_generator = get_magpie_generator(
98
- num_turns, num_rows, system_prompt, is_sample
99
- )
100
- response_generator = get_response_generator(num_turns, system_prompt, is_sample)
101
  total_steps: int = num_rows * 2
102
  batch_size = DEFAULT_BATCH_SIZE
103
 
@@ -209,12 +202,12 @@ def push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private):
209
  return original_dataframe
210
 
211
 
212
- def push_dataset_to_argilla(
213
  org_name: str,
214
  repo_name: str,
215
  system_prompt: str,
216
  num_turns: int = 1,
217
- n_rows: int = 10,
218
  private: bool = False,
219
  oauth_token: Union[gr.OAuthToken, None] = None,
220
  progress=gr.Progress(),
@@ -222,7 +215,7 @@ def push_dataset_to_argilla(
222
  dataframe = generate_dataset(
223
  system_prompt=system_prompt,
224
  num_turns=num_turns,
225
- num_rows=n_rows,
226
  )
227
  push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
228
  try:
@@ -344,29 +337,54 @@ def push_dataset_to_argilla(
344
  return ""
345
 
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
348
  with gr.Column() as main_ui:
349
  gr.Markdown(value="## 1. Describe the dataset you want")
350
  with gr.Row():
351
- with gr.Column(scale=1):
352
  dataset_description = gr.Textbox(
353
  label="Dataset description",
354
  placeholder="Give a precise description of your desired dataset.",
355
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  examples = gr.Examples(
357
  examples=DEFAULT_DATASET_DESCRIPTIONS,
358
  inputs=[dataset_description],
359
  cache_examples=False,
360
- label="Example descriptions",
361
  )
362
-
363
- load_btn = gr.Button("Load dataset", variant="primary")
364
- with gr.Column(scale=3):
365
  pass
366
 
367
  gr.HTML(value="<hr>")
368
- gr.Markdown(value="## 2. Configure your task")
369
- with gr.Row():
370
  with gr.Column(scale=1):
371
  system_prompt = gr.Textbox(
372
  label="System prompt",
@@ -381,14 +399,21 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
381
  interactive=True,
382
  info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
383
  )
384
- btn_apply_to_sample_dataset = gr.Button("Refresh dataset")
 
 
385
  with gr.Column(scale=3):
386
- dataframe = gr.Dataframe()
 
 
 
 
 
387
 
388
  gr.HTML(value="<hr>")
389
  gr.Markdown(value="## 3. Generate your dataset")
390
- with gr.Row():
391
- with gr.Column(scale=1):
392
  org_name = get_org_dropdown()
393
  repo_name = gr.Textbox(
394
  label="Repo name",
@@ -396,7 +421,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
396
  value=f"my-distiset-{str(uuid.uuid4())[:8]}",
397
  interactive=True,
398
  )
399
- n_rows = gr.Number(
400
  label="Number of rows",
401
  value=10,
402
  interactive=True,
@@ -410,21 +435,38 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
410
  )
411
  btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
412
  with gr.Column(scale=3):
413
- success_message = gr.Markdown()
414
-
415
- pipeline_code = get_pipeline_code_ui(
416
- generate_pipeline_code(system_prompt.value, num_turns.value, n_rows.value)
417
- )
 
 
 
 
 
 
 
 
 
 
 
418
 
419
- gr.on(
420
- triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
421
  fn=generate_system_prompt,
422
- inputs=[dataset_description],
423
- outputs=[system_prompt, dataframe],
424
  show_progress=True,
425
  ).then(
426
  fn=generate_sample_dataset,
427
- inputs=[system_prompt],
 
 
 
 
 
 
 
428
  outputs=[dataframe],
429
  show_progress=True,
430
  )
@@ -444,21 +486,34 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
444
  outputs=[success_message],
445
  show_progress=True,
446
  ).success(
447
- fn=push_dataset_to_argilla,
 
 
 
 
448
  inputs=[
449
  org_name,
450
  repo_name,
451
  system_prompt,
452
  num_turns,
453
- n_rows,
454
  private,
455
  ],
456
  outputs=[success_message],
457
  show_progress=True,
458
  ).success(
459
- fn=show_success_message_hub,
460
  inputs=[org_name, repo_name],
461
  outputs=[success_message],
 
 
 
 
 
 
 
 
462
  )
463
- app.load(fn=swap_visibilty, outputs=main_ui)
 
464
  app.load(fn=get_org_dropdown, outputs=[org_name])
 
10
  from huggingface_hub import HfApi
11
 
12
  from src.distilabel_dataset_generator.apps.base import (
 
 
13
  hide_success_message,
14
+ show_success_message,
15
  validate_argilla_user_workspace_dataset,
16
  validate_push_to_hub,
17
  )
 
24
  )
25
  from src.distilabel_dataset_generator.pipelines.sft import (
26
  DEFAULT_DATASET_DESCRIPTIONS,
 
27
  generate_pipeline_code,
28
  get_magpie_generator,
29
  get_prompt_generator,
 
33
  _LOGGED_OUT_CSS,
34
  get_argilla_client,
35
  get_org_dropdown,
36
+ swap_visibility,
37
  )
38
 
39
 
 
52
  return dataframe
53
 
54
 
55
def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
    """Generate a system prompt from a free-text dataset description.

    Builds a prompt generator with ``get_prompt_generator(temperature)``, feeds
    it the description as its single ``instruction`` input and returns the
    ``generation`` of the first result, reporting progress along the way.
    """
    progress(0.0, desc="Generating system prompt")
    progress(0.3, desc="Initializing text generation")
    prompt_task = get_prompt_generator(temperature)
    progress(0.7, desc="Generating system prompt")
    task_inputs = [{"instruction": dataset_description}]
    first_batch = next(prompt_task.process(task_inputs))
    generated_prompt = first_batch[0]["generation"]
    progress(1.0, desc="System prompt generated")
    return generated_prompt
71
 
72
 
73
def generate_sample_dataset(system_prompt, num_turns, progress=gr.Progress()):
    """Generate a 10-row sample dataset for preview.

    Delegates to ``generate_dataset`` with ``is_sample=True`` and returns its
    result unchanged.
    """
    sample_options = {
        "system_prompt": system_prompt,
        "num_turns": num_turns,
        "num_rows": 10,
        "progress": progress,
        "is_sample": True,
    }
    return generate_dataset(**sample_options)
82
 
83
 
84
  def generate_dataset(
 
89
  progress=gr.Progress(),
90
  ) -> pd.DataFrame:
91
  progress(0.0, desc="(1/2) Generating instructions")
92
+ magpie_generator = get_magpie_generator(system_prompt, num_turns, is_sample)
93
+ response_generator = get_response_generator(system_prompt, num_turns, is_sample)
 
 
94
  total_steps: int = num_rows * 2
95
  batch_size = DEFAULT_BATCH_SIZE
96
 
 
202
  return original_dataframe
203
 
204
 
205
+ def push_dataset(
206
  org_name: str,
207
  repo_name: str,
208
  system_prompt: str,
209
  num_turns: int = 1,
210
+ num_rows: int = 10,
211
  private: bool = False,
212
  oauth_token: Union[gr.OAuthToken, None] = None,
213
  progress=gr.Progress(),
 
215
  dataframe = generate_dataset(
216
  system_prompt=system_prompt,
217
  num_turns=num_turns,
218
+ num_rows=num_rows,
219
  )
220
  push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
221
  try:
 
337
  return ""
338
 
339
 
340
def show_pipeline_code_visibility():
    """Return an update mapping that makes the ``pipeline_code_ui`` accordion visible."""
    visible_accordion = gr.Accordion(visible=True)
    return {pipeline_code_ui: visible_accordion}
342
+
343
+
344
def hide_pipeline_code_visibility():
    """Return an update mapping that hides the ``pipeline_code_ui`` accordion."""
    hidden_accordion = gr.Accordion(visible=False)
    return {pipeline_code_ui: hidden_accordion}
346
+
347
+
348
+ ######################
349
+ # Gradio UI
350
+ ######################
351
+
352
+
353
  with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
354
  with gr.Column() as main_ui:
355
  gr.Markdown(value="## 1. Describe the dataset you want")
356
  with gr.Row():
357
+ with gr.Column(scale=2):
358
  dataset_description = gr.Textbox(
359
  label="Dataset description",
360
  placeholder="Give a precise description of your desired dataset.",
361
  )
362
+ with gr.Accordion("Temperature", open=False):
363
+ temperature = gr.Slider(
364
+ minimum=0.1,
365
+ maximum=1,
366
+ value=0.8,
367
+ step=0.1,
368
+ interactive=True,
369
+ show_label=False,
370
+ )
371
+ load_btn = gr.Button(
372
+ "Create dataset",
373
+ variant="primary",
374
+ )
375
+ with gr.Column(scale=2):
376
  examples = gr.Examples(
377
  examples=DEFAULT_DATASET_DESCRIPTIONS,
378
  inputs=[dataset_description],
379
  cache_examples=False,
380
+ label="Examples",
381
  )
382
+ with gr.Column(scale=1):
 
 
383
  pass
384
 
385
  gr.HTML(value="<hr>")
386
+ gr.Markdown(value="## 2. Configure your dataset")
387
+ with gr.Row(equal_height=False):
388
  with gr.Column(scale=1):
389
  system_prompt = gr.Textbox(
390
  label="System prompt",
 
399
  interactive=True,
400
  info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
401
  )
402
+ btn_apply_to_sample_dataset = gr.Button(
403
+ "Refresh dataset", variant="secondary", size="sm"
404
+ )
405
  with gr.Column(scale=3):
406
+ dataframe = gr.Dataframe(
407
+ headers=["prompt", "completion"],
408
+ wrap=True,
409
+ height=500,
410
+ interactive=False,
411
+ )
412
 
413
  gr.HTML(value="<hr>")
414
  gr.Markdown(value="## 3. Generate your dataset")
415
+ with gr.Row(equal_height=False):
416
+ with gr.Column(scale=2):
417
  org_name = get_org_dropdown()
418
  repo_name = gr.Textbox(
419
  label="Repo name",
 
421
  value=f"my-distiset-{str(uuid.uuid4())[:8]}",
422
  interactive=True,
423
  )
424
+ num_rows = gr.Number(
425
  label="Number of rows",
426
  value=10,
427
  interactive=True,
 
435
  )
436
  btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
437
  with gr.Column(scale=3):
438
+ success_message = gr.Markdown(visible=True)
439
+ with gr.Accordion(
440
+ "Do you want to go further? Customize and run with Distilabel",
441
+ open=False,
442
+ visible=False,
443
+ ) as pipeline_code_ui:
444
+ code = generate_pipeline_code(
445
+ system_prompt=system_prompt.value,
446
+ num_turns=num_turns.value,
447
+ num_rows=num_rows.value,
448
+ )
449
+ pipeline_code = gr.Code(
450
+ value=code,
451
+ language="python",
452
+ label="Distilabel Pipeline Code",
453
+ )
454
 
455
+ load_btn.click(
 
456
  fn=generate_system_prompt,
457
+ inputs=[dataset_description, temperature],
458
+ outputs=[system_prompt],
459
  show_progress=True,
460
  ).then(
461
  fn=generate_sample_dataset,
462
+ inputs=[system_prompt, num_turns],
463
+ outputs=[dataframe],
464
+ show_progress=True,
465
+ )
466
+
467
+ btn_apply_to_sample_dataset.click(
468
+ fn=generate_sample_dataset,
469
+ inputs=[system_prompt, num_turns],
470
  outputs=[dataframe],
471
  show_progress=True,
472
  )
 
486
  outputs=[success_message],
487
  show_progress=True,
488
  ).success(
489
+ fn=hide_pipeline_code_visibility,
490
+ inputs=[],
491
+ outputs=[pipeline_code_ui],
492
+ ).success(
493
+ fn=push_dataset,
494
  inputs=[
495
  org_name,
496
  repo_name,
497
  system_prompt,
498
  num_turns,
499
+ num_rows,
500
  private,
501
  ],
502
  outputs=[success_message],
503
  show_progress=True,
504
  ).success(
505
+ fn=show_success_message,
506
  inputs=[org_name, repo_name],
507
  outputs=[success_message],
508
+ ).success(
509
+ fn=generate_pipeline_code,
510
+ inputs=[system_prompt, num_turns, num_rows],
511
+ outputs=[pipeline_code],
512
+ ).success(
513
+ fn=show_pipeline_code_visibility,
514
+ inputs=[],
515
+ outputs=[pipeline_code_ui],
516
  )
517
+
518
+ app.load(fn=swap_visibility, outputs=main_ui)
519
  app.load(fn=get_org_dropdown, outputs=[org_name])
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -1,4 +1,4 @@
1
- import re
2
  import uuid
3
  from typing import List, Union
4
 
@@ -10,10 +10,8 @@ from distilabel.distiset import Distiset
10
  from huggingface_hub import HfApi
11
 
12
  from src.distilabel_dataset_generator.apps.base import (
13
- get_argilla_client,
14
- get_pipeline_code_ui,
15
  hide_success_message,
16
- show_success_message_hub,
17
  validate_argilla_user_workspace_dataset,
18
  validate_push_to_hub,
19
  )
@@ -26,7 +24,6 @@ from src.distilabel_dataset_generator.pipelines.embeddings import (
26
  )
27
  from src.distilabel_dataset_generator.pipelines.textcat import (
28
  DEFAULT_DATASET_DESCRIPTIONS,
29
- PROMPT_CREATION_PROMPT,
30
  generate_pipeline_code,
31
  get_labeller_generator,
32
  get_prompt_generator,
@@ -37,45 +34,42 @@ from src.distilabel_dataset_generator.utils import (
37
  get_argilla_client,
38
  get_org_dropdown,
39
  get_preprocess_labels,
40
- swap_visibilty,
41
  )
42
 
43
 
44
- def generate_system_prompt(dataset_description, progress=gr.Progress()):
45
  progress(0.0, desc="Generating text classification task")
46
  progress(0.3, desc="Initializing text generation")
47
- generate_description = get_prompt_generator()
48
  progress(0.7, desc="Generating text classification task")
49
- system_prompt = next(
50
  generate_description.process(
51
  [
52
  {
53
- "system_prompt": PROMPT_CREATION_PROMPT,
54
  "instruction": dataset_description,
55
  }
56
  ]
57
  )
58
  )[0]["generation"]
59
  progress(1.0, desc="Text classification task generated")
60
- return system_prompt, pd.DataFrame()
61
-
 
 
62
 
63
- def generate_sample_dataset(system_prompt, progress=gr.Progress()):
64
- df = generate_dataset(
65
  system_prompt=system_prompt,
66
- difficulty="mixed",
67
- clarity="mixed",
68
- labels=[],
69
- num_labels=1,
70
  num_rows=10,
71
  progress=progress,
72
  is_sample=True,
73
  )
74
- if "label" in df.columns:
75
- df = df[["label", "text"]]
76
- elif "labels" in df.columns:
77
- df = df[["labels", "text"]]
78
- return df
79
 
80
 
81
  def generate_dataset(
@@ -88,17 +82,13 @@ def generate_dataset(
88
  is_sample: bool = False,
89
  progress=gr.Progress(),
90
  ) -> pd.DataFrame:
91
- if is_sample:
92
- multiplier = 1
93
- else:
94
- multiplier = 2
95
  progress(0.0, desc="(1/2) Generating text classification data")
96
  labels = get_preprocess_labels(labels)
97
  textcat_generator = get_textcat_generator(
98
  difficulty=difficulty, clarity=clarity, is_sample=is_sample
99
  )
100
  labeller_generator = get_labeller_generator(
101
- system_prompt=system_prompt,
102
  labels=labels,
103
  num_labels=num_labels,
104
  )
@@ -110,13 +100,15 @@ def generate_dataset(
110
  textcat_results = []
111
  while n_processed < num_rows:
112
  progress(
113
- multiplier * 0.5 * n_processed / num_rows,
114
  total=total_steps,
115
  desc="(1/2) Generating text classification data",
116
  )
117
  remaining_rows = num_rows - n_processed
118
  batch_size = min(batch_size, remaining_rows)
119
- inputs = [{"task": system_prompt} for _ in range(batch_size)]
 
 
120
  batch = list(textcat_generator.process(inputs=inputs))
121
  textcat_results.extend(batch[0])
122
  n_processed += batch_size
@@ -124,58 +116,41 @@ def generate_dataset(
124
  result["text"] = result["input_text"]
125
 
126
  # label text classification data
127
- progress(multiplier * 0.5, desc="(1/2) Generating text classification data")
128
- if not is_sample:
129
- n_processed = 0
130
- labeller_results = []
131
- while n_processed < num_rows:
132
- progress(
133
- 0.5 + 0.5 * n_processed / num_rows,
134
- total=total_steps,
135
- desc="(1/2) Labeling text classification data",
136
- )
137
- batch = textcat_results[n_processed : n_processed + batch_size]
138
- labels_batch = list(labeller_generator.process(inputs=batch))
139
- labeller_results.extend(labels_batch[0])
140
- n_processed += batch_size
141
  progress(
142
- 1,
143
  total=total_steps,
144
- desc="(2/2) Creating dataset",
145
  )
 
 
 
 
 
 
 
 
 
146
 
147
  # create final dataset
148
  distiset_results = []
149
- source_results = textcat_results if is_sample else labeller_results
150
- for result in source_results:
151
  record = {
152
  key: result[key]
153
- for key in ["text", "label" if is_sample else "labels"]
154
  if key in result
155
  }
156
  distiset_results.append(record)
157
 
158
  dataframe = pd.DataFrame(distiset_results)
159
- if not is_sample:
160
- if num_labels == 1:
161
- dataframe = dataframe.rename(columns={"labels": "label"})
162
- dataframe["label"] = dataframe["label"].apply(
163
- lambda x: x.lower().strip() if x.lower().strip() in labels else None
164
- )
165
- else:
166
- dataframe["labels"] = dataframe["labels"].apply(
167
- lambda x: (
168
- list(
169
- set(
170
- label.lower().strip()
171
- for label in x
172
- if label.lower().strip() in labels
173
- )
174
- )
175
- if isinstance(x, list)
176
- else None
177
- )
178
- )
179
  progress(1.0, desc="Dataset generation completed")
180
  return dataframe
181
 
@@ -213,14 +188,14 @@ def push_dataset_to_hub(
213
  )
214
 
215
 
216
- def push_dataset_to_argilla(
217
  org_name: str,
218
  repo_name: str,
219
  system_prompt: str,
220
  difficulty: str,
221
  clarity: str,
222
  num_labels: int = 1,
223
- n_rows: int = 10,
224
  labels: List[str] = None,
225
  private: bool = False,
226
  oauth_token: Union[gr.OAuthToken, None] = None,
@@ -232,7 +207,7 @@ def push_dataset_to_argilla(
232
  clarity=clarity,
233
  num_labels=num_labels,
234
  labels=labels,
235
- num_rows=n_rows,
236
  )
237
  push_dataset_to_hub(
238
  dataframe, org_name, repo_name, num_labels, labels, oauth_token, private
@@ -283,7 +258,7 @@ def push_dataset_to_argilla(
283
  )
284
 
285
  dataframe["text_length"] = dataframe["text"].apply(len)
286
- dataframe["text_embeddings"] = get_embeddings(dataframe["text"])
287
 
288
  progress(0.5, desc="Creating dataset")
289
  rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
@@ -332,15 +307,6 @@ def push_dataset_to_argilla(
332
  return ""
333
 
334
 
335
- def update_suggested_labels(system_prompt):
336
- new_labels = re.findall(r"'(\b[\w-]+\b)'", system_prompt)
337
- if not new_labels:
338
- return gr.Warning(
339
- "No labels found in the system prompt. Please add labels manually."
340
- )
341
- return gr.update(choices=new_labels, value=new_labels)
342
-
343
-
344
  def validate_input_labels(labels):
345
  if not labels or len(labels) < 2:
346
  raise gr.Error(
@@ -353,44 +319,74 @@ def update_max_num_labels(labels):
353
  return gr.update(maximum=len(labels) if labels else 1)
354
 
355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
357
  with gr.Column() as main_ui:
358
  gr.Markdown("## 1. Describe the dataset you want")
359
  with gr.Row():
360
- with gr.Column(scale=1):
361
  dataset_description = gr.Textbox(
362
  label="Dataset description",
363
  placeholder="Give a precise description of your desired dataset.",
364
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
  examples = gr.Examples(
366
  examples=DEFAULT_DATASET_DESCRIPTIONS,
367
  inputs=[dataset_description],
368
  cache_examples=False,
369
- label="Example descriptions",
370
  )
371
- load_btn = gr.Button("Load dataset", variant="primary")
372
- with gr.Column(scale=3):
373
  pass
374
 
375
  gr.HTML("<hr>")
376
- gr.Markdown("## 2. Configure your task")
377
- with gr.Row():
378
  with gr.Column(scale=1):
379
  system_prompt = gr.Textbox(
380
  label="System prompt",
381
  placeholder="You are a helpful assistant.",
382
  visible=True,
383
  )
384
- difficulty = gr.Dropdown(
385
- choices=[
386
- ("High School", "high school"),
387
- ("College", "college"),
388
- ("PhD", "PhD"),
389
- ("Mixed", "mixed"),
390
- ],
391
- value="mixed",
392
- label="Difficulty",
393
- info="Select the comprehension level for the text. Ensure it matches the task context.",
 
 
 
 
394
  interactive=True,
395
  )
396
  clarity = gr.Dropdown(
@@ -408,30 +404,30 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
408
  info="Set how easily the correct label or labels can be identified.",
409
  interactive=True,
410
  )
411
- labels = gr.Dropdown(
412
- choices=[],
413
- allow_custom_value=True,
 
 
 
 
 
 
 
414
  interactive=True,
415
- label="Labels",
416
- multiselect=True,
417
- info="Add the labels to classify the text.",
418
  )
419
- num_labels = gr.Number(
420
- label="Number of labels per text",
421
- value=1,
422
- minimum=1,
423
- maximum=10,
424
- info="Select 1 for single-label and >1 for multi-label.",
425
- interactive=True,
426
  )
427
- btn_apply_to_sample_dataset = gr.Button("Refresh dataset")
428
  with gr.Column(scale=3):
429
- dataframe = gr.Dataframe()
 
 
430
 
431
  gr.HTML("<hr>")
432
  gr.Markdown("## 3. Generate your dataset")
433
- with gr.Row():
434
- with gr.Column(scale=1):
435
  org_name = get_org_dropdown()
436
  repo_name = gr.Textbox(
437
  label="Repo name",
@@ -439,7 +435,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
439
  value=f"my-distiset-{str(uuid.uuid4())[:8]}",
440
  interactive=True,
441
  )
442
- n_rows = gr.Number(
443
  label="Number of rows",
444
  value=10,
445
  interactive=True,
@@ -454,39 +450,54 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
454
  btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
455
  with gr.Column(scale=3):
456
  success_message = gr.Markdown(visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
 
458
- pipeline_code = get_pipeline_code_ui(
459
- generate_pipeline_code(
460
- system_prompt.value,
461
- difficulty=difficulty.value,
462
- clarity=clarity.value,
463
- labels=labels.value,
464
- num_labels=num_labels.value,
465
- num_rows=n_rows.value,
466
- )
467
- )
468
-
469
- gr.on(
470
- triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
471
  fn=generate_system_prompt,
472
- inputs=[dataset_description],
473
- outputs=[system_prompt, dataframe],
474
  show_progress=True,
475
  ).then(
476
  fn=generate_sample_dataset,
477
- inputs=[system_prompt],
478
  outputs=[dataframe],
479
  show_progress=True,
480
- ).then(
481
- fn=update_suggested_labels,
482
- inputs=[system_prompt],
483
- outputs=labels,
484
  ).then(
485
  fn=update_max_num_labels,
486
  inputs=[labels],
487
  outputs=[num_labels],
488
  )
489
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
  btn_push_to_hub.click(
491
  fn=validate_argilla_user_workspace_dataset,
492
  inputs=[repo_name],
@@ -502,7 +513,11 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
502
  outputs=[success_message],
503
  show_progress=True,
504
  ).success(
505
- fn=push_dataset_to_argilla,
 
 
 
 
506
  inputs=[
507
  org_name,
508
  repo_name,
@@ -510,16 +525,32 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
510
  difficulty,
511
  clarity,
512
  num_labels,
513
- n_rows,
514
  labels,
515
  private,
516
  ],
517
  outputs=[success_message],
518
  show_progress=True,
519
  ).success(
520
- fn=show_success_message_hub,
521
  inputs=[org_name, repo_name],
522
  outputs=[success_message],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  )
524
- app.load(fn=swap_visibilty, outputs=main_ui)
 
525
  app.load(fn=get_org_dropdown, outputs=[org_name])
 
1
+ import json
2
  import uuid
3
  from typing import List, Union
4
 
 
10
  from huggingface_hub import HfApi
11
 
12
  from src.distilabel_dataset_generator.apps.base import (
 
 
13
  hide_success_message,
14
+ show_success_message,
15
  validate_argilla_user_workspace_dataset,
16
  validate_push_to_hub,
17
  )
 
24
  )
25
  from src.distilabel_dataset_generator.pipelines.textcat import (
26
  DEFAULT_DATASET_DESCRIPTIONS,
 
27
  generate_pipeline_code,
28
  get_labeller_generator,
29
  get_prompt_generator,
 
34
  get_argilla_client,
35
  get_org_dropdown,
36
  get_preprocess_labels,
37
+ swap_visibility,
38
  )
39
 
40
 
41
def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
    """Turn a free-text dataset description into a classification task and labels.

    Args:
        dataset_description: User-written description, sent to the prompt
            generator as the ``instruction`` field.
        temperature: Sampling temperature forwarded to ``get_prompt_generator``.
        progress: Gradio progress tracker used to report the coarse steps.

    Returns:
        A ``(system_prompt, labels)`` tuple read from the ``classification_task``
        and ``labels`` keys of the JSON document produced by the generator.

    Raises:
        gr.Error: If the generation is not valid JSON or lacks one of the two
            expected keys, so the UI shows an actionable message instead of a
            raw ``JSONDecodeError``/``KeyError`` traceback.
    """
    progress(0.0, desc="Generating text classification task")
    progress(0.3, desc="Initializing text generation")
    generate_description = get_prompt_generator(temperature)
    progress(0.7, desc="Generating text classification task")
    request = [{"instruction": dataset_description}]
    # `process` yields batches; the first batch holds the single requested row.
    result = next(generate_description.process(request))[0]["generation"]
    progress(1.0, desc="Text classification task generated")
    try:
        data = json.loads(result)
        system_prompt = data["classification_task"]
        labels = data["labels"]
    except (json.JSONDecodeError, KeyError, TypeError) as err:
        # TypeError covers a missing (None) generation or a non-object JSON value.
        raise gr.Error(
            "The model did not return a valid classification task. "
            "Please try again or rephrase the dataset description."
        ) from err
    return system_prompt, labels
60
 
61
def generate_sample_dataset(system_prompt, difficulty, clarity, labels, num_labels, progress=gr.Progress()):
    """Build the preview dataframe shown in the UI.

    Delegates to ``generate_dataset`` with a fixed size of 10 rows and
    ``is_sample=True``; every other argument is forwarded unchanged.
    """
    sample_options = {
        "system_prompt": system_prompt,
        "difficulty": difficulty,
        "clarity": clarity,
        "labels": labels,
        "num_labels": num_labels,
        "num_rows": 10,
        "progress": progress,
        "is_sample": True,
    }
    return generate_dataset(**sample_options)
 
 
 
 
73
 
74
 
75
  def generate_dataset(
 
82
  is_sample: bool = False,
83
  progress=gr.Progress(),
84
  ) -> pd.DataFrame:
 
 
 
 
85
  progress(0.0, desc="(1/2) Generating text classification data")
86
  labels = get_preprocess_labels(labels)
87
  textcat_generator = get_textcat_generator(
88
  difficulty=difficulty, clarity=clarity, is_sample=is_sample
89
  )
90
  labeller_generator = get_labeller_generator(
91
+ system_prompt=f"{system_prompt} {', '.join(labels)}",
92
  labels=labels,
93
  num_labels=num_labels,
94
  )
 
100
  textcat_results = []
101
  while n_processed < num_rows:
102
  progress(
103
+ 2 * 0.5 * n_processed / num_rows,
104
  total=total_steps,
105
  desc="(1/2) Generating text classification data",
106
  )
107
  remaining_rows = num_rows - n_processed
108
  batch_size = min(batch_size, remaining_rows)
109
+ inputs = [
110
+ {"task": f"{system_prompt} {', '.join(labels)}"} for _ in range(batch_size)
111
+ ]
112
  batch = list(textcat_generator.process(inputs=inputs))
113
  textcat_results.extend(batch[0])
114
  n_processed += batch_size
 
116
  result["text"] = result["input_text"]
117
 
118
  # label text classification data
119
+ progress(2 * 0.5, desc="(1/2) Generating text classification data")
120
+ n_processed = 0
121
+ labeller_results = []
122
+ while n_processed < num_rows:
 
 
 
 
 
 
 
 
 
 
123
  progress(
124
+ 0.5 + 0.5 * n_processed / num_rows,
125
  total=total_steps,
126
+ desc="(1/2) Labeling text classification data",
127
  )
128
+ batch = textcat_results[n_processed : n_processed + batch_size]
129
+ labels_batch = list(labeller_generator.process(inputs=batch))
130
+ labeller_results.extend(labels_batch[0])
131
+ n_processed += batch_size
132
+ progress(
133
+ 1,
134
+ total=total_steps,
135
+ desc="(2/2) Creating dataset",
136
+ )
137
 
138
  # create final dataset
139
  distiset_results = []
140
+ for result in labeller_results:
 
141
  record = {
142
  key: result[key]
143
+ for key in ["labels", "text"]
144
  if key in result
145
  }
146
  distiset_results.append(record)
147
 
148
  dataframe = pd.DataFrame(distiset_results)
149
+ if num_labels == 1:
150
+ dataframe = dataframe.rename(columns={"labels": "label"})
151
+ dataframe["label"] = dataframe["label"].apply(
152
+ lambda x: x.lower().strip() if x.lower().strip() in labels else None
153
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  progress(1.0, desc="Dataset generation completed")
155
  return dataframe
156
 
 
188
  )
189
 
190
 
191
+ def push_dataset(
192
  org_name: str,
193
  repo_name: str,
194
  system_prompt: str,
195
  difficulty: str,
196
  clarity: str,
197
  num_labels: int = 1,
198
+ num_rows: int = 10,
199
  labels: List[str] = None,
200
  private: bool = False,
201
  oauth_token: Union[gr.OAuthToken, None] = None,
 
207
  clarity=clarity,
208
  num_labels=num_labels,
209
  labels=labels,
210
+ num_rows=num_rows,
211
  )
212
  push_dataset_to_hub(
213
  dataframe, org_name, repo_name, num_labels, labels, oauth_token, private
 
258
  )
259
 
260
  dataframe["text_length"] = dataframe["text"].apply(len)
261
+ dataframe["text_embeddings"] = get_embeddings(dataframe["text"].to_list())
262
 
263
  progress(0.5, desc="Creating dataset")
264
  rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
 
307
  return ""
308
 
309
 
 
 
 
 
 
 
 
 
 
310
  def validate_input_labels(labels):
311
  if not labels or len(labels) < 2:
312
  raise gr.Error(
 
319
  return gr.update(maximum=len(labels) if labels else 1)
320
 
321
 
322
def show_pipeline_code_visibility():
    """Return the update that makes the pipeline-code accordion visible."""
    shown_accordion = gr.Accordion(visible=True)
    return {pipeline_code_ui: shown_accordion}
324
+
325
+
326
def hide_pipeline_code_visibility():
    """Return the update that hides the pipeline-code accordion."""
    hidden_accordion = gr.Accordion(visible=False)
    return {pipeline_code_ui: hidden_accordion}
328
+
329
+
330
+ ######################
331
+ # Gradio UI
332
+ ######################
333
+
334
+
335
  with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
336
  with gr.Column() as main_ui:
337
  gr.Markdown("## 1. Describe the dataset you want")
338
  with gr.Row():
339
+ with gr.Column(scale=2):
340
  dataset_description = gr.Textbox(
341
  label="Dataset description",
342
  placeholder="Give a precise description of your desired dataset.",
343
  )
344
+ with gr.Accordion("Temperature", open=False):
345
+ temperature = gr.Slider(
346
+ minimum=0.1,
347
+ maximum=1,
348
+ value=0.8,
349
+ step=0.1,
350
+ interactive=True,
351
+ show_label=False,
352
+ )
353
+ load_btn = gr.Button(
354
+ "Create dataset",
355
+ variant="primary",
356
+ )
357
+ with gr.Column(scale=2):
358
  examples = gr.Examples(
359
  examples=DEFAULT_DATASET_DESCRIPTIONS,
360
  inputs=[dataset_description],
361
  cache_examples=False,
362
+ label="Examples",
363
  )
364
+ with gr.Column(scale=1):
 
365
  pass
366
 
367
  gr.HTML("<hr>")
368
+ gr.Markdown("## 2. Configure your dataset")
369
+ with gr.Row(equal_height=False):
370
  with gr.Column(scale=1):
371
  system_prompt = gr.Textbox(
372
  label="System prompt",
373
  placeholder="You are a helpful assistant.",
374
  visible=True,
375
  )
376
+ labels = gr.Dropdown(
377
+ choices=[],
378
+ allow_custom_value=True,
379
+ interactive=True,
380
+ label="Labels",
381
+ multiselect=True,
382
+ info="Add the labels to classify the text.",
383
+ )
384
+ num_labels = gr.Number(
385
+ label="Number of labels per text",
386
+ value=1,
387
+ minimum=1,
388
+ maximum=10,
389
+ info="Select 1 for single-label and >1 for multi-label.",
390
  interactive=True,
391
  )
392
  clarity = gr.Dropdown(
 
404
  info="Set how easily the correct label or labels can be identified.",
405
  interactive=True,
406
  )
407
+ difficulty = gr.Dropdown(
408
+ choices=[
409
+ ("High School", "high school"),
410
+ ("College", "college"),
411
+ ("PhD", "PhD"),
412
+ ("Mixed", "mixed"),
413
+ ],
414
+ value="mixed",
415
+ label="Difficulty",
416
+ info="Select the comprehension level for the text. Ensure it matches the task context.",
417
  interactive=True,
 
 
 
418
  )
419
+ btn_apply_to_sample_dataset = gr.Button(
420
+ "Refresh dataset", variant="secondary", size="sm"
 
 
 
 
 
421
  )
 
422
  with gr.Column(scale=3):
423
+ dataframe = gr.Dataframe(
424
+ headers=["labels", "text"], wrap=True, height=500, interactive=False
425
+ )
426
 
427
  gr.HTML("<hr>")
428
  gr.Markdown("## 3. Generate your dataset")
429
+ with gr.Row(equal_height=False):
430
+ with gr.Column(scale=2):
431
  org_name = get_org_dropdown()
432
  repo_name = gr.Textbox(
433
  label="Repo name",
 
435
  value=f"my-distiset-{str(uuid.uuid4())[:8]}",
436
  interactive=True,
437
  )
438
+ num_rows = gr.Number(
439
  label="Number of rows",
440
  value=10,
441
  interactive=True,
 
450
  btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
451
  with gr.Column(scale=3):
452
  success_message = gr.Markdown(visible=True)
453
+ with gr.Accordion(
454
+ "Do you want to go further? Customize and run with Distilabel",
455
+ open=False,
456
+ visible=False,
457
+ ) as pipeline_code_ui:
458
+ code = generate_pipeline_code(
459
+ system_prompt.value,
460
+ difficulty=difficulty.value,
461
+ clarity=clarity.value,
462
+ labels=labels.value,
463
+ num_labels=num_labels.value,
464
+ num_rows=num_rows.value,
465
+ )
466
+ pipeline_code = gr.Code(
467
+ value=code,
468
+ language="python",
469
+ label="Distilabel Pipeline Code",
470
+ )
471
 
472
+ load_btn.click(
 
 
 
 
 
 
 
 
 
 
 
 
473
  fn=generate_system_prompt,
474
+ inputs=[dataset_description, temperature],
475
+ outputs=[system_prompt, labels],
476
  show_progress=True,
477
  ).then(
478
  fn=generate_sample_dataset,
479
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels],
480
  outputs=[dataframe],
481
  show_progress=True,
 
 
 
 
482
  ).then(
483
  fn=update_max_num_labels,
484
  inputs=[labels],
485
  outputs=[num_labels],
486
  )
487
 
488
+ labels.input(
489
+ fn=update_max_num_labels,
490
+ inputs=[labels],
491
+ outputs=[num_labels],
492
+ )
493
+
494
+ btn_apply_to_sample_dataset.click(
495
+ fn=generate_sample_dataset,
496
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels],
497
+ outputs=[dataframe],
498
+ show_progress=True,
499
+ )
500
+
501
  btn_push_to_hub.click(
502
  fn=validate_argilla_user_workspace_dataset,
503
  inputs=[repo_name],
 
513
  outputs=[success_message],
514
  show_progress=True,
515
  ).success(
516
+ fn=hide_pipeline_code_visibility,
517
+ inputs=[],
518
+ outputs=[pipeline_code_ui],
519
+ ).success(
520
+ fn=push_dataset,
521
  inputs=[
522
  org_name,
523
  repo_name,
 
525
  difficulty,
526
  clarity,
527
  num_labels,
528
+ num_rows,
529
  labels,
530
  private,
531
  ],
532
  outputs=[success_message],
533
  show_progress=True,
534
  ).success(
535
+ fn=show_success_message,
536
  inputs=[org_name, repo_name],
537
  outputs=[success_message],
538
+ ).success(
539
+ fn=generate_pipeline_code,
540
+ inputs=[
541
+ system_prompt,
542
+ difficulty,
543
+ clarity,
544
+ labels,
545
+ num_labels,
546
+ num_rows,
547
+ ],
548
+ outputs=[pipeline_code],
549
+ ).success(
550
+ fn=show_pipeline_code_visibility,
551
+ inputs=[],
552
+ outputs=[pipeline_code_ui],
553
  )
554
+
555
+ app.load(fn=swap_visibility, outputs=main_ui)
556
  app.load(fn=get_org_dropdown, outputs=[org_name])
src/distilabel_dataset_generator/pipelines/eval.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from datasets import get_dataset_config_names, get_dataset_split_names
4
+ from distilabel.llms import InferenceEndpointsLLM
5
+ from distilabel.steps.tasks import (
6
+ UltraFeedback,
7
+ TextGeneration,
8
+ )
9
+
10
+ from src.distilabel_dataset_generator.pipelines.base import (
11
+ MODEL,
12
+ _get_next_api_key,
13
+ )
14
+ from src.distilabel_dataset_generator.utils import extract_column_names
15
+
16
+
17
def get_ultrafeedback_evaluator(aspect, is_sample):
    """Create and load an ``UltraFeedback`` task for the given aspect.

    The backing LLM samples at temperature 0.7 and is limited to 256 new
    tokens for sample runs, 2048 otherwise.
    """
    token_budget = 256 if is_sample else 2048
    llm = InferenceEndpointsLLM(
        model_id=MODEL,
        tokenizer_id=MODEL,
        api_key=_get_next_api_key(),
        generation_kwargs={
            "temperature": 0.7,
            "max_new_tokens": token_budget,
        },
    )
    evaluator = UltraFeedback(llm=llm, aspect=aspect)
    # The task is loaded before it is returned to the caller.
    evaluator.load()
    return evaluator
32
+
33
+
34
def get_custom_evaluator(prompt_template, structured_output, columns, is_sample):
    """Create and load a ``TextGeneration`` task driven by a custom template.

    ``structured_output`` is passed as the schema of a JSON structured-output
    configuration; ``prompt_template`` and ``columns`` are forwarded as the
    task's ``template`` and ``columns``. The LLM samples at temperature 0.7
    with 256 new tokens for sample runs and 2048 otherwise.
    """
    token_budget = 256 if is_sample else 2048
    json_output = {"format": "json", "schema": structured_output}
    llm = InferenceEndpointsLLM(
        model_id=MODEL,
        tokenizer_id=MODEL,
        api_key=_get_next_api_key(),
        structured_output=json_output,
        generation_kwargs={
            "temperature": 0.7,
            "max_new_tokens": token_budget,
        },
    )
    evaluator = TextGeneration(
        llm=llm,
        template=prompt_template,
        columns=columns,
    )
    evaluator.load()
    return evaluator
51
+
52
+
53
def generate_ultrafeedback_pipeline_code(
    repo_id, subset, split, aspects, instruction_column, response_columns, num_rows
):
    """Render a standalone distilabel script that runs UltraFeedback evaluation.

    Args:
        repo_id: Hub dataset id written into the ``load_dataset`` call.
        subset: Dataset configuration name.
        split: Dataset split name; the script loads its first ``num_rows`` rows.
        aspects: UltraFeedback aspects. Exactly one aspect yields a single
            task; any other count yields one task per aspect whose outputs are
            merged with ``CombineOutputs``.
        instruction_column: Column name passed to the script's
            ``preprocess_data`` placeholder.
        response_columns: Response column(s) passed to the same placeholder.
        num_rows: Number of rows to load.

    Returns:
        The script source as a string. ``preprocess_data`` is referenced but
        not defined by the script; the user is expected to supply it.
    """
    if len(aspects) == 1:
        # The selected aspect is written as a literal: the generated script
        # has no `aspect` variable of its own.
        code = f"""
# Requirements: `pip install distilabel[hf-inference-endpoints]`
import os
from datasets import load_dataset
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import UltraFeedback
from distilabel.llms import InferenceEndpointsLLM

MODEL = "{MODEL}"
os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained

hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}[:{num_rows}]")
data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries

with Pipeline(name="ultrafeedback") as pipeline:

    load_the_dataset = LoadDataFromDicts(
        data = data,
    )

    ultrafeedback_evaluator = UltraFeedback(
        llm=InferenceEndpointsLLM(
            model_id=MODEL,
            tokenizer_id=MODEL,
            api_key=os.environ["HF_TOKEN"],
            generation_kwargs={{
                "temperature": 0.7,
                "max_new_tokens": 2048,
            }},
        ),
        aspect="{aspects[0]}",
    )

    load_the_dataset >> ultrafeedback_evaluator

if __name__ == "__main__":
    distiset = pipeline.run()
"""
    else:
        # ASPECTS is emitted into the script so its loop has something to
        # iterate over, and the LLM call is closed before `output_mappings`,
        # which belongs to the UltraFeedback task.
        code = f"""
# Requirements: `pip install distilabel[hf-inference-endpoints]`
import os
from datasets import load_dataset
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts, CombineOutputs
from distilabel.steps.tasks import UltraFeedback
from distilabel.llms import InferenceEndpointsLLM

MODEL = "{MODEL}"
ASPECTS = {list(aspects)!r}
os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained

hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}[:{num_rows}]")
data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries

with Pipeline(name="ultrafeedback") as pipeline:

    load_the_dataset = LoadDataFromDicts(
        data = data,
    )

    tasks = []
    for aspect in ASPECTS:
        evaluate_responses = UltraFeedback(
            name=f"evaluate-responses-{{aspect}}",
            aspect=aspect,
            llm=InferenceEndpointsLLM(
                model_id=MODEL,
                tokenizer_id=MODEL,
                api_key=os.environ["HF_TOKEN"],
                generation_kwargs={{
                    "temperature": 0.7,
                    "max_new_tokens": 2048,
                }},
            ),
            output_mappings={{
                "ratings": f"ratings_{{aspect}}",
                "types": f"type_{{aspect}}",
                "rationales": f"rationales_for_types_{{aspect}}",
                "rationales-for-ratings": f"rationales_for_ratings_{{aspect}}",
            }} if aspect in ["truthfulness", "helpfulness"] else {{"rationales": f"rationales_{{aspect}}", "ratings": f"ratings_{{aspect}}"}},
        )
        tasks.append(evaluate_responses)

    combine_outputs = CombineOutputs()

    load_the_dataset >> tasks >> combine_outputs

if __name__ == "__main__":
    distiset = pipeline.run()
"""
    return code
147
+
148
+
149
def generate_custom_pipeline_code(
    repo_id, subset, split, prompt_template, structured_output, num_rows
):
    """Render a standalone distilabel script for a custom-prompt evaluation.

    Args:
        repo_id: Hub dataset id for ``LoadDataFromHub``.
        subset: Dataset configuration name.
        split: Dataset split name.
        prompt_template: Template text assigned to ``CUSTOM_TEMPLATE`` in the
            script. It is embedded with ``repr`` so quotes and line breaks in
            the template still produce a valid Python string literal.
        structured_output: Schema interpolated as-is into the LLM's
            ``structured_output`` setting; also passed to
            ``extract_column_names`` to obtain the task's ``columns``.
        num_rows: Value written to ``num_examples``.

    Returns:
        The script source as a string.
    """
    columns = extract_column_names(structured_output)
    code = f"""
# Requirements: `pip install distilabel[hf-inference-endpoints, instructor]`
import os
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
from distilabel.llms import InferenceEndpointsLLM

MODEL = "{MODEL}"
CUSTOM_TEMPLATE = {str(prompt_template)!r}
os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained

with Pipeline(name="custom-evaluation") as pipeline:
    load_the_dataset = LoadDataFromHub(
        repo_id="{repo_id}",
        config="{subset}",
        split="{split}",
        num_examples={num_rows},
        batch_size=2
    )
    custom_evaluator = TextGeneration(
        llm=InferenceEndpointsLLM(
            model_id=MODEL,
            tokenizer_id=MODEL,
            api_key=os.environ["HF_TOKEN"],
            structured_output={{"format": "json", "schema": {structured_output}}},
            generation_kwargs={{
                "temperature": 0.7,
                "max_new_tokens": 2048,
            }},
        ),
        template=CUSTOM_TEMPLATE,
        columns={columns}
    )

    load_the_dataset >> custom_evaluator

if __name__ == "__main__":
    distiset = pipeline.run()
"""
    return code
194
+
195
+
196
def generate_pipeline_code(repo_id, aspects, instruction_column, response_columns, prompt_template, structured_output, num_rows, eval_type):
    """Pick the evaluation script generator that matches ``eval_type``.

    With no ``repo_id`` the subset and split fall back to ``"default"`` and
    ``"train"``; otherwise the first configuration of the dataset and the
    first split of that configuration are used. ``"ultrafeedback"`` selects
    the UltraFeedback script, any other value the custom-template script.
    """
    subset, split = "default", "train"
    if repo_id is not None:
        subset = get_dataset_config_names(repo_id)[0]
        split = get_dataset_split_names(repo_id, subset)[0]
    if eval_type != "ultrafeedback":
        return generate_custom_pipeline_code(
            repo_id, subset, split, prompt_template, structured_output, num_rows
        )
    return generate_ultrafeedback_pipeline_code(
        repo_id, subset, split, aspects, instruction_column, response_columns, num_rows
    )
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -138,52 +138,26 @@ def _get_output_mappings(num_turns):
138
  return {"conversation": "messages"}
139
 
140
 
141
- def generate_pipeline_code(system_prompt, num_turns, num_rows):
142
- input_mappings = _get_output_mappings(num_turns)
143
- code = f"""
144
- # Requirements: `pip install distilabel[hf-inference-endpoints]`
145
- import os
146
- from distilabel.pipeline import Pipeline
147
- from distilabel.steps import KeepColumns
148
- from distilabel.steps.tasks import MagpieGenerator
149
- from distilabel.llms import InferenceEndpointsLLM
150
-
151
- MODEL = "{MODEL}"
152
- SYSTEM_PROMPT = "{system_prompt}"
153
- os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
154
-
155
- with Pipeline(name="sft") as pipeline:
156
- magpie = MagpieGenerator(
157
  llm=InferenceEndpointsLLM(
 
158
  model_id=MODEL,
159
  tokenizer_id=MODEL,
160
- magpie_pre_query_template="llama3",
161
- generation_kwargs={{
162
- "temperature": 0.9,
163
- "do_sample": True,
164
  "max_new_tokens": 2048,
165
- "stop_sequences": {_STOP_SEQUENCES}
166
- }},
167
- api_key=os.environ["HF_TOKEN"],
168
  ),
169
- n_turns={num_turns},
170
- num_rows={num_rows},
171
- batch_size=1,
172
- system_prompt=SYSTEM_PROMPT,
173
- output_mappings={input_mappings},
174
- )
175
- keep_columns = KeepColumns(
176
- columns={list(input_mappings.values())} + ["model_name"],
177
  )
178
- magpie.connect(keep_columns)
179
-
180
- if __name__ == "__main__":
181
- distiset = pipeline.run()
182
- """
183
- return code
184
 
185
 
186
- def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
187
  input_mappings = _get_output_mappings(num_turns)
188
  output_mappings = input_mappings.copy()
189
  if num_turns == 1:
@@ -228,7 +202,7 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
228
  return magpie_generator
229
 
230
 
231
- def get_response_generator(num_turns, system_prompt, is_sample):
232
  if num_turns == 1:
233
  response_generator = TextGeneration(
234
  llm=InferenceEndpointsLLM(
@@ -262,19 +236,46 @@ def get_response_generator(num_turns, system_prompt, is_sample):
262
  return response_generator
263
 
264
 
265
- def get_prompt_generator():
266
- prompt_generator = TextGeneration(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  llm=InferenceEndpointsLLM(
268
- api_key=_get_next_api_key(),
269
  model_id=MODEL,
270
  tokenizer_id=MODEL,
271
- generation_kwargs={
272
- "temperature": 0.8,
273
- "max_new_tokens": 2048,
274
  "do_sample": True,
275
- },
 
 
 
276
  ),
277
- use_system_prompt=True,
 
 
 
 
278
  )
279
- prompt_generator.load()
280
- return prompt_generator
 
 
 
 
 
 
 
 
138
  return {"conversation": "messages"}
139
 
140
 
141
def get_prompt_generator(temperature):
    """Create and load the ``TextGeneration`` task that writes system prompts.

    The task uses ``PROMPT_CREATION_PROMPT`` as its system prompt and samples
    with the caller-supplied ``temperature`` and up to 2048 new tokens.
    """
    sampling_options = {
        "temperature": temperature,
        "max_new_tokens": 2048,
        "do_sample": True,
    }
    llm = InferenceEndpointsLLM(
        api_key=_get_next_api_key(),
        model_id=MODEL,
        tokenizer_id=MODEL,
        generation_kwargs=sampling_options,
    )
    generator = TextGeneration(
        llm=llm,
        system_prompt=PROMPT_CREATION_PROMPT,
        use_system_prompt=True,
    )
    generator.load()
    return generator
 
 
 
 
158
 
159
 
160
+ def get_magpie_generator(system_prompt, num_turns, is_sample):
161
  input_mappings = _get_output_mappings(num_turns)
162
  output_mappings = input_mappings.copy()
163
  if num_turns == 1:
 
202
  return magpie_generator
203
 
204
 
205
+ def get_response_generator(system_prompt, num_turns, is_sample):
206
  if num_turns == 1:
207
  response_generator = TextGeneration(
208
  llm=InferenceEndpointsLLM(
 
236
  return response_generator
237
 
238
 
239
def generate_pipeline_code(system_prompt, num_turns, num_rows):
    """Render a standalone distilabel script that reproduces the SFT pipeline.

    Args:
        system_prompt: Text assigned to ``SYSTEM_PROMPT`` in the script. It is
            embedded with ``repr`` so quotes and line breaks in the prompt
            still produce a valid Python string literal.
        num_turns: Written to ``n_turns``; also selects the output mappings
            via ``_get_output_mappings``.
        num_rows: Written to ``num_rows``.

    Returns:
        The script source as a string.
    """
    input_mappings = _get_output_mappings(num_turns)
    code = f"""
# Requirements: `pip install distilabel[hf-inference-endpoints]`
import os
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns
from distilabel.steps.tasks import MagpieGenerator
from distilabel.llms import InferenceEndpointsLLM

MODEL = "{MODEL}"
SYSTEM_PROMPT = {str(system_prompt)!r}
os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained

with Pipeline(name="sft") as pipeline:
    magpie = MagpieGenerator(
        llm=InferenceEndpointsLLM(
            model_id=MODEL,
            tokenizer_id=MODEL,
            magpie_pre_query_template="llama3",
            generation_kwargs={{
                "temperature": 0.9,
                "do_sample": True,
                "max_new_tokens": 2048,
                "stop_sequences": {_STOP_SEQUENCES}
            }},
            api_key=os.environ["HF_TOKEN"],
        ),
        n_turns={num_turns},
        num_rows={num_rows},
        batch_size=1,
        system_prompt=SYSTEM_PROMPT,
        output_mappings={input_mappings},
    )
    keep_columns = KeepColumns(
        columns={list(input_mappings.values())} + ["model_name"],
    )
    magpie.connect(keep_columns)

if __name__ == "__main__":
    distiset = pipeline.run()
"""
    return code
src/distilabel_dataset_generator/pipelines/textcat.py CHANGED
@@ -1,4 +1,5 @@
1
  import random
 
2
  from typing import List
3
 
4
  from distilabel.llms import InferenceEndpointsLLM
@@ -22,25 +23,27 @@ The prompt you write should follow the same style and structure as the following
22
 
23
  If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
24
 
25
- Classify the following customer review of a cinema as either 'positive' or 'negative'.
26
 
27
- Classify the following news article into one or more of the following categories: 'politics', 'sports', 'technology', 'entertainment', 'health', 'business', 'environment', 'education', 'science', 'international'.
28
 
29
- Determine the sentiment of the following social media post: 'ambiguous', 'sarcastic', 'informative', 'emotional'.
30
 
31
- Identify the issue category for the following technical support ticket: 'billing', 'technical', 'account', 'shipping', 'returns', 'installation', 'subscription'.
32
 
33
- Classify the following movie review into one of the following categories: 'critical', 'praise', 'disappointed', 'enthusiastic'.
34
 
35
- Determine the level of customer satisfaction from the following customer service transcript: 'satisfied', 'dissatisfied', 'highly-satisfied', 'somewhat-dissatisfied', 'indifferent'.
36
 
37
- Categorize the following product description into one of the following product types: 'smartphone', 'laptop', 'tablet', 'smartwatch', 'e-reader', 'headphones'.
38
 
39
- Classify the following tweet as expressing either 'support' or 'opposition' to the political event discussed.
40
 
41
- Classify the following restaurant review into one of the following categories: 'food-quality', 'service', 'ambiance', or 'price'.
42
 
43
- Classify the following blog post based on its primary fashion trend or style: 'casual', 'formal', 'streetwear', 'vintage' or 'sustainable-fashion'.
 
 
44
 
45
  User dataset description:
46
  """
@@ -51,6 +54,82 @@ DEFAULT_DATASET_DESCRIPTIONS = [
51
  ]
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def generate_pipeline_code(
55
  system_prompt: str,
56
  difficulty: str = None,
@@ -146,63 +225,3 @@ with Pipeline(name="textcat") as pipeline:
146
  distiset = pipeline.run()
147
  """
148
  )
149
-
150
-
151
- def get_textcat_generator(difficulty, clarity, is_sample):
152
- textcat_generator = GenerateTextClassificationData(
153
- llm=InferenceEndpointsLLM(
154
- model_id=MODEL,
155
- tokenizer_id=MODEL,
156
- api_key=_get_next_api_key(),
157
- generation_kwargs={
158
- "temperature": 0.9,
159
- "max_new_tokens": 256 if is_sample else 2048,
160
- "do_sample": True,
161
- "top_k": 50,
162
- "top_p": 0.95,
163
- },
164
- ),
165
- difficulty=None if difficulty == "mixed" else difficulty,
166
- clarity=None if clarity == "mixed" else clarity,
167
- seed=random.randint(0, 2**32 - 1),
168
- )
169
- textcat_generator.load()
170
- return textcat_generator
171
-
172
-
173
- def get_labeller_generator(system_prompt, labels, num_labels):
174
- labeller_generator = TextClassification(
175
- llm=InferenceEndpointsLLM(
176
- model_id=MODEL,
177
- tokenizer_id=MODEL,
178
- api_key=_get_next_api_key(),
179
- generation_kwargs={
180
- "temperature": 0.7,
181
- "max_new_tokens": 2048,
182
- },
183
- ),
184
- context=system_prompt,
185
- available_labels=labels,
186
- n=num_labels,
187
- default_label="unknown",
188
- )
189
- labeller_generator.load()
190
- return labeller_generator
191
-
192
-
193
- def get_prompt_generator():
194
- prompt_generator = TextGeneration(
195
- llm=InferenceEndpointsLLM(
196
- api_key=_get_next_api_key(),
197
- model_id=MODEL,
198
- tokenizer_id=MODEL,
199
- generation_kwargs={
200
- "temperature": 0.8,
201
- "max_new_tokens": 2048,
202
- "do_sample": True,
203
- },
204
- ),
205
- use_system_prompt=True,
206
- )
207
- prompt_generator.load()
208
- return prompt_generator
 
1
  import random
2
+ from pydantic import BaseModel, Field
3
  from typing import List
4
 
5
  from distilabel.llms import InferenceEndpointsLLM
 
23
 
24
  If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
25
 
26
+ {"classification_task": "Classify the following customer review of a cinema as", "labels": ["positive", "negative"]}
27
 
28
+ {"classification_task": "Categorize the following news article into one or more of the following categories:", "labels": ["politics", "sports", "technology", "entertainment", "health", "business", "environment", "education", "science", "international"]}
29
 
30
+ {"classification_task": "Classify the following news article into one or more of the following categories:", "labels": ['politics', 'sports', 'technology', 'entertainment', 'health', 'business', 'environment', 'education', 'science', 'international']}
31
 
32
+ {"classification_task": "Determine the sentiment of the following social media post:", "labels": ['ambiguous', 'sarcastic', 'informative', 'emotional']}
33
 
34
+ {"classification_task": "Identify the issue category for the following technical support ticket:", "labels": ['billing', 'technical', 'account', 'shipping', 'returns', 'installation', 'subscription']}
35
 
36
+ {"classification_task": "Classify the following movie review into one of the following categories:", "labels": ['critical', 'praise', 'disappointed', 'enthusiastic']}
37
 
38
+ {"classification_task": "Categorize the following customer service transcript into one of the following categories:", "labels": ['satisfied', 'dissatisfied', 'highly-satisfied', 'somewhat-dissatisfied', 'indifferent']}
39
 
40
+ {"classification_task": "Classify the following product description into one of the following product types:", "labels": ['smartphone', 'laptop', 'tablet', 'smartwatch', 'e-reader', 'headphones']}
41
 
42
+ {"classification_task": "Categorize the following tweet expressing the political event discussed as", "labels": ['support', 'opposition']}
43
 
44
+ {"classification_task": "Classify the following restaurant review into one of the following categories:", "labels": ['food-quality', 'service', 'ambiance', 'price']}
45
+
46
+ {"classification_task": "Categorize the following blog post based on its primary fashion trend or style:", "labels": ['casual', 'formal', 'streetwear', 'vintage', 'sustainable-fashion']}
47
 
48
  User dataset description:
49
  """
 
54
  ]
55
 
56
 
57
class TextClassificationTask(BaseModel):
    """Schema for a generated text-classification task.

    Used as the structured-output schema of the prompt generator: the model
    is expected to return a task description together with its label set.
    """

    # Natural-language statement of the classification task.
    classification_task: str = Field(
        ...,
        description="The classification task to be performed.",
        title="classification_task",
    )

    # Candidate labels a text can be assigned for this task.
    labels: list[str] = Field(
        ...,
        description="The possible labels for the classification task.",
        title="Labels",
    )
69
+
70
+
71
def get_prompt_generator(temperature):
    """Build and load the task that turns a dataset description into a prompt.

    Args:
        temperature: Sampling temperature forwarded to the LLM's
            ``generation_kwargs``.

    Returns:
        A loaded ``TextGeneration`` task whose LLM is constrained to emit JSON
        matching ``TextClassificationTask`` and which uses
        ``PROMPT_CREATION_PROMPT`` as its system prompt.
    """
    sampling_options = {
        "temperature": temperature,
        "max_new_tokens": 2048,
        "do_sample": True,
    }
    json_schema = {"format": "json", "schema": TextClassificationTask}

    llm = InferenceEndpointsLLM(
        api_key=_get_next_api_key(),
        model_id=MODEL,
        tokenizer_id=MODEL,
        structured_output=json_schema,
        generation_kwargs=sampling_options,
    )

    generator = TextGeneration(
        llm=llm,
        system_prompt=PROMPT_CREATION_PROMPT,
        use_system_prompt=True,
    )
    generator.load()
    return generator
89
+
90
+
91
def get_textcat_generator(difficulty, clarity, is_sample):
    """Build and load the task that generates synthetic texts to classify.

    Args:
        difficulty: Difficulty setting; the value ``"mixed"`` is translated to
            ``None`` before being handed to the task.
        clarity: Clarity setting; the value ``"mixed"`` is translated to
            ``None`` before being handed to the task.
        is_sample: When truthy, generation is capped at 256 new tokens instead
            of 2048.

    Returns:
        A loaded ``GenerateTextClassificationData`` task with a random seed.
    """
    token_budget = 256 if is_sample else 2048
    llm = InferenceEndpointsLLM(
        model_id=MODEL,
        tokenizer_id=MODEL,
        api_key=_get_next_api_key(),
        generation_kwargs={
            "temperature": 0.9,
            "max_new_tokens": token_budget,
            "do_sample": True,
            "top_k": 50,
            "top_p": 0.95,
        },
    )

    # "mixed" means "do not pin this dimension", which the task expresses as None.
    effective_difficulty = None if difficulty == "mixed" else difficulty
    effective_clarity = None if clarity == "mixed" else clarity

    generator = GenerateTextClassificationData(
        llm=llm,
        difficulty=effective_difficulty,
        clarity=effective_clarity,
        seed=random.randint(0, 2**32 - 1),
    )
    generator.load()
    return generator
111
+
112
+
113
def get_labeller_generator(system_prompt, labels, num_labels):
    """Build and load the task that assigns labels to generated texts.

    Args:
        system_prompt: Passed to the task as its ``context``.
        labels: Passed to the task as ``available_labels``.
        num_labels: Passed to the task as ``n``.

    Returns:
        A loaded ``TextClassification`` task whose ``default_label`` is
        ``"unknown"``.
    """
    llm = InferenceEndpointsLLM(
        model_id=MODEL,
        tokenizer_id=MODEL,
        api_key=_get_next_api_key(),
        generation_kwargs={
            "temperature": 0.7,
            "max_new_tokens": 2048,
        },
    )

    labeller = TextClassification(
        llm=llm,
        context=system_prompt,
        available_labels=labels,
        n=num_labels,
        default_label="unknown",
    )
    labeller.load()
    return labeller
131
+
132
+
133
  def generate_pipeline_code(
134
  system_prompt: str,
135
  difficulty: str = None,
 
225
  distiset = pipeline.run()
226
  """
227
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/distilabel_dataset_generator/utils.py CHANGED
@@ -1,8 +1,11 @@
 
1
  import os
2
  from typing import List, Optional, Union
3
 
4
  import argilla as rg
5
  import gradio as gr
 
 
6
  from gradio.oauth import (
7
  OAUTH_CLIENT_ID,
8
  OAUTH_CLIENT_SECRET,
@@ -11,6 +14,7 @@ from gradio.oauth import (
11
  get_space,
12
  )
13
  from huggingface_hub import whoami
 
14
 
15
  _LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}"
16
 
@@ -50,22 +54,22 @@ def list_orgs(oauth_token: OAuthToken = None):
50
  return []
51
  data = whoami(oauth_token.token)
52
  if data["auth"]["type"] == "oauth":
53
- organisations = [data["name"]] + [org["name"] for org in data["orgs"]]
54
  elif data["auth"]["type"] == "access_token":
55
- organisations = [org["name"] for org in data["orgs"]]
56
  else:
57
- organisations = [
58
  entry["entity"]["name"]
59
  for entry in data["auth"]["accessToken"]["fineGrained"]["scoped"]
60
  if "repo.write" in entry["permissions"]
61
  ]
62
- organisations = [org for org in organisations if org != data["name"]]
63
- organisations = [data["name"]] + organisations
64
  except Exception as e:
65
  raise gr.Error(
66
  f"Failed to get organizations: {e}. See if you are logged and connected: https://huggingface.co/settings/connected-applications."
67
  )
68
- return organisations
69
 
70
 
71
  def get_org_dropdown(oauth_token: OAuthToken = None):
@@ -89,7 +93,7 @@ def get_token(oauth_token: OAuthToken = None):
89
  return ""
90
 
91
 
92
- def swap_visibilty(oauth_token: Optional[OAuthToken] = None):
93
  if oauth_token:
94
  return gr.update(elem_classes=["main_ui_logged_in"])
95
  else:
@@ -132,6 +136,91 @@ def get_argilla_client() -> Union[rg.Argilla, None]:
132
  except Exception:
133
  return None
134
 
135
-
136
  def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
137
  return list(set([label.lower().strip() for label in labels])) if labels else []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
  import os
3
  from typing import List, Optional, Union
4
 
5
  import argilla as rg
6
  import gradio as gr
7
+ import numpy as np
8
+ import pandas as pd
9
  from gradio.oauth import (
10
  OAUTH_CLIENT_ID,
11
  OAUTH_CLIENT_SECRET,
 
14
  get_space,
15
  )
16
  from huggingface_hub import whoami
17
+ from jinja2 import Environment, meta
18
 
19
  _LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}"
20
 
 
54
  return []
55
  data = whoami(oauth_token.token)
56
  if data["auth"]["type"] == "oauth":
57
+ organizations = [data["name"]] + [org["name"] for org in data["orgs"]]
58
  elif data["auth"]["type"] == "access_token":
59
+ organizations = [org["name"] for org in data["orgs"]]
60
  else:
61
+ organizations = [
62
  entry["entity"]["name"]
63
  for entry in data["auth"]["accessToken"]["fineGrained"]["scoped"]
64
  if "repo.write" in entry["permissions"]
65
  ]
66
+ organizations = [org for org in organizations if org != data["name"]]
67
+ organizations = [data["name"]] + organizations
68
  except Exception as e:
69
  raise gr.Error(
70
  f"Failed to get organizations: {e}. See if you are logged and connected: https://huggingface.co/settings/connected-applications."
71
  )
72
+ return organizations
73
 
74
 
75
  def get_org_dropdown(oauth_token: OAuthToken = None):
 
93
  return ""
94
 
95
 
96
+ def swap_visibility(oauth_token: Optional[OAuthToken] = None):
97
  if oauth_token:
98
  return gr.update(elem_classes=["main_ui_logged_in"])
99
  else:
 
136
  except Exception:
137
  return None
138
 
 
139
  def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
140
  return list(set([label.lower().strip() for label in labels])) if labels else []
141
+
142
+
143
def column_to_list(dataframe: pd.DataFrame, column_name: str) -> List[str]:
    """Return the values of one dataframe column as a plain list.

    Args:
        dataframe: Frame to read from.
        column_name: Name of the column to extract.

    Returns:
        The column's values, in row order.

    Raises:
        ValueError: If ``column_name`` is not a column of ``dataframe``.
    """
    if column_name not in dataframe.columns:
        raise ValueError(f"Column '{column_name}' does not exist.")
    return dataframe[column_name].tolist()
148
+
149
+
150
def _role_contents(messages, role: str) -> list:
    """Collect the ``content`` of every chat message with the given role."""
    return [
        message["content"]
        for message in messages
        if isinstance(message, dict)
        and message.get("role") == role
        and "content" in message
    ]


def _parse_chat_json(text: str):
    """Decode ``text`` as a JSON list of message dicts, or return ``None``."""
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return None
    # Valid JSON that is not a list of objects (e.g. "42", '"x"', '{"a": 1}')
    # is ordinary text that merely happens to parse, not a conversation.
    if isinstance(parsed, list) and all(isinstance(item, dict) for item in parsed):
        return parsed
    return None


def _extract_instruction(value) -> str:
    """Return the last user message of a conversation, or the raw text."""
    if isinstance(value, (list, np.ndarray)):
        user_contents = _role_contents(value, "user")
        return user_contents[-1] if user_contents else ""
    if isinstance(value, str):
        messages = _parse_chat_json(value)
        if messages is None:
            return value
        user_contents = _role_contents(messages, "user")
        return user_contents[-1] if user_contents else ""
    return ""


def _extract_generations(value) -> list:
    """Return the generations contributed by a single response cell."""
    if isinstance(value, (list, np.ndarray)):
        if all(isinstance(item, dict) and "role" in item for item in value):
            # Chat-formatted cell: keep only the final assistant turn.
            return _role_contents(value, "assistant")[-1:]
        # Plain sequence: every element is its own generation.
        return list(value)
    if isinstance(value, str):
        messages = _parse_chat_json(value)
        if messages is None:
            return [value]
        return _role_contents(messages, "assistant")[-1:]
    return []


def process_columns(
    dataframe,
    instruction_column: str,
    response_columns: Union[str, List[str]],
) -> List[dict]:
    """Convert dataframe rows into instruction/generations records.

    For each row, the instruction cell may be a list/array of chat messages
    (the last ``"user"`` content is used), a JSON string encoding such a list,
    or plain text (used verbatim). Any other cell type yields ``""``.

    Each response cell may be a chat-formatted list/array (the last
    ``"assistant"`` content is used), a plain list/array (all elements are
    used), a JSON string encoding a chat list, or plain text. Other cell
    types contribute nothing.

    Strings that are valid JSON but do not decode to a list of message
    objects are treated as plain text instead of raising.

    Args:
        dataframe: Frame to iterate over with ``iterrows()``.
        instruction_column: Name of the column holding the instruction.
        response_columns: One column name, or a list of them, holding
            responses.

    Returns:
        One ``{"instruction": ..., "generations": [...]}`` dict per row.
    """
    if isinstance(response_columns, str):
        response_columns = [response_columns]

    data = []
    for _, row in dataframe.iterrows():
        generations = []
        for col in response_columns:
            generations.extend(_extract_generations(row[col]))
        data.append(
            {
                "instruction": _extract_instruction(row[instruction_column]),
                "generations": generations,
            }
        )

    return data
211
+
212
+
213
def extract_column_names(prompt_template: str) -> List[str]:
    """List the variables a Jinja2 prompt template expects from its context.

    Args:
        prompt_template: Jinja2 template source.

    Returns:
        The template's undeclared variable names, sorted alphabetically so
        the result is the same on every run (the underlying lookup returns an
        unordered set).
    """
    parsed_content = Environment().parse(prompt_template)
    return sorted(meta.find_undeclared_variables(parsed_content))
218
+
219
+
220
def pad_or_truncate_list(lst, target_length):
    """Force a list to exactly ``target_length`` items.

    Lists that are too long keep their *last* ``target_length`` items; lists
    that are too short are right-padded with ``None``.

    Args:
        lst: Input list; ``None`` or an empty list is treated as ``[]``.
        target_length: Desired number of items. Zero or a negative value
            yields an empty result.

    Returns:
        A new list of the requested length.
    """
    lst = lst or []
    lst_length = len(lst)
    if lst_length >= target_length:
        # Slice from an explicit start index: ``lst[-target_length:]`` would
        # return the whole list when target_length is 0, since -0 == 0.
        return lst[lst_length - target_length:]
    return lst + [None] * (target_length - lst_length)