davidberenstein1957 HF staff commited on
Commit
099e99c
1 Parent(s): 080f560

refactor: redesign of the generator

Browse files
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
 
 
3
  from src.distilabel_dataset_generator.apps.faq import app as faq_app
4
  from src.distilabel_dataset_generator.apps.sft import app as sft_app
5
  from src.distilabel_dataset_generator.apps.textcat import app as textcat_app
@@ -23,64 +24,37 @@ css = """
23
  background-color: black;
24
  }
25
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  """
27
 
28
- demo = gr.TabbedInterface(
29
  [textcat_app, sft_app, faq_app],
30
  ["Text Classification", "Supervised Fine-Tuning", "FAQ"],
31
  css=css,
32
  title="""
33
- <style>
34
- .header-container {
35
- display: flex;
36
- align-items: center;
37
- justify-content: center;
38
- position: relative;
39
- padding: 20px 0;
40
- }
41
- .logo-container {
42
- position: absolute;
43
- left: 0;
44
- top: 0;
45
- }
46
- .title-container {
47
- text-align: center;
48
- }
49
- @media (max-width: 600px) {
50
- .header-container {
51
- flex-direction: column;
52
- }
53
- .logo-container {
54
- position: static;
55
- margin-bottom: 20px;
56
- }
57
- }
58
- button[role="tab"].selected,
59
- button[role="tab"][aria-selected="true"],
60
- button[role="tab"][data-tab-id][aria-selected="true"] {
61
- background-color: #000000;
62
- color: white;
63
- border: none;
64
- font-size: 16px;
65
- font-weight: bold;
66
- box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
67
- transition: background-color 0.3s ease, color 0.3s ease;
68
- }
69
- </style>
70
- <div class="header-container">
71
- <div class="logo-container">
72
- <a href="https://github.com/argilla-io/distilabel" target="_blank" rel="noopener noreferrer">
73
- <img src="https://distilabel.argilla.io/latest/assets/distilabel-black.svg" alt="Distilabel Logo" style="width: 150px; height: auto;">
74
- </a>
75
- </div>
76
- <div class="title-container">
77
- <h1 style="margin: 0; font-size: 2em;">🧬 Synthetic Data Generator</h1>
78
- <p style="margin: 10px 0 0 0; color: #666; font-size: 1.1em;">Build datasets using natural language</p>
79
- </div>
80
- </div>
81
  """,
 
82
  theme=theme,
83
  )
84
 
 
85
  if __name__ == "__main__":
86
  demo.launch()
 
1
  import gradio as gr
2
 
3
+ from src.distilabel_dataset_generator._tabbedinterface import TabbedInterface
4
  from src.distilabel_dataset_generator.apps.faq import app as faq_app
5
  from src.distilabel_dataset_generator.apps.sft import app as sft_app
6
  from src.distilabel_dataset_generator.apps.textcat import app as textcat_app
 
24
  background-color: black;
25
  }
26
  }
27
+ button[role="tab"].selected,
28
+ button[role="tab"][aria-selected="true"],
29
+ button[role="tab"][data-tab-id][aria-selected="true"] {
30
+ background-color: #000000;
31
+ color: white;
32
+ border: none;
33
+ font-size: 16px;
34
+ font-weight: bold;
35
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
36
+ transition: background-color 0.3s ease, color 0.3s ease;
37
+ }
38
+ .gallery {
39
+ color: black !important;
40
+ }
41
+ .flex-shrink-0.truncate.px-1 {
42
+ color: black !important;
43
+ }
44
  """
45
 
46
+ demo = TabbedInterface(
47
  [textcat_app, sft_app, faq_app],
48
  ["Text Classification", "Supervised Fine-Tuning", "FAQ"],
49
  css=css,
50
  title="""
51
+ <h1>Synthetic Data Generator</h1>
52
+ <h3>Build datasets using natural language</h3>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  """,
54
+ head="Synthetic Data Generator",
55
  theme=theme,
56
  )
57
 
58
+
59
  if __name__ == "__main__":
60
  demo.launch()
demo.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from src.distilabel_dataset_generator._tabbedinterface import TabbedInterface
4
+ from src.distilabel_dataset_generator.apps.eval import app as eval_app
5
+ from src.distilabel_dataset_generator.apps.faq import app as faq_app
6
+ from src.distilabel_dataset_generator.apps.sft import app as sft_app
7
+ from src.distilabel_dataset_generator.apps.textcat import app as textcat_app
8
+
9
+ theme = gr.themes.Monochrome(
10
+ spacing_size="md",
11
+ font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
12
+ )
13
+
14
+ css = """
15
+ .main_ui_logged_out{opacity: 0.3; pointer-events: none}
16
+ .tabitem{border: 0px}
17
+ .group_padding{padding: .55em}
18
+ #space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none}
19
+ #system_prompt_examples {
20
+ color: black;
21
+ }
22
+ @media (prefers-color-scheme: dark) {
23
+ #system_prompt_examples {
24
+ color: white;
25
+ background-color: black;
26
+ }
27
+ }
28
+ button[role="tab"].selected,
29
+ button[role="tab"][aria-selected="true"],
30
+ button[role="tab"][data-tab-id][aria-selected="true"] {
31
+ background-color: #000000;
32
+ color: white;
33
+ border: none;
34
+ font-size: 16px;
35
+ font-weight: bold;
36
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
37
+ transition: background-color 0.3s ease, color 0.3s ease;
38
+ }
39
+ .gallery {
40
+ color: black !important;
41
+ }
42
+ .flex-shrink-0.truncate.px-1 {
43
+ color: black !important;
44
+ }
45
+ """
46
+
47
+ demo = TabbedInterface(
48
+ [textcat_app, sft_app, eval_app, faq_app],
49
+ ["Text Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"],
50
+ css=css,
51
+ title="""
52
+ <h1>Synthetic Data Generator</h1>
53
+ <h3>Build datasets using natural language</h3>
54
+ """,
55
+ head="Synthetic Data Generator",
56
+ theme=theme,
57
+ )
58
+
59
+
60
+ if __name__ == "__main__":
61
+ demo.launch()
pdm.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -6,11 +6,13 @@ authors = [
6
  {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"},
7
  ]
8
  dependencies = [
9
- "distilabel[hf-inference-endpoints,argilla]>=1.4.1",
10
- "gradio[oauth]>=5.5.0",
11
  "transformers>=4.44.2",
12
  "sentence-transformers>=3.2.0",
13
  "model2vec>=0.2.4",
 
 
14
  ]
15
  requires-python = "<3.13,>=3.10"
16
  readme = "README.md"
 
6
  {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"},
7
  ]
8
  dependencies = [
9
+ "distilabel[hf-inference-endpoints,argilla,outlines]>=1.4.1",
10
+ "gradio[oauth]<5.0.0",
11
  "transformers>=4.44.2",
12
  "sentence-transformers>=3.2.0",
13
  "model2vec>=0.2.4",
14
+ "gradio-huggingfacehub-search>=0.0.7",
15
+ "argilla>=2.4.0",
16
  ]
17
  requires-python = "<3.13,>=3.10"
18
  readme = "README.md"
requirements.txt CHANGED
@@ -1,7 +1,148 @@
1
- transformers
2
- gradio[oauth]
3
- distilabel[hf-inference-endpoints,argilla]
4
- beautifulsoup4
5
- sentence-transformers
6
- model2vec
7
- outlines
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is @generated by PDM.
2
+ # Please do not edit it manually.
3
+
4
+ aiofiles==23.2.1
5
+ aiohappyeyeballs==2.4.3
6
+ aiohttp==3.11.7
7
+ aiosignal==1.3.1
8
+ airportsdata==20241001
9
+ annotated-types==0.7.0
10
+ anyio==4.6.2.post1
11
+ argilla==2.4.0
12
+ asttokens==2.4.1
13
+ async-timeout==5.0.1; python_version < "3.11"
14
+ attrs==24.2.0
15
+ authlib==1.3.2
16
+ certifi==2024.8.30
17
+ cffi==1.17.1; platform_python_implementation != "PyPy"
18
+ charset-normalizer==3.4.0
19
+ click==8.1.7
20
+ cloudpickle==3.1.0
21
+ colorama==0.4.6; platform_system == "Windows" or sys_platform == "win32"
22
+ contourpy==1.3.1
23
+ cryptography==43.0.3
24
+ cycler==0.12.1
25
+ datasets==3.1.0
26
+ decorator==5.1.1
27
+ dill==0.3.8
28
+ diskcache==5.6.3
29
+ distilabel==1.4.1
30
+ distilabel[argilla,hf-inference-endpoints,outlines]==1.4.1
31
+ exceptiongroup==1.2.2; python_version < "3.11"
32
+ executing==2.1.0
33
+ fastapi==0.115.5
34
+ ffmpy==0.4.0
35
+ filelock==3.16.1
36
+ fonttools==4.55.0
37
+ frozenlist==1.5.0
38
+ fsspec==2024.9.0
39
+ fsspec[http]==2024.9.0
40
+ gradio==4.44.1
41
+ gradio-client==1.3.0
42
+ gradio-huggingfacehub-search==0.0.7
43
+ gradio[oauth]==4.44.1
44
+ h11==0.14.0
45
+ httpcore==1.0.7
46
+ httpx==0.27.2
47
+ huggingface-hub==0.26.2
48
+ idna==3.10
49
+ importlib-resources==6.4.5
50
+ interegular==0.3.3
51
+ ipython==8.29.0
52
+ itsdangerous==2.2.0
53
+ jedi==0.19.2
54
+ jinja2==3.1.4
55
+ joblib==1.4.2
56
+ jsonschema==4.23.0
57
+ jsonschema-specifications==2024.10.1
58
+ kiwisolver==1.4.7
59
+ lark==1.2.2
60
+ llvmlite==0.43.0
61
+ markdown-it-py==3.0.0
62
+ markupsafe==2.1.5
63
+ matplotlib==3.9.2
64
+ matplotlib-inline==0.1.7
65
+ mdurl==0.1.2
66
+ model2vec==0.3.3
67
+ mpmath==1.3.0; python_version >= "3.9"
68
+ multidict==6.1.0
69
+ multiprocess==0.70.16
70
+ nest-asyncio==1.6.0
71
+ networkx==3.4.2
72
+ numba==0.60.0
73
+ numpy==1.26.4
74
+ nvidia-cublas-cu12==12.4.5.8; platform_system == "Linux" and platform_machine == "x86_64"
75
+ nvidia-cuda-cupti-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64"
76
+ nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64"
77
+ nvidia-cuda-runtime-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64"
78
+ nvidia-cudnn-cu12==9.1.0.70; platform_system == "Linux" and platform_machine == "x86_64"
79
+ nvidia-cufft-cu12==11.2.1.3; platform_system == "Linux" and platform_machine == "x86_64"
80
+ nvidia-curand-cu12==10.3.5.147; platform_system == "Linux" and platform_machine == "x86_64"
81
+ nvidia-cusolver-cu12==11.6.1.9; platform_system == "Linux" and platform_machine == "x86_64"
82
+ nvidia-cusparse-cu12==12.3.1.170; platform_system == "Linux" and platform_machine == "x86_64"
83
+ nvidia-nccl-cu12==2.21.5; platform_system == "Linux" and platform_machine == "x86_64"
84
+ nvidia-nvjitlink-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64"
85
+ nvidia-nvtx-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64"
86
+ orjson==3.10.11
87
+ outlines==0.1.4
88
+ outlines-core==0.1.17
89
+ packaging==24.2
90
+ pandas==2.2.3
91
+ parso==0.8.4
92
+ pexpect==4.9.0; sys_platform != "win32" and sys_platform != "emscripten"
93
+ pillow==10.4.0
94
+ portalocker==3.0.0
95
+ prompt-toolkit==3.0.48
96
+ propcache==0.2.0
97
+ ptyprocess==0.7.0; sys_platform != "win32" and sys_platform != "emscripten"
98
+ pure-eval==0.2.3
99
+ pyarrow==18.0.0
100
+ pycountry==24.6.1
101
+ pycparser==2.22; platform_python_implementation != "PyPy"
102
+ pydantic==2.10.0
103
+ pydantic-core==2.27.0
104
+ pydub==0.25.1
105
+ pygments==2.18.0
106
+ pyparsing==3.2.0
107
+ python-dateutil==2.9.0.post0
108
+ python-multipart==0.0.17
109
+ pytz==2024.2
110
+ pywin32==308; platform_system == "Windows"
111
+ pyyaml==6.0.2
112
+ referencing==0.35.1
113
+ regex==2024.11.6
114
+ requests==2.32.3
115
+ rich==13.9.4
116
+ rpds-py==0.21.0
117
+ ruff==0.7.4; sys_platform != "emscripten"
118
+ safetensors==0.4.5
119
+ scikit-learn==1.5.2
120
+ scipy==1.14.1
121
+ semantic-version==2.10.0
122
+ sentence-transformers==3.3.1
123
+ setuptools==75.6.0
124
+ shellingham==1.5.4
125
+ six==1.16.0
126
+ sniffio==1.3.1
127
+ stack-data==0.6.3
128
+ starlette==0.41.3
129
+ sympy==1.13.1; python_version >= "3.9"
130
+ tblib==3.0.0
131
+ threadpoolctl==3.5.0
132
+ tokenizers==0.20.3
133
+ tomlkit==0.12.0
134
+ torch==2.5.1
135
+ tqdm==4.67.0
136
+ traitlets==5.14.3
137
+ transformers==4.46.3
138
+ triton==3.1.0; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13"
139
+ typer==0.13.1
140
+ typing-extensions==4.12.2
141
+ tzdata==2024.2
142
+ universal-pathlib==0.2.5
143
+ urllib3==2.2.3
144
+ uvicorn==0.32.1; sys_platform != "emscripten"
145
+ wcwidth==0.2.13
146
+ websockets==12.0
147
+ xxhash==3.5.0
148
+ yarl==1.18.0
src/distilabel_dataset_generator/_tabbedinterface.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file defines two useful high-level abstractions to build Gradio apps: Interface and TabbedInterface.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from collections.abc import Sequence
8
+
9
+ import gradio as gr
10
+ from gradio.blocks import Blocks
11
+ from gradio.components import HTML
12
+ from gradio.layouts import Tab, Tabs
13
+ from gradio.themes import ThemeClass as Theme
14
+ from gradio_client.documentation import document
15
+
16
+
17
+ @document()
18
+ class TabbedInterface(Blocks):
19
+ """
20
+ A TabbedInterface is created by providing a list of Interfaces or Blocks, each of which gets
21
+ rendered in a separate tab. Only the components from the Interface/Blocks will be rendered in the tab.
22
+ Certain high-level attributes of the Blocks (e.g. custom `css`, `js`, and `head` attributes) will not be loaded.
23
+
24
+ Demos: tabbed_interface_lite
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ interface_list: Sequence[Blocks],
30
+ tab_names: list[str] | None = None,
31
+ title: str | None = None,
32
+ theme: Theme | str | None = None,
33
+ analytics_enabled: bool | None = None,
34
+ css: str | None = None,
35
+ js: str | None = None,
36
+ head: str | None = None,
37
+ ):
38
+ """
39
+ Parameters:
40
+ interface_list: A list of Interfaces (or Blocks) to be rendered in the tabs.
41
+ tab_names: A list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
42
+ title: The tab title to display when this demo is opened in a browser window.
43
+ theme: A Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the Hugging Face Hub (e.g. "gradio/monochrome"). If None, will use the Default theme.
44
+ analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
45
+ css: Custom css as a string or path to a css file. This css will be included in the demo webpage.
46
+ js: Custom js as a string or path to a js file. The custom js should in the form of a single js function. This function will automatically be executed when the page loads. For more flexibility, use the head parameter to insert js inside <script> tags.
47
+ head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, multiple scripts, stylesheets, etc. to the page.
48
+ Returns:
49
+ a Gradio Tabbed Interface for the given interfaces
50
+ """
51
+ super().__init__(
52
+ title=title or "Gradio",
53
+ theme=theme,
54
+ analytics_enabled=analytics_enabled,
55
+ mode="tabbed_interface",
56
+ css=css,
57
+ js=js,
58
+ head=head,
59
+ )
60
+ if tab_names is None:
61
+ tab_names = [f"Tab {i}" for i in range(len(interface_list))]
62
+ with self:
63
+ if title:
64
+ HTML(value=title)
65
+ with gr.Row():
66
+ with gr.Column(scale=1):
67
+ gr.LoginButton(value="Sign in!", size="sm", scale=2)
68
+ with gr.Column(scale=3):
69
+ pass
70
+ with Tabs():
71
+ for interface, tab_name in zip(interface_list, tab_names, strict=False):
72
+ with Tab(label=tab_name):
73
+ interface.render()
src/distilabel_dataset_generator/apps/base.py CHANGED
@@ -168,8 +168,7 @@ def get_main_ui(
168
 
169
  def validate_argilla_user_workspace_dataset(
170
  dataset_name: str,
171
- final_dataset: pd.DataFrame,
172
- add_to_existing_dataset: bool,
173
  oauth_token: Union[OAuthToken, None] = None,
174
  progress=gr.Progress(),
175
  ) -> str:
@@ -193,7 +192,7 @@ def validate_argilla_user_workspace_dataset(
193
  dataset = client.datasets(name=dataset_name, workspace=hf_user)
194
  if dataset and not add_to_existing_dataset:
195
  raise gr.Error(f"Dataset {dataset_name} already exists")
196
- return final_dataset
197
 
198
 
199
  def get_org_dropdown(oauth_token: OAuthToken = None):
@@ -302,7 +301,8 @@ def get_iterate_on_sample_dataset_ui(
302
 
303
 
304
  def get_pipeline_code_ui(pipeline_code: str) -> gr.Code:
305
- gr.Markdown("## Or run this pipeline locally with distilabel")
 
306
  gr.Markdown(
307
  "You can run this pipeline locally with distilabel. For more information, please refer to the [distilabel documentation](https://distilabel.argilla.io/) or go to the FAQ tab at the top of the page for more information."
308
  )
@@ -400,7 +400,7 @@ def push_pipeline_code_to_hub(
400
  oauth_token: Union[OAuthToken, None] = None,
401
  progress=gr.Progress(),
402
  ):
403
- repo_id = _check_push_to_hub(org_name, repo_name)
404
  progress(0.1, desc="Uploading pipeline code")
405
  with io.BytesIO(pipeline_code.encode("utf-8")) as f:
406
  upload_file(
@@ -427,7 +427,7 @@ def push_dataset_to_hub(
427
  task: str = TEXTCAT_TASK,
428
  ) -> pd.DataFrame:
429
  progress(0.1, desc="Setting up dataset")
430
- repo_id = _check_push_to_hub(org_name, repo_name)
431
 
432
  if task == TEXTCAT_TASK:
433
  if num_labels == 1:
@@ -459,7 +459,7 @@ def push_dataset_to_hub(
459
  return dataframe
460
 
461
 
462
- def _check_push_to_hub(org_name, repo_name):
463
  repo_id = (
464
  f"{org_name}/{repo_name}"
465
  if repo_name is not None and org_name is not None
@@ -491,7 +491,7 @@ def get_success_message_row() -> gr.Markdown:
491
  return success_message
492
 
493
 
494
- def show_success_message_argilla() -> gr.Markdown:
495
  client = get_argilla_client()
496
  argilla_api_url = client.api_url
497
  return gr.Markdown(
@@ -499,7 +499,13 @@ def show_success_message_argilla() -> gr.Markdown:
499
  <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
500
  <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
501
  <p style="margin-top: 0.5em;">
502
- Your dataset is now available at:
 
 
 
 
 
 
503
  <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
504
  {argilla_api_url}
505
  </a>
@@ -513,23 +519,5 @@ def show_success_message_argilla() -> gr.Markdown:
513
  )
514
 
515
 
516
- def show_success_message_hub(org_name, repo_name) -> gr.Markdown:
517
- return gr.Markdown(
518
- value=f"""
519
- <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
520
- <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
521
- <p style="margin-top: 0.5em;">
522
- The generated dataset is in the right format for fine-tuning with TRL, AutoTrain or other frameworks.
523
- Your dataset is now available at:
524
- <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
525
- https://huggingface.co/datasets/{org_name}/{repo_name}
526
- </a>
527
- </p>
528
- </div>
529
- """,
530
- visible=True,
531
- )
532
-
533
-
534
  def hide_success_message() -> gr.Markdown:
535
- return gr.Markdown(visible=False)
 
168
 
169
  def validate_argilla_user_workspace_dataset(
170
  dataset_name: str,
171
+ add_to_existing_dataset: bool = True,
 
172
  oauth_token: Union[OAuthToken, None] = None,
173
  progress=gr.Progress(),
174
  ) -> str:
 
192
  dataset = client.datasets(name=dataset_name, workspace=hf_user)
193
  if dataset and not add_to_existing_dataset:
194
  raise gr.Error(f"Dataset {dataset_name} already exists")
195
+ return ""
196
 
197
 
198
  def get_org_dropdown(oauth_token: OAuthToken = None):
 
301
 
302
 
303
  def get_pipeline_code_ui(pipeline_code: str) -> gr.Code:
304
+ gr.Markdown("## Customize and run locally with distilabel")
305
+ gr.HTML("<hr>")
306
  gr.Markdown(
307
  "You can run this pipeline locally with distilabel. For more information, please refer to the [distilabel documentation](https://distilabel.argilla.io/) or go to the FAQ tab at the top of the page for more information."
308
  )
 
400
  oauth_token: Union[OAuthToken, None] = None,
401
  progress=gr.Progress(),
402
  ):
403
+ repo_id = validate_push_to_hub(org_name, repo_name)
404
  progress(0.1, desc="Uploading pipeline code")
405
  with io.BytesIO(pipeline_code.encode("utf-8")) as f:
406
  upload_file(
 
427
  task: str = TEXTCAT_TASK,
428
  ) -> pd.DataFrame:
429
  progress(0.1, desc="Setting up dataset")
430
+ repo_id = validate_push_to_hub(org_name, repo_name)
431
 
432
  if task == TEXTCAT_TASK:
433
  if num_labels == 1:
 
459
  return dataframe
460
 
461
 
462
+ def validate_push_to_hub(org_name, repo_name):
463
  repo_id = (
464
  f"{org_name}/{repo_name}"
465
  if repo_name is not None and org_name is not None
 
491
  return success_message
492
 
493
 
494
+ def show_success_message_hub(org_name, repo_name) -> gr.Markdown:
495
  client = get_argilla_client()
496
  argilla_api_url = client.api_url
497
  return gr.Markdown(
 
499
  <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
500
  <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
501
  <p style="margin-top: 0.5em;">
502
+ Your dataset is now available the Hugging Face Hub:
503
+ <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
504
+ https://huggingface.co/datasets/{org_name}/{repo_name}
505
+ </a>
506
+ </p>
507
+ <p style="margin-top: 0.5em;">
508
+ Your dataset is now available within Argilla:
509
  <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
510
  {argilla_api_url}
511
  </a>
 
519
  )
520
 
521
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
  def hide_success_message() -> gr.Markdown:
523
+ return gr.Markdown(value="")
src/distilabel_dataset_generator/apps/eval.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import gradio as gr
4
+ import pandas as pd
5
+ from datasets import load_dataset
6
+ from gradio_huggingfacehub_search import HuggingfaceHubSearch
7
+
8
+ from src.distilabel_dataset_generator.utils import get_org_dropdown
9
+
10
+
11
+ def get_iframe(hub_repo_id) -> str:
12
+ if not hub_repo_id:
13
+ raise gr.Error("Hub repo id is required")
14
+ url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
15
+ iframe = f"""
16
+ <iframe
17
+ src="{url}"
18
+ frameborder="0"
19
+ width="100%"
20
+ height="600px"
21
+ ></iframe>
22
+ """
23
+ return iframe
24
+
25
+
26
+ def get_valid_columns(df: pd.DataFrame):
27
+ valid_columns = []
28
+ for col in df.columns:
29
+ sample_val = df[col].iloc[0]
30
+ if isinstance(sample_val, str) or (
31
+ isinstance(sample_val, list)
32
+ and all(isinstance(item, dict) for item in sample_val)
33
+ ):
34
+ valid_columns.append(col)
35
+ return valid_columns
36
+
37
+
38
+ def load_dataset_from_hub(hub_repo_id: str, n_rows: int = 10):
39
+ gr.Info(message="Loading dataset ...")
40
+ if not hub_repo_id:
41
+ raise gr.Error("Hub repo id is required")
42
+ ds_dict = load_dataset(hub_repo_id)
43
+ splits = list(ds_dict.keys())
44
+ ds = ds_dict[splits[0]]
45
+ if n_rows:
46
+ ds = ds.select(range(n_rows))
47
+ df = ds.to_pandas()
48
+ # Get columns that contain either strings or lists of dictionaries
49
+ valid_columns = get_valid_columns(df)
50
+ return (
51
+ df,
52
+ gr.Dropdown(choices=valid_columns, label="Instruction Column"),
53
+ gr.Dropdown(choices=valid_columns, label="Instruction Column"),
54
+ gr.Dropdown(choices=valid_columns, label="Response Column"),
55
+ )
56
+
57
+
58
+ def define_evaluation_aspects(task_type: str):
59
+ if task_type == "instruction":
60
+ return gr.Dropdown(
61
+ value=["overall-rating"],
62
+ choices=["complexity", "quality"],
63
+ label="Evaluation Aspects",
64
+ multiselect=True,
65
+ interactive=True,
66
+ )
67
+ elif task_type == "instruction-response":
68
+ return gr.Dropdown(
69
+ value=["overall-rating"],
70
+ choices=["helpfulness", "truthfulness", "overall-rating", "honesty"],
71
+ label="Evaluation Aspects",
72
+ multiselect=True,
73
+ interactive=True,
74
+ )
75
+ else:
76
+ return gr.Dropdown(interactive=False)
77
+
78
+
79
+ def evaluate_instruction(df: pd.DataFrame, aspects: list[str], instruction_column: str):
80
+ pass
81
+
82
+
83
+ def evaluate_instruction_response(
84
+ df: pd.DataFrame, aspects: list[str], instruction_column: str, response_column: str
85
+ ):
86
+ pass
87
+
88
+
89
+ def evaluate_custom(
90
+ df: pd.DataFrame, aspects: list[str], prompt_template: str, structured_output: dict
91
+ ):
92
+ pass
93
+
94
+
95
+ def _apply_to_dataset(
96
+ df: pd.DataFrame,
97
+ eval_type: str,
98
+ aspects_instruction: list[str],
99
+ instruction_column: str,
100
+ aspects_instruction_response: list[str],
101
+ instruction_column_response: str,
102
+ response_column_response: str,
103
+ aspects_custom: list[str],
104
+ prompt_template: str,
105
+ structured_output: dict,
106
+ ):
107
+ if eval_type == "instruction":
108
+ df = evaluate_instruction(df, aspects_instruction, instruction_column)
109
+ elif eval_type == "instruction-response":
110
+ df = evaluate_instruction_response(
111
+ df,
112
+ aspects_instruction_response,
113
+ instruction_column_response,
114
+ response_column_response,
115
+ )
116
+ elif eval_type == "custom":
117
+ df = evaluate_custom(df, aspects_custom, prompt_template, structured_output)
118
+ return df
119
+
120
+
121
+ def apply_to_sample_dataset(
122
+ repo_id: str,
123
+ eval_type: str,
124
+ aspects_instruction: list[str],
125
+ aspects_instruction_response: list[str],
126
+ aspects_custom: list[str],
127
+ instruction_instruction: str,
128
+ instruction_instruction_response: str,
129
+ response_instruction_response: str,
130
+ prompt_template: str,
131
+ structured_output: dict,
132
+ ):
133
+ df, _, _, _ = load_dataset_from_hub(repo_id, n_rows=10)
134
+ df = _apply_to_dataset(
135
+ df,
136
+ eval_type,
137
+ aspects_instruction,
138
+ instruction_instruction,
139
+ aspects_instruction_response,
140
+ instruction_instruction_response,
141
+ response_instruction_response,
142
+ aspects_custom,
143
+ prompt_template,
144
+ structured_output,
145
+ )
146
+ return df
147
+
148
+
149
+ def push_to_hub(
150
+ org_name: str,
151
+ repo_name: str,
152
+ private: bool,
153
+ n_rows: int,
154
+ original_repo_id: str,
155
+ eval_type: str,
156
+ aspects_instruction: list[str],
157
+ aspects_instruction_response: list[str],
158
+ aspects_custom: list[str],
159
+ instruction_instruction: str,
160
+ instruction_instruction_response: str,
161
+ response_instruction_response: str,
162
+ prompt_template: str,
163
+ structured_output: dict,
164
+ ):
165
+ df, _, _, _ = load_dataset_from_hub(original_repo_id, n_rows=n_rows)
166
+ df = _apply_to_dataset(
167
+ df,
168
+ eval_type,
169
+ aspects_instruction,
170
+ instruction_instruction,
171
+ aspects_instruction_response,
172
+ instruction_instruction_response,
173
+ response_instruction_response,
174
+ aspects_custom,
175
+ prompt_template,
176
+ structured_output,
177
+ )
178
+ new_repo_id = f"{org_name}/{repo_name}"
179
+ print(df)
180
+
181
+
182
+ with gr.Blocks() as app:
183
+ gr.Markdown("## Select your input dataset")
184
+ gr.HTML("<hr>")
185
+ with gr.Row():
186
+ with gr.Column(scale=1):
187
+ search_in = HuggingfaceHubSearch(
188
+ label="Search",
189
+ placeholder="Search for a Dataset",
190
+ search_type="dataset",
191
+ sumbit_on_select=True,
192
+ )
193
+ load_btn = gr.Button("Load Dataset")
194
+ with gr.Column(scale=3):
195
+ search_out = gr.HTML(label="Dataset Preview")
196
+
197
+ gr.Markdown("## Configure your task")
198
+ gr.HTML("<hr>")
199
+ with gr.Row():
200
+ with gr.Column(scale=1):
201
+ eval_type = gr.Dropdown(
202
+ label="Evaluation Type",
203
+ choices=["instruction", "instruction-response", "custom"],
204
+ visible=False,
205
+ )
206
+ with gr.Tab("instruction") as tab_instruction:
207
+ aspects_instruction = define_evaluation_aspects("instruction")
208
+ instruction_instruction = gr.Dropdown(
209
+ label="Instruction Column", interactive=True
210
+ )
211
+ tab_instruction.select(
212
+ lambda: "instruction",
213
+ inputs=[],
214
+ outputs=[eval_type],
215
+ )
216
+ with gr.Tab("instruction-response") as tab_instruction_response:
217
+ aspects_instruction_response = define_evaluation_aspects(
218
+ "instruction-response"
219
+ )
220
+ instruction_instruction_response = gr.Dropdown(
221
+ label="Instruction Column", interactive=True
222
+ )
223
+ response_instruction_response = gr.Dropdown(
224
+ label="Response Column", interactive=True
225
+ )
226
+ tab_instruction_response.select(
227
+ lambda: "instruction-response",
228
+ inputs=[],
229
+ outputs=[eval_type],
230
+ )
231
+ with gr.Tab("custom") as tab_custom:
232
+ aspects_custom = define_evaluation_aspects("custom")
233
+ prompt_template = gr.Code(
234
+ label="Prompt Template",
235
+ value="{{column_1}} based on {{column_2}}",
236
+ language="markdown",
237
+ interactive=True,
238
+ )
239
+ structured_output = gr.Code(
240
+ label="Structured Output",
241
+ value=json.dumps({"eval_aspect": "str"}),
242
+ language="json",
243
+ interactive=True,
244
+ )
245
+ tab_custom.select(
246
+ lambda: "custom",
247
+ inputs=[],
248
+ outputs=[eval_type],
249
+ )
250
+ btn_apply_to_sample_dataset = gr.Button("Refresh dataset")
251
+ with gr.Column(scale=3):
252
+ dataframe = gr.Dataframe()
253
+
254
+ gr.Markdown("## Generate your dataset")
255
+ gr.HTML("<hr>")
256
+ with gr.Row():
257
+ with gr.Column(scale=1):
258
+ org_name = get_org_dropdown()
259
+ repo_name = gr.Textbox(
260
+ label="Repo name",
261
+ placeholder="dataset_name",
262
+ value="my-distiset",
263
+ interactive=True,
264
+ )
265
+ n_rows = gr.Number(
266
+ label="Number of rows",
267
+ value=10,
268
+ interactive=True,
269
+ scale=1,
270
+ )
271
+ private = gr.Checkbox(
272
+ label="Private dataset",
273
+ value=False,
274
+ interactive=True,
275
+ scale=1,
276
+ )
277
+ btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
278
+ with gr.Column(scale=3):
279
+ success_message = gr.Markdown(visible=False)
280
+
281
+ search_in.submit(get_iframe, inputs=search_in, outputs=search_out)
282
+ load_btn.click(
283
+ load_dataset_from_hub,
284
+ inputs=[search_in],
285
+ outputs=[
286
+ dataframe,
287
+ instruction_instruction,
288
+ instruction_instruction_response,
289
+ response_instruction_response,
290
+ ],
291
+ )
292
+ btn_apply_to_sample_dataset.click(
293
+ apply_to_sample_dataset,
294
+ inputs=[
295
+ search_in,
296
+ eval_type,
297
+ aspects_instruction,
298
+ aspects_instruction_response,
299
+ aspects_custom,
300
+ instruction_instruction,
301
+ instruction_instruction_response,
302
+ response_instruction_response,
303
+ prompt_template,
304
+ structured_output,
305
+ ],
306
+ outputs=dataframe,
307
+ )
308
+ btn_push_to_hub.click(
309
+ push_to_hub,
310
+ inputs=[
311
+ org_name,
312
+ repo_name,
313
+ private,
314
+ n_rows,
315
+ search_in,
316
+ eval_type,
317
+ aspects_instruction,
318
+ aspects_instruction_response,
319
+ aspects_custom,
320
+ instruction_instruction,
321
+ instruction_instruction_response,
322
+ response_instruction_response,
323
+ prompt_template,
324
+ structured_output,
325
+ ],
326
+ outputs=success_message,
327
+ )
328
+ app.load(fn=get_org_dropdown, outputs=[org_name])
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -1,4 +1,5 @@
1
  import ast
 
2
  from typing import Dict, List, Union
3
 
4
  import argilla as rg
@@ -10,16 +11,11 @@ from huggingface_hub import HfApi
10
 
11
  from src.distilabel_dataset_generator.apps.base import (
12
  get_argilla_client,
13
- get_main_ui,
14
  get_pipeline_code_ui,
15
  hide_success_message,
16
- push_pipeline_code_to_hub,
17
- show_success_message_argilla,
18
  show_success_message_hub,
19
  validate_argilla_user_workspace_dataset,
20
- )
21
- from src.distilabel_dataset_generator.apps.base import (
22
- push_dataset_to_hub as push_to_hub_base,
23
  )
24
  from src.distilabel_dataset_generator.pipelines.base import (
25
  DEFAULT_BATCH_SIZE,
@@ -30,16 +26,15 @@ from src.distilabel_dataset_generator.pipelines.embeddings import (
30
  )
31
  from src.distilabel_dataset_generator.pipelines.sft import (
32
  DEFAULT_DATASET_DESCRIPTIONS,
33
- DEFAULT_DATASETS,
34
- DEFAULT_SYSTEM_PROMPTS,
35
  PROMPT_CREATION_PROMPT,
36
  generate_pipeline_code,
37
  get_magpie_generator,
38
  get_prompt_generator,
39
  get_response_generator,
40
  )
41
-
42
- TASK = "supervised_fine_tuning"
 
43
 
44
 
45
  def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
@@ -57,33 +52,176 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
57
  return dataframe
58
 
59
 
60
- def push_dataset_to_hub(
61
- dataframe: pd.DataFrame,
62
- private: bool = True,
63
- org_name: str = None,
64
- repo_name: str = None,
65
- oauth_token: Union[gr.OAuthToken, None] = None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  progress=gr.Progress(),
67
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  original_dataframe = dataframe.copy(deep=True)
69
  dataframe = convert_dataframe_messages(dataframe)
70
- try:
71
- push_to_hub_base(
72
- dataframe, private, org_name, repo_name, oauth_token, progress, task=TASK
73
- )
74
- except Exception as e:
75
- raise gr.Error(f"Error pushing dataset to the Hub: {e}")
 
 
76
  return original_dataframe
77
 
78
 
79
  def push_dataset_to_argilla(
80
- dataframe: pd.DataFrame,
81
- dataset_name: str,
 
 
 
 
82
  oauth_token: Union[gr.OAuthToken, None] = None,
83
  progress=gr.Progress(),
84
  ) -> pd.DataFrame:
85
- original_dataframe = dataframe.copy(deep=True)
86
- dataframe = convert_dataframe_messages(dataframe)
 
 
 
 
87
  try:
88
  progress(0.1, desc="Setting up user and workspace")
89
  client = get_argilla_client()
@@ -185,10 +323,10 @@ def push_dataset_to_argilla(
185
  dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])
186
 
187
  progress(0.5, desc="Creating dataset")
188
- rg_dataset = client.datasets(name=dataset_name, workspace=hf_user)
189
  if rg_dataset is None:
190
  rg_dataset = rg.Dataset(
191
- name=dataset_name,
192
  workspace=hf_user,
193
  settings=settings,
194
  client=client,
@@ -200,309 +338,123 @@ def push_dataset_to_argilla(
200
  progress(1.0, desc="Dataset pushed to Argilla")
201
  except Exception as e:
202
  raise gr.Error(f"Error pushing dataset to Argilla: {e}")
203
- return original_dataframe
204
-
205
-
206
- def generate_system_prompt(dataset_description, progress=gr.Progress()):
207
- progress(0.0, desc="Generating system prompt")
208
- if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
209
- index = DEFAULT_DATASET_DESCRIPTIONS.index(dataset_description)
210
- if index < len(DEFAULT_SYSTEM_PROMPTS):
211
- return DEFAULT_SYSTEM_PROMPTS[index]
212
 
213
- progress(0.3, desc="Initializing text generation")
214
- generate_description = get_prompt_generator()
215
- progress(0.7, desc="Generating system prompt")
216
- result = next(
217
- generate_description.process(
218
- [
219
- {
220
- "system_prompt": PROMPT_CREATION_PROMPT,
221
- "instruction": dataset_description,
222
- }
223
- ]
224
- )
225
- )[0]["generation"]
226
- progress(1.0, desc="System prompt generated")
227
- return result
228
 
229
-
230
- def generate_dataset(
231
- system_prompt: str,
232
- num_turns: int = 1,
233
- num_rows: int = 5,
234
- is_sample: bool = False,
235
- progress=gr.Progress(),
236
- ) -> pd.DataFrame:
237
- progress(0.0, desc="(1/2) Generating instructions")
238
- magpie_generator = get_magpie_generator(
239
- num_turns, num_rows, system_prompt, is_sample
240
- )
241
- response_generator = get_response_generator(num_turns, system_prompt, is_sample)
242
- total_steps: int = num_rows * 2
243
- batch_size = DEFAULT_BATCH_SIZE
244
-
245
- # create instructions
246
- n_processed = 0
247
- magpie_results = []
248
- while n_processed < num_rows:
249
- progress(
250
- 0.5 * n_processed / num_rows,
251
- total=total_steps,
252
- desc="(1/2) Generating instructions",
253
- )
254
- remaining_rows = num_rows - n_processed
255
- batch_size = min(batch_size, remaining_rows)
256
- inputs = [{"system_prompt": system_prompt} for _ in range(batch_size)]
257
- batch = list(magpie_generator.process(inputs=inputs))
258
- magpie_results.extend(batch[0])
259
- n_processed += batch_size
260
- progress(0.5, desc="(1/2) Generating instructions")
261
-
262
- # generate responses
263
- n_processed = 0
264
- response_results = []
265
- if num_turns == 1:
266
- while n_processed < num_rows:
267
- progress(
268
- 0.5 + 0.5 * n_processed / num_rows,
269
- total=total_steps,
270
- desc="(2/2) Generating responses",
271
- )
272
- batch = magpie_results[n_processed : n_processed + batch_size]
273
- responses = list(response_generator.process(inputs=batch))
274
- response_results.extend(responses[0])
275
- n_processed += batch_size
276
- for result in response_results:
277
- result["prompt"] = result["instruction"]
278
- result["completion"] = result["generation"]
279
- result["system_prompt"] = system_prompt
280
- else:
281
- for result in magpie_results:
282
- result["conversation"].insert(
283
- 0, {"role": "system", "content": system_prompt}
284
  )
285
- result["messages"] = result["conversation"]
286
- while n_processed < num_rows:
287
- progress(
288
- 0.5 + 0.5 * n_processed / num_rows,
289
- total=total_steps,
290
- desc="(2/2) Generating responses",
291
  )
292
- batch = magpie_results[n_processed : n_processed + batch_size]
293
- responses = list(response_generator.process(inputs=batch))
294
- response_results.extend(responses[0])
295
- n_processed += batch_size
296
- for result in response_results:
297
- result["messages"].append(
298
- {"role": "assistant", "content": result["generation"]}
299
  )
300
- progress(
301
- 1,
302
- total=total_steps,
303
- desc="(2/2) Creating dataset",
304
- )
305
-
306
- # create distiset
307
- distiset_results = []
308
- for result in response_results:
309
- record = {}
310
- for relevant_keys in [
311
- "messages",
312
- "prompt",
313
- "completion",
314
- "model_name",
315
- "system_prompt",
316
- ]:
317
- if relevant_keys in result:
318
- record[relevant_keys] = result[relevant_keys]
319
- distiset_results.append(record)
320
-
321
- distiset = Distiset(
322
- {
323
- "default": Dataset.from_list(distiset_results),
324
- }
325
- )
326
-
327
- # If not pushing to hub generate the dataset directly
328
- distiset = distiset["default"]
329
- if num_turns == 1:
330
- outputs = distiset.to_pandas()[["system_prompt", "prompt", "completion"]]
331
- else:
332
- outputs = distiset.to_pandas()[["messages"]]
333
- dataframe = pd.DataFrame(outputs)
334
- progress(1.0, desc="Dataset generation completed")
335
- return dataframe
336
-
337
-
338
- (
339
- app,
340
- main_ui,
341
- custom_input_ui,
342
- dataset_description,
343
- examples,
344
- btn_generate_system_prompt,
345
- system_prompt,
346
- sample_dataset,
347
- btn_generate_sample_dataset,
348
- dataset_name,
349
- add_to_existing_dataset,
350
- btn_generate_full_dataset_argilla,
351
- btn_generate_and_push_to_argilla,
352
- btn_push_to_argilla,
353
- org_name,
354
- repo_name,
355
- private,
356
- btn_generate_full_dataset,
357
- btn_generate_and_push_to_hub,
358
- btn_push_to_hub,
359
- final_dataset,
360
- success_message,
361
- ) = get_main_ui(
362
- default_dataset_descriptions=DEFAULT_DATASET_DESCRIPTIONS,
363
- default_system_prompts=DEFAULT_SYSTEM_PROMPTS,
364
- default_datasets=DEFAULT_DATASETS,
365
- fn_generate_system_prompt=generate_system_prompt,
366
- fn_generate_dataset=generate_dataset,
367
- task=TASK,
368
- )
369
-
370
- with app:
371
- with main_ui:
372
- with custom_input_ui:
373
  num_turns = gr.Number(
374
  value=1,
375
  label="Number of turns in the conversation",
376
  minimum=1,
377
  maximum=4,
378
  step=1,
 
379
  info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
380
  )
381
- num_rows = gr.Number(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  value=10,
383
- label="Number of rows in the dataset",
384
- minimum=1,
385
- maximum=500,
386
- info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
387
  )
 
 
 
 
 
 
 
 
 
388
 
389
- pipeline_code = get_pipeline_code_ui(
390
- generate_pipeline_code(system_prompt.value, num_turns.value, num_rows.value)
391
- )
392
 
393
- # define app triggers
394
  gr.on(
395
- triggers=[
396
- btn_generate_full_dataset.click,
397
- btn_generate_full_dataset_argilla.click,
398
- ],
399
- fn=hide_success_message,
400
- outputs=[success_message],
401
  ).then(
402
- fn=generate_dataset,
403
- inputs=[system_prompt, num_turns, num_rows],
404
- outputs=[final_dataset],
405
  show_progress=True,
406
  )
407
 
408
- btn_generate_and_push_to_argilla.click(
409
  fn=validate_argilla_user_workspace_dataset,
410
- inputs=[dataset_name, final_dataset, add_to_existing_dataset],
411
- outputs=[final_dataset],
412
- show_progress=True,
413
- ).success(
414
- fn=hide_success_message,
415
  outputs=[success_message],
416
- ).success(
417
- fn=generate_dataset,
418
- inputs=[system_prompt, num_turns, num_rows],
419
- outputs=[final_dataset],
420
- show_progress=True,
421
- ).success(
422
- fn=push_dataset_to_argilla,
423
- inputs=[final_dataset, dataset_name],
424
- outputs=[final_dataset],
425
- show_progress=True,
426
- ).success(
427
- fn=show_success_message_argilla,
428
- inputs=[],
429
- outputs=[success_message],
430
- )
431
-
432
- btn_generate_and_push_to_hub.click(
433
- fn=hide_success_message,
434
- outputs=[success_message],
435
- ).then(
436
- fn=generate_dataset,
437
- inputs=[system_prompt, num_turns, num_rows],
438
- outputs=[final_dataset],
439
- show_progress=True,
440
- ).then(
441
- fn=push_dataset_to_hub,
442
- inputs=[final_dataset, private, org_name, repo_name],
443
- outputs=[final_dataset],
444
  show_progress=True,
445
  ).then(
446
- fn=push_pipeline_code_to_hub,
447
- inputs=[pipeline_code, org_name, repo_name],
448
- outputs=[],
449
- show_progress=True,
450
- ).success(
451
- fn=show_success_message_hub,
452
  inputs=[org_name, repo_name],
453
  outputs=[success_message],
454
- )
455
-
456
- btn_push_to_hub.click(
457
- fn=hide_success_message,
458
- outputs=[success_message],
459
- ).then(
460
- fn=push_dataset_to_hub,
461
- inputs=[final_dataset, private, org_name, repo_name],
462
- outputs=[final_dataset],
463
- show_progress=True,
464
- ).then(
465
- fn=push_pipeline_code_to_hub,
466
- inputs=[pipeline_code, org_name, repo_name],
467
- outputs=[],
468
  show_progress=True,
469
  ).success(
470
- fn=show_success_message_hub,
471
- inputs=[org_name, repo_name],
472
- outputs=[success_message],
473
- )
474
-
475
- btn_push_to_argilla.click(
476
  fn=hide_success_message,
477
  outputs=[success_message],
478
- ).success(
479
- fn=validate_argilla_user_workspace_dataset,
480
- inputs=[dataset_name, final_dataset, add_to_existing_dataset],
481
- outputs=[final_dataset],
482
  show_progress=True,
483
  ).success(
484
  fn=push_dataset_to_argilla,
485
- inputs=[final_dataset, dataset_name],
486
- outputs=[final_dataset],
 
 
 
 
 
 
 
487
  show_progress=True,
488
  ).success(
489
- fn=show_success_message_argilla,
490
- inputs=[],
491
  outputs=[success_message],
492
  )
493
-
494
- system_prompt.change(
495
- fn=generate_pipeline_code,
496
- inputs=[system_prompt, num_turns, num_rows],
497
- outputs=[pipeline_code],
498
- )
499
- num_turns.change(
500
- fn=generate_pipeline_code,
501
- inputs=[system_prompt, num_turns, num_rows],
502
- outputs=[pipeline_code],
503
- )
504
- num_rows.change(
505
- fn=generate_pipeline_code,
506
- inputs=[system_prompt, num_turns, num_rows],
507
- outputs=[pipeline_code],
508
- )
 
1
  import ast
2
+ import uuid
3
  from typing import Dict, List, Union
4
 
5
  import argilla as rg
 
11
 
12
  from src.distilabel_dataset_generator.apps.base import (
13
  get_argilla_client,
 
14
  get_pipeline_code_ui,
15
  hide_success_message,
 
 
16
  show_success_message_hub,
17
  validate_argilla_user_workspace_dataset,
18
+ validate_push_to_hub,
 
 
19
  )
20
  from src.distilabel_dataset_generator.pipelines.base import (
21
  DEFAULT_BATCH_SIZE,
 
26
  )
27
  from src.distilabel_dataset_generator.pipelines.sft import (
28
  DEFAULT_DATASET_DESCRIPTIONS,
 
 
29
  PROMPT_CREATION_PROMPT,
30
  generate_pipeline_code,
31
  get_magpie_generator,
32
  get_prompt_generator,
33
  get_response_generator,
34
  )
35
+ from src.distilabel_dataset_generator.utils import (
36
+ get_org_dropdown,
37
+ )
38
 
39
 
40
  def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
 
52
  return dataframe
53
 
54
 
55
+ def generate_system_prompt(dataset_description, progress=gr.Progress()):
56
+ progress(0.0, desc="Generating system prompt")
57
+
58
+ progress(0.3, desc="Initializing text generation")
59
+ generate_description = get_prompt_generator()
60
+ progress(0.7, desc="Generating system prompt")
61
+ result = next(
62
+ generate_description.process(
63
+ [
64
+ {
65
+ "system_prompt": PROMPT_CREATION_PROMPT,
66
+ "instruction": dataset_description,
67
+ }
68
+ ]
69
+ )
70
+ )[0]["generation"]
71
+ progress(1.0, desc="System prompt generated")
72
+ return result, pd.DataFrame()
73
+
74
+
75
+ def generate_sample_dataset(system_prompt, progress=gr.Progress()):
76
+ df = generate_dataset(
77
+ system_prompt=system_prompt,
78
+ num_turns=1,
79
+ num_rows=10,
80
+ progress=progress,
81
+ is_sample=True,
82
+ )
83
+ return df
84
+
85
+
86
+ def generate_dataset(
87
+ system_prompt: str,
88
+ num_turns: int = 1,
89
+ num_rows: int = 10,
90
+ is_sample: bool = False,
91
  progress=gr.Progress(),
92
+ ) -> pd.DataFrame:
93
+ progress(0.0, desc="(1/2) Generating instructions")
94
+ magpie_generator = get_magpie_generator(
95
+ num_turns, num_rows, system_prompt, is_sample
96
+ )
97
+ response_generator = get_response_generator(num_turns, system_prompt, is_sample)
98
+ total_steps: int = num_rows * 2
99
+ batch_size = DEFAULT_BATCH_SIZE
100
+
101
+ # create instructions
102
+ n_processed = 0
103
+ magpie_results = []
104
+ while n_processed < num_rows:
105
+ progress(
106
+ 0.5 * n_processed / num_rows,
107
+ total=total_steps,
108
+ desc="(1/2) Generating instructions",
109
+ )
110
+ remaining_rows = num_rows - n_processed
111
+ batch_size = min(batch_size, remaining_rows)
112
+ inputs = [{"system_prompt": system_prompt} for _ in range(batch_size)]
113
+ batch = list(magpie_generator.process(inputs=inputs))
114
+ magpie_results.extend(batch[0])
115
+ n_processed += batch_size
116
+ progress(0.5, desc="(1/2) Generating instructions")
117
+
118
+ # generate responses
119
+ n_processed = 0
120
+ response_results = []
121
+ if num_turns == 1:
122
+ while n_processed < num_rows:
123
+ progress(
124
+ 0.5 + 0.5 * n_processed / num_rows,
125
+ total=total_steps,
126
+ desc="(2/2) Generating responses",
127
+ )
128
+ batch = magpie_results[n_processed : n_processed + batch_size]
129
+ responses = list(response_generator.process(inputs=batch))
130
+ response_results.extend(responses[0])
131
+ n_processed += batch_size
132
+ for result in response_results:
133
+ result["prompt"] = result["instruction"]
134
+ result["completion"] = result["generation"]
135
+ result["system_prompt"] = system_prompt
136
+ else:
137
+ for result in magpie_results:
138
+ result["conversation"].insert(
139
+ 0, {"role": "system", "content": system_prompt}
140
+ )
141
+ result["messages"] = result["conversation"]
142
+ while n_processed < num_rows:
143
+ progress(
144
+ 0.5 + 0.5 * n_processed / num_rows,
145
+ total=total_steps,
146
+ desc="(2/2) Generating responses",
147
+ )
148
+ batch = magpie_results[n_processed : n_processed + batch_size]
149
+ responses = list(response_generator.process(inputs=batch))
150
+ response_results.extend(responses[0])
151
+ n_processed += batch_size
152
+ for result in response_results:
153
+ result["messages"].append(
154
+ {"role": "assistant", "content": result["generation"]}
155
+ )
156
+ progress(
157
+ 1,
158
+ total=total_steps,
159
+ desc="(2/2) Creating dataset",
160
+ )
161
+
162
+ # create distiset
163
+ distiset_results = []
164
+ for result in response_results:
165
+ record = {}
166
+ for relevant_keys in [
167
+ "messages",
168
+ "prompt",
169
+ "completion",
170
+ "model_name",
171
+ "system_prompt",
172
+ ]:
173
+ if relevant_keys in result:
174
+ record[relevant_keys] = result[relevant_keys]
175
+ distiset_results.append(record)
176
+
177
+ distiset = Distiset(
178
+ {
179
+ "default": Dataset.from_list(distiset_results),
180
+ }
181
+ )
182
+
183
+ # If not pushing to hub generate the dataset directly
184
+ distiset = distiset["default"]
185
+ if num_turns == 1:
186
+ outputs = distiset.to_pandas()[["prompt", "completion", "system_prompt"]]
187
+ else:
188
+ outputs = distiset.to_pandas()[["messages"]]
189
+ dataframe = pd.DataFrame(outputs)
190
+ progress(1.0, desc="Dataset generation completed")
191
+ return dataframe
192
+
193
+
194
+ def push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private):
195
+ repo_id = validate_push_to_hub(org_name, repo_name)
196
  original_dataframe = dataframe.copy(deep=True)
197
  dataframe = convert_dataframe_messages(dataframe)
198
+ distiset = Distiset({"default": Dataset.from_pandas(dataframe)})
199
+ distiset.push_to_hub(
200
+ repo_id=repo_id,
201
+ private=private,
202
+ include_script=False,
203
+ token=oauth_token.token,
204
+ create_pr=False,
205
+ )
206
  return original_dataframe
207
 
208
 
209
  def push_dataset_to_argilla(
210
+ org_name: str,
211
+ repo_name: str,
212
+ system_prompt: str,
213
+ num_turns: int = 1,
214
+ n_rows: int = 10,
215
+ private: bool = False,
216
  oauth_token: Union[gr.OAuthToken, None] = None,
217
  progress=gr.Progress(),
218
  ) -> pd.DataFrame:
219
+ dataframe = generate_dataset(
220
+ system_prompt=system_prompt,
221
+ num_turns=num_turns,
222
+ num_rows=n_rows,
223
+ )
224
+ push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
225
  try:
226
  progress(0.1, desc="Setting up user and workspace")
227
  client = get_argilla_client()
 
323
  dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])
324
 
325
  progress(0.5, desc="Creating dataset")
326
+ rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
327
  if rg_dataset is None:
328
  rg_dataset = rg.Dataset(
329
+ name=repo_name,
330
  workspace=hf_user,
331
  settings=settings,
332
  client=client,
 
338
  progress(1.0, desc="Dataset pushed to Argilla")
339
  except Exception as e:
340
  raise gr.Error(f"Error pushing dataset to Argilla: {e}")
341
+ return ""
 
 
 
 
 
 
 
 
342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
 
344
+ with gr.Blocks() as app:
345
+ gr.Markdown("## Describe the dataset you want")
346
+ gr.HTML("<hr>")
347
+ with gr.Row():
348
+ with gr.Column(scale=1):
349
+ dataset_description = gr.Textbox(
350
+ label="Dataset description",
351
+ placeholder="Give a precise description of your desired dataset.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  )
353
+ examples = gr.Examples(
354
+ examples=DEFAULT_DATASET_DESCRIPTIONS,
355
+ inputs=[dataset_description],
356
+ cache_examples=False,
357
+ label="Example descriptions",
 
358
  )
359
+ system_prompt = gr.Textbox(
360
+ label="System prompt",
361
+ placeholder="You are a helpful assistant.",
362
+ visible=False,
 
 
 
363
  )
364
+ load_btn = gr.Button("Load Dataset")
365
+ with gr.Column(scale=3):
366
+ pass
367
+
368
+ gr.Markdown("## Configure your task")
369
+ gr.HTML("<hr>")
370
+ with gr.Row():
371
+ with gr.Column(scale=1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  num_turns = gr.Number(
373
  value=1,
374
  label="Number of turns in the conversation",
375
  minimum=1,
376
  maximum=4,
377
  step=1,
378
+ interactive=True,
379
  info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
380
  )
381
+ btn_apply_to_sample_dataset = gr.Button("Refresh dataset")
382
+ with gr.Column(scale=3):
383
+ dataframe = gr.Dataframe()
384
+
385
+ gr.Markdown("## Generate your dataset")
386
+ gr.HTML("<hr>")
387
+ with gr.Row():
388
+ with gr.Column(scale=1):
389
+ org_name = get_org_dropdown()
390
+ repo_name = gr.Textbox(
391
+ label="Repo name",
392
+ placeholder="dataset_name",
393
+ value=f"my-distiset-{str(uuid.uuid4())[:8]}",
394
+ interactive=True,
395
+ )
396
+ n_rows = gr.Number(
397
+ label="Number of rows",
398
  value=10,
399
+ interactive=True,
400
+ scale=1,
 
 
401
  )
402
+ private = gr.Checkbox(
403
+ label="Private dataset",
404
+ value=False,
405
+ interactive=True,
406
+ scale=1,
407
+ )
408
+ btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
409
+ with gr.Column(scale=3):
410
+ success_message = gr.Markdown()
411
 
412
+ pipeline_code = get_pipeline_code_ui(
413
+ generate_pipeline_code(system_prompt.value, num_turns.value, n_rows.value)
414
+ )
415
 
 
416
  gr.on(
417
+ triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
418
+ fn=generate_system_prompt,
419
+ inputs=[dataset_description],
420
+ outputs=[system_prompt, dataframe],
421
+ show_progress=True,
 
422
  ).then(
423
+ fn=generate_sample_dataset,
424
+ inputs=[system_prompt],
425
+ outputs=[dataframe],
426
  show_progress=True,
427
  )
428
 
429
+ btn_push_to_hub.click(
430
  fn=validate_argilla_user_workspace_dataset,
431
+ inputs=[repo_name],
 
 
 
 
432
  outputs=[success_message],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
  show_progress=True,
434
  ).then(
435
+ fn=validate_push_to_hub,
 
 
 
 
 
436
  inputs=[org_name, repo_name],
437
  outputs=[success_message],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
  show_progress=True,
439
  ).success(
 
 
 
 
 
 
440
  fn=hide_success_message,
441
  outputs=[success_message],
 
 
 
 
442
  show_progress=True,
443
  ).success(
444
  fn=push_dataset_to_argilla,
445
+ inputs=[
446
+ org_name,
447
+ repo_name,
448
+ system_prompt,
449
+ num_turns,
450
+ n_rows,
451
+ private,
452
+ ],
453
+ outputs=[success_message],
454
  show_progress=True,
455
  ).success(
456
+ fn=show_success_message_hub,
457
+ inputs=[org_name, repo_name],
458
  outputs=[success_message],
459
  )
460
+ app.load(fn=get_org_dropdown, outputs=[org_name])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -1,24 +1,21 @@
1
  import re
 
2
  from typing import List, Union
3
 
4
  import argilla as rg
5
  import gradio as gr
6
  import pandas as pd
7
- from datasets import Dataset
 
8
  from huggingface_hub import HfApi
9
 
10
  from src.distilabel_dataset_generator.apps.base import (
11
  get_argilla_client,
12
- get_main_ui,
13
  get_pipeline_code_ui,
14
  hide_success_message,
15
- push_pipeline_code_to_hub,
16
- show_success_message_argilla,
17
  show_success_message_hub,
18
  validate_argilla_user_workspace_dataset,
19
- )
20
- from src.distilabel_dataset_generator.apps.base import (
21
- push_dataset_to_hub as push_to_hub_base,
22
  )
23
  from src.distilabel_dataset_generator.pipelines.base import (
24
  DEFAULT_BATCH_SIZE,
@@ -29,166 +26,24 @@ from src.distilabel_dataset_generator.pipelines.embeddings import (
29
  )
30
  from src.distilabel_dataset_generator.pipelines.textcat import (
31
  DEFAULT_DATASET_DESCRIPTIONS,
32
- DEFAULT_DATASETS,
33
- DEFAULT_SYSTEM_PROMPTS,
34
  PROMPT_CREATION_PROMPT,
35
  generate_pipeline_code,
36
  get_labeller_generator,
37
  get_prompt_generator,
38
  get_textcat_generator,
39
  )
40
- from src.distilabel_dataset_generator.utils import get_preprocess_labels
41
-
42
- TASK = "text_classification"
43
-
44
-
45
- def push_dataset_to_hub(
46
- dataframe: pd.DataFrame,
47
- private: bool = True,
48
- org_name: str = None,
49
- repo_name: str = None,
50
- oauth_token: Union[gr.OAuthToken, None] = None,
51
- progress=gr.Progress(),
52
- labels: List[str] = None,
53
- num_labels: int = 1,
54
- ):
55
- original_dataframe = dataframe.copy(deep=True)
56
- dataframe = dataframe[
57
- (dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
58
- ]
59
- labels = get_preprocess_labels(labels)
60
- try:
61
- push_to_hub_base(
62
- dataframe,
63
- private,
64
- org_name,
65
- repo_name,
66
- oauth_token,
67
- progress,
68
- labels,
69
- num_labels,
70
- task=TASK,
71
- )
72
- except Exception as e:
73
- raise gr.Error(f"Error pushing dataset to the Hub: {e}")
74
- return original_dataframe
75
-
76
-
77
- def push_dataset_to_argilla(
78
- dataframe: pd.DataFrame,
79
- dataset_name: str,
80
- oauth_token: Union[gr.OAuthToken, None] = None,
81
- progress=gr.Progress(),
82
- num_labels: int = 1,
83
- labels: List[str] = None,
84
- ) -> pd.DataFrame:
85
- original_dataframe = dataframe.copy(deep=True)
86
- dataframe = dataframe[
87
- (dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
88
- ]
89
- try:
90
- progress(0.1, desc="Setting up user and workspace")
91
- client = get_argilla_client()
92
- hf_user = HfApi().whoami(token=oauth_token.token)["name"]
93
- labels = get_preprocess_labels(labels)
94
- settings = rg.Settings(
95
- fields=[
96
- rg.TextField(
97
- name="text",
98
- description="The text classification data",
99
- title="Text",
100
- ),
101
- ],
102
- questions=[
103
- (
104
- rg.LabelQuestion(
105
- name="label",
106
- title="Label",
107
- description="The label of the text",
108
- labels=labels,
109
- )
110
- if num_labels == 1
111
- else rg.MultiLabelQuestion(
112
- name="labels",
113
- title="Labels",
114
- description="The labels of the conversation",
115
- labels=labels,
116
- )
117
- ),
118
- ],
119
- metadata=[
120
- rg.IntegerMetadataProperty(name="text_length", title="Text Length"),
121
- ],
122
- vectors=[
123
- rg.VectorField(
124
- name="text_embeddings",
125
- dimensions=get_sentence_embedding_dimensions(),
126
- )
127
- ],
128
- guidelines="Please review the text and provide or correct the label where needed.",
129
- )
130
-
131
- dataframe["text_length"] = dataframe["text"].apply(len)
132
- dataframe["text_embeddings"] = get_embeddings(dataframe["text"])
133
-
134
- progress(0.5, desc="Creating dataset")
135
- rg_dataset = client.datasets(name=dataset_name, workspace=hf_user)
136
- if rg_dataset is None:
137
- rg_dataset = rg.Dataset(
138
- name=dataset_name,
139
- workspace=hf_user,
140
- settings=settings,
141
- client=client,
142
- )
143
- rg_dataset = rg_dataset.create()
144
- progress(0.7, desc="Pushing dataset to Argilla")
145
- hf_dataset = Dataset.from_pandas(dataframe)
146
- records = [
147
- rg.Record(
148
- fields={
149
- "text": sample["text"],
150
- },
151
- metadata={"text_length": sample["text_length"]},
152
- vectors={"text_embeddings": sample["text_embeddings"]},
153
- suggestions=(
154
- [
155
- rg.Suggestion(
156
- question_name="label" if num_labels == 1 else "labels",
157
- value=(
158
- sample["label"] if num_labels == 1 else sample["labels"]
159
- ),
160
- )
161
- ]
162
- if (
163
- (num_labels == 1 and sample["label"] in labels)
164
- or (
165
- num_labels > 1
166
- and all(label in labels for label in sample["labels"])
167
- )
168
- )
169
- else []
170
- ),
171
- )
172
- for sample in hf_dataset
173
- ]
174
- rg_dataset.records.log(records=records)
175
- progress(1.0, desc="Dataset pushed to Argilla")
176
- except Exception as e:
177
- raise gr.Error(f"Error pushing dataset to Argilla: {e}")
178
- return original_dataframe
179
 
180
 
181
  def generate_system_prompt(dataset_description, progress=gr.Progress()):
182
  progress(0.0, desc="Generating text classification task")
183
- if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
184
- index = DEFAULT_DATASET_DESCRIPTIONS.index(dataset_description)
185
- if index < len(DEFAULT_SYSTEM_PROMPTS):
186
- return DEFAULT_SYSTEM_PROMPTS[index]
187
-
188
  progress(0.3, desc="Initializing text generation")
189
  generate_description = get_prompt_generator()
190
  progress(0.7, desc="Generating text classification task")
191
- result = next(
192
  generate_description.process(
193
  [
194
  {
@@ -199,7 +54,25 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
199
  )
200
  )[0]["generation"]
201
  progress(1.0, desc="Text classification task generated")
202
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
 
205
  def generate_dataset(
@@ -212,6 +85,10 @@ def generate_dataset(
212
  is_sample: bool = False,
213
  progress=gr.Progress(),
214
  ) -> pd.DataFrame:
 
 
 
 
215
  progress(0.0, desc="(1/2) Generating text classification data")
216
  labels = get_preprocess_labels(labels)
217
  textcat_generator = get_textcat_generator(
@@ -230,7 +107,7 @@ def generate_dataset(
230
  textcat_results = []
231
  while n_processed < num_rows:
232
  progress(
233
- 0.5 * n_processed / num_rows,
234
  total=total_steps,
235
  desc="(1/2) Generating text classification data",
236
  )
@@ -244,7 +121,7 @@ def generate_dataset(
244
  result["text"] = result["input_text"]
245
 
246
  # label text classification data
247
- progress(0.5, desc="(1/2) Generating text classification data")
248
  if not is_sample:
249
  n_processed = 0
250
  labeller_results = []
@@ -300,6 +177,158 @@ def generate_dataset(
300
  return dataframe
301
 
302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  def update_suggested_labels(system_prompt):
304
  new_labels = re.findall(r"'(\b[\w-]+\b)'", system_prompt)
305
  if not new_labels:
@@ -321,41 +350,34 @@ def update_max_num_labels(labels):
321
  return gr.update(maximum=len(labels) if labels else 1)
322
 
323
 
324
- (
325
- app,
326
- main_ui,
327
- custom_input_ui,
328
- dataset_description,
329
- examples,
330
- btn_generate_system_prompt,
331
- system_prompt,
332
- sample_dataset,
333
- btn_generate_sample_dataset,
334
- dataset_name,
335
- add_to_existing_dataset,
336
- btn_generate_full_dataset_argilla,
337
- btn_generate_and_push_to_argilla,
338
- btn_push_to_argilla,
339
- org_name,
340
- repo_name,
341
- private,
342
- btn_generate_full_dataset,
343
- btn_generate_and_push_to_hub,
344
- btn_push_to_hub,
345
- final_dataset,
346
- success_message,
347
- ) = get_main_ui(
348
- default_dataset_descriptions=DEFAULT_DATASET_DESCRIPTIONS,
349
- default_system_prompts=DEFAULT_SYSTEM_PROMPTS,
350
- default_datasets=DEFAULT_DATASETS,
351
- fn_generate_system_prompt=generate_system_prompt,
352
- fn_generate_dataset=generate_dataset,
353
- task=TASK,
354
- )
355
-
356
- with app:
357
- with main_ui:
358
- with custom_input_ui:
359
  difficulty = gr.Dropdown(
360
  choices=[
361
  ("High School", "high school"),
@@ -366,6 +388,7 @@ with app:
366
  value="mixed",
367
  label="Difficulty",
368
  info="Select the comprehension level for the text. Ensure it matches the task context.",
 
369
  )
370
  clarity = gr.Dropdown(
371
  choices=[
@@ -380,51 +403,78 @@ with app:
380
  value="mixed",
381
  label="Clarity",
382
  info="Set how easily the correct label or labels can be identified.",
 
 
 
 
 
 
 
 
 
383
  )
384
- with gr.Column():
385
- labels = gr.Dropdown(
386
- choices=[],
387
- value=["negative", "positive"],
388
- allow_custom_value=True,
389
- interactive=True,
390
- label="Labels",
391
- multiselect=True,
392
- info="Add the labels to classify the text.",
393
- )
394
- with gr.Blocks():
395
- btn_suggested_labels = gr.Button(
396
- value="Add suggested labels",
397
- variant="primary",
398
- size="sm",
399
- )
400
  num_labels = gr.Number(
401
  label="Number of labels per text",
402
  value=1,
403
  minimum=1,
404
  maximum=10,
405
  info="Select 1 for single-label and >1 for multi-label.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
  )
407
- num_rows = gr.Number(
408
  label="Number of rows",
409
  value=10,
410
- minimum=1,
411
- maximum=500,
412
- info="Select the number of rows in the dataset. More rows will take more time.",
413
  )
414
-
415
- pipeline_code = get_pipeline_code_ui(
416
- generate_pipeline_code(
417
- system_prompt.value,
418
- difficulty=difficulty.value,
419
- clarity=clarity.value,
420
- labels=labels.value,
421
- num_labels=num_labels.value,
422
- num_rows=num_rows.value,
423
  )
 
 
 
 
 
 
 
 
 
 
 
 
424
  )
 
425
 
426
- # define app triggers
427
- btn_suggested_labels.click(
 
 
 
 
 
 
 
 
 
 
428
  fn=update_suggested_labels,
429
  inputs=[system_prompt],
430
  outputs=labels,
@@ -434,141 +484,39 @@ with app:
434
  outputs=[num_labels],
435
  )
436
 
437
- gr.on(
438
- triggers=[
439
- btn_generate_full_dataset.click,
440
- btn_generate_full_dataset_argilla.click,
441
- ],
442
- fn=hide_success_message,
443
- outputs=[success_message],
444
- ).then(
445
- fn=validate_input_labels,
446
- inputs=[labels],
447
- outputs=[labels],
448
- ).success(
449
- fn=generate_dataset,
450
- inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
451
- outputs=[final_dataset],
452
- show_progress=True,
453
- )
454
-
455
- btn_generate_and_push_to_argilla.click(
456
  fn=validate_argilla_user_workspace_dataset,
457
- inputs=[dataset_name, final_dataset, add_to_existing_dataset],
458
- outputs=[final_dataset],
459
- show_progress=True,
460
- ).success(
461
- fn=hide_success_message,
462
- outputs=[success_message],
463
- ).success(
464
- fn=generate_dataset,
465
- inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
466
- outputs=[final_dataset],
467
- show_progress=True,
468
- ).success(
469
- fn=push_dataset_to_argilla,
470
- inputs=[final_dataset, dataset_name, num_labels, labels],
471
- outputs=[final_dataset],
472
- show_progress=True,
473
- ).success(
474
- fn=show_success_message_argilla,
475
- inputs=[],
476
- outputs=[success_message],
477
- )
478
-
479
- btn_generate_and_push_to_hub.click(
480
- fn=hide_success_message,
481
  outputs=[success_message],
482
- ).then(
483
- fn=generate_dataset,
484
- inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
485
- outputs=[final_dataset],
486
- show_progress=True,
487
- ).then(
488
- fn=push_dataset_to_hub,
489
- inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
490
- outputs=[final_dataset],
491
  show_progress=True,
492
  ).then(
493
- fn=push_pipeline_code_to_hub,
494
- inputs=[pipeline_code, org_name, repo_name],
495
- outputs=[],
496
- show_progress=True,
497
- ).success(
498
- fn=show_success_message_hub,
499
  inputs=[org_name, repo_name],
500
  outputs=[success_message],
501
- )
502
-
503
- btn_push_to_hub.click(
504
- fn=hide_success_message,
505
- outputs=[success_message],
506
- ).then(
507
- fn=push_dataset_to_hub,
508
- inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
509
- outputs=[final_dataset],
510
- show_progress=True,
511
- ).then(
512
- fn=push_pipeline_code_to_hub,
513
- inputs=[pipeline_code, org_name, repo_name],
514
- outputs=[],
515
  show_progress=True,
516
  ).success(
517
- fn=show_success_message_hub,
518
- inputs=[org_name, repo_name],
519
- outputs=[success_message],
520
- )
521
-
522
- btn_push_to_argilla.click(
523
  fn=hide_success_message,
524
  outputs=[success_message],
525
- ).success(
526
- fn=validate_argilla_user_workspace_dataset,
527
- inputs=[dataset_name, final_dataset, add_to_existing_dataset],
528
- outputs=[final_dataset],
529
  show_progress=True,
530
  ).success(
531
  fn=push_dataset_to_argilla,
532
- inputs=[final_dataset, dataset_name, num_labels, labels],
533
- outputs=[final_dataset],
 
 
 
 
 
 
 
 
 
 
534
  show_progress=True,
535
  ).success(
536
- fn=show_success_message_argilla,
537
- inputs=[],
538
  outputs=[success_message],
539
  )
540
 
541
- system_prompt.change(
542
- fn=generate_pipeline_code,
543
- inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
544
- outputs=[pipeline_code],
545
- )
546
- difficulty.change(
547
- fn=generate_pipeline_code,
548
- inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
549
- outputs=[pipeline_code],
550
- )
551
- clarity.change(
552
- fn=generate_pipeline_code,
553
- inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
554
- outputs=[pipeline_code],
555
- )
556
- labels.change(
557
- fn=generate_pipeline_code,
558
- inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
559
- outputs=[pipeline_code],
560
- ).then(
561
- fn=update_max_num_labels,
562
- inputs=[labels],
563
- outputs=[num_labels],
564
- )
565
- num_labels.change(
566
- fn=generate_pipeline_code,
567
- inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
568
- outputs=[pipeline_code],
569
- )
570
- num_rows.change(
571
- fn=generate_pipeline_code,
572
- inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
573
- outputs=[pipeline_code],
574
- )
 
1
  import re
2
+ import uuid
3
  from typing import List, Union
4
 
5
  import argilla as rg
6
  import gradio as gr
7
  import pandas as pd
8
+ from datasets import ClassLabel, Dataset, Features, Sequence, Value
9
+ from distilabel.distiset import Distiset
10
  from huggingface_hub import HfApi
11
 
12
  from src.distilabel_dataset_generator.apps.base import (
13
  get_argilla_client,
 
14
  get_pipeline_code_ui,
15
  hide_success_message,
 
 
16
  show_success_message_hub,
17
  validate_argilla_user_workspace_dataset,
18
+ validate_push_to_hub,
 
 
19
  )
20
  from src.distilabel_dataset_generator.pipelines.base import (
21
  DEFAULT_BATCH_SIZE,
 
26
  )
27
  from src.distilabel_dataset_generator.pipelines.textcat import (
28
  DEFAULT_DATASET_DESCRIPTIONS,
 
 
29
  PROMPT_CREATION_PROMPT,
30
  generate_pipeline_code,
31
  get_labeller_generator,
32
  get_prompt_generator,
33
  get_textcat_generator,
34
  )
35
+ from src.distilabel_dataset_generator.utils import (
36
+ get_org_dropdown,
37
+ get_preprocess_labels,
38
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  def generate_system_prompt(dataset_description, progress=gr.Progress()):
42
  progress(0.0, desc="Generating text classification task")
 
 
 
 
 
43
  progress(0.3, desc="Initializing text generation")
44
  generate_description = get_prompt_generator()
45
  progress(0.7, desc="Generating text classification task")
46
+ system_prompt = next(
47
  generate_description.process(
48
  [
49
  {
 
54
  )
55
  )[0]["generation"]
56
  progress(1.0, desc="Text classification task generated")
57
+ return system_prompt, pd.DataFrame()
58
+
59
+
60
+ def generate_sample_dataset(system_prompt, progress=gr.Progress()):
61
+ df = generate_dataset(
62
+ system_prompt=system_prompt,
63
+ difficulty="mixed",
64
+ clarity="mixed",
65
+ labels=[],
66
+ num_labels=1,
67
+ num_rows=10,
68
+ progress=progress,
69
+ is_sample=True,
70
+ )
71
+ if "label" in df.columns:
72
+ df = df[["label", "text"]]
73
+ elif "labels" in df.columns:
74
+ df = df[["labels", "text"]]
75
+ return df
76
 
77
 
78
  def generate_dataset(
 
85
  is_sample: bool = False,
86
  progress=gr.Progress(),
87
  ) -> pd.DataFrame:
88
+ if is_sample:
89
+ multiplier = 1
90
+ else:
91
+ multiplier = 2
92
  progress(0.0, desc="(1/2) Generating text classification data")
93
  labels = get_preprocess_labels(labels)
94
  textcat_generator = get_textcat_generator(
 
107
  textcat_results = []
108
  while n_processed < num_rows:
109
  progress(
110
+ multiplier * 0.5 * n_processed / num_rows,
111
  total=total_steps,
112
  desc="(1/2) Generating text classification data",
113
  )
 
121
  result["text"] = result["input_text"]
122
 
123
  # label text classification data
124
+ progress(multiplier * 0.5, desc="(1/2) Generating text classification data")
125
  if not is_sample:
126
  n_processed = 0
127
  labeller_results = []
 
177
  return dataframe
178
 
179
 
180
+ def push_dataset_to_hub(
181
+ dataframe: pd.DataFrame,
182
+ org_name: str,
183
+ repo_name: str,
184
+ num_labels: int = 1,
185
+ labels: List[str] = None,
186
+ oauth_token: Union[gr.OAuthToken, None] = None,
187
+ private: bool = False,
188
+ ):
189
+ repo_id = validate_push_to_hub(org_name, repo_name)
190
+ labels = get_preprocess_labels(labels)
191
+ if num_labels == 1:
192
+ dataframe["label"] = dataframe["label"].replace("", None)
193
+ features = Features(
194
+ {"text": Value("string"), "label": ClassLabel(names=labels)}
195
+ )
196
+ else:
197
+ features = Features(
198
+ {
199
+ "text": Value("string"),
200
+ "labels": Sequence(feature=ClassLabel(names=labels)),
201
+ }
202
+ )
203
+ distiset = Distiset({"default": Dataset.from_pandas(dataframe, features=features)})
204
+ distiset.push_to_hub(
205
+ repo_id=repo_id,
206
+ private=private,
207
+ include_script=False,
208
+ token=oauth_token.token,
209
+ create_pr=False,
210
+ )
211
+
212
+
213
+ def push_dataset_to_argilla(
214
+ org_name: str,
215
+ repo_name: str,
216
+ system_prompt: str,
217
+ difficulty: str,
218
+ clarity: str,
219
+ num_labels: int = 1,
220
+ n_rows: int = 10,
221
+ labels: List[str] = None,
222
+ private: bool = False,
223
+ oauth_token: Union[gr.OAuthToken, None] = None,
224
+ progress=gr.Progress(),
225
+ ) -> pd.DataFrame:
226
+ dataframe = generate_dataset(
227
+ system_prompt=system_prompt,
228
+ difficulty=difficulty,
229
+ clarity=clarity,
230
+ num_labels=num_labels,
231
+ labels=labels,
232
+ num_rows=n_rows,
233
+ )
234
+ push_dataset_to_hub(
235
+ dataframe, org_name, repo_name, num_labels, labels, oauth_token, private
236
+ )
237
+ dataframe = dataframe[
238
+ (dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
239
+ ]
240
+ try:
241
+ progress(0.1, desc="Setting up user and workspace")
242
+ client = get_argilla_client()
243
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
244
+ labels = get_preprocess_labels(labels)
245
+ settings = rg.Settings(
246
+ fields=[
247
+ rg.TextField(
248
+ name="text",
249
+ description="The text classification data",
250
+ title="Text",
251
+ ),
252
+ ],
253
+ questions=[
254
+ (
255
+ rg.LabelQuestion(
256
+ name="label",
257
+ title="Label",
258
+ description="The label of the text",
259
+ labels=labels,
260
+ )
261
+ if num_labels == 1
262
+ else rg.MultiLabelQuestion(
263
+ name="labels",
264
+ title="Labels",
265
+ description="The labels of the conversation",
266
+ labels=labels,
267
+ )
268
+ ),
269
+ ],
270
+ metadata=[
271
+ rg.IntegerMetadataProperty(name="text_length", title="Text Length"),
272
+ ],
273
+ vectors=[
274
+ rg.VectorField(
275
+ name="text_embeddings",
276
+ dimensions=get_sentence_embedding_dimensions(),
277
+ )
278
+ ],
279
+ guidelines="Please review the text and provide or correct the label where needed.",
280
+ )
281
+
282
+ dataframe["text_length"] = dataframe["text"].apply(len)
283
+ dataframe["text_embeddings"] = get_embeddings(dataframe["text"])
284
+
285
+ progress(0.5, desc="Creating dataset")
286
+ rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
287
+ if rg_dataset is None:
288
+ rg_dataset = rg.Dataset(
289
+ name=repo_name,
290
+ workspace=hf_user,
291
+ settings=settings,
292
+ client=client,
293
+ )
294
+ rg_dataset = rg_dataset.create()
295
+ progress(0.7, desc="Pushing dataset to Argilla")
296
+ hf_dataset = Dataset.from_pandas(dataframe)
297
+ records = [
298
+ rg.Record(
299
+ fields={
300
+ "text": sample["text"],
301
+ },
302
+ metadata={"text_length": sample["text_length"]},
303
+ vectors={"text_embeddings": sample["text_embeddings"]},
304
+ suggestions=(
305
+ [
306
+ rg.Suggestion(
307
+ question_name="label" if num_labels == 1 else "labels",
308
+ value=(
309
+ sample["label"] if num_labels == 1 else sample["labels"]
310
+ ),
311
+ )
312
+ ]
313
+ if (
314
+ (num_labels == 1 and sample["label"] in labels)
315
+ or (
316
+ num_labels > 1
317
+ and all(label in labels for label in sample["labels"])
318
+ )
319
+ )
320
+ else []
321
+ ),
322
+ )
323
+ for sample in hf_dataset
324
+ ]
325
+ rg_dataset.records.log(records=records)
326
+ progress(1.0, desc="Dataset pushed to Argilla")
327
+ except Exception as e:
328
+ raise gr.Error(f"Error pushing dataset to Argilla: {e}")
329
+ return ""
330
+
331
+
332
  def update_suggested_labels(system_prompt):
333
  new_labels = re.findall(r"'(\b[\w-]+\b)'", system_prompt)
334
  if not new_labels:
 
350
  return gr.update(maximum=len(labels) if labels else 1)
351
 
352
 
353
+ with gr.Blocks() as app:
354
+ gr.Markdown("## Describe the dataset you want")
355
+ gr.HTML("<hr>")
356
+ with gr.Row():
357
+ with gr.Column(scale=1):
358
+ dataset_description = gr.Textbox(
359
+ label="Dataset description",
360
+ placeholder="Give a precise description of your desired dataset.",
361
+ )
362
+ examples = gr.Examples(
363
+ examples=DEFAULT_DATASET_DESCRIPTIONS,
364
+ inputs=[dataset_description],
365
+ cache_examples=False,
366
+ label="Example descriptions",
367
+ )
368
+ system_prompt = gr.Textbox(
369
+ label="System prompt",
370
+ placeholder="You are a helpful assistant.",
371
+ visible=False,
372
+ )
373
+ load_btn = gr.Button("Load Dataset")
374
+ with gr.Column(scale=3):
375
+ pass
376
+
377
+ gr.Markdown("## Configure your task")
378
+ gr.HTML("<hr>")
379
+ with gr.Row():
380
+ with gr.Column(scale=1):
 
 
 
 
 
 
 
381
  difficulty = gr.Dropdown(
382
  choices=[
383
  ("High School", "high school"),
 
388
  value="mixed",
389
  label="Difficulty",
390
  info="Select the comprehension level for the text. Ensure it matches the task context.",
391
+ interactive=True,
392
  )
393
  clarity = gr.Dropdown(
394
  choices=[
 
403
  value="mixed",
404
  label="Clarity",
405
  info="Set how easily the correct label or labels can be identified.",
406
+ interactive=True,
407
+ )
408
+ labels = gr.Dropdown(
409
+ choices=[],
410
+ allow_custom_value=True,
411
+ interactive=True,
412
+ label="Labels",
413
+ multiselect=True,
414
+ info="Add the labels to classify the text.",
415
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
  num_labels = gr.Number(
417
  label="Number of labels per text",
418
  value=1,
419
  minimum=1,
420
  maximum=10,
421
  info="Select 1 for single-label and >1 for multi-label.",
422
+ interactive=True,
423
+ )
424
+ btn_apply_to_sample_dataset = gr.Button("Refresh dataset")
425
+ with gr.Column(scale=3):
426
+ dataframe = gr.Dataframe()
427
+
428
+ gr.Markdown("## Generate your dataset")
429
+ gr.HTML("<hr>")
430
+ with gr.Row():
431
+ with gr.Column(scale=1):
432
+ org_name = get_org_dropdown()
433
+ repo_name = gr.Textbox(
434
+ label="Repo name",
435
+ placeholder="dataset_name",
436
+ value=f"my-distiset-{str(uuid.uuid4())[:8]}",
437
+ interactive=True,
438
  )
439
+ n_rows = gr.Number(
440
  label="Number of rows",
441
  value=10,
442
+ interactive=True,
443
+ scale=1,
 
444
  )
445
+ private = gr.Checkbox(
446
+ label="Private dataset",
447
+ value=False,
448
+ interactive=True,
449
+ scale=1,
 
 
 
 
450
  )
451
+ btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
452
+ with gr.Column(scale=3):
453
+ success_message = gr.Markdown(visible=True)
454
+
455
+ pipeline_code = get_pipeline_code_ui(
456
+ generate_pipeline_code(
457
+ system_prompt.value,
458
+ difficulty=difficulty.value,
459
+ clarity=clarity.value,
460
+ labels=labels.value,
461
+ num_labels=num_labels.value,
462
+ num_rows=n_rows.value,
463
  )
464
+ )
465
 
466
+ gr.on(
467
+ triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
468
+ fn=generate_system_prompt,
469
+ inputs=[dataset_description],
470
+ outputs=[system_prompt, dataframe],
471
+ show_progress=True,
472
+ ).then(
473
+ fn=generate_sample_dataset,
474
+ inputs=[system_prompt],
475
+ outputs=[dataframe],
476
+ show_progress=True,
477
+ ).then(
478
  fn=update_suggested_labels,
479
  inputs=[system_prompt],
480
  outputs=labels,
 
484
  outputs=[num_labels],
485
  )
486
 
487
+ btn_push_to_hub.click(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  fn=validate_argilla_user_workspace_dataset,
489
+ inputs=[repo_name],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
  outputs=[success_message],
 
 
 
 
 
 
 
 
 
491
  show_progress=True,
492
  ).then(
493
+ fn=validate_push_to_hub,
 
 
 
 
 
494
  inputs=[org_name, repo_name],
495
  outputs=[success_message],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
  show_progress=True,
497
  ).success(
 
 
 
 
 
 
498
  fn=hide_success_message,
499
  outputs=[success_message],
 
 
 
 
500
  show_progress=True,
501
  ).success(
502
  fn=push_dataset_to_argilla,
503
+ inputs=[
504
+ org_name,
505
+ repo_name,
506
+ system_prompt,
507
+ difficulty,
508
+ clarity,
509
+ num_labels,
510
+ n_rows,
511
+ labels,
512
+ private,
513
+ ],
514
+ outputs=[success_message],
515
  show_progress=True,
516
  ).success(
517
+ fn=show_success_message_hub,
518
+ inputs=[org_name, repo_name],
519
  outputs=[success_message],
520
  )
521
 
522
+ app.load(fn=get_org_dropdown, outputs=[org_name])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -1,4 +1,3 @@
1
- import pandas as pd
2
  from distilabel.llms import InferenceEndpointsLLM
3
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
4
 
@@ -119,36 +118,11 @@ The prompt you write should follow the same style and structure as the following
119
  User dataset description:
120
  """
121
 
122
- DEFAULT_DATASET_DESCRIPTIONS = (
123
  "rude customer assistant for a phone company",
124
  "assistant that solves math puzzles using python",
125
- )
126
- DEFAULT_SYSTEM_PROMPTS = [
127
- """You are a customer support agent for a phone company. Your purpose is to assist customers with their phone-related issues, but you are not very patient and tend to be a bit rude. User queries will be straightforward and clear, but you will respond in a somewhat blunt and curt manner. Remember to keep your responses concise and to the point. User queries are often about phone plans, billing, and technical issues. Your responses should be direct and focus on resolving the issue at hand, but with a slightly abrasive tone. User queries will be concise and to the point, User queries are often about phone plans, billing, and technical issues.""",
128
- """You are an AI assistant designed to solve mathematical puzzles and problems using Python programming. Your purpose is to help users tackle various math-related challenges by writing, testing, and explaining Python code. Provide step-by-step solutions, break down complex problems into manageable parts, and offer clear explanations of mathematical concepts and their implementation in Python. Assist users in understanding the logic behind the code and the mathematical principles that govern the solution. User questions are direct and concise.""",
129
- ]
130
- DEFAULT_DATASETS = [
131
- pd.DataFrame.from_dict(
132
- {
133
- "prompt": [
134
- "I'm having trouble with my phone signal at home. It's weak. What can you do to help me?"
135
- ],
136
- "completion": [
137
- """Look, we're not responsible for your signal quality at your house. It's probably a homeowner issue, like your roof covering our tower or something. Can you check if you have a clear view of the roof or if there are any obstructions?"""
138
- ],
139
- }
140
- ),
141
- pd.DataFrame.from_dict(
142
- {
143
- "prompt": [
144
- "If the product of the largest and smallest numbers of the five given numbers (3, 7, 8, 12, 20) is greater than the sum of the other three numbers, then the largest number among them is A) 7 B) 12 C) 20 D) 8 Why is this statement true?"
145
- ],
146
- "completion": [
147
- """**To solve this problem, we will follow these steps: 1. Calculate the product of the largest and smallest numbers among the given set. 2. Calculate the sum of the remaining three numbers. 3. Compare the product with the sum to check the given condition. Here's the Python code to solve the problem: ```python # Define the given numbers numbers = [3, 7, 8, 12, 20] # Sort the numbers in ascending order numbers.sort() # Get the smallest number (first element after sorting) and the largest number (last element after sorting) smallest_number = numbers[0] largest_number = numbers[-1] # Calculate the product of the smallest and largest numbers product = smallest_number * largest_number # Calculate the sum of the remaining three numbers sum_of_remaining_numbers = sum(numbers[1:-1]) # Check if the product is greater than the sum of the remaining numbers if product > sum_of_remaining_numbers: print("The statement is true.") else: print("The statement is false.") # Print the largest number among the options largest_option = [7, 12, 20] print("The largest number among the options is:", max(largest"""
148
- ],
149
- }
150
- ),
151
  ]
 
152
  _STOP_SEQUENCES = [
153
  "<|eot_id|>",
154
  "<|start_header_id|>",
 
 
1
  from distilabel.llms import InferenceEndpointsLLM
2
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
3
 
 
118
  User dataset description:
119
  """
120
 
121
+ DEFAULT_DATASET_DESCRIPTIONS = [
122
  "rude customer assistant for a phone company",
123
  "assistant that solves math puzzles using python",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  ]
125
+
126
  _STOP_SEQUENCES = [
127
  "<|eot_id|>",
128
  "<|start_header_id|>",
src/distilabel_dataset_generator/pipelines/textcat.py CHANGED
@@ -1,13 +1,13 @@
 
1
  from typing import List
2
 
3
- import pandas as pd
4
- import random
5
  from distilabel.llms import InferenceEndpointsLLM
6
  from distilabel.steps.tasks import (
7
  GenerateTextClassificationData,
8
  TextClassification,
9
  TextGeneration,
10
  )
 
11
  from src.distilabel_dataset_generator.pipelines.base import (
12
  MODEL,
13
  _get_next_api_key,
@@ -50,32 +50,6 @@ DEFAULT_DATASET_DESCRIPTIONS = [
50
  "A dataset covering news articles about various topics.",
51
  ]
52
 
53
- DEFAULT_DATASETS = [
54
- pd.DataFrame.from_dict(
55
- {
56
- "text": [
57
- "I love the product! It's amazing and I'll buy it again.",
58
- "The product was okay, but I wouldn't buy it again.",
59
- ],
60
- "label": ["positive", "negative"],
61
- }
62
- ),
63
- pd.DataFrame.from_dict(
64
- {
65
- "text": [
66
- "Yesterday, the US stock market had a significant increase.",
67
- "New research suggests that the Earth is not a perfect sphere.",
68
- ],
69
- "labels": [["economy", "politics"], ["science", "environment"]],
70
- }
71
- ),
72
- ]
73
-
74
- DEFAULT_SYSTEM_PROMPTS = [
75
- "Classify the following customer review as either 'positive' or 'negative'.",
76
- "Classify the following news article into one of the following categories: 'politics', 'economy', 'environment', 'science', 'health'.",
77
- ]
78
-
79
 
80
  def generate_pipeline_code(
81
  system_prompt: str,
 
1
+ import random
2
  from typing import List
3
 
 
 
4
  from distilabel.llms import InferenceEndpointsLLM
5
  from distilabel.steps.tasks import (
6
  GenerateTextClassificationData,
7
  TextClassification,
8
  TextGeneration,
9
  )
10
+
11
  from src.distilabel_dataset_generator.pipelines.base import (
12
  MODEL,
13
  _get_next_api_key,
 
50
  "A dataset covering news articles about various topics.",
51
  ]
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  def generate_pipeline_code(
55
  system_prompt: str,
src/distilabel_dataset_generator/utils.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from typing import Union, List, Optional
3
 
4
  import argilla as rg
5
  import gradio as gr
@@ -36,9 +36,7 @@ else:
36
 
37
 
38
  def get_login_button():
39
- return gr.LoginButton(
40
- value="Sign in with Hugging Face!", size="lg", scale=2
41
- ).activate()
42
 
43
 
44
  def get_duplicate_button():
@@ -52,6 +50,8 @@ def list_orgs(oauth_token: OAuthToken = None):
52
  data = whoami(oauth_token.token)
53
  if data["auth"]["type"] == "oauth":
54
  organisations = [data["name"]] + [org["name"] for org in data["orgs"]]
 
 
55
  else:
56
  organisations = [
57
  entry["entity"]["name"]
@@ -64,12 +64,16 @@ def list_orgs(oauth_token: OAuthToken = None):
64
 
65
 
66
  def get_org_dropdown(oauth_token: OAuthToken = None):
67
- orgs = list_orgs(oauth_token)
 
 
 
68
  return gr.Dropdown(
69
  label="Organization",
70
  choices=orgs,
71
  value=orgs[0] if orgs else None,
72
  allow_custom_value=True,
 
73
  )
74
 
75
 
@@ -123,5 +127,6 @@ def get_argilla_client() -> Union[rg.Argilla, None]:
123
  except Exception:
124
  return None
125
 
 
126
  def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
127
- return list(set([label.lower().strip() for label in labels])) if labels else []
 
1
  import os
2
+ from typing import List, Optional, Union
3
 
4
  import argilla as rg
5
  import gradio as gr
 
36
 
37
 
38
  def get_login_button():
39
+ return gr.LoginButton(value="Sign in!", size="sm", scale=2).activate()
 
 
40
 
41
 
42
  def get_duplicate_button():
 
50
  data = whoami(oauth_token.token)
51
  if data["auth"]["type"] == "oauth":
52
  organisations = [data["name"]] + [org["name"] for org in data["orgs"]]
53
+ elif data["auth"]["type"] == "access_token":
54
+ organisations = [org["name"] for org in data["orgs"]]
55
  else:
56
  organisations = [
57
  entry["entity"]["name"]
 
64
 
65
 
66
  def get_org_dropdown(oauth_token: OAuthToken = None):
67
+ if oauth_token:
68
+ orgs = list_orgs(oauth_token)
69
+ else:
70
+ orgs = []
71
  return gr.Dropdown(
72
  label="Organization",
73
  choices=orgs,
74
  value=orgs[0] if orgs else None,
75
  allow_custom_value=True,
76
+ interactive=True,
77
  )
78
 
79
 
 
127
  except Exception:
128
  return None
129
 
130
+
131
  def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
132
+ return list(set([label.lower().strip() for label in labels])) if labels else []