Guilherme34 committed on
Commit e77427d · verified · 1 Parent(s): b9a08dd

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .dockerignore +57 -0
  2. .env.local.template +54 -0
  3. .gitattributes +5 -0
  4. .gitignore +173 -0
  5. LICENSE +201 -0
  6. README.md +168 -6
  7. app.py +4 -0
  8. assets/argilla.png +3 -0
  9. assets/flow.png +3 -0
  10. assets/logo.png +0 -0
  11. assets/logo.svg +1 -0
  12. assets/ui-full.png +3 -0
  13. assets/ui.png +3 -0
  14. docker-compose.yml +17 -0
  15. docker/.env.docker.template +43 -0
  16. docker/Dockerfile +45 -0
  17. docker/README.md +80 -0
  18. docker/argilla/compose.yml +118 -0
  19. docker/ollama/compose.yml +48 -0
  20. docker/ollama/entrypoint.sh +35 -0
  21. examples/argilla-deployment.py +18 -0
  22. examples/blog_private_synthetic_data_generation.md +222 -0
  23. examples/fine-tune-deepseek-reasoning-sft.ipynb +0 -0
  24. examples/fine-tune-modernbert-classifier.ipynb +538 -0
  25. examples/fine-tune-modernbert-rag.ipynb +980 -0
  26. examples/fine-tune-smollm2-on-synthetic-data.ipynb +310 -0
  27. examples/hf-dedicated-or-tgi-deployment.py +19 -0
  28. examples/hf-serverless-deployment-deepseek.py +16 -0
  29. examples/hf-serverless-deployment.py +15 -0
  30. examples/hf-serverless-different-model-for-completion.py +16 -0
  31. examples/ollama-deployment.py +22 -0
  32. examples/ollama-different-model-for-completion.py +26 -0
  33. examples/openai-deployment.py +18 -0
  34. examples/vllm-deployment.py +21 -0
  35. packages.txt +2 -0
  36. pdm.lock +0 -0
  37. pyproject.toml +40 -0
  38. requirements.txt +1 -0
  39. src/synthetic_dataset_generator/__init__.py +20 -0
  40. src/synthetic_dataset_generator/__main__.py +4 -0
  41. src/synthetic_dataset_generator/_distiset.py +148 -0
  42. src/synthetic_dataset_generator/_inference_endpoints.py +58 -0
  43. src/synthetic_dataset_generator/_tabbedinterface.py +69 -0
  44. src/synthetic_dataset_generator/app.py +35 -0
  45. src/synthetic_dataset_generator/apps/__init__.py +0 -0
  46. src/synthetic_dataset_generator/apps/about.py +15 -0
  47. src/synthetic_dataset_generator/apps/base.py +270 -0
  48. src/synthetic_dataset_generator/apps/chat.py +1142 -0
  49. src/synthetic_dataset_generator/apps/eval.py +894 -0
  50. src/synthetic_dataset_generator/apps/rag.py +972 -0
.dockerignore ADDED
@@ -0,0 +1,57 @@
1
+ # Version control
2
+ .git
3
+ .gitignore
4
+
5
+ # Python
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+ *.so
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+
28
+ # Virtual environments
29
+ .env*
30
+ !.env.example
31
+ .venv
32
+ env/
33
+ venv/
34
+ ENV/
35
+
36
+ # IDE
37
+ .idea/
38
+ .vscode/
39
+ *.swp
40
+ *.swo
41
+
42
+ # Testing
43
+ .tox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+
53
+ # Project specific
54
+ nltk_data/
55
+ .pdm-python
56
+ .pdm.toml
57
+ __pypackages__/
.env.local.template ADDED
@@ -0,0 +1,54 @@
1
+ # =============================================================================
2
+ # LOCAL/API CONFIGURATION
3
+ # =============================================================================
4
+
5
+ # -----------------------------------------------------------------------------
6
+ # REQUIRED CONFIGURATION
7
+ # -----------------------------------------------------------------------------
8
+ # Hugging Face token (required for all setups)
9
+ HF_TOKEN=hf_...
10
+
11
+ # Generation Settings
12
+ MAX_NUM_TOKENS=2048
13
+ MAX_NUM_ROWS=1000
14
+ DEFAULT_BATCH_SIZE=5
15
+
16
+ # Required for chat data generation with Llama or Qwen models
17
+ # Options: "llama3", "qwen2", or custom template string
18
+ MAGPIE_PRE_QUERY_TEMPLATE=llama3
19
+
20
+ # -----------------------------------------------------------------------------
21
+ # A. CLOUD API SERVICES
22
+ # -----------------------------------------------------------------------------
23
+
24
+ # 1. HUGGING FACE INFERENCE API (Default, Recommended)
25
+ MODEL=meta-llama/Llama-3.1-8B-Instruct
26
+ # MODEL=Qwen/Qwen2.5-1.5B-Instruct
27
+
28
+ # 2. OPENAI API
29
+ # OPENAI_BASE_URL=https://api.openai.com/v1/
30
+ # MODEL=gpt-4
31
+ # API_KEY=sk-...
32
+
33
+ # 3. HUGGING FACE SPACE FOR ARGILLA (optional)
34
+ # ARGILLA_API_URL=https://your-space.hf.space/
35
+ # ARGILLA_API_KEY=your_key
36
+
37
+ # -----------------------------------------------------------------------------
38
+ # B. LOCAL SERVICES (Requires Installation)
39
+ # -----------------------------------------------------------------------------
40
+
41
+ # 1. LOCAL OLLAMA
42
+ # OLLAMA_BASE_URL=http://127.0.0.1:11434/
43
+ # MODEL=llama3.2:1b
44
+ # TOKENIZER_ID=meta-llama/Llama-3.2-1B-Instruct
45
+
46
+ # 2. LOCAL VLLM
47
+ # VLLM_BASE_URL=http://127.0.0.1:8000/
48
+ # MODEL=Qwen/Qwen2.5-1.5B-Instruct
49
+ # TOKENIZER_ID=Qwen/Qwen2.5-1.5B-Instruct
50
+
51
+ # 3. LOCAL TGI
52
+ # HUGGINGFACE_BASE_URL=http://127.0.0.1:3000/
53
+ # MODEL=meta-llama/Llama-3.1-8B-Instruct
54
+ # TOKENIZER_ID=meta-llama/Llama-3.1-8B-Instruct
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/flow.png filter=lfs diff=lfs merge=lfs -text
37
+ *.sh text eol=lf
38
+ assets/argilla.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/ui-full.png filter=lfs diff=lfs merge=lfs -text
40
+ assets/ui.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,173 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm-project.org/#use-with-ide
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+ .python-version
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
164
+ .DS_Store
165
+
166
+ # nltk
167
+ nltk_data/
168
+
169
+ # examples
170
+ models/
171
+
172
+ # Elasticsearch data
173
+ elasticsearch_data/
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,10 +1,172 @@
1
  ---
2
  title: Synthetic Data Generator
3
- emoji: 🐠
4
- colorFrom: blue
5
- colorTo: blue
6
- sdk: docker
7
- pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: Synthetic Data Generator
3
+ short_description: Build datasets using natural language
4
+ emoji: 🧬
5
+ colorFrom: yellow
6
+ colorTo: pink
7
+ sdk: gradio
8
+ sdk_version: 5.8.0
9
+ app_file: app.py
10
+ pinned: true
11
+ license: apache-2.0
12
+ hf_oauth: true
13
+ #header: mini
14
+ hf_oauth_scopes:
15
+ - read-repos
16
+ - write-repos
17
+ - manage-repos
18
+ - inference-api
19
  ---
20
 
21
+ > [!IMPORTANT]
22
+ > The original authors have moved on to other projects. While the code might still be functional for its original purpose, please be aware that the original team does not plan to develop new features, bug fixes, or updates. If you'd like to become a maintainer, please open an issue to discuss.
23
+ >
24
+ >
25
+ <br>
26
+
27
+ <h2 align="center">
28
+ <a href=""><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" width="80%"></a>
29
+ </h2>
30
+ <h3 align="center">Build datasets using natural language</h3>
31
+
32
+ ![Synthetic Data Generator](https://huggingface.co/spaces/argilla/synthetic-data-generator/resolve/main/assets/ui-full.png)
33
+
34
+ ## Introduction
35
+
36
+ Synthetic Data Generator is a tool that allows you to create high-quality datasets for training and fine-tuning language models. It leverages the power of distilabel and LLMs to generate synthetic data tailored to your specific needs. [The announcement blog](https://huggingface.co/blog/synthetic-data-generator) goes over a practical example of how to use it but you can also watch the [video](https://www.youtube.com/watch?v=nXjVtnGeEss) to see it in action.
37
+
38
+ Supported Tasks:
39
+
40
+ - Text Classification
41
+ - Chat Data for Supervised Fine-Tuning
42
+ - Retrieval Augmented Generation
43
+
44
+ This tool simplifies the process of creating custom datasets, enabling you to:
45
+
46
+ - Describe the characteristics of your desired application
47
+ - Iterate on sample datasets
48
+ - Produce full-scale datasets
49
+ - Push your datasets to the [Hugging Face Hub](https://huggingface.co/datasets?other=datacraft) and/or [Argilla](https://docs.argilla.io/)
50
+
51
+ By using the Synthetic Data Generator, you can rapidly prototype and create datasets for your use case, accelerating your AI development process.
52
+
53
+ <p align="center">
54
+ <a href="https://twitter.com/argilla_io">
55
+ <img src="https://img.shields.io/badge/twitter-black?logo=x"/>
56
+ </a>
57
+ <a href="https://www.linkedin.com/company/argilla-io">
58
+ <img src="https://img.shields.io/badge/linkedin-blue?logo=linkedin"/>
59
+ </a>
60
+ <a href="http://hf.co/join/discord">
61
+ <img src="https://img.shields.io/badge/Discord-7289DA?&logo=discord&logoColor=white"/>
62
+ </a>
63
+ </p>
64
+
65
+ ## Installation
66
+
67
+ You can simply install the package with:
68
+
69
+ ```bash
70
+ pip install synthetic-dataset-generator
71
+ ```
72
+
73
+ ### Quickstart
74
+
75
+ ```python
76
+ from synthetic_dataset_generator import launch
77
+
78
+ launch()
79
+ ```
80
+
81
+ ### Environment Variables
82
+
83
+ - `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints. You can find some configuration examples in the [examples](examples/) folder.
84
+
85
+ You can set the following environment variables to customize the generation process.
86
+
87
+ - `MAX_NUM_TOKENS`: The maximum number of tokens to generate, defaults to `2048`.
88
+ - `MAX_NUM_ROWS`: The maximum number of rows to generate, defaults to `1000`.
89
+ - `DEFAULT_BATCH_SIZE`: The default batch size to use for generating the dataset, defaults to `5`.
90
+
91
+ Optionally, you can use different API providers and models.
92
+
93
+ - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`.
94
+ - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the `HF_TOKEN` environment variable.
95
+ - `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`.
96
+ - `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/`.
97
+ - `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`.
98
+ - `VLLM_BASE_URL`: The base URL for any VLLM compatible API, e.g. `http://localhost:8000/`.
99
+
100
+ To use a specific model exclusively for generating completions, set the corresponding environment variables by appending `_COMPLETION` to the ones mentioned earlier. For example, you can use `MODEL_COMPLETION` and `OPENAI_BASE_URL_COMPLETION`.
101
+
102
+ SFT and Chat Data generation is not supported with OpenAI Endpoints. Additionally, it must be configured per model family, based on the model's prompt template, using the right `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE` environment variables.
103
+
104
+ - `TOKENIZER_ID`: The tokenizer ID to use for the magpie pipeline, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`.
105
+ - `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. `llama3` and `qwen2` are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"`, respectively. For other models, you can pass a custom pre-query template string.
106
+
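For example, the same variables can be set programmatically before calling `launch()`, in the spirit of the scripts in the [examples](examples/) folder. A minimal sketch for a local Ollama setup (the token, base URL, and model values below are placeholders taken from `.env.local.template`; replace them with your own):

```python
import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..."  # placeholder: your Hugging Face token

# Placeholder values for a local Ollama server (see .env.local.template)
os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/"
os.environ["MODEL"] = "llama3.2:1b"
os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.2-1B-Instruct"
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"

launch()
```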
107
+ Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
108
+
109
+ - `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla.
110
+ - `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla.
111
+
112
+ To save the generated datasets to a local directory instead of pushing them to the Hugging Face Hub, set the following environment variable:
113
+
114
+ - `SAVE_LOCAL_DIR`: The local directory to save the generated datasets to.
115
+
116
+ You can use our environment template as a starting point:
117
+
118
+ ```bash
119
+ cp .env.local.template .env
120
+ ```
121
+
122
+ ### Argilla integration
123
+
124
+ Argilla is an open source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/).
125
+
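As a quick sanity check that your Argilla instance is reachable, here is a minimal, untested sketch assuming the Argilla 2.x Python SDK is installed; it reuses the same `ARGILLA_API_URL` and `ARGILLA_API_KEY` values the generator reads:

```python
import os

import argilla as rg  # assumes the argilla 2.x SDK

# Connect with the same credentials the generator uses
client = rg.Argilla(
    api_url=os.environ["ARGILLA_API_URL"],
    api_key=os.environ["ARGILLA_API_KEY"],
)
print(client)  # client for the running Argilla instance

# Datasets pushed by the generator can then be reviewed in the Argilla UI
# or queried through the SDK (see the Argilla docs for details).
```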
126
+ ![Argilla integration](https://huggingface.co/spaces/argilla/synthetic-data-generator/resolve/main/assets/argilla.png)
127
+
128
+ ## Custom synthetic data generation?
129
+
130
+ Each pipeline is based on distilabel, so you can easily change the LLM or the pipeline steps.
131
+
132
+ Check out the [distilabel library](https://github.com/argilla-io/distilabel) for more information.
133
+
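For a rough idea of what such a pipeline looks like, here is a minimal sketch assuming distilabel 1.x. The pipelines shipped with this tool are more elaborate; this only illustrates the pattern of swapping the LLM or the steps:

```python
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration

with Pipeline(name="toy-generation") as pipeline:
    # Seed the pipeline with a couple of instructions
    load = LoadDataFromDicts(
        data=[
            {"instruction": "Summarize the benefits of synthetic data."},
            {"instruction": "Write a short FAQ entry about data privacy."},
        ]
    )
    # Swap this LLM (or the step itself) to customize generation
    generate = TextGeneration(
        llm=InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct")
    )
    load >> generate

if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False)
    print(distiset)
```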
134
+ ## Development
135
+
136
+ Install the dependencies:
137
+
138
+ ```bash
139
+ # Create a virtual environment
140
+ python -m venv .venv
141
+ source .venv/bin/activate
142
+
143
+ # Install the dependencies
144
+ pip install -e . # pdm install
145
+ ```
146
+
147
+ Run the app:
148
+
149
+ ```bash
150
+ python app.py
151
+ ```
152
+
153
+ ## 🐳 Docker Setup
154
+
155
+ The containerized tool uses Ollama for local LLM inference and Argilla for data curation. Here's the architecture:
156
+
157
+ ![Container Structure](https://cdn-uploads.huggingface.co/production/uploads/64461026e1fd8d65b27e6187/Uz-kDOBrV-_GahUrc1K_O.png)
158
+
159
+ Quick setup with all services (App + Ollama + Argilla):
160
+
161
+ ```bash
162
+ # Copy environment template
163
+ cp docker/.env.docker.template .env # Add your HF_TOKEN in .env
164
+
165
+ # Build all services (this may take a few minutes)
166
+ docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml build
167
+
168
+ # Start all services
169
+ docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml up -d
170
+ ```
171
+
172
+ > For more detailed Docker configurations and setups, check [docker/README.md](docker/README.md)
app.py ADDED
@@ -0,0 +1,4 @@
1
+ from synthetic_dataset_generator import launch
2
+
3
+ if __name__ == "__main__":
4
+ launch()
assets/argilla.png ADDED

Git LFS Details

  • SHA256: 1892b7867842f7f5154c3923278c42d21ec7b6c4bacd159951b8d32d9e64524b
  • Pointer size: 131 Bytes
  • Size of remote file: 475 kB
assets/flow.png ADDED

Git LFS Details

  • SHA256: b0465f5f3ed2a87b14cc609a1f25a1e7b0bfeb1cc8cab534a6ec79a9a8651996
  • Pointer size: 132 Bytes
  • Size of remote file: 1.81 MB
assets/logo.png ADDED
assets/logo.svg ADDED
assets/ui-full.png ADDED

Git LFS Details

  • SHA256: a38e10e98dd3ed4c93bfd0a5ec7ebc2584cd4ed54c120aad5da9809b8422dc75
  • Pointer size: 131 Bytes
  • Size of remote file: 968 kB
assets/ui.png ADDED

Git LFS Details

  • SHA256: fdd5805b833fca7b064a67f220489e88bee139348b094bf50a907adb733aad5b
  • Pointer size: 131 Bytes
  • Size of remote file: 652 kB
docker-compose.yml ADDED
@@ -0,0 +1,17 @@
1
+ services:
2
+ app:
3
+ build:
4
+ context: .
5
+ dockerfile: docker/Dockerfile
6
+ image: synthetic-data-generator:app
7
+ ports:
8
+ - "7860:7860"
9
+ env_file:
10
+ - .env
11
+ networks:
12
+ - app-network
13
+
14
+ networks:
15
+ app-network:
16
+ name: synthetic-data-network
17
+ driver: bridge
docker/.env.docker.template ADDED
@@ -0,0 +1,43 @@
1
+ # =============================================================================
2
+ # DOCKER CONFIGURATION ONLY - FULL SETUP (APP + OLLAMA + ARGILLA)
3
+ # =============================================================================
4
+
5
+ # Note: Before building:
6
+ # 1. Copy this template to the root directory: cp docker/.env.docker.template .env
7
+ # 2. Comment/uncomment the sections you want to use (OLLAMA and/or ARGILLA)
8
+ # 3. Then build and run with the appropriate docker compose command
9
+
10
+ # Hugging Face token with read/write permissions
11
+ HF_TOKEN=your_token_here
12
+
13
+ # -----------------------------------------------------------------------------
14
+ # GENERATION SETTINGS
15
+ # -----------------------------------------------------------------------------
16
+ MAX_NUM_TOKENS=2048
17
+ MAX_NUM_ROWS=1000
18
+ DEFAULT_BATCH_SIZE=5
19
+
20
+ # -----------------------------------------------------------------------------
21
+ # OLLAMA DOCKER CONFIGURATION
22
+ # -----------------------------------------------------------------------------
23
+ OLLAMA_BASE_URL=http://ollama:11434
24
+ OLLAMA_HARDWARE=latest # latest (for CPU/NVIDIA), rocm (for AMD)
25
+
26
+ # LLAMA 3.2
27
+ MODEL=llama3.2:1b
28
+ TOKENIZER_ID=meta-llama/Llama-3.2-1B-Instruct
29
+ MAGPIE_PRE_QUERY_TEMPLATE=llama3
30
+
31
+ # DEEPSEEK R1
32
+ #MODEL=deepseek-r1:1.5b # must match ollama tags https://ollama.com/library/deepseek-r1:1.5b
33
+ #TOKENIZER_ID=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
34
+ #MAGPIE_PRE_QUERY_TEMPLATE= "<|begin▁of▁sentence|>User: "
35
+
36
+ # -----------------------------------------------------------------------------
37
+ # ARGILLA DOCKER CONFIGURATION (persistent data)
38
+ # -----------------------------------------------------------------------------
39
+ ARGILLA_API_URL=http://argilla:6900
40
+ ARGILLA_USERNAME=admin
41
+ ARGILLA_PASSWORD=admin1234
42
+ ARGILLA_API_KEY=admin.1234
43
+ ARGILLA_REINDEX_DATASET=1
docker/Dockerfile ADDED
@@ -0,0 +1,45 @@
1
+ # Use Python slim image as base
2
+ FROM python:3.10-slim
3
+
4
+ # Set environment variables
5
+ ENV PYTHONUNBUFFERED=1 \
6
+ PYTHONDONTWRITEBYTECODE=1 \
7
+ PIP_NO_CACHE_DIR=1
8
+
9
+ # Create and set working directory
10
+ WORKDIR /app
11
+
12
+ # Create non-root user first
13
+ RUN useradd -m -u 1000 appuser
14
+
15
+ # Install system dependencies including build tools
16
+ RUN apt-get update && apt-get install -y --no-install-recommends \
17
+ curl \
18
+ build-essential \
19
+ cmake \
20
+ libgl1-mesa-glx \
21
+ libglib2.0-0 \
22
+ libsm6 \
23
+ libxext6 \
24
+ libxrender-dev \
25
+ && rm -rf /var/lib/apt/lists/*
26
+
27
+ # Install pdm
28
+ RUN pip install --no-cache-dir pdm
29
+
30
+ # Copy project files and set permissions
31
+ COPY . .
32
+ RUN chown -R appuser:appuser /app && \
33
+ chmod -R 755 /app
34
+
35
+ # Switch to non-root user
36
+ USER appuser
37
+
38
+ # Install dependencies in a virtual environment
39
+ RUN pdm install --prod --frozen-lockfile
40
+
41
+ # Expose Gradio port
42
+ EXPOSE 7860
43
+
44
+ # Start command using pdm run to use the virtual environment
45
+ CMD ["pdm", "run", "python", "-m", "synthetic_dataset_generator"]
docker/README.md ADDED
@@ -0,0 +1,80 @@
1
+ # Docker Configuration Guide
2
+
3
+ Each service runs in its own container, communicating through internal networks. The core app connects to Ollama for model inference and Argilla for data review:
4
+
5
+ ![Container Structure](https://cdn-uploads.huggingface.co/production/uploads/64461026e1fd8d65b27e6187/Uz-kDOBrV-_GahUrc1K_O.png)
6
+
7
+ The application can be run with different configurations using Docker Compose:
8
+
9
+ - `docker-compose.yml`: Core application
10
+ - `docker/ollama/compose.yml`: Ollama service for local LLM inference
11
+ - `docker/argilla/compose.yml`: Argilla service for data curation
12
+
13
+ ## Ollama Integration
14
+
15
+ The `MODEL` variable in your `.env` file determines which model Ollama will download and use. For example:
16
+ ```env
17
+ MODEL=llama3.2:1b
18
+ ```
19
+
20
+ ## Setup Options
21
+
22
+ ### Full Setup (App + Ollama + Argilla)
23
+ ```bash
24
+ # Keep all sections uncommented in .env
25
+ docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml build
26
+ docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml up -d
27
+ ```
28
+
29
+ ### App + Ollama
30
+ ```bash
31
+ # Comment out ARGILLA section in .env
32
+ docker compose -f docker-compose.yml -f docker/ollama/compose.yml build
33
+ docker compose -f docker-compose.yml -f docker/ollama/compose.yml up -d
34
+ ```
35
+
36
+ ### App + Argilla
37
+ ```bash
38
+ # Comment out OLLAMA section in .env
39
+ docker compose -f docker-compose.yml -f docker/argilla/compose.yml build
40
+ docker compose -f docker-compose.yml -f docker/argilla/compose.yml up -d
41
+ ```
42
+
43
+ ### App Only
44
+ ```bash
45
+ # Comment out both OLLAMA and ARGILLA sections in .env
46
+ docker compose -f docker-compose.yml build
47
+ docker compose -f docker-compose.yml up -d
48
+ ```
49
+
50
+ ## Managing Services
51
+
52
+ Services are built separately but are linked together. If you already have some services built and want to add another:
53
+
54
+ 1. You don't need to rebuild existing services
55
+ 2. Just build the new service
56
+ 3. Stop everything with `down` and start again with `up`
57
+
58
+ For example, if you have App + Ollama and want to add Argilla:
59
+ ```bash
60
+ docker compose -f docker/argilla/compose.yml build # only build Argilla
61
+ docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml down
62
+ docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml up -d
63
+ ```
64
+
65
+ Similarly, if you have built all services but want to run only some of them:
66
+ > **Important**: When running specific services, remember to comment out unused services in `.env` first
67
+
68
+ ```bash
69
+ # No need to build again, just start the services you need
70
+ docker compose -f docker-compose.yml -f docker/ollama/compose.yml up -d # start only App + Ollama
71
+ ```
72
+
73
+ ## Service URLs
74
+
75
+ Once running, access the services at:
76
+ - App: http://localhost:7860
77
+ - Argilla: http://localhost:6900 (if enabled)
78
+ - Ollama: http://localhost:11434 (if enabled)
79
+
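A small sketch to verify from the host that the containers are responding (assuming the default ports above and Python with `requests` installed; the `/api/ready` path matches the healthcheck used in `docker/argilla/compose.yml`):

```python
import requests

# Endpoints exposed by the default compose files
checks = {
    "app": "http://localhost:7860",
    "argilla": "http://localhost:6900/api/ready",
    "ollama": "http://localhost:11434",
}

for name, url in checks.items():
    try:
        status = requests.get(url, timeout=5).status_code
        print(f"{name}: HTTP {status}")
    except requests.RequestException as exc:
        print(f"{name}: not reachable ({exc})")
```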
80
+ > Note: Services will be available after a few seconds while they initialize. Ollama models and Argilla datasets are persisted and available after restarts
docker/argilla/compose.yml ADDED
@@ -0,0 +1,118 @@
1
+ services:
2
+ app:
3
+ extends:
4
+ file: docker-compose.yml
5
+ service: app
6
+ depends_on:
7
+ argilla:
8
+ condition: service_healthy
9
+ required: false
10
+ environment:
11
+ - ARGILLA_API_URL=http://argilla:6900
12
+
13
+ elasticsearch:
14
+ image: docker.elastic.co/elasticsearch/elasticsearch:8.17.0
15
+ environment:
16
+ - ES_JAVA_OPTS=-Xms512m -Xmx512m
17
+ - node.name=elasticsearch
18
+ - cluster.name=es-argilla-local
19
+ - discovery.type=single-node
20
+ - cluster.routing.allocation.disk.threshold_enabled=false
21
+ - xpack.security.enabled=false
22
+ volumes:
23
+ - es_data:/usr/share/elasticsearch/data
24
+ networks:
25
+ - app-network
26
+ ports:
27
+ - "9200:9200"
28
+ - "9300:9300"
29
+ ulimits:
30
+ memlock:
31
+ soft: -1
32
+ hard: -1
33
+ nofile:
34
+ soft: 65536
35
+ hard: 65536
36
+ healthcheck:
37
+ test: ["CMD", "curl", "-f", "http://localhost:9200"]
38
+ interval: 30s
39
+ timeout: 10s
40
+ retries: 3
41
+
42
+ postgres:
43
+ image: postgres:14
44
+ environment:
45
+ POSTGRES_USER: postgres
46
+ POSTGRES_PASSWORD: postgres
47
+ POSTGRES_DB: argilla
48
+ networks:
49
+ - app-network
50
+ volumes:
51
+ - postgres_data:/var/lib/postgresql/data
52
+
53
+ redis:
54
+ image: redis
55
+ networks:
56
+ - app-network
57
+
58
+ argilla:
59
+ image: argilla/argilla-server:latest
60
+ ports:
61
+ - "6900:6900"
62
+ healthcheck:
63
+ test: ["CMD", "curl", "-f", "http://localhost:6900/api/ready"]
64
+ interval: 30s
65
+ timeout: 10s
66
+ retries: 3
67
+ env_file:
68
+ - .env
69
+ environment:
70
+ - ARGILLA_HOME_PATH=/var/lib/argilla
71
+ - ARGILLA_ELASTICSEARCH=http://elasticsearch:9200
72
+ - ARGILLA_DATABASE_URL=postgresql+asyncpg://postgres:postgres@postgres:5432/argilla
73
+ - ARGILLA_REDIS_URL=redis://redis:6379/0
74
+ - USERNAME=${ARGILLA_USERNAME}
75
+ - PASSWORD=${ARGILLA_PASSWORD}
76
+ - API_KEY=${ARGILLA_API_KEY}
77
+ - WORKSPACE=default
78
+ volumes:
79
+ - argilla_data:/argilla
80
+ networks:
81
+ - app-network
82
+ depends_on:
83
+ elasticsearch:
84
+ condition: service_healthy
85
+ postgres:
86
+ condition: service_started
87
+ redis:
88
+ condition: service_started
89
+
90
+ worker:
91
+ image: argilla/argilla-server:latest
92
+ env_file:
93
+ - .env
94
+ environment:
95
+ - ARGILLA_HOME_PATH=/var/lib/argilla
96
+ - ARGILLA_ELASTICSEARCH=http://elasticsearch:9200
97
+ - ARGILLA_DATABASE_URL=postgresql+asyncpg://postgres:postgres@postgres:5432/argilla
98
+ - ARGILLA_REDIS_URL=redis://redis:6379/0
99
+ - BACKGROUND_NUM_WORKERS=2
100
+ - USERNAME=${ARGILLA_USERNAME}
101
+ - PASSWORD=${ARGILLA_PASSWORD}
102
+ - API_KEY=${ARGILLA_API_KEY}
103
+ - WORKSPACE=default
104
+ networks:
105
+ - app-network
106
+ depends_on:
107
+ - postgres
108
+ - elasticsearch
109
+ - redis
110
+ command: sh -c 'python -m argilla_server worker --num-workers $${BACKGROUND_NUM_WORKERS}'
111
+
112
+ volumes:
113
+ es_data:
114
+ name: synthetic-data-es
115
+ argilla_data:
116
+ name: synthetic-data-argilla
117
+ postgres_data:
118
+ name: synthetic-data-postgres
docker/ollama/compose.yml ADDED
@@ -0,0 +1,48 @@
1
+ services:
2
+ app:
3
+ extends:
4
+ file: docker-compose.yml
5
+ service: app
6
+ depends_on:
7
+ ollama:
8
+ condition: service_healthy
9
+ required: true
10
+ environment:
11
+ - OLLAMA_BASE_URL=http://ollama:11434
12
+
13
+ ollama:
14
+ image: ollama/ollama:${OLLAMA_HARDWARE:-latest}
15
+ ports:
16
+ - "11434:11434"
17
+ env_file:
18
+ - .env
19
+ environment:
20
+ - OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-}
21
+ volumes:
22
+ - ollama_data:/root/.ollama
23
+ - ./docker/ollama/entrypoint.sh:/entrypoint.sh
24
+ networks:
25
+ - app-network
26
+ deploy:
27
+ resources:
28
+ reservations:
29
+ devices:
30
+ - driver: nvidia
31
+ count: all
32
+ capabilities: [gpu]
33
+ tty: true
34
+ entrypoint: ["/usr/bin/bash", "/entrypoint.sh"]
35
+ healthcheck:
36
+ test:
37
+ - "CMD-SHELL"
38
+ - |
39
+ test -f /tmp/ollama_ready && \
40
+ bash -c '</dev/tcp/localhost/11434'
41
+ interval: 10s
42
+ timeout: 10s
43
+ retries: 100
44
+ start_period: 10s
45
+
46
+ volumes:
47
+ ollama_data:
48
+ name: synthetic-data-ollama
docker/ollama/entrypoint.sh ADDED
@@ -0,0 +1,35 @@
1
+ #!/bin/bash
2
+
3
+ # Start Ollama in the background
4
+ /bin/ollama serve &
5
+ # Record Process ID
6
+ pid=$!
7
+
8
+ # Pause for Ollama to start
9
+ sleep 5
10
+
11
+ # Extract model name from MODEL variable (removing quotes if present)
12
+ MODEL_NAME=$(echo $MODEL | tr -d '"')
13
+
14
+ # Verify that MODEL_NAME has a value
15
+ if [ -z "$MODEL_NAME" ]; then
16
+ echo "❌ No model specified in MODEL environment variable"
17
+ else
18
+ # Check if model exists
19
+ if ollama list | grep -q "$MODEL_NAME"; then
20
+ echo "🟢 Model ($MODEL_NAME) already installed"
21
+ touch /tmp/ollama_ready
22
+ else
23
+ echo "🔴 Retrieving model ($MODEL_NAME)..."
24
+ # Try to pull the model; only create the ready flag once the download is confirmed
25
+ if ollama pull "$MODEL_NAME" 2>/dev/null && ollama list | grep -q "$MODEL_NAME"; then
26
+ echo "🟢 Model download complete!"
27
+ touch /tmp/ollama_ready
28
+ else
29
+ echo "❌ Error downloading model ($MODEL_NAME)"
30
+ fi
31
+ fi
32
+ fi
33
+
34
+ # Wait for Ollama process to finish
35
+ wait $pid
examples/argilla-deployment.py ADDED
@@ -0,0 +1,18 @@
1
+ # /// script
2
+ # requires-python = ">=3.11,<3.12"
3
+ # dependencies = [
4
+ # "synthetic-dataset-generator",
5
+ # ]
6
+ # ///
7
+ import os
8
+
9
+ from synthetic_dataset_generator import launch
10
+
11
+ # Follow https://docs.argilla.io/latest/getting_started/quickstart/ to get your Argilla API key and URL
12
+ os.environ["HF_TOKEN"] = "hf_..."
13
+ os.environ["ARGILLA_API_URL"] = (
14
+ "https://[your-owner-name]-[your_space_name].hf.space" # argilla base url
15
+ )
16
+ os.environ["ARGILLA_API_KEY"] = "my_api_key" # argilla api key
17
+
18
+ launch()
examples/blog_private_synthetic_data_generation.md ADDED
@@ -0,0 +1,222 @@
1
+ # Private Synthetic Data Generation Made Easy: Out-of-the-Box with Docker, Argilla & Ollama
2
+
3
+ > "Empowering organizations with a turnkey solution for synthetic dataset creation in private environments."
4
+
5
+ The increasing adoption of AI solutions across industries has created an unprecedented demand for high-quality training data. As organizations scale their AI initiatives, they face the dual challenge of generating substantial, domain-specific datasets while ensuring data privacy and security. Traditional approaches often involve compromises: either using public datasets that may not fully align with specific needs, or investing heavily in custom data generation infrastructure.
6
+
7
+ The complexity of this challenge is amplified by regulatory requirements, resource constraints, and the need for specialized expertise. Organizations must navigate GDPR, CCPA, and industry-specific regulations while maintaining efficient data generation pipelines. This has created a pressing need for solutions that can operate entirely within private infrastructure while maintaining enterprise-grade capabilities.
8
+
9
+ ## The Challenge
10
+
11
+ The development of AI models requires extensive training data, yet organizations face significant obstacles in data generation and management. Privacy regulations and security requirements often prevent the use of public datasets or cloud-based generation services. Additionally, existing solutions typically demand complex infrastructure setups and significant technical expertise, increasing both implementation time and costs.
12
+
13
+ Modern enterprises require a solution that addresses several critical aspects:
14
+ 1. Data Privacy: Complete control over data generation and storage
15
+ 2. Infrastructure Flexibility: Deployment options that fit existing systems
16
+ 3. Quality Assurance: Tools for data validation and curation
17
+ 4. Scalability: Ability to grow with increasing data needs
18
+ 5. Cost Efficiency: Reduction in infrastructure and maintenance costs
19
+
20
+ ## The Solution
21
+
22
+ This out-of-the-box Synthetic Dataset Generator approach leverages the power of three technologies to create a seamless, private data generation pipeline. At its core is the [Synthetic Dataset Generator](https://github.com/argilla-io/synthetic-data-generator), a tool designed for dataset creation. [Ollama](https://ollama.ai/) ensures secure local LLM inference with [Distilabel](https://github.com/argilla-io/distilabel) integration, while [Argilla's](https://argilla.io/) data curation capabilities complete the workflow, all operating within your secure infrastructure.
23
+
24
+ This architecture delivers key technical advantages:
25
+ - Full data sovereignty with containerized local deployment
26
+ - End-to-end pipeline from generation to validation
27
+ - Modular design for system integration
28
+
29
+ Here's how it all fits together:
30
+
31
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64461026e1fd8d65b27e6187/Uz-kDOBrV-_GahUrc1K_O.png)
32
+
33
+ Let's explore how these components work together in a practical workflow.
34
+
35
+ ## 1. Installation & Setup
36
+
37
+
38
+
39
+ ### 1.1 Clone Repository
40
+ ```bash
41
+ git clone https://github.com/argilla-io/synthetic-data-generator
42
+ cd synthetic-data-generator
43
+ ```
44
+
45
+ ### 1.2 Environment Setup
46
+ ```bash
47
+ # Copy environment template
48
+ cp docker/.env.docker.template .env
49
+
50
+ # Model configuration in .env (if using Ollama)
51
+ MODEL="deepseek-r1:1.5b" # Must match Ollama model name
52
+ ```
53
+
54
+ ### 1.3 Build & Deploy Services
55
+ > Pro tip: Even if you're planning to use just one component initially, we recommend building all services to enable future functionality without rebuilding. For detailed deployment options, check the [Docker documentation](https://github.com/argilla-io/synthetic-data-generator/blob/main/docker/README.md).
56
+
57
+ > Note: Ollama runs on CPU/GPU for Linux/Windows in Docker. For macOS, only CPU is supported in Docker - for GPU support, install Ollama separately ([details](https://ollama.com/blog/ollama-is-now-available-as-an-official-docker-image)).
58
+
59
+ ```bash
60
+ # Build all services
61
+ docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml build
62
+ # Start all services
63
+ docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml up -d
64
+ ```
65
+
66
+ To view logs, either:
67
+ - Use Docker Desktop's interface
68
+ - Remove the `-d` flag when running the above command
69
+ - Or execute the following for specific service logs:
70
+ ```bash
71
+ # Core App logs
72
+ docker compose logs -f app
73
+ # Ollama logs
74
+ docker compose -f docker-compose.yml -f docker/ollama/compose.yml logs -f ollama
75
+ # Argilla logs
76
+ docker compose -f docker-compose.yml -f docker/argilla/compose.yml logs -f argilla
77
+ ```
78
+
79
+ ## 2. Dataset Generation
80
+
81
+ The tool currently supports **Text Classification**, **Chat**, and **RAG** datasets. The task determines the shape of the dataset you will generate: classification requires categories, chat data requires a conversation format, and RAG requires question-answer pairs with relevant context. For RAG, you can additionally generate retrieval and reranking data to strengthen different parts of an information retrieval system.
82
+
83
+ For a detailed overview of the generation process, check out the [introduction to the Synthetic Data Generator](https://huggingface.co/blog/synthetic-data-generator).
84
+
85
+
86
+ ### 2.1. **Dataset Description**
87
+
88
+ Let's walk through creating a **RAG dataset**.
89
+ ```text
90
+ A dataset to retrieve information from information security policies
91
+ ```
92
+
93
+ System initializes and processes the prompt:
94
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64461026e1fd8d65b27e6187/sxH8JChF-HnGMOilymYpA.png)
95
+
96
+
97
+ ### 2.2. **Task Configuration & Sample Generation**
98
+ The system analyzes the description and automatically generates the system prompt and optimal parameters. Samples are then generated for validation (modify the system prompt or parameters manually if needed, then click save to regenerate the sample data):
99
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64461026e1fd8d65b27e6187/mYVlGNnz6YNrPJutxmBtR.png)
100
+
101
+
102
+ ### 2.3. **Full Dataset Generation**
103
+ After validating the sample data quality, proceed with full dataset generation. Configure the following parameters:
104
+
105
+ - **Repository Owner**: Your Hugging Face username for dataset hosting
106
+ - **Dataset Name**: A descriptive name following standard naming conventions
107
+ - **Number of Examples**: Define dataset size (recommended: 100-1000 for initial deployments)
108
+ - **Temperature**: Controls generation creativity (default 0.7 balances coherence and diversity)
109
+ - **Privacy Settings**: Optional dataset privacy configuration for Hugging Face Hub
110
+
111
+ The temperature parameter significantly impacts output quality:
112
+ - 0.5-0.7: Optimal for technical documentation and factual content
113
+ - 0.7-0.8: Balanced for general purpose datasets
114
+ - 0.8-1.0: Increased creativity, suitable for conversational data
115
+
116
+
117
+ The system initiates the generation pipeline, leveraging Distilabel for structured output:
118
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64461026e1fd8d65b27e6187/PWNT_bLHwFjeoFX7AhA-z.png)
119
+
120
+
121
+ Upon completion, the dataset is pushed to Hugging Face Hub:
122
+ ![Generation Complete](https://cdn-uploads.huggingface.co/production/uploads/64461026e1fd8d65b27e6187/ohd4S-RyNI406uLPf4bnZ.png)
123
+
124
+ Access your generated dataset through the Hugging Face Hub interface:
125
+
126
+ <iframe
127
+ src="https://huggingface.co/datasets/daqc/info-security-policies-rag-distiset/embed/viewer/default/train"
128
+ frameborder="0"
129
+ width="100%"
130
+ height="560px"
131
+ ></iframe>
132
+
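To consume the pushed dataset programmatically, a minimal sketch with the 🤗 `datasets` library, using the repository id shown in the viewer above:

```python
from datasets import load_dataset

# Repository id from the embedded viewer above
dataset = load_dataset("daqc/info-security-policies-rag-distiset", split="train")

print(dataset)     # dataset size and column names
print(dataset[0])  # first generated example
```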
133
+
134
+
135
+ ## 3. Data Curation with Argilla
136
+
137
+ The integration with Argilla provides enterprise-grade dataset curation capabilities through a comprehensive review system. This phase is crucial for ensuring data quality and maintaining high standards in your training datasets.
138
+
139
+ ### Environment Configuration
140
+ Before accessing Argilla's features, ensure proper configuration in your `.env` file.
141
+
142
+
143
+ ### Curation Workflow
144
+
145
+ 1. **Dataset Integration**
146
+ Upon generation completion, the dataset is automatically ingested into Argilla. The system maintains data integrity and version control throughout the process. All datasets and progress persist across Docker restarts unless you explicitly remove the Argilla services and volumes.
147
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64461026e1fd8d65b27e6187/0gF6iLywhKafEo3z94cd-.png)
148
+
149
+
150
+ 2. **Quality Assurance Process**
151
+ Argilla's interface provides comprehensive tools for dataset validation:
152
+ - Semantic analysis of generated content
153
+ - Consistency checking across entries
154
+ - Metadata validation and enrichment
155
+ - Collaborative review capabilities
156
+
157
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64461026e1fd8d65b27e6187/h9kJ-4lA0LcFC8g6g_vwF.png)
158
+
159
+
160
+
161
+ 3. **Dataset Publication**
162
+ After thorough review, export your curated dataset to Hugging Face Hub:
163
+
164
+ > Note: Consider using a new repository name to preserve both raw and curated datasets separately.
165
+
166
+ - Configure repository settings
167
+ - Set visibility and access controls
168
+ - Add dataset cards and documentation
169
+
170
+ ![Export Configuration](https://cdn-uploads.huggingface.co/production/uploads/64461026e1fd8d65b27e6187/CPwtVr_Jw6mndNCOU2a5T.png)
171
+
172
+
173
+ The curated dataset maintains full provenance tracking and quality metrics:
174
+ <iframe
175
+ src="https://huggingface.co/datasets/daqc/info-security-policies-rag-distiset-argilla/embed/viewer/default/train"
176
+ frameborder="0"
177
+ width="100%"
178
+ height="560px"
179
+ ></iframe>
180
+
181
+ # 🎉 You're Done!
182
+ Congratulations! You've successfully completed the end-to-end dataset generation and curation process. Your curated dataset is now ready for model training.
183
+
184
+ ## Experience the Solution
185
+
186
+ For a hands-on preview of the Synthetic Dataset Generator's capabilities, explore the hosted space. This allows you to evaluate the interface and functionality before deploying your own instance:
187
+
188
+ <iframe
189
+ src="https://argilla-synthetic-data-generator.hf.space"
190
+ frameborder="0"
191
+ width="850"
192
+ height="450"
193
+ referrerpolicy="same-origin"
194
+ sandbox="allow-scripts"
195
+ ></iframe>
196
+
197
+ Create your own deployment by <a href="https://huggingface.co/spaces/argilla/synthetic-data-generator?duplicate=true">duplicating this Space</a>.
198
+
199
+ ## What's Next?
200
+
201
+ After successfully generating your first dataset, several advanced implementation paths are available:
202
+
203
+ Extend your dataset generation capabilities:
204
+ - [Fine-tune models on synthetic data](https://huggingface.co/blog/davidberenstein1957/fine-tune-a-smollm-on-synthetic-data-of-llm) for domain-specific tasks
205
+ - [Create specialized reasoning datasets](https://huggingface.co/blog/sdiazlor/fine-tune-deepseek-with-a-synthetic-reasoning-data) for advanced model training
206
+
207
+ ## Conclusion
208
+
209
+ The Synthetic Dataset Generator represents a significant advancement in private data generation technology, addressing the growing need for high-quality training data while maintaining security and control. By leveraging containerized architecture and local LLM inference, organizations can now generate custom datasets without compromising on data privacy or quality.
210
+
211
+ The solution's modular design enables seamless integration with existing ML pipelines while providing enterprise-grade features like persistent storage, comprehensive monitoring, and scalable infrastructure. Through collaborative validation workflows and structured quality control processes, teams can efficiently create and curate datasets tailored to their specific needs.
212
+
213
+ This combination of security, efficiency, and flexibility makes the Synthetic Dataset Generator an essential tool for organizations looking to accelerate their AI development while maintaining complete control over their data generation pipeline.
214
+
215
+ ## References & Documentation
216
+
217
+
218
+ - [Synthetic Dataset Generator](https://github.com/argilla-io/synthetic-data-generator): Open-source tool for dataset generation using natural language
219
+ - [Distilabel Framework](https://github.com/argilla-io/distilabel): Advanced dataset generation framework
220
+ - [Docker Best Practices](https://docs.docker.com/develop/develop-images/dockerfile_best-practices/): Container optimization guidelines
221
+ - [Argilla Documentation](https://docs.argilla.io): Data curation platform documentation
222
+ - [Ollama Integration](https://github.com/jmorganca/ollama): Local LLM deployment guide
examples/fine-tune-deepseek-reasoning-sft.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
examples/fine-tune-modernbert-classifier.ipynb ADDED
@@ -0,0 +1,538 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Fine-tune ModernBERT for text classification using synthetic data\n",
8
+ "\n",
9
+ "LLMs are great general purpose models, but they are not always the best choice for a specific task. Therefore, smaller and more specialized models are important for sustainable, efficient, and cheaper AI.\n",
10
+ "A lack of domain-specific datasets is a common problem for smaller and more specialized models. This is because it is difficult to find a dataset that is both representative and diverse enough for a specific task. We solve this problem by generating a synthetic dataset from an LLM using the `synthetic-data-generator`, which is available as a [Hugging Face Space](https://huggingface.co/spaces/argilla/synthetic-data-generator) or on [GitHub](https://github.com/argilla-io/synthetic-data-generator).\n",
11
+ "\n",
12
+ "In this example, we will fine-tune a ModernBERT model on a synthetic dataset generated from the synthetic-data-generator. This demonstrates the effectiveness of synthetic data and the novel ModernBERT model, which is a new and improved version of BERT models, with an 8192 token context length, significantly better downstream performance, and much faster processing speeds.\n",
13
+ "\n",
14
+ "## Install the dependencies"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "# Install Pytorch & other libraries\n",
24
+ "%pip install \"torch==2.5.0\" \"torchvision==0.20.0\" \n",
25
+ "%pip install \"setuptools<71.0.0\" scikit-learn \n",
26
+ " \n",
27
+ "# Install Hugging Face libraries\n",
28
+ "%pip install --upgrade \\\n",
29
+ " \"datasets==3.1.0\" \\\n",
30
+ " \"accelerate==1.2.1\" \\\n",
31
+ " \"hf-transfer==0.1.8\"\n",
32
+ " \n",
33
+ "# ModernBERT is not yet available in an official release, so we need to install it from github\n",
34
+ "%pip install \"git+https://github.com/huggingface/transformers.git@6e0515e99c39444caae39472ee1b2fd76ece32f1\" --upgrade"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "markdown",
39
+ "metadata": {},
40
+ "source": [
41
+ "## The problem\n",
42
+ "\n",
43
+ "The [nvidia/domain-classifier](https://huggingface.co/nvidia/domain-classifier) is a model that can classify the domain of a text, which can help with curating data. This model is cool but is based on DeBERTa V3 Base, which is an outdated architecture that requires custom code to run, has a context length of 512 tokens, and is not as fast as the ModernBERT model. The labels for the model are:\n",
44
+ "\n",
45
+ "```\n",
46
+ "'Adult', 'Arts_and_Entertainment', 'Autos_and_Vehicles', 'Beauty_and_Fitness', 'Books_and_Literature', 'Business_and_Industrial', 'Computers_and_Electronics', 'Finance', 'Food_and_Drink', 'Games', 'Health', 'Hobbies_and_Leisure', 'Home_and_Garden', 'Internet_and_Telecom', 'Jobs_and_Education', 'Law_and_Government', 'News', 'Online_Communities', 'People_and_Society', 'Pets_and_Animals', 'Real_Estate', 'Science', 'Sensitive_Subjects', 'Shopping', 'Sports', 'Travel_and_Transportation'\n",
47
+ "```\n",
48
+ "\n",
49
+ "The data on which the model was trained is not available, so we cannot use it for our purposes. We can however generate a synthetic data to solve this problem."
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "metadata": {
55
+ "vscode": {
56
+ "languageId": "plaintext"
57
+ }
58
+ },
59
+ "source": [
60
+ "## Let's generate some data\n",
61
+ "\n",
62
+ "Let's go to the [hosted Hugging Face Space](https://huggingface.co/spaces/argilla/synthetic-data-generator) to generate the data. This is done in three steps 1) we come up with a dataset description, 2) iterate on the task configuration, and 3) generate and push the data to Hugging Face. A more detailed flow can be found in [this blogpost](https://huggingface.co/blog/synthetic-data-generator). \n",
63
+ "\n",
64
+ "<iframe\n",
65
+ "\tsrc=\"https://argilla-synthetic-data-generator.hf.space\"\n",
66
+ "\tframeborder=\"0\"\n",
67
+ "\twidth=\"850\"\n",
68
+ "\theight=\"450\"\n",
69
+ "></iframe>\n",
70
+ "\n",
71
+ "For this example, we will generate 1000 examples with a temperature of 1. After some iteration, we come up with the following system prompt:\n",
72
+ "\n",
73
+ "```\n",
74
+ "Long texts (at least 2000 words) from various media sources like Wikipedia, Reddit, Common Crawl, websites, commercials, online forums, books, newspapers and folders that cover multiple topics. Classify the text based on its main subject matter into one of the following categories\n",
75
+ "```\n",
76
+ "\n",
77
+ "We press the \"Push to Hub\" button and wait for the data to be generated. This takes a few minutes and we end up with a dataset with 1000 examples. The labels are nicely distributed across the categories, varied in length, and the texts look diverse and interesting.\n",
78
+ "\n",
79
+ "<iframe\n",
80
+ " src=\"https://huggingface.co/datasets/argilla/synthetic-domain-text-classification/embed/viewer/default/train\"\n",
81
+ " frameborder=\"0\"\n",
82
+ " width=\"100%\"\n",
83
+ " height=\"560px\"\n",
84
+ "></iframe>\n",
85
+ "\n",
86
+ "The data is also pushed to Argilla, so we recommend inspecting and validating the labels before fine-tuning the model."
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "markdown",
91
+ "metadata": {},
92
+ "source": [
93
+ "## Finetuning the ModernBERT model\n",
94
+ "\n",
95
+ "We mostly rely on the blog from [Philipp Schmid](https://www.philschmid.de/fine-tune-modern-bert-in-2025). I will use basic consumer hardware, my Apple M1 Max with 32GB of shared memory. We will use the `datasets` library to load the data and the `transformers` library to fine-tune the model."
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": 1,
101
+ "metadata": {},
102
+ "outputs": [
103
+ {
104
+ "name": "stderr",
105
+ "output_type": "stream",
106
+ "text": [
107
+ "/Users/davidberenstein/Documents/programming/argilla/synthetic-data-generator/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
108
+ " from .autonotebook import tqdm as notebook_tqdm\n"
109
+ ]
110
+ },
111
+ {
112
+ "data": {
113
+ "text/plain": [
114
+ "{'text': 'Recently, there has been an increase in property values within the suburban areas of several cities due to improvements in infrastructure and lifestyle amenities such as parks, retail stores, and educational institutions nearby. Additionally, new housing developments are emerging, catering to different family needs with varying sizes and price ranges. These changes have influenced investment decisions for many looking to buy or sell properties.',\n",
115
+ " 'label': 14}"
116
+ ]
117
+ },
118
+ "execution_count": 1,
119
+ "metadata": {},
120
+ "output_type": "execute_result"
121
+ }
122
+ ],
123
+ "source": [
124
+ "from datasets import load_dataset\n",
125
+ "from datasets.arrow_dataset import Dataset\n",
126
+ "from datasets.dataset_dict import DatasetDict, IterableDatasetDict\n",
127
+ "from datasets.iterable_dataset import IterableDataset\n",
128
+ " \n",
129
+ "# Dataset id from huggingface.co/dataset\n",
130
+ "dataset_id = \"argilla/synthetic-domain-text-classification\"\n",
131
+ " \n",
132
+ "# Load raw dataset\n",
133
+ "train_dataset = load_dataset(dataset_id, split='train')\n",
134
+ "\n",
135
+ "split_dataset = train_dataset.train_test_split(test_size=0.1)\n",
136
+ "split_dataset['train'][0]"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "markdown",
141
+ "metadata": {},
142
+ "source": [
143
+ "First, we need to tokenize the data. We will use the `AutoTokenizer` class from the `transformers` library to load the tokenizer."
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": 2,
149
+ "metadata": {},
150
+ "outputs": [
151
+ {
152
+ "name": "stderr",
153
+ "output_type": "stream",
154
+ "text": [
155
+ "Map: 100%|██████████| 900/900 [00:00<00:00, 4787.61 examples/s]\n",
156
+ "Map: 100%|██████████| 100/100 [00:00<00:00, 4163.70 examples/s]\n"
157
+ ]
158
+ },
159
+ {
160
+ "data": {
161
+ "text/plain": [
162
+ "dict_keys(['labels', 'input_ids', 'attention_mask'])"
163
+ ]
164
+ },
165
+ "execution_count": 2,
166
+ "metadata": {},
167
+ "output_type": "execute_result"
168
+ }
169
+ ],
170
+ "source": [
171
+ "from transformers import AutoTokenizer\n",
172
+ " \n",
173
+ "# Model id to load the tokenizer\n",
174
+ "model_id = \"answerdotai/ModernBERT-base\"\n",
175
+ "\n",
176
+ "# Load Tokenizer\n",
177
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
178
+ " \n",
179
+ "# Tokenize helper function\n",
180
+ "def tokenize(batch):\n",
181
+ " return tokenizer(batch['text'], padding=True, truncation=True, return_tensors=\"pt\")\n",
182
+ " \n",
183
+ "# Tokenize dataset\n",
184
+ "if \"label\" in split_dataset[\"train\"].features.keys():\n",
185
+ " split_dataset = split_dataset.rename_column(\"label\", \"labels\") # to match Trainer\n",
186
+ "tokenized_dataset = split_dataset.map(tokenize, batched=True, remove_columns=[\"text\"])\n",
187
+ " \n",
188
+ "tokenized_dataset[\"train\"].features.keys()"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "markdown",
193
+ "metadata": {},
194
+ "source": [
195
+ "Now, we need to prepare the model. We will use the `AutoModelForSequenceClassification` class from the `transformers` library to load the model."
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "execution_count": 3,
201
+ "metadata": {},
202
+ "outputs": [
203
+ {
204
+ "name": "stderr",
205
+ "output_type": "stream",
206
+ "text": [
207
+ "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
208
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
209
+ ]
210
+ }
211
+ ],
212
+ "source": [
213
+ "from transformers import AutoModelForSequenceClassification\n",
214
+ " \n",
215
+ "# Model id to load the tokenizer\n",
216
+ "model_id = \"answerdotai/ModernBERT-base\"\n",
217
+ " \n",
218
+ "# Prepare model labels - useful for inference\n",
219
+ "labels = tokenized_dataset[\"train\"].features[\"labels\"].names\n",
220
+ "num_labels = len(labels)\n",
221
+ "label2id, id2label = dict(), dict()\n",
222
+ "for i, label in enumerate(labels):\n",
223
+ " label2id[label] = str(i)\n",
224
+ " id2label[str(i)] = label\n",
225
+ " \n",
226
+ "# Download the model from huggingface.co/models\n",
227
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
228
+ " model_id, num_labels=num_labels, label2id=label2id, id2label=id2label,\n",
229
+ ")"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "markdown",
234
+ "metadata": {},
235
+ "source": [
236
+ "We will use a simple F1 score as the evaluation metric."
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": 4,
242
+ "metadata": {},
243
+ "outputs": [],
244
+ "source": [
245
+ "import numpy as np\n",
246
+ "from sklearn.metrics import f1_score\n",
247
+ " \n",
248
+ "# Metric helper method\n",
249
+ "def compute_metrics(eval_pred):\n",
250
+ " predictions, labels = eval_pred\n",
251
+ " predictions = np.argmax(predictions, axis=1)\n",
252
+ " score = f1_score(\n",
253
+ " labels, predictions, labels=labels, pos_label=1, average=\"weighted\"\n",
254
+ " )\n",
255
+ " return {\"f1\": float(score) if score == 1 else score}"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "markdown",
260
+ "metadata": {},
261
+ "source": [
262
+ "Finally, we need to define the training arguments. We will use the `TrainingArguments` class from the `transformers` library to define the training arguments."
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": 6,
268
+ "metadata": {},
269
+ "outputs": [
270
+ {
271
+ "name": "stderr",
272
+ "output_type": "stream",
273
+ "text": [
274
+ "/Users/davidberenstein/Documents/programming/argilla/synthetic-data-generator/.venv/lib/python3.11/site-packages/transformers/training_args.py:2241: UserWarning: `use_mps_device` is deprecated and will be removed in version 5.0 of 🤗 Transformers. `mps` device will be used by default if available similar to the way `cuda` device is used.Therefore, no action from user is required. \n",
275
+ " warnings.warn(\n"
276
+ ]
277
+ }
278
+ ],
279
+ "source": [
280
+ "from huggingface_hub import HfFolder\n",
281
+ "from transformers import Trainer, TrainingArguments\n",
282
+ " \n",
283
+ "# Define training args\n",
284
+ "training_args = TrainingArguments(\n",
285
+ " output_dir= \"ModernBERT-domain-classifier\",\n",
286
+ " per_device_train_batch_size=32,\n",
287
+ " per_device_eval_batch_size=16,\n",
288
+ " learning_rate=5e-5,\n",
289
+ "\t\tnum_train_epochs=5,\n",
290
+ " bf16=True, # bfloat16 training \n",
291
+ " optim=\"adamw_torch_fused\", # improved optimizer \n",
292
+ " # logging & evaluation strategies\n",
293
+ " logging_strategy=\"steps\",\n",
294
+ " logging_steps=100,\n",
295
+ " eval_strategy=\"epoch\",\n",
296
+ " save_strategy=\"epoch\",\n",
297
+ " save_total_limit=2,\n",
298
+ " load_best_model_at_end=True,\n",
299
+ " use_mps_device=True,\n",
300
+ " metric_for_best_model=\"f1\",\n",
301
+ " # push to hub parameters\n",
302
+ " push_to_hub=True,\n",
303
+ " hub_strategy=\"every_save\",\n",
304
+ " hub_token=HfFolder.get_token(),\n",
305
+ ")\n",
306
+ " \n",
307
+ "# Create a Trainer instance\n",
308
+ "trainer = Trainer(\n",
309
+ " model=model,\n",
310
+ " args=training_args,\n",
311
+ " train_dataset=tokenized_dataset[\"train\"],\n",
312
+ " eval_dataset=tokenized_dataset[\"test\"],\n",
313
+ " compute_metrics=compute_metrics,\n",
314
+ ")"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": 7,
320
+ "metadata": {},
321
+ "outputs": [
322
+ {
323
+ "name": "stderr",
324
+ "output_type": "stream",
325
+ "text": [
326
+ " \n",
327
+ " 20%|██ | 29/145 [11:32<33:16, 17.21s/it]"
328
+ ]
329
+ },
330
+ {
331
+ "name": "stdout",
332
+ "output_type": "stream",
333
+ "text": [
334
+ "{'eval_loss': 0.729780912399292, 'eval_f1': 0.7743598318036522, 'eval_runtime': 3.5337, 'eval_samples_per_second': 28.299, 'eval_steps_per_second': 1.981, 'epoch': 1.0}\n"
335
+ ]
336
+ },
337
+ {
338
+ "name": "stderr",
339
+ "output_type": "stream",
340
+ "text": [
341
+ " \n",
342
+ " 40%|████ | 58/145 [22:57<25:56, 17.89s/it]"
343
+ ]
344
+ },
345
+ {
346
+ "name": "stdout",
347
+ "output_type": "stream",
348
+ "text": [
349
+ "{'eval_loss': 0.4369044005870819, 'eval_f1': 0.8310764765820946, 'eval_runtime': 3.3266, 'eval_samples_per_second': 30.061, 'eval_steps_per_second': 2.104, 'epoch': 2.0}\n"
350
+ ]
351
+ },
352
+ {
353
+ "name": "stderr",
354
+ "output_type": "stream",
355
+ "text": [
356
+ " \n",
357
+ " 60%|██████ | 87/145 [35:16<17:06, 17.70s/it]"
358
+ ]
359
+ },
360
+ {
361
+ "name": "stdout",
362
+ "output_type": "stream",
363
+ "text": [
364
+ "{'eval_loss': 0.6091340184211731, 'eval_f1': 0.8399274488570763, 'eval_runtime': 3.2772, 'eval_samples_per_second': 30.514, 'eval_steps_per_second': 2.136, 'epoch': 3.0}\n"
365
+ ]
366
+ },
367
+ {
368
+ "name": "stderr",
369
+ "output_type": "stream",
370
+ "text": [
371
+ " 69%|██████▉ | 100/145 [41:03<18:02, 24.06s/it]"
372
+ ]
373
+ },
374
+ {
375
+ "name": "stdout",
376
+ "output_type": "stream",
377
+ "text": [
378
+ "{'loss': 0.7663, 'grad_norm': 7.232136249542236, 'learning_rate': 1.5517241379310346e-05, 'epoch': 3.45}\n"
379
+ ]
380
+ },
381
+ {
382
+ "name": "stderr",
383
+ "output_type": "stream",
384
+ "text": [
385
+ " \n",
386
+ " 80%|████████ | 116/145 [47:23<08:50, 18.30s/it]"
387
+ ]
388
+ },
389
+ {
390
+ "name": "stdout",
391
+ "output_type": "stream",
392
+ "text": [
393
+ "{'eval_loss': 0.43516409397125244, 'eval_f1': 0.8797674004703547, 'eval_runtime': 3.2975, 'eval_samples_per_second': 30.326, 'eval_steps_per_second': 2.123, 'epoch': 4.0}\n"
394
+ ]
395
+ },
396
+ {
397
+ "name": "stderr",
398
+ "output_type": "stream",
399
+ "text": [
400
+ " \n",
401
+ "100%|██████████| 145/145 [1:00:40<00:00, 19.18s/it]"
402
+ ]
403
+ },
404
+ {
405
+ "name": "stdout",
406
+ "output_type": "stream",
407
+ "text": [
408
+ "{'eval_loss': 0.39272159337997437, 'eval_f1': 0.8914389523348718, 'eval_runtime': 3.5564, 'eval_samples_per_second': 28.118, 'eval_steps_per_second': 1.968, 'epoch': 5.0}\n"
409
+ ]
410
+ },
411
+ {
412
+ "name": "stderr",
413
+ "output_type": "stream",
414
+ "text": [
415
+ "100%|██████████| 145/145 [1:00:42<00:00, 25.12s/it]\n"
416
+ ]
417
+ },
418
+ {
419
+ "name": "stdout",
420
+ "output_type": "stream",
421
+ "text": [
422
+ "{'train_runtime': 3642.7783, 'train_samples_per_second': 1.235, 'train_steps_per_second': 0.04, 'train_loss': 0.535627057634551, 'epoch': 5.0}\n"
423
+ ]
424
+ },
425
+ {
426
+ "name": "stderr",
427
+ "output_type": "stream",
428
+ "text": [
429
+ "events.out.tfevents.1735555878.Davids-MacBook-Pro.local.23438.0: 100%|██████████| 9.32k/9.32k [00:00<00:00, 55.0kB/s]\n"
430
+ ]
431
+ },
432
+ {
433
+ "data": {
434
+ "text/plain": [
435
+ "CommitInfo(commit_url='https://huggingface.co/davidberenstein1957/domain-classifier/commit/915f4b03c230cc8f376f13729728f14347400041', commit_message='End of training', commit_description='', oid='915f4b03c230cc8f376f13729728f14347400041', pr_url=None, repo_url=RepoUrl('https://huggingface.co/davidberenstein1957/domain-classifier', endpoint='https://huggingface.co', repo_type='model', repo_id='davidberenstein1957/domain-classifier'), pr_revision=None, pr_num=None)"
436
+ ]
437
+ },
438
+ "execution_count": 7,
439
+ "metadata": {},
440
+ "output_type": "execute_result"
441
+ }
442
+ ],
443
+ "source": [
444
+ "trainer.train()\n",
445
+ "# Save processor and create model card\n",
446
+ "tokenizer.save_pretrained(\"ModernBERT-domain-classifier\")\n",
447
+ "trainer.create_model_card()\n",
448
+ "trainer.push_to_hub()"
449
+ ]
450
+ },
451
+ {
452
+ "cell_type": "markdown",
453
+ "metadata": {},
454
+ "source": [
455
+ "We get an F1 score of 0.89 on the test set, which is pretty good for the small dataset and time spent."
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "markdown",
460
+ "metadata": {},
461
+ "source": [
462
+ "## Run inference\n",
463
+ "\n",
464
+ "We can now load the model and run inference."
465
+ ]
466
+ },
467
+ {
468
+ "cell_type": "code",
469
+ "execution_count": 11,
470
+ "metadata": {},
471
+ "outputs": [
472
+ {
473
+ "name": "stderr",
474
+ "output_type": "stream",
475
+ "text": [
476
+ "Device set to use mps:0\n"
477
+ ]
478
+ },
479
+ {
480
+ "data": {
481
+ "text/plain": [
482
+ "[{'label': 'health', 'score': 0.6779336333274841}]"
483
+ ]
484
+ },
485
+ "execution_count": 11,
486
+ "metadata": {},
487
+ "output_type": "execute_result"
488
+ }
489
+ ],
490
+ "source": [
491
+ "from transformers import pipeline\n",
492
+ " \n",
493
+ "# load model from huggingface.co/models using our repository id\n",
494
+ "classifier = pipeline(\n",
495
+ " task=\"text-classification\", \n",
496
+ " model=\"argilla/ModernBERT-domain-classifier\", \n",
497
+ " device=0,\n",
498
+ ")\n",
499
+ " \n",
500
+ "sample = \"Smoking is bad for your health.\"\n",
501
+ " \n",
502
+ "classifier(sample)"
503
+ ]
504
+ },
505
+ {
506
+ "cell_type": "markdown",
507
+ "metadata": {},
508
+ "source": [
509
+ "## Conclusion\n",
510
+ "\n",
511
+ "We have shown that we can generate a synthetic dataset from an LLM and fine-tune a ModernBERT model on it. This demonstrates the effectiveness of synthetic data and the novel ModernBERT model, which is a new and improved version of BERT models, with an 8192 token context length, significantly better downstream performance, and much faster processing speeds. \n",
512
+ "\n",
513
+ "Pretty cool for 20 minutes of generating data, and an hour of fine-tuning on consumer hardware."
514
+ ]
515
+ }
516
+ ],
517
+ "metadata": {
518
+ "kernelspec": {
519
+ "display_name": ".venv",
520
+ "language": "python",
521
+ "name": "python3"
522
+ },
523
+ "language_info": {
524
+ "codemirror_mode": {
525
+ "name": "ipython",
526
+ "version": 3
527
+ },
528
+ "file_extension": ".py",
529
+ "mimetype": "text/x-python",
530
+ "name": "python",
531
+ "nbconvert_exporter": "python",
532
+ "pygments_lexer": "ipython3",
533
+ "version": "3.11.11"
534
+ }
535
+ },
536
+ "nbformat": 4,
537
+ "nbformat_minor": 2
538
+ }
examples/fine-tune-modernbert-rag.ipynb ADDED
@@ -0,0 +1,980 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Fine-tune ModernBERT with Synthetic Data for RAG\n",
8
+ "\n",
9
+ "This notebook demonstrates the fine-tuning process of `modernbert-embed-base` using synthetic data tailored for the Retrieval-Augmented Generation (RAG) model.\n",
10
+ "\n",
11
+ "It provides a complete walkthrough of the fine-tuning process after generating synthetic data using the Synthetic Data Generator. For a comprehensive explanation of the methodology and additional details, refer to the blog post: [Fine-tune ModernBERT for RAG with Synthetic Data](https://huggingface.co/blog/fine-tune-modernbert-for-rag-with-synthetic-data)."
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "markdown",
16
+ "metadata": {},
17
+ "source": [
18
+ "## Getting Started"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "markdown",
23
+ "metadata": {},
24
+ "source": [
25
+ "### Install the Dependencies"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "!pip install torch\n",
35
+ "!pip install datasets\n",
36
+ "!pip install sentence-transformers\n",
37
+ "!pip install haystack-ai\n",
38
+ "!pip install git+https://github.com/huggingface/transformers.git # for the latest version of transformers"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "markdown",
43
+ "metadata": {},
44
+ "source": [
45
+ "### Import the Required Libraries"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": 1,
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "import torch\n",
55
+ "from torch.utils.data import DataLoader\n",
56
+ "\n",
57
+ "from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict\n",
58
+ "\n",
59
+ "\n",
60
+ "from sentence_transformers import (\n",
61
+ " SentenceTransformer,\n",
62
+ " SentenceTransformerModelCardData,\n",
63
+ " CrossEncoder,\n",
64
+ " InputExample,\n",
65
+ " SentenceTransformerTrainer,\n",
66
+ ")\n",
67
+ "from sentence_transformers.losses import TripletLoss\n",
68
+ "from sentence_transformers.training_args import (\n",
69
+ " SentenceTransformerTrainingArguments,\n",
70
+ " BatchSamplers,\n",
71
+ ")\n",
72
+ "from sentence_transformers.evaluation import TripletEvaluator\n",
73
+ "from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator\n",
74
+ "\n",
75
+ "\n",
76
+ "from haystack import Document, Pipeline\n",
77
+ "from haystack.document_stores.in_memory import InMemoryDocumentStore\n",
78
+ "from haystack.components.embedders import (\n",
79
+ " SentenceTransformersDocumentEmbedder,\n",
80
+ " SentenceTransformersTextEmbedder,\n",
81
+ ")\n",
82
+ "from haystack.components.rankers import SentenceTransformersDiversityRanker\n",
83
+ "from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever\n",
84
+ "from haystack.components.builders import ChatPromptBuilder\n",
85
+ "from haystack.components.generators.chat import HuggingFaceAPIChatGenerator\n",
86
+ "from haystack.dataclasses import ChatMessage\n",
87
+ "from haystack.utils import Secret\n",
88
+ "from haystack.utils.hf import HFGenerationAPIType"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "markdown",
93
+ "metadata": {},
94
+ "source": [
95
+ "### Configure the Environment"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": 2,
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "MODEL = \"nomic-ai/modernbert-embed-base\"\n",
105
+ "REPO_NAME = \"sdiazlor\" # your HF username here\n",
106
+ "MODEL_NAME_BIENCODER = \"modernbert-embed-base-biencoder-human-rights\"\n",
107
+ "MODEL_NAME_CROSSENCODER = \"modernbert-embed-base-crossencoder-human-rights\""
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "metadata": {},
114
+ "outputs": [
115
+ {
116
+ "name": "stdout",
117
+ "output_type": "stream",
118
+ "text": [
119
+ "Using device: mps\n"
120
+ ]
121
+ }
122
+ ],
123
+ "source": [
124
+ "if torch.backends.mps.is_available():\n",
125
+ " device = \"mps\"\n",
126
+ "elif torch.cuda.is_available():\n",
127
+ " device = \"cuda\"\n",
128
+ "else:\n",
129
+ " device = \"cpu\"\n",
130
+ "\n",
131
+ "print(f\"Using device: {device}\")"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "markdown",
136
+ "metadata": {},
137
+ "source": [
138
+ "## Pre-process the Synthetic Data"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": 3,
144
+ "metadata": {},
145
+ "outputs": [
146
+ {
147
+ "data": {
148
+ "text/plain": [
149
+ "Dataset({\n",
150
+ " features: ['context', 'question', 'response', 'positive_retrieval', 'negative_retrieval', 'positive_reranking', 'negative_reranking'],\n",
151
+ " num_rows: 1000\n",
152
+ "})"
153
+ ]
154
+ },
155
+ "execution_count": 3,
156
+ "metadata": {},
157
+ "output_type": "execute_result"
158
+ }
159
+ ],
160
+ "source": [
161
+ "# Combine the generated datasets from files and prompts\n",
162
+ "\n",
163
+ "dataset_rag_from_file = load_dataset(f\"{REPO_NAME}/rag-human-rights-from-files\", split=\"train\")\n",
164
+ "dataset_rag_from_prompt = load_dataset(f\"{REPO_NAME}/rag-human-rights-from-prompt\", split=\"train\")\n",
165
+ "\n",
166
+ "combined_rag_dataset = concatenate_datasets(\n",
167
+ " [dataset_rag_from_file, dataset_rag_from_prompt]\n",
168
+ ")\n",
169
+ "\n",
170
+ "combined_rag_dataset"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": null,
176
+ "metadata": {},
177
+ "outputs": [
178
+ {
179
+ "data": {
180
+ "text/plain": [
181
+ "Dataset({\n",
182
+ " features: ['context', 'question', 'response', 'positive_retrieval', 'negative_retrieval', 'positive_reranking', 'negative_reranking'],\n",
183
+ " num_rows: 828\n",
184
+ "})"
185
+ ]
186
+ },
187
+ "execution_count": 6,
188
+ "metadata": {},
189
+ "output_type": "execute_result"
190
+ }
191
+ ],
192
+ "source": [
193
+ "# Filter out examples with empty or NaN values\n",
194
+ "\n",
195
+ "def filter_empty_or_nan(example):\n",
196
+ " return all(\n",
197
+ " value is not None and str(value).strip() != \"\" for value in example.values()\n",
198
+ " )\n",
199
+ "\n",
200
+ "filtered_rag_dataset = combined_rag_dataset.filter(filter_empty_or_nan).shuffle(seed=42)\n",
201
+ "filtered_rag_dataset"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "metadata": {},
208
+ "outputs": [
209
+ {
210
+ "name": "stdout",
211
+ "output_type": "stream",
212
+ "text": [
213
+ "Dataset({\n",
214
+ " features: ['anchor', 'positive', 'negative'],\n",
215
+ " num_rows: 828\n",
216
+ "})\n",
217
+ "Dataset({\n",
218
+ " features: ['anchor', 'positive'],\n",
219
+ " num_rows: 828\n",
220
+ "})\n"
221
+ ]
222
+ }
223
+ ],
224
+ "source": [
225
+ "# Rename, select and reorder columns according to the expected format for the SentenceTransformer and CrossEncoder models\n",
226
+ "\n",
227
+ "def rename_and_reorder_columns(dataset, rename_map, selected_columns):\n",
228
+ " for old_name, new_name in rename_map.items():\n",
229
+ " if old_name in dataset.column_names:\n",
230
+ " dataset = dataset.rename_column(old_name, new_name)\n",
231
+ " dataset = dataset.select_columns(selected_columns)\n",
232
+ " return dataset\n",
233
+ "\n",
234
+ "clean_rag_dataset_biencoder = rename_and_reorder_columns(\n",
235
+ " filtered_rag_dataset,\n",
236
+ " rename_map={\"context\": \"anchor\", \"positive_retrieval\": \"positive\", \"negative_retrieval\": \"negative\"},\n",
237
+ " selected_columns=[\"anchor\", \"positive\", \"negative\"],\n",
238
+ ")\n",
239
+ "\n",
240
+ "clean_rag_dataset_crossencoder = rename_and_reorder_columns(\n",
241
+ " filtered_rag_dataset,\n",
242
+ " rename_map={\"context\": \"anchor\", \"positive_retrieval\": \"positive\"}, #TODO\n",
243
+ " selected_columns=[\"anchor\", \"positive\"],\n",
244
+ ")\n",
245
+ "\n",
246
+ "print(clean_rag_dataset_biencoder)\n",
247
+ "print(clean_rag_dataset_crossencoder)"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "code",
252
+ "execution_count": null,
253
+ "metadata": {},
254
+ "outputs": [
255
+ {
256
+ "name": "stderr",
257
+ "output_type": "stream",
258
+ "text": [
259
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at Snowflake/snowflake-arctic-embed-m-v1.5 and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']\n",
260
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
261
+ ]
262
+ },
263
+ {
264
+ "data": {
265
+ "application/vnd.jupyter.widget-view+json": {
266
+ "model_id": "406c4d22f43f41d592d3b94da2955444",
267
+ "version_major": 2,
268
+ "version_minor": 0
269
+ },
270
+ "text/plain": [
271
+ "Map: 0%| | 0/828 [00:00<?, ? examples/s]"
272
+ ]
273
+ },
274
+ "metadata": {},
275
+ "output_type": "display_data"
276
+ },
277
+ {
278
+ "data": {
279
+ "text/plain": [
280
+ "Dataset({\n",
281
+ " features: ['anchor', 'positive', 'score'],\n",
282
+ " num_rows: 828\n",
283
+ "})"
284
+ ]
285
+ },
286
+ "execution_count": 8,
287
+ "metadata": {},
288
+ "output_type": "execute_result"
289
+ }
290
+ ],
291
+ "source": [
292
+ "# Add scores to train the CrossEncoder model, which requires sentence pairs with a score indicating how related they are.\n",
293
+ "# Check the available models: https://huggingface.co/spaces/mteb/leaderboard\n",
294
+ "\n",
295
+ "model_reranking = CrossEncoder(\n",
296
+ " model_name=\"Snowflake/snowflake-arctic-embed-m-v1.5\", device=device\n",
297
+ ")\n",
298
+ "\n",
299
+ "def add_reranking_scores(batch):\n",
300
+ " pairs = list(zip(batch[\"anchor\"], batch[\"positive\"]))\n",
301
+ " batch[\"score\"] = model_reranking.predict(pairs)\n",
302
+ " return batch\n",
303
+ "\n",
304
+ "clean_rag_dataset_crossencoder = clean_rag_dataset_crossencoder.map(\n",
305
+ " add_reranking_scores, batched=True, batch_size=250\n",
306
+ ")\n",
307
+ "clean_rag_dataset_crossencoder"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": null,
313
+ "metadata": {},
314
+ "outputs": [
315
+ {
316
+ "name": "stdout",
317
+ "output_type": "stream",
318
+ "text": [
319
+ "DatasetDict({\n",
320
+ " train: Dataset({\n",
321
+ " features: ['anchor', 'positive', 'negative'],\n",
322
+ " num_rows: 662\n",
323
+ " })\n",
324
+ " eval: Dataset({\n",
325
+ " features: ['anchor', 'positive', 'negative'],\n",
326
+ " num_rows: 166\n",
327
+ " })\n",
328
+ "})\n",
329
+ "DatasetDict({\n",
330
+ " train: Dataset({\n",
331
+ " features: ['anchor', 'positive', 'score'],\n",
332
+ " num_rows: 662\n",
333
+ " })\n",
334
+ " eval: Dataset({\n",
335
+ " features: ['anchor', 'positive', 'score'],\n",
336
+ " num_rows: 166\n",
337
+ " })\n",
338
+ "})\n"
339
+ ]
340
+ }
341
+ ],
342
+ "source": [
343
+ "# Split the datasets into training and evaluation sets\n",
344
+ "def split_dataset(dataset, train_size=0.8, seed=42):\n",
345
+ " train_eval_split = dataset.train_test_split(test_size=1 - train_size, seed=seed)\n",
346
+ "\n",
347
+ " dataset_dict = DatasetDict(\n",
348
+ " {\"train\": train_eval_split[\"train\"], \"eval\": train_eval_split[\"test\"]}\n",
349
+ " )\n",
350
+ "\n",
351
+ " return dataset_dict\n",
352
+ "\n",
353
+ "dataset_rag_biencoder = split_dataset(clean_rag_dataset_biencoder)\n",
354
+ "dataset_rag_crossencoder = split_dataset(clean_rag_dataset_crossencoder)\n",
355
+ "\n",
356
+ "print(dataset_rag_biencoder)\n",
357
+ "print(dataset_rag_crossencoder)"
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "markdown",
362
+ "metadata": {},
363
+ "source": [
364
+ "## Train the Bi-Encoder model for Retrieval"
365
+ ]
366
+ },
367
+ {
368
+ "cell_type": "code",
369
+ "execution_count": null,
370
+ "metadata": {},
371
+ "outputs": [],
372
+ "source": [
373
+ "# Load the base model and create the SentenceTransformer model\n",
374
+ "model_biencoder = SentenceTransformer(\n",
375
+ " MODEL,\n",
376
+ " model_card_data=SentenceTransformerModelCardData(\n",
377
+ " language=\"en\",\n",
378
+ " license=\"apache-2.0\",\n",
379
+ " model_name=MODEL_NAME_BIENCODER,\n",
380
+ " ),\n",
381
+ ")\n",
382
+ "model_biencoder.gradient_checkpointing_enable() # Enable gradient checkpointing to save memory"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": null,
388
+ "metadata": {},
389
+ "outputs": [],
390
+ "source": [
391
+ "# Select the TripleLoss loss function which requires sentence triplets (anchor, positive, negative)\n",
392
+ "# Check the available losses: https://sbert.net/docs/sentence_transformer/loss_overview.html\n",
393
+ "\n",
394
+ "loss_biencoder = TripletLoss"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "metadata": {},
401
+ "outputs": [
402
+ {
403
+ "name": "stderr",
404
+ "output_type": "stream",
405
+ "text": [
406
+ "/Users/sdiazlor/.pyenv/versions/3.11.4/envs/distilabel-tutorials/lib/python3.11/site-packages/transformers/training_args.py:2243: UserWarning: `use_mps_device` is deprecated and will be removed in version 5.0 of 🤗 Transformers. `mps` device will be used by default if available similar to the way `cuda` device is used.Therefore, no action from user is required. \n",
407
+ " warnings.warn(\n"
408
+ ]
409
+ }
410
+ ],
411
+ "source": [
412
+ "# Define the training arguments for the SentenceTransformer model\n",
413
+ "# Customize them as needed for your requirements\n",
414
+ "\n",
415
+ "training_args = SentenceTransformerTrainingArguments(\n",
416
+ " output_dir=f\"models/{MODEL_NAME_BIENCODER}\",\n",
417
+ " num_train_epochs=3,\n",
418
+ " per_device_train_batch_size=4,\n",
419
+ " gradient_accumulation_steps=4,\n",
420
+ " per_device_eval_batch_size=4,\n",
421
+ " warmup_ratio=0.1,\n",
422
+ " learning_rate=2e-5,\n",
423
+ " lr_scheduler_type=\"cosine\",\n",
424
+ " fp16=False, # or True if stable on your MPS device\n",
425
+ " bf16=False,\n",
426
+ " batch_sampler=BatchSamplers.NO_DUPLICATES,\n",
427
+ " eval_strategy=\"epoch\",\n",
428
+ " save_strategy=\"epoch\",\n",
429
+ " save_total_limit=2,\n",
430
+ " logging_steps=100,\n",
431
+ " load_best_model_at_end=True,\n",
432
+ " use_mps_device=(device == \"mps\"),\n",
433
+ ")"
434
+ ]
435
+ },
436
+ {
437
+ "cell_type": "code",
438
+ "execution_count": null,
439
+ "metadata": {},
440
+ "outputs": [],
441
+ "source": [
442
+ "# Define the evaluator to assess the performance of the model\n",
443
+ "triplet_evaluator = TripletEvaluator(\n",
444
+ " anchors=dataset_rag_biencoder[\"eval\"][\"anchor\"],\n",
445
+ " positives=dataset_rag_biencoder[\"eval\"][\"positive\"],\n",
446
+ " negatives=dataset_rag_biencoder[\"eval\"][\"negative\"],\n",
447
+ ")"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "code",
452
+ "execution_count": null,
453
+ "metadata": {},
454
+ "outputs": [
455
+ {
456
+ "name": "stderr",
457
+ "output_type": "stream",
458
+ "text": [
459
+ "/Users/sdiazlor/.pyenv/versions/3.11.4/envs/distilabel-tutorials/lib/python3.11/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
460
+ " with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]\n"
461
+ ]
462
+ },
463
+ {
464
+ "data": {
465
+ "text/html": [
466
+ "\n",
467
+ " <div>\n",
468
+ " \n",
469
+ " <progress value='123' max='123' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
470
+ " [123/123 25:34, Epoch 2/3]\n",
471
+ " </div>\n",
472
+ " <table border=\"1\" class=\"dataframe\">\n",
473
+ " <thead>\n",
474
+ " <tr style=\"text-align: left;\">\n",
475
+ " <th>Epoch</th>\n",
476
+ " <th>Training Loss</th>\n",
477
+ " <th>Validation Loss</th>\n",
478
+ " <th>Cosine Accuracy</th>\n",
479
+ " </tr>\n",
480
+ " </thead>\n",
481
+ " <tbody>\n",
482
+ " <tr>\n",
483
+ " <td>1</td>\n",
484
+ " <td>No log</td>\n",
485
+ " <td>3.655929</td>\n",
486
+ " <td>0.969880</td>\n",
487
+ " </tr>\n",
488
+ " <tr>\n",
489
+ " <td>2</td>\n",
490
+ " <td>14.374000</td>\n",
491
+ " <td>3.498395</td>\n",
492
+ " <td>0.981928</td>\n",
493
+ " </tr>\n",
494
+ " </tbody>\n",
495
+ "</table><p>"
496
+ ],
497
+ "text/plain": [
498
+ "<IPython.core.display.HTML object>"
499
+ ]
500
+ },
501
+ "metadata": {},
502
+ "output_type": "display_data"
503
+ },
504
+ {
505
+ "data": {
506
+ "application/vnd.jupyter.widget-view+json": {
507
+ "model_id": "faad6e9752f34babadff7a966ae55d87",
508
+ "version_major": 2,
509
+ "version_minor": 0
510
+ },
511
+ "text/plain": [
512
+ "Computing widget examples: 0%| | 0/1 [00:00<?, ?example/s]"
513
+ ]
514
+ },
515
+ "metadata": {},
516
+ "output_type": "display_data"
517
+ },
518
+ {
519
+ "name": "stderr",
520
+ "output_type": "stream",
521
+ "text": [
522
+ "/Users/sdiazlor/.pyenv/versions/3.11.4/envs/distilabel-tutorials/lib/python3.11/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
523
+ " with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]\n",
524
+ "/Users/sdiazlor/.pyenv/versions/3.11.4/envs/distilabel-tutorials/lib/python3.11/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
525
+ " with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]\n"
526
+ ]
527
+ }
528
+ ],
529
+ "source": [
530
+ "# Train the model. This will take some time depending on the size of the dataset and the model\n",
531
+ "# Remember to adjust the training arguments according to your requirements\n",
532
+ "\n",
533
+ "trainer = SentenceTransformerTrainer(\n",
534
+ " model=model_biencoder,\n",
535
+ " args=training_args,\n",
536
+ " train_dataset=dataset_rag_biencoder[\"train\"],\n",
537
+ " eval_dataset=dataset_rag_biencoder[\"eval\"],\n",
538
+ " loss=loss_biencoder,\n",
539
+ " evaluator=triplet_evaluator,\n",
540
+ ")\n",
541
+ "trainer.train()"
542
+ ]
543
+ },
544
+ {
545
+ "cell_type": "code",
546
+ "execution_count": null,
547
+ "metadata": {},
548
+ "outputs": [],
549
+ "source": [
550
+ "# Save the model to the local directory and push it to the Hub\n",
551
+ "model_biencoder.save_pretrained(f\"models/{MODEL_NAME_BIENCODER}\")\n",
552
+ "model_biencoder.push_to_hub(f\"{REPO_NAME}/{MODEL_NAME_BIENCODER}\")"
553
+ ]
554
+ },
555
+ {
556
+ "cell_type": "markdown",
557
+ "metadata": {},
558
+ "source": [
559
+ "## Train the Cross-Encoder model for Ranking"
560
+ ]
561
+ },
562
+ {
563
+ "cell_type": "code",
564
+ "execution_count": null,
565
+ "metadata": {},
566
+ "outputs": [],
567
+ "source": [
568
+ "# Prepare the training and evaluation samples for the CrossEncoder model\n",
569
+ "\n",
570
+ "train_samples = []\n",
571
+ "for row in dataset_rag_crossencoder[\"train\"]:\n",
572
+ " # Suppose 'score' is a float or an integer that you want to predict\n",
573
+ " train_samples.append(\n",
574
+ " InputExample(texts=[row[\"anchor\"], row[\"positive\"]], label=float(row[\"score\"]))\n",
575
+ " )\n",
576
+ "\n",
577
+ "eval_samples = []\n",
578
+ "for row in dataset_rag_crossencoder[\"eval\"]:\n",
579
+ " eval_samples.append(\n",
580
+ " InputExample(texts=[row[\"anchor\"], row[\"positive\"]], label=float(row[\"score\"]))\n",
581
+ " )\n",
582
+ "\n",
583
+ "# Initialize the DataLoader for the training samples\n",
584
+ "train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=4)"
585
+ ]
586
+ },
587
+ {
588
+ "cell_type": "code",
589
+ "execution_count": null,
590
+ "metadata": {},
591
+ "outputs": [],
592
+ "source": [
593
+ "# Initialize the CrossEncoder model. Set the number of labels to 1 for regression tasks\n",
594
+ "model_crossencoder = CrossEncoder(model_name=MODEL, num_labels=1)"
595
+ ]
596
+ },
597
+ {
598
+ "cell_type": "code",
599
+ "execution_count": null,
600
+ "metadata": {},
601
+ "outputs": [],
602
+ "source": [
603
+ "# Define the evaluator\n",
604
+ "evaluator = CECorrelationEvaluator.from_input_examples(eval_samples)"
605
+ ]
606
+ },
607
+ {
608
+ "cell_type": "code",
609
+ "execution_count": null,
610
+ "metadata": {},
611
+ "outputs": [
612
+ {
613
+ "data": {
614
+ "application/vnd.jupyter.widget-view+json": {
615
+ "model_id": "9517a852f3d34cff86808c4b10cf8973",
616
+ "version_major": 2,
617
+ "version_minor": 0
618
+ },
619
+ "text/plain": [
620
+ "Epoch: 0%| | 0/3 [00:00<?, ?it/s]"
621
+ ]
622
+ },
623
+ "metadata": {},
624
+ "output_type": "display_data"
625
+ },
626
+ {
627
+ "data": {
628
+ "application/vnd.jupyter.widget-view+json": {
629
+ "model_id": "6e942043c5a24e77bd6172cb5492d2a7",
630
+ "version_major": 2,
631
+ "version_minor": 0
632
+ },
633
+ "text/plain": [
634
+ "Iteration: 0%| | 0/166 [00:00<?, ?it/s]"
635
+ ]
636
+ },
637
+ "metadata": {},
638
+ "output_type": "display_data"
639
+ },
640
+ {
641
+ "data": {
642
+ "application/vnd.jupyter.widget-view+json": {
643
+ "model_id": "d039d5acf3ed424e9ff6d0b30b51aceb",
644
+ "version_major": 2,
645
+ "version_minor": 0
646
+ },
647
+ "text/plain": [
648
+ "Iteration: 0%| | 0/166 [00:00<?, ?it/s]"
649
+ ]
650
+ },
651
+ "metadata": {},
652
+ "output_type": "display_data"
653
+ },
654
+ {
655
+ "data": {
656
+ "application/vnd.jupyter.widget-view+json": {
657
+ "model_id": "5fd5d0442b76448e8cab18b652e29ad8",
658
+ "version_major": 2,
659
+ "version_minor": 0
660
+ },
661
+ "text/plain": [
662
+ "Iteration: 0%| | 0/166 [00:00<?, ?it/s]"
663
+ ]
664
+ },
665
+ "metadata": {},
666
+ "output_type": "display_data"
667
+ }
668
+ ],
669
+ "source": [
670
+ "# Train the CrossEncoder model\n",
671
+ "\n",
672
+ "model_crossencoder.fit(\n",
673
+ " train_dataloader=train_dataloader,\n",
674
+ " evaluator=evaluator,\n",
675
+ " epochs=3,\n",
676
+ " warmup_steps=500,\n",
677
+ " output_path=f\"models/{MODEL_NAME_CROSSENCODER}\",\n",
678
+ " save_best_model=True,\n",
679
+ ")"
680
+ ]
681
+ },
682
+ {
683
+ "cell_type": "code",
684
+ "execution_count": null,
685
+ "metadata": {},
686
+ "outputs": [],
687
+ "source": [
688
+ "# Save the model to the local directory and push it to the Hub\n",
689
+ "model_crossencoder.save_pretrained(f\"models/{MODEL_NAME_CROSSENCODER}\")\n",
690
+ "model_crossencoder.push_to_hub(f\"{REPO_NAME}/{MODEL_NAME_CROSSENCODER}\")"
691
+ ]
692
+ },
693
+ {
694
+ "cell_type": "markdown",
695
+ "metadata": {},
696
+ "source": [
697
+ "## Build the RAG Pipeline\n",
698
+ "\n",
699
+ "The following section is inspired by the Haystack tutorial, check it for further details: [Creating Your First QA Pipeline with Retrieval-Augmentation](https://haystack.deepset.ai/tutorials/27_first_rag_pipeline)"
700
+ ]
701
+ },
702
+ {
703
+ "cell_type": "code",
704
+ "execution_count": 4,
705
+ "metadata": {},
706
+ "outputs": [],
707
+ "source": [
708
+ "# Add the documents to the DocumentStore\n",
709
+ "# Use the already chunked documents from original datasets\n",
710
+ "\n",
711
+ "df = combined_rag_dataset.to_pandas()\n",
712
+ "df = df.drop_duplicates(subset=[\"context\"]) # drop duplicates based on \"context\" column\n",
713
+ "df = df.sample(n=10, random_state=42) # optional: sample a subset of the dataset\n",
714
+ "dataset = Dataset.from_pandas(df)\n",
715
+ "\n",
716
+ "docs = [Document(content=doc[\"context\"]) for doc in dataset]"
717
+ ]
718
+ },
719
+ {
720
+ "cell_type": "code",
721
+ "execution_count": null,
722
+ "metadata": {},
723
+ "outputs": [],
724
+ "source": [
725
+ "# Initialize the document store and store the documents with the embeddings using our bi-encoder model\n",
726
+ "\n",
727
+ "document_store = InMemoryDocumentStore()\n",
728
+ "doc_embedder = SentenceTransformersDocumentEmbedder(\n",
729
+ " model=f\"{REPO_NAME}/{MODEL_NAME_BIENCODER}\",\n",
730
+ ")\n",
731
+ "doc_embedder.warm_up()\n",
732
+ "\n",
733
+ "docs_with_embeddings = doc_embedder.run(docs)\n",
734
+ "document_store.write_documents(docs_with_embeddings[\"documents\"])\n",
735
+ "\n",
736
+ "text_embedder = SentenceTransformersTextEmbedder(\n",
737
+ " model=f\"{REPO_NAME}/{MODEL_NAME_BIENCODER}\",\n",
738
+ ")"
739
+ ]
740
+ },
741
+ {
742
+ "cell_type": "code",
743
+ "execution_count": null,
744
+ "metadata": {},
745
+ "outputs": [],
746
+ "source": [
747
+ "# Initialize the retriever (our bi-encoder model) and the ranker (our cross-encoder model)\n",
748
+ "\n",
749
+ "retriever = InMemoryEmbeddingRetriever(document_store)\n",
750
+ "ranker = SentenceTransformersDiversityRanker(\n",
751
+ " model=f\"{REPO_NAME}/{MODEL_NAME_CROSSENCODER}\"\n",
752
+ ")"
753
+ ]
754
+ },
755
+ {
756
+ "cell_type": "code",
757
+ "execution_count": null,
758
+ "metadata": {},
759
+ "outputs": [],
760
+ "source": [
761
+ "# Define the prompt builder and the chat generator to interact with the models using the HF Serverless Inference API\n",
762
+ "\n",
763
+ "template = [\n",
764
+ " ChatMessage.from_user(\n",
765
+ " \"\"\"\n",
766
+ "Given the following information, answer the question.\n",
767
+ "\n",
768
+ "Context:\n",
769
+ "{% for document in documents %}\n",
770
+ " {{ document.content }}\n",
771
+ "{% endfor %}\n",
772
+ "\n",
773
+ "Question: {{question}}\n",
774
+ "Answer:\n",
775
+ "\"\"\"\n",
776
+ " )\n",
777
+ "]\n",
778
+ "\n",
779
+ "prompt_builder = ChatPromptBuilder(template=template)\n",
780
+ "\n",
781
+ "chat_generator = HuggingFaceAPIChatGenerator(\n",
782
+ " api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,\n",
783
+ " api_params={\"model\": \"meta-llama/Llama-3.1-8B-Instruct\"},\n",
784
+ " token=Secret.from_env_var(\"HF_TOKEN\"),\n",
785
+ ")"
786
+ ]
787
+ },
788
+ {
789
+ "cell_type": "code",
790
+ "execution_count": null,
791
+ "metadata": {},
792
+ "outputs": [],
793
+ "source": [
794
+ "# Initialize the pipeline with the components\n",
795
+ "\n",
796
+ "rag_pipeline = Pipeline()\n",
797
+ "rag_pipeline.add_component(\"text_embedder\", text_embedder)\n",
798
+ "rag_pipeline.add_component(\"retriever\", retriever)\n",
799
+ "rag_pipeline.add_component(\"ranker\", ranker)\n",
800
+ "rag_pipeline.add_component(\"prompt_builder\", prompt_builder)\n",
801
+ "rag_pipeline.add_component(\"llm\", chat_generator)"
802
+ ]
803
+ },
804
+ {
805
+ "cell_type": "code",
806
+ "execution_count": null,
807
+ "metadata": {},
808
+ "outputs": [
809
+ {
810
+ "data": {
811
+ "text/plain": [
812
+ "<haystack.core.pipeline.pipeline.Pipeline object at 0x32e75b4d0>\n",
813
+ "🚅 Components\n",
814
+ " - text_embedder: SentenceTransformersTextEmbedder\n",
815
+ " - retriever: InMemoryEmbeddingRetriever\n",
816
+ " - ranker: SentenceTransformersDiversityRanker\n",
817
+ " - prompt_builder: ChatPromptBuilder\n",
818
+ " - llm: HuggingFaceAPIChatGenerator\n",
819
+ "🛤️ Connections\n",
820
+ " - text_embedder.embedding -> retriever.query_embedding (List[float])\n",
821
+ " - retriever.documents -> ranker.documents (List[Document])\n",
822
+ " - ranker.documents -> prompt_builder.documents (List[Document])\n",
823
+ " - prompt_builder.prompt -> llm.messages (List[ChatMessage])"
824
+ ]
825
+ },
826
+ "execution_count": 12,
827
+ "metadata": {},
828
+ "output_type": "execute_result"
829
+ }
830
+ ],
831
+ "source": [
832
+ "# Connect the components to each other\n",
833
+ "\n",
834
+ "rag_pipeline.connect(\"text_embedder.embedding\", \"retriever.query_embedding\")\n",
835
+ "rag_pipeline.connect(\"retriever.documents\", \"ranker.documents\")\n",
836
+ "rag_pipeline.connect(\"ranker\", \"prompt_builder\")\n",
837
+ "rag_pipeline.connect(\"prompt_builder.prompt\", \"llm.messages\")"
838
+ ]
839
+ },
840
+ {
841
+ "cell_type": "code",
842
+ "execution_count": null,
843
+ "metadata": {},
844
+ "outputs": [
845
+ {
846
+ "data": {
847
+ "application/vnd.jupyter.widget-view+json": {
848
+ "model_id": "80c813c847524f1493067f6dbe65c725",
849
+ "version_major": 2,
850
+ "version_minor": 0
851
+ },
852
+ "text/plain": [
853
+ "Batches: 0%| | 0/1 [00:00<?, ?it/s]"
854
+ ]
855
+ },
856
+ "metadata": {},
857
+ "output_type": "display_data"
858
+ },
859
+ {
860
+ "name": "stdout",
861
+ "output_type": "stream",
862
+ "text": [
863
+ "It seems that there is not enough information given in the human rights protocols provided to accurately answer the question. However, we can inform you that there are several types of human rights documents that this could be referring too. Event the most widely respected declared world document on human rights for Example - Exernal and some Individual (Part 1 Art.) and some other attempted Separation apart include: The convention lists several key rights such as \n",
864
+ "\n",
865
+ "1. Right to Life \n",
866
+ "2. Right to Liberty and Security \n",
867
+ "3. Freedom from Torture \n",
868
+ "4. Freedom from Slavery \n",
869
+ "5. Right to a Fair Trial \n",
870
+ "6. No Punishment without Law \n",
871
+ "7. Respect for Family Life \n",
872
+ "... (and throughout given information 44 protocals - are actually chapter and not... How is the answer \n",
873
+ " \n",
874
+ "\n",
875
+ "Not possible to answer your question due to lack of information, however we can tell you Event the most widely respected declared world document on human rights.\n"
876
+ ]
877
+ }
878
+ ],
879
+ "source": [
880
+ "# Make a query to the pipeline without references included in your documentation\n",
881
+ "question = \"How many human rights there are?\"\n",
882
+ "\n",
883
+ "response = rag_pipeline.run(\n",
884
+ " {\n",
885
+ " \"text_embedder\": {\"text\": question},\n",
886
+ " \"prompt_builder\": {\"question\": question},\n",
887
+ " \"ranker\": {\"query\": question},\n",
888
+ " }\n",
889
+ ")\n",
890
+ "\n",
891
+ "print(response[\"llm\"][\"replies\"][0].text)"
892
+ ]
893
+ },
894
+ {
895
+ "cell_type": "code",
896
+ "execution_count": null,
897
+ "metadata": {},
898
+ "outputs": [
899
+ {
900
+ "data": {
901
+ "application/vnd.jupyter.widget-view+json": {
902
+ "model_id": "2995f14154d148589129a3f449adc5d5",
903
+ "version_major": 2,
904
+ "version_minor": 0
905
+ },
906
+ "text/plain": [
907
+ "Batches: 0%| | 0/1 [00:00<?, ?it/s]"
908
+ ]
909
+ },
910
+ "metadata": {},
911
+ "output_type": "display_data"
912
+ },
913
+ {
914
+ "name": "stdout",
915
+ "output_type": "stream",
916
+ "text": [
917
+ "The information you provided does not directly list the \"Right of Fair Trial\" but looking under articles of the Convention for the Protection of Human Rights and Fundamental Freedoms, Article 6, also known as the Right to a Fair Trial, gives a clear idea.\n",
918
+ "\n",
919
+ " Article 6. Right to a fair Trial\n",
920
+ " \n",
921
+ "\n",
922
+ "1. Everyone is entitled to a fair and public hearing within a reasonable time by an independent and impartial tribunal established by law.\n",
923
+ " \n",
924
+ "2, everybody shall be presumed innocent until proven guilty by a final decision of a competent court.\n",
925
+ " \n",
926
+ "3. Everyone charged with a criminal offence has the following minimum rights:\n",
927
+ "\n",
928
+ " a to be informed promptly, in a language which he understands and in detail, of the charges, if any, against him.\n",
929
+ " b to have adequate time and facilities for the preparation of his defence.\n",
930
+ " c to defend himself in person or through legal assistance of his own choosing or, if he has not sufficient means to pay for legal assistance, to be given it free when the interests of justice so require.\n",
931
+ " d to be tried in his presence, and to defend himself in person or through legal assistance of his own choosing; to be informed, if he does not have legal assistance chosen or appointed under Article 5 Part 3 of this Convention, to communicate with the defence he has chosen\n",
932
+ " e to have the free assistance of an interpreter if he cannot understand or speak the language used in court.\n",
933
+ " \n",
934
+ " \n",
935
+ "4. Everyone sentenced has the right to, review by a higher tribunal according to law\n",
936
+ "\n",
937
+ "5. Everyone sentenced has the right to, take up or pursue his occupation.\n",
938
+ "\n",
939
+ "6. Sentences may, also include restoration of rights or removal of disabilities\n"
940
+ ]
941
+ }
942
+ ],
943
+ "source": [
944
+ "# Make a query to the pipeline with references included in your documentation\n",
945
+ "question = \"What's the Right of Fair Trial?\"\n",
946
+ "\n",
947
+ "response = rag_pipeline.run(\n",
948
+ " {\n",
949
+ " \"text_embedder\": {\"text\": question},\n",
950
+ " \"prompt_builder\": {\"question\": question},\n",
951
+ " \"ranker\": {\"query\": question},\n",
952
+ " }\n",
953
+ ")\n",
954
+ "\n",
955
+ "print(response[\"llm\"][\"replies\"][0].text)"
956
+ ]
957
+ }
958
+ ],
959
+ "metadata": {
960
+ "kernelspec": {
961
+ "display_name": "distilabel-tutorials",
962
+ "language": "python",
963
+ "name": "python3"
964
+ },
965
+ "language_info": {
966
+ "codemirror_mode": {
967
+ "name": "ipython",
968
+ "version": 3
969
+ },
970
+ "file_extension": ".py",
971
+ "mimetype": "text/x-python",
972
+ "name": "python",
973
+ "nbconvert_exporter": "python",
974
+ "pygments_lexer": "ipython3",
975
+ "version": "3.11.4"
976
+ }
977
+ },
978
+ "nbformat": 4,
979
+ "nbformat_minor": 2
980
+ }
examples/fine-tune-smollm2-on-synthetic-data.ipynb ADDED
@@ -0,0 +1,310 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Fine-tune a SmolLM on domain-specific synthetic data from a LLM\n",
8
+ "\n",
9
+ "Yes, smoll models can beat GPT4-like models on domain-specific tasks but don't expect miracles. When comparing smoll vs large, consider all costs and gains like difference performance and the value of using private and local models and data that you own.\n",
10
+ "\n",
11
+ "The [Hugging Face SmolLM models](https://github.com/huggingface/smollm) are blazingly fast and remarkably powerful. With its 135M, 360M and 1.7B parameter models, it is a great choice for a small and fast model. The great thing about SmolLM is that it is a general-purpose model that can be fine-tuned on domain-specific data.\n",
12
+ "\n",
13
+ "A lack of domain-specific datasets is a common problem for smaller and more specialized models. This is because it is difficult to find a dataset that is both representative and diverse enough for a specific task. We solve this problem by generating a synthetic dataset from an LLM using the `synthetic-data-generator`, which is available as a [Hugging Face Space](https://huggingface.co/spaces/argilla/synthetic-data-generator) or on [GitHub](https://github.com/argilla-io/synthetic-data-generator).\n",
14
+ "\n",
15
+ "In this example, we will fine-tune a SmolLM2 model on a synthetic dataset generated from `meta-llama/Meta-Llama-3.1-8B-Instruct` with the `synthetic-data-generator`.\n",
16
+ "\n",
17
+ "## Install the dependencies\n",
18
+ "\n",
19
+ "We will install some basic dependencies for the fine-tuning with `trl` but we will use the Synthetic Data Generator UI to generate the synthetic dataset."
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "!pip install transformers datasets trl torch"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "metadata": {},
34
+ "source": [
35
+ "## The problem\n",
36
+ "\n",
37
+ "Reasoning data has proven to be a fundamental change in the performance of generative models. Reasoning is amazing but it also means the model generates more \"chatty\" during the token generation process, causing the model to become slower and more expensive. For this reason, we want to create a model that can reason without being too chatty. Therefore, we will generate a concise reasoning dataset and fine-tune a SmolLM2 model on it.\n",
38
+ "\n",
39
+ "## Let's generate some data\n",
40
+ "\n",
41
+ "Let's go to the [hosted Hugging Face Space](https://huggingface.co/spaces/argilla/synthetic-data-generator) to generate the data. This is done in three steps 1) we come up with a dataset description, 2) iterate on the task configuration, and 3) generate and push the data to Hugging Face. A more detailed flow can be found in [this blog post](https://huggingface.co/blog/synthetic-data-generator). \n",
42
+ "\n",
43
+ "<iframe\n",
44
+ "\tsrc=\"https://argilla-synthetic-data-generator.hf.space\"\n",
45
+ "\tframeborder=\"0\"\n",
46
+ "\twidth=\"850\"\n",
47
+ "\theight=\"450\"\n",
48
+ "></iframe>\n",
49
+ "\n",
50
+ "For this example, we will generate 5000 chat data examples for a single turn in the conversation. All examples have been generated with a temperature of 1. After some iteration, we come up with the following system prompt:\n",
51
+ "\n",
52
+ "```\n",
53
+ "You are an AI assistant who provides brief and to-the-point responses with logical step-by-step reasoning. Your purpose is to offer straightforward explanations and answers so that you can get to the heart of the issue. Respond with extremely concise, direct justifications and evidence-based conclusions. User questions are direct and concise.\n",
54
+ "```\n",
55
+ "\n",
56
+ "We press the \"Push to Hub\" button and wait for the data to be generated. This takes a few hours and we end up with a dataset with 5000 examples, which is the maximum number of examples we can generate in a single run. You can scale this by deploying a private instance of the Synthetic Data Generator. \n",
57
+ "\n",
58
+ "<iframe\n",
59
+ " src=\"https://huggingface.co/datasets/argilla/synthetic-concise-reasoning-sft-filtered/embed/viewer/default/train\"\n",
60
+ " frameborder=\"0\"\n",
61
+ " width=\"100%\"\n",
62
+ " height=\"560px\"\n",
63
+ "></iframe>\n",
64
+ "\n",
65
+ "The data is pushed to Argilla too so we recommend inspecting and validating the the data before finetuning the actual model. We applied some basic filters and transformations to the data to make it more suitable for fine-tuning.\n",
66
+ "\n",
67
+ "## Fine-tune the model\n",
68
+ "\n",
69
+ "We will use TRL to fine-tune the model. It is part of the Hugging Face ecosystem and works seamlessly on top of datasets generated by the synthetic data generator without needing to do any data transformations.\n",
70
+ "\n",
71
+ "### Load the model\n",
72
+ "\n",
73
+ "We will first load the model and tokenizer and set up the chat format."
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": 5,
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "# Import necessary libraries\n",
83
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
84
+ "from datasets import load_dataset\n",
85
+ "from trl import SFTConfig, SFTTrainer, setup_chat_format\n",
86
+ "import torch\n",
87
+ "import os\n",
88
+ "\n",
89
+ "device = (\n",
90
+ " \"cuda\"\n",
91
+ " if torch.cuda.is_available()\n",
92
+ " else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
93
+ ")\n",
94
+ "\n",
95
+ "# Load the model and tokenizer\n",
96
+ "model_name = \"HuggingFaceTB/SmolLM2-360M\"\n",
97
+ "model = AutoModelForCausalLM.from_pretrained(\n",
98
+ " pretrained_model_name_or_path=model_name\n",
99
+ ")\n",
100
+ "tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)\n",
101
+ "\n",
102
+ "# Set up the chat format\n",
103
+ "model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "markdown",
108
+ "metadata": {},
109
+ "source": [
110
+ "### Test the base model\n",
111
+ "\n",
112
+ "We will first test the base model to see how it performs on the task. During this step we will also generate a prompt for the model to respond to, to see how it performs on the task."
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 2,
118
+ "metadata": {},
119
+ "outputs": [
120
+ {
121
+ "name": "stderr",
122
+ "output_type": "stream",
123
+ "text": [
124
+ "Device set to use mps:0\n"
125
+ ]
126
+ },
127
+ {
128
+ "data": {
129
+ "text/plain": [
130
+ "[{'generated_text': 'What is the primary function of mitochondria within a cell?\\n\\nMitochondria are the powerhouses of the cell. They are responsible for the production of ATP (adenosine triphosphate) and the energy required for cellular processes.\\n\\nWhat is the function of the mitochondria in the cell?\\n\\nThe mitochondria are the powerhouses of the cell. They are responsible for the production of ATP (adenosine triphosphate) and the energy required for cellular processes.\\n\\nWhat is the function of the mitochondria in the cell?\\n\\nThe'}]"
131
+ ]
132
+ },
133
+ "execution_count": 2,
134
+ "metadata": {},
135
+ "output_type": "execute_result"
136
+ }
137
+ ],
138
+ "source": [
139
+ "from transformers import pipeline\n",
140
+ "\n",
141
+ "prompt = \"What is the primary function of mitochondria within a cell?\"\n",
142
+ "\n",
143
+ "pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer, device=device)\n",
144
+ "pipe(prompt, max_new_tokens=100)"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "markdown",
149
+ "metadata": {},
150
+ "source": [
151
+ "### Load the dataset\n",
152
+ "\n",
153
+ "For fine-tuning, we need to load the dataset and tokenize it. We will use the `synthetic-concise-reasoning-sft-filtered` dataset that we generated in the previous step."
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": 2,
159
+ "metadata": {},
160
+ "outputs": [
161
+ {
162
+ "name": "stderr",
163
+ "output_type": "stream",
164
+ "text": [
165
+ "Map: 100%|██████████| 4133/4133 [00:00<00:00, 18478.53 examples/s]\n"
166
+ ]
167
+ }
168
+ ],
169
+ "source": [
170
+ "from datasets import load_dataset\n",
171
+ "\n",
172
+ "ds = load_dataset(\"argilla/synthetic-concise-reasoning-sft-filtered\")\n",
173
+ "def tokenize_function(examples):\n",
174
+ " examples[\"text\"] = tokenizer.apply_chat_template([{\"role\": \"user\", \"content\": examples[\"prompt\"].strip()}, {\"role\": \"assistant\", \"content\": examples[\"completion\"].strip()}], tokenize=False)\n",
175
+ " return examples\n",
176
+ "ds = ds.map(tokenize_function)\n",
177
+ "ds = ds.shuffle()"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "markdown",
182
+ "metadata": {},
183
+ "source": [
184
+ "### Fine-tune the model\n",
185
+ "\n",
186
+ "We will now fine-tune the model. We will use the `SFTTrainer` from the `trl` library to fine-tune the model. We will use a batch size of 4 and a learning rate of 5e-5. We will also use the `use_mps_device` flag to use the MPS device if available."
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "os.environ[\"PYTORCH_MPS_HIGH_WATERMARK_RATIO\"] = \"0.0\"\n",
196
+ "\n",
197
+ "# Configure the SFTTrainer\n",
198
+ "sft_config = SFTConfig(\n",
199
+ " output_dir=\"./sft_output\",\n",
200
+ " num_train_epochs=1,\n",
201
+ " per_device_train_batch_size=4, # Set according to your GPU memory capacity\n",
202
+ " learning_rate=5e-5, # Common starting point for fine-tuning\n",
203
+ " logging_steps=100, # Frequency of logging training metrics\n",
204
+ " use_mps_device= True if device == \"mps\" else False,\n",
205
+ " hub_model_id=\"argilla/SmolLM2-360M-synthetic-concise-reasoning\", # Set a unique name for your model\n",
206
+ " push_to_hub=True,\n",
207
+ ")\n",
208
+ "\n",
209
+ "# Initialize the SFTTrainer\n",
210
+ "trainer = SFTTrainer(\n",
211
+ " model=model,\n",
212
+ " args=sft_config,\n",
213
+ " train_dataset=ds[\"train\"],\n",
214
+ " tokenizer=tokenizer,\n",
215
+ ")\n",
216
+ "trainer.train()"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "markdown",
221
+ "metadata": {},
222
+ "source": [
223
+ "```\n",
224
+ "# {'loss': 1.4498, 'grad_norm': 2.3919131755828857, 'learning_rate': 4e-05, 'epoch': 0.1}\n",
225
+ "# {'loss': 1.362, 'grad_norm': 1.6650595664978027, 'learning_rate': 3e-05, 'epoch': 0.19}\n",
226
+ "# {'loss': 1.3778, 'grad_norm': 1.4778285026550293, 'learning_rate': 2e-05, 'epoch': 0.29}\n",
227
+ "# {'loss': 1.3735, 'grad_norm': 2.1424977779388428, 'learning_rate': 1e-05, 'epoch': 0.39}\n",
228
+ "# {'loss': 1.3512, 'grad_norm': 2.3498542308807373, 'learning_rate': 0.0, 'epoch': 0.48}\n",
229
+ "# {'train_runtime': 1911.514, 'train_samples_per_second': 1.046, 'train_steps_per_second': 0.262, 'train_loss': 1.3828572998046875, 'epoch': 0.48}\n",
230
+ "```\n",
231
+ "\n",
232
+ "For the example, we did not use a specific validation set but we can see the loss is decreasing, so we assume the model is generalsing well to the training data. To get a better understanding of the model's performance, let's test it again with the same prompt.\n",
233
+ "\n",
234
+ "### Run inference\n",
235
+ "\n",
236
+ "We can now run inference with [the fine-tuned model](https://huggingface.co/argilla/SmolLM2-360M-synthetic-concise-reasoning/blob/main/README.md)."
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": 12,
242
+ "metadata": {},
243
+ "outputs": [
244
+ {
245
+ "name": "stderr",
246
+ "output_type": "stream",
247
+ "text": [
248
+ "Device set to use mps\n"
249
+ ]
250
+ },
251
+ {
252
+ "data": {
253
+ "text/plain": [
254
+ "'The primary function of mitochondria is to generate energy for the cell. They are organelles found in eukaryotic cells that convert nutrients into ATP (adenosine triphosphate), which is the primary source of energy for cellular processes.\\nMitochondria are responsible for:\\n\\nEnergy production: Mitochondria produce ATP through a process called oxidative phosphorylation, which involves the transfer of electrons from food molecules to oxygen.\\nEnergy storage: Mitochondria store energy in the form of adenosine triphosphate (ATP), which is used by the cell for various cellular processes.\\nCellular respiration: Mitochondria also participate in cellular respiration, a'"
255
+ ]
256
+ },
257
+ "execution_count": 12,
258
+ "metadata": {},
259
+ "output_type": "execute_result"
260
+ }
261
+ ],
262
+ "source": [
263
+ "prompt = \"What is the primary function of mitochondria within a cell?\"\n",
264
+ "\n",
265
+ "generator = pipeline(\n",
266
+ " \"text-generation\",\n",
267
+ " model=\"argilla/SmolLM2-360M-synthetic-concise-reasoning\",\n",
268
+ " device=\"mps\",\n",
269
+ ")\n",
270
+ "generator(\n",
271
+ " [{\"role\": \"user\", \"content\": prompt}], max_new_tokens=128, return_full_text=False\n",
272
+ ")[0][\"generated_text\"]"
273
+ ]
274
+ },
275
+ {
276
+ "cell_type": "markdown",
277
+ "metadata": {},
278
+ "source": [
279
+ "## Conclusion\n",
280
+ "\n",
281
+ "We have fine-tuned a SmolLM2 model on a synthetic dataset generated from a large language model. We have seen that the model performs well on the task and that the synthetic data is a great way to generate diverse and representative data for supervised fine-tuning. \n",
282
+ "\n",
283
+ "In practice, you would likely want to spend more time on the data quality and fine-tuning the model but the flow shows the Synthetic Data Generator is a great tool to generate synthetic data for any task.\n",
284
+ "\n",
285
+ "Overall, I think it is pretty cool for a couple of hours of generation and fine-tuning on consumer hardware.\n"
286
+ ]
287
+ }
288
+ ],
289
+ "metadata": {
290
+ "kernelspec": {
291
+ "display_name": ".venv",
292
+ "language": "python",
293
+ "name": "python3"
294
+ },
295
+ "language_info": {
296
+ "codemirror_mode": {
297
+ "name": "ipython",
298
+ "version": 3
299
+ },
300
+ "file_extension": ".py",
301
+ "mimetype": "text/x-python",
302
+ "name": "python",
303
+ "nbconvert_exporter": "python",
304
+ "pygments_lexer": "ipython3",
305
+ "version": "3.11.9"
306
+ }
307
+ },
308
+ "nbformat": 4,
309
+ "nbformat_minor": 2
310
+ }
examples/hf-dedicated-or-tgi-deployment.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.11,<3.12"
3
+ # dependencies = [
4
+ # "synthetic-dataset-generator",
5
+ # ]
6
+ # ///
7
+ import os
8
+
9
+ from synthetic_dataset_generator import launch
10
+
11
+ os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
12
+ os.environ["HUGGINGFACE_BASE_URL"] = "http://127.0.0.1:3000/" # dedicated endpoint/TGI
13
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # magpie template
14
+ os.environ["TOKENIZER_ID"] = (
15
+ "meta-llama/Llama-3.1-8B-Instruct" # tokenizer for model hosted on endpoint
16
+ )
17
+ os.environ["MODEL"] = None # model is linked to endpoint
18
+
19
+ launch()
examples/hf-serverless-deployment-deepseek.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.11,<3.12"
3
+ # dependencies = [
4
+ # "synthetic-dataset-generator",
5
+ # ]
6
+ # ///
7
+ import os
8
+
9
+ from synthetic_dataset_generator import launch
10
+
11
+ os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
12
+ os.environ["MODEL"] = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" # use model for instructions
13
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "<|begin▁of▁sentence|>User: " # use the custom template for the model
14
+
15
+
16
+ launch()
examples/hf-serverless-deployment.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.11,<3.12"
3
+ # dependencies = [
4
+ # "synthetic-dataset-generator",
5
+ # ]
6
+ # ///
7
+ import os
8
+
9
+ from synthetic_dataset_generator import launch
10
+
11
+ os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
12
+ os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct" # use model for generation
13
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # use the template for the model
14
+
15
+ launch()
examples/hf-serverless-different-model-for-completion.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.11,<3.12"
3
+ # dependencies = [
4
+ # "synthetic-dataset-generator",
5
+ # ]
6
+ # ///
7
+ import os
8
+
9
+ from synthetic_dataset_generator import launch
10
+
11
+ os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
12
+ os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct" # use model for instruction generation
13
+ os.environ["MODEL_COMPLETION"] = "meta-llama/Llama-3.1-70B-Instruct" # use model for completion generation
14
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # use the template for the model
15
+
16
+ launch()
examples/ollama-deployment.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.11,<3.12"
3
+ # dependencies = [
4
+ # "synthetic-dataset-generator",
5
+ # ]
6
+ # ///
7
+ # ollama serve
8
+ # ollama run qwen2.5:32b-instruct-q5_K_S
9
+ import os
10
+
11
+ from synthetic_dataset_generator import launch
12
+
13
+ os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
14
+ os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/" # ollama base url
15
+ os.environ["MODEL"] = "qwen2.5:32b-instruct-q5_K_S" # model id
16
+ os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-32B-Instruct" # tokenizer id
17
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "qwen2"
18
+ os.environ["MAX_NUM_ROWS"] = "10000"
19
+ os.environ["DEFAULT_BATCH_SIZE"] = "2"
20
+ os.environ["MAX_NUM_TOKENS"] = "1024"
21
+
22
+ launch()
examples/ollama-different-model-for-completion.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.11,<3.12"
3
+ # dependencies = [
4
+ # "synthetic-dataset-generator",
5
+ # ]
6
+ # ///
7
+ # ollama serve
8
+ # ollama run llama3.2
9
+ # ollama run llama3.2:1b
10
+ import os
11
+
12
+ from synthetic_dataset_generator import launch
13
+
14
+ os.environ["OLLAMA_BASE_URL"] = (
15
+ "http://127.0.0.1:11434/" # in this case, the same base url for both models
16
+ )
17
+
18
+ os.environ["MODEL"] = "llama3.2" # model for instruction generation
19
+ os.environ["MODEL_COMPLETION"] = "llama3.2:1b" # model for completion generation
20
+
21
+ os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.2-3B-Instruct" # tokenizer for instruction generation
22
+ os.environ["TOKENIZER_ID_COMPLETION"] = "meta-llama/Llama-3.2-1B-Instruct" # tokenizer for completion generation
23
+
24
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # magpie template required for instruction generation
25
+
26
+ launch()
examples/openai-deployment.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.11,<3.12"
3
+ # dependencies = [
4
+ # "synthetic-dataset-generator",
5
+ # ]
6
+ # ///
7
+
8
+ import os
9
+
10
+ from synthetic_dataset_generator import launch
11
+
12
+ os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
13
+ os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/" # openai base url
14
+ os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY") # openai api key
15
+ os.environ["MODEL"] = "gpt-4o" # model id
16
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = None # chat data not supported with OpenAI
17
+
18
+ launch()
examples/vllm-deployment.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.11,<3.12"
3
+ # dependencies = [
4
+ # "synthetic-dataset-generator",
5
+ # ]
6
+ # ///
7
+ # vllm serve Qwen/Qwen2.5-1.5B-Instruct
8
+ import os
9
+
10
+ from synthetic_dataset_generator import launch
11
+
12
+ os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
13
+ os.environ["VLLM_BASE_URL"] = "http://127.0.0.1:8000/" # vllm base url
14
+ os.environ["MODEL"] = "Qwen/Qwen2.5-1.5B-Instruct" # model id
15
+ os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-1.5B-Instruct" # tokenizer id
16
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "qwen2"
17
+ os.environ["MAX_NUM_ROWS"] = "10000"
18
+ os.environ["DEFAULT_BATCH_SIZE"] = "2"
19
+ os.environ["MAX_NUM_TOKENS"] = "1024"
20
+
21
+ launch()
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ poppler-utils
2
+ tesseract-ocr
pdm.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,40 @@
1
+ [project]
2
+ name = "synthetic-dataset-generator"
3
+ version = "0.2.0"
4
+ description = "Build datasets using natural language"
5
+ authors = [
6
+ {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"},
7
+ ]
8
+ keywords = [
9
+ "gradio",
10
+ "synthetic-data",
11
+ "huggingface",
12
+ "argilla",
13
+ "generative-ai",
14
+ "ai",
15
+ ]
16
+ requires-python = "<3.13,>=3.10"
17
+ readme = "README.md"
18
+ license = {text = "Apache 2"}
19
+
20
+ dependencies = [
21
+ "argilla>=2.4.0,<3.0.0",
22
+ "distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines,vllm,vision]>=1.5.0,<2.00",
23
+ "gradio[oauth]>=5.4.0,<6.0.0",
24
+ "gradio-huggingfacehub-search>=0.0.12,<1.0.0",
25
+ "huggingface-hub>=0.26.0,<0.28.0",
26
+ "model2vec>=0.2.4,<1.0.0",
27
+ "nltk>=3.9.1,<4.0.0",
28
+ "pydantic>=2.10.5,<3.0.0",
29
+ "sentence-transformers>=3.2.0,<4.0.0",
30
+ "transformers>=4.44.2,<5.0.0",
31
+ "unstructured[md,pdf,docx]>=0.16.3,<1.0.0",
32
+ "setuptools",
33
+ ]
34
+
35
+ [build-system]
36
+ requires = ["pdm-backend"]
37
+ build-backend = "pdm.backend"
38
+
39
+ [tool.pdm]
40
+ distribution = true
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ -e git+https://github.com/argilla-io/synthetic-data-generator.git#egg=synthetic-dataset-generator
src/synthetic_dataset_generator/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from gradio import TabbedInterface
3
+
4
+ from synthetic_dataset_generator import ( # noqa
5
+ _distiset,
6
+ _inference_endpoints,
7
+ )
8
+
9
+ def launch(*args, **kwargs):
10
+ """Launch the synthetic dataset generator.
11
+ Based on the `TabbedInterface` from Gradio.
12
+ Parameters: https://www.gradio.app/docs/gradio/tabbedinterface
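+ 
+ Example (a rough sketch; keyword arguments are forwarded to Gradio's `launch`, so any standard Gradio launch parameter such as `share` can be passed):
+ 
+ >>> from synthetic_dataset_generator import launch
+ >>> launch(share=False)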
13
+ """
14
+ from synthetic_dataset_generator.app import demo
15
+ return demo.launch(*args, server_name="0.0.0.0", **kwargs)
16
+
17
+
18
+ launch.__doc__ = TabbedInterface.launch.__doc__
19
+ launch.__signature__ = inspect.signature(TabbedInterface.launch)
20
+ launch.__annotations__ = TabbedInterface.launch.__annotations__
src/synthetic_dataset_generator/__main__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ if __name__ == "__main__":
2
+ from synthetic_dataset_generator import launch
3
+
4
+ launch()
src/synthetic_dataset_generator/_distiset.py ADDED
@@ -0,0 +1,148 @@
1
+ from typing import Optional
2
+
3
+ import distilabel
4
+ import distilabel.distiset
5
+ import gradio as gr
6
+ from distilabel.utils.card.dataset_card import (
7
+ DistilabelDatasetCard,
8
+ size_categories_parser,
9
+ )
10
+ from huggingface_hub import DatasetCardData, HfApi
11
+
12
+
13
+ class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
14
+ def _generate_card(
15
+ self,
16
+ repo_id: str,
17
+ token: str,
18
+ include_script: bool = False,
19
+ filename_py: Optional[str] = None,
20
+ ) -> None:
21
+ """Generates a dataset card and pushes it to the Hugging Face Hub, and
22
+ if the `pipeline.yaml` path is available in the `Distiset`, uploads that
23
+ to the same repository.
24
+
25
+ Args:
26
+ repo_id: The ID of the repository to push to, from the `push_to_hub` method.
27
+ token: The token to authenticate with the Hugging Face Hub, from the `push_to_hub` method.
28
+ include_script: Whether to upload the script to the hugging face repository.
29
+ filename_py: The name of the script. If `include_script` is True, the script will
30
+ be uploaded to the repository using this name, otherwise it won't be used.
31
+ """
32
+ card = self._get_card(
33
+ repo_id=repo_id,
34
+ token=token,
35
+ include_script=include_script,
36
+ filename_py=filename_py,
37
+ )
38
+
39
+ card.push_to_hub(
40
+ repo_id,
41
+ repo_type="dataset",
42
+ token=token,
43
+ )
44
+ if self.pipeline_path:
45
+ # If the pipeline.yaml is available, upload it to the Hugging Face Hub as well.
46
+ HfApi().upload_file(
47
+ path_or_fileobj=self.pipeline_path,
48
+ path_in_repo=distilabel.distiset.PIPELINE_CONFIG_FILENAME,
49
+ repo_id=repo_id,
50
+ repo_type="dataset",
51
+ token=token,
52
+ )
53
+
54
+ def _get_card(
55
+ self,
56
+ repo_id: str,
57
+ token: Optional[str] = None,
58
+ include_script: bool = False,
59
+ filename_py: Optional[str] = None,
60
+ ) -> DistilabelDatasetCard:
61
+ """Generates the dataset card for the `Distiset`.
62
+
63
+ Note:
64
+ If `repo_id` and `token` are provided, it will extract the metadata from the README.md file
65
+ on the hub.
66
+
67
+ Args:
68
+ repo_id: Name of the repository to push to, or the path for the distiset if saved to disk.
69
+ token: The token to authenticate with the Hugging Face Hub.
70
+ We assume that if it's provided, the dataset will be in the Hugging Face Hub,
71
+ so the README metadata will be extracted from there.
72
+ include_script: Whether to upload the script to the hugging face repository.
73
+ filename_py: The name of the script. If `include_script` is True, the script will
74
+ be uploaded to the repository using this name, otherwise it won't be used.
75
+
76
+ Returns:
77
+ The dataset card for the `Distiset`.
78
+ """
79
+ sample_records = {}
80
+ for name, dataset in self.items():
81
+ sample_records[name] = (
82
+ dataset[0] if not isinstance(dataset, dict) else dataset["train"][0]
83
+ )
84
+
85
+ columns = self["default"].column_names
87
+
88
+ if ("label" in columns and "text" in columns) or (
89
+ "labels" in columns and "text" in columns
90
+ ):
91
+ task_categories = ["text-classification"]
92
+ elif ("prompt" in columns and "completion" in columns) or (
93
+ "messages" in columns
94
+ ):
95
+ task_categories: list[str] = [
96
+ "text-generation",
97
+ "text2text-generation",
98
+ "question-answering",
99
+ ]
100
+ elif "context" in columns and "question" in columns and "response" in columns:
101
+ task_categories: list[str] = [
102
+ "text-generation",
103
+ "text2text-generation",
104
+ "text-retrieval",
105
+ "question-answering"
106
+ ]
107
+ if (
108
+ "positive_retrieval" in columns and "negative_retrieval" in columns
109
+ ) or ("positive_reranking" in columns and "negative_reranking" in columns):
110
+ task_categories.append("sentence-similarity")
111
+ else:
112
+ task_categories: list[str] = []
113
+ gr.Info(
114
+ f"No task categories found for dataset with columns: {columns}. "
115
+ "Please notify the distilabel team if you think this is an error."
116
+ )
117
+
118
+ readme_metadata = {}
119
+ if repo_id and token:
120
+ readme_metadata = self._extract_readme_metadata(repo_id, token)
121
+
122
+ metadata = {
123
+ **readme_metadata,
124
+ "size_categories": size_categories_parser(
125
+ max(len(dataset) for dataset in self.values())
126
+ ),
127
+ "task_categories": task_categories,
128
+ "tags": [
129
+ "synthetic",
130
+ "distilabel",
131
+ "rlaif",
132
+ "datacraft",
133
+ ],
134
+ }
135
+
136
+ card = DistilabelDatasetCard.from_template(
137
+ card_data=DatasetCardData(**metadata),
138
+ repo_id=repo_id,
139
+ sample_records=sample_records,
140
+ include_script=include_script,
141
+ filename_py=filename_py,
142
+ references=self.citations,
143
+ )
144
+
145
+ return card
146
+
147
+
148
+ distilabel.distiset.Distiset = CustomDistisetWithAdditionalTag
src/synthetic_dataset_generator/_inference_endpoints.py ADDED
@@ -0,0 +1,58 @@
1
+ import warnings
2
+
3
+ import distilabel
4
+ import distilabel.distiset
5
+ from distilabel.models import InferenceEndpointsLLM
6
+ from pydantic import (
7
+ ValidationError,
8
+ model_validator,
9
+ )
10
+
11
+
12
+ class CustomInferenceEndpointsLLM(InferenceEndpointsLLM):
13
+ @model_validator(mode="after") # type: ignore
14
+ def only_one_of_model_id_endpoint_name_or_base_url_provided(
15
+ self,
16
+ ) -> "InferenceEndpointsLLM":
17
+ """Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also
18
+ provided, a warning will be shown informing the user that the provided `base_url` will be ignored in
19
+ favour of the dynamically calculated one."""
20
+
21
+ if self.base_url and (self.model_id or self.endpoint_name):
22
+ warnings.warn( # type: ignore
23
+ f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
24
+ " or `endpoint_name` is also provided, the `base_url` will either be ignored"
25
+ " or overwritten with the one generated from either of those args, for serverless"
26
+ " or dedicated inference endpoints, respectively."
27
+ )
28
+
29
+ if self.use_magpie_template and self.tokenizer_id is None:
30
+ raise ValueError(
31
+ "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
32
+ " set a `tokenizer_id` and try again."
33
+ )
34
+
35
+ if (
36
+ self.model_id
37
+ and self.tokenizer_id is None
38
+ and self.structured_output is not None
39
+ ):
40
+ self.tokenizer_id = self.model_id
41
+
42
+ if self.base_url and not (self.model_id or self.endpoint_name):
43
+ return self
44
+
45
+ if self.model_id and not self.endpoint_name:
46
+ return self
47
+
48
+ if self.endpoint_name and not self.model_id:
49
+ return self
50
+
51
+ raise ValidationError(
52
+ f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
53
+ f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
54
+ f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
55
+ )
56
+
57
+
58
+ distilabel.models.llms.InferenceEndpointsLLM = CustomInferenceEndpointsLLM
src/synthetic_dataset_generator/_tabbedinterface.py ADDED
@@ -0,0 +1,69 @@
1
+ """
2
+ This file defines two useful high-level abstractions to build Gradio apps: Interface and TabbedInterface.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from collections.abc import Sequence
8
+
9
+ import gradio as gr
10
+ from gradio.blocks import Blocks
11
+ from gradio.layouts import Tab, Tabs
12
+ from gradio.themes import ThemeClass as Theme
13
+ from gradio_client.documentation import document
14
+
15
+
16
+ @document()
17
+ class TabbedInterface(Blocks):
18
+ """
19
+ A TabbedInterface is created by providing a list of Interfaces or Blocks, each of which gets
20
+ rendered in a separate tab. Only the components from the Interface/Blocks will be rendered in the tab.
21
+ Certain high-level attributes of the Blocks (e.g. custom `css`, `js`, and `head` attributes) will not be loaded.
22
+
23
+ Demos: tabbed_interface_lite
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ interface_list: Sequence[Blocks],
29
+ tab_names: list[str] | None = None,
30
+ title: str | None = None,
31
+ theme: Theme | str | None = None,
32
+ analytics_enabled: bool | None = None,
33
+ css: str | None = None,
34
+ js: str | None = None,
35
+ head: str | None = None,
36
+ ):
37
+ """
38
+ Parameters:
39
+ interface_list: A list of Interfaces (or Blocks) to be rendered in the tabs.
40
+ tab_names: A list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
41
+ title: The tab title to display when this demo is opened in a browser window.
42
+ theme: A Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the Hugging Face Hub (e.g. "gradio/monochrome"). If None, will use the Default theme.
43
+ analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
44
+ css: Custom css as a string or path to a css file. This css will be included in the demo webpage.
45
+ js: Custom js as a string or path to a js file. The custom js should in the form of a single js function. This function will automatically be executed when the page loads. For more flexibility, use the head parameter to insert js inside <script> tags.
46
+ head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, multiple scripts, stylesheets, etc. to the page.
47
+ Returns:
48
+ a Gradio Tabbed Interface for the given interfaces
49
+ """
50
+ super().__init__(
51
+ title="Synthetic Data Generator",
52
+ theme=theme,
53
+ analytics_enabled=analytics_enabled,
54
+ mode="tabbed_interface",
55
+ css=css,
56
+ js=js,
57
+ head=head,
58
+ )
59
+ if tab_names is None:
60
+ tab_names = [f"Tab {i}" for i in range(len(interface_list))]
61
+ with self:
62
+ h3 = "<div style='text-align: center;'><h2>Build datasets using natural language</h2></div>"
63
+ if title:
64
+ gr.HTML(value=title + h3)
65
+ gr.LoginButton(value="Sign in", variant="primary", elem_id="sign_in_button")
66
+ with Tabs():
67
+ for interface, tab_name in zip(interface_list, tab_names, strict=False):
68
+ with Tab(label=tab_name):
69
+ interface.render()
src/synthetic_dataset_generator/app.py ADDED
@@ -0,0 +1,35 @@
1
+ from synthetic_dataset_generator._tabbedinterface import TabbedInterface
2
+
3
+ # from synthetic_dataset_generator.apps.eval import app as eval_app
4
+ from synthetic_dataset_generator.apps.rag import app as rag_app
5
+ from synthetic_dataset_generator.apps.about import app as about_app
6
+ from synthetic_dataset_generator.apps.chat import app as chat_app
7
+ from synthetic_dataset_generator.apps.textcat import app as textcat_app
8
+
9
+ theme = "argilla/argilla-theme"
10
+
11
+ css = """
12
+ .main_ui_logged_out{opacity: 0.3; pointer-events: none}
13
+ button[role="tab"][aria-selected="true"] { border: 0; background: var(--button-primary-background-fill); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
14
+ button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill); background: var(var(--button-primary-background-fill-hover))}
15
+ .tabitem {border: 0; padding-inline: 0}
16
+ .gallery-item {background: var(--background-fill-secondary); text-align: left}
17
+ .table-wrap .tbody td {vertical-align: top}
18
+ #system_prompt_examples {color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
19
+ .container {padding-inline: 0 !important}
20
+ .gradio-container { width: 100% !important; }
21
+ .gradio-row { display: flex !important; flex-direction: row !important; }
22
+ .gradio-column { flex: 1 !important; min-width: 0 !important; }
23
+ #sign_in_button {flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto;}
24
+ .datasets {height: 70px;}
25
+ """
26
+
27
+ image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
28
+
29
+ demo = TabbedInterface(
30
+ [textcat_app, chat_app, rag_app, about_app],
31
+ ["Text Classification", "Chat Data", "RAG", "About"],
32
+ css=css,
33
+ title=image,
34
+ theme=theme,
35
+ )
src/synthetic_dataset_generator/apps/__init__.py ADDED
File without changes
src/synthetic_dataset_generator/apps/about.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ with gr.Blocks() as app:
4
+ gr.Markdown(
5
+ """
6
+ Synthetic data is artificially generated information that mimics real-world data. It allows overcoming data limitations by expanding or enhancing datasets.
7
+
8
+ Introducing the Synthetic Data Generator, a user-friendly application that takes a no-code approach to creating custom datasets with Large Language Models (LLMs). The best part: A simple step-by-step process, making dataset creation a non-technical breeze, allowing anyone to create datasets and models in minutes and without any code.
9
+
10
+ The synthetic data generator takes your custom prompt and returns a dataset for your use case, using a synthetic data pipeline. In the background, this is powered by [distilabel](https://distilabel.argilla.io/latest/) and the [free Hugging Face text-generation API](https://huggingface.co/docs/api-inference/en/index), but we don't need to worry about these complexities and can focus on using the UI.
11
+
12
+ - Read more in [our announcement blog post](https://huggingface.co/blog/synthetic-data-generator)
13
+ - Find the library on [GitHub](https://github.com/argilla-io/synthetic-data-generator)
14
+ """
15
+ )
src/synthetic_dataset_generator/apps/base.py ADDED
@@ -0,0 +1,270 @@
1
+ import io
2
+ import uuid
3
+ from tqdm import tqdm
4
+ from typing import Union
5
+
6
+ import argilla as rg
7
+ import gradio as gr
8
+ import pandas as pd
9
+ from datasets import Dataset, concatenate_datasets, get_dataset_config_names, get_dataset_split_names, load_dataset
10
+ from gradio import OAuthToken
11
+ from huggingface_hub import HfApi, upload_file, repo_exists
12
+ from unstructured.chunking.title import chunk_by_title
13
+ from unstructured.partition.auto import partition
14
+
15
+ from synthetic_dataset_generator.constants import MAX_NUM_ROWS, SAVE_LOCAL_DIR
16
+ from synthetic_dataset_generator.utils import get_argilla_client
17
+
18
+ if SAVE_LOCAL_DIR is not None:
19
+ import os
20
+ os.makedirs(SAVE_LOCAL_DIR, exist_ok=True)
21
+
22
+
23
+ def validate_argilla_user_workspace_dataset(
24
+ dataset_name: str,
25
+ add_to_existing_dataset: bool = True,
26
+ oauth_token: Union[OAuthToken, None] = None,
27
+ progress=gr.Progress(),
28
+ ) -> str:
29
+ progress(0.1, desc="Validating dataset configuration")
30
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
31
+ client = get_argilla_client()
32
+ if dataset_name is None or dataset_name == "":
33
+ raise gr.Error("Dataset name is required")
34
+ # Create user if it doesn't exist
35
+ rg_user = client.users(username=hf_user)
36
+ if rg_user is None:
37
+ rg_user = client.users.add(
38
+ rg.User(username=hf_user, role="admin", password=str(uuid.uuid4()))
39
+ )
40
+ # Create workspace if it doesn't exist
41
+ workspace = client.workspaces(name=hf_user)
42
+ if workspace is None:
43
+ workspace = client.workspaces.add(rg.Workspace(name=hf_user))
44
+ workspace.add_user(hf_user)
45
+ # Check if dataset exists
46
+ dataset = client.datasets(name=dataset_name, workspace=hf_user)
47
+ if dataset and not add_to_existing_dataset:
48
+ raise gr.Error(f"Dataset {dataset_name} already exists")
49
+ progress(1.0, desc="Dataset configuration validated")
50
+ return ""
51
+
52
+
53
+ def push_pipeline_code_to_hub(
54
+ pipeline_code: str,
55
+ org_name: str,
56
+ repo_name: str,
57
+ oauth_token: Union[OAuthToken, None] = None,
58
+ progress=gr.Progress(),
59
+ ):
60
+ repo_id: str | None = validate_push_to_hub(org_name, repo_name)
61
+ progress(0.1, desc="Uploading pipeline code")
62
+ with io.BytesIO(pipeline_code.encode("utf-8")) as f:
63
+ upload_file(
64
+ path_or_fileobj=f,
65
+ path_in_repo="pipeline.py",
66
+ repo_id=repo_id,
67
+ repo_type="dataset",
68
+ token=oauth_token.token,
69
+ commit_message="Include pipeline script",
70
+ create_pr=False,
71
+ )
72
+ progress(1.0, desc="Pipeline code uploaded")
73
+
74
+
75
+ def validate_push_to_hub(org_name: str, repo_name: str):
76
+ repo_id = (
77
+ f"{org_name}/{repo_name}"
78
+ if repo_name is not None and org_name is not None
79
+ else None
80
+ )
81
+ if repo_id is not None:
82
+ if not all([repo_id, org_name, repo_name]):
83
+ raise gr.Error(
84
+ "Please provide a `repo_name` and `org_name` to push the dataset to."
85
+ )
86
+ return repo_id
87
+
88
+
89
+ def combine_datasets(
90
+ repo_id: str, dataset: Dataset, oauth_token: Union[OAuthToken, None]
91
+ ) -> Dataset:
92
+ try:
93
+ new_dataset = load_dataset(
94
+ repo_id,
95
+ split="train",
96
+ download_mode="force_redownload",
97
+ token=oauth_token.token,
98
+ )
99
+ return concatenate_datasets([dataset, new_dataset])
100
+ except Exception:
101
+ return dataset
102
+
103
+
104
+ def show_success_message(org_name: str, repo_name: str) -> gr.Markdown:
105
+ client = get_argilla_client()
106
+ if client is None:
107
+ return gr.Markdown(
108
+ value=f"""
109
+ <div style="padding: 1em; background-color: var(--block-background-fill); border-color: var(--border-color-primary); border-width: 1px; border-radius: 5px;">
110
+ <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
111
+ <p style="margin-top: 0.5em;">
112
+ The generated dataset is in the right format for fine-tuning with TRL, AutoTrain, or other frameworks.
113
+ <div style="display: flex; gap: 10px;">
114
+ <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" class="lg primary svelte-1137axg" style="color: white !important; margin-top: 0.5em; text-decoration: none;">
115
+ Open in Hugging Face
116
+ </a>
117
+ </div>
118
+ </p>
119
+ <p style="margin-top: 1em; color: var(--block-title-text-color)">
120
+ By configuring an `ARGILLA_API_URL` and `ARGILLA_API_KEY` you can curate the dataset in Argilla.
121
+ Unfamiliar with Argilla? Here are some docs to help you get started:
122
+ <br>• <a href="https://docs.argilla.io/latest/getting_started/quickstart/" target="_blank">How to get started with Argilla</a>
123
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
124
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
125
+ </p>
126
+ </div>
127
+ """,
128
+ visible=True,
129
+ height=None,
130
+ min_height=None,
131
+ max_height=None,
132
+ )
133
+ argilla_api_url = client.api_url
134
+ # Transform Docker internal URL to localhost if needed
135
+ if "argilla:" in argilla_api_url:
136
+ argilla_api_url = argilla_api_url.replace("argilla:", "127.0.0.1:")
137
+ return gr.Markdown(
138
+ value=f"""
139
+ <div style="padding: 1em; background-color: var(--block-background-fill); border-color: var(--border-color-primary); border-width: 1px; border-radius: 5px;">
140
+ <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
141
+ <p style="margin-top: 0.5em;">
142
+ The generated dataset is <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank">available in the Hub</a>. It is in the right format for fine-tuning with TRL, AutoTrain, or other frameworks.
143
+ <div style="display: flex; gap: 10px;">
144
+ <a href="{argilla_api_url}" target="_blank" class="lg primary svelte-1137axg" style="color: white !important; margin-top: 0.5em; text-decoration: none;">
145
+ Open in Argilla
146
+ </a>
147
+ </div>
148
+ </p>
149
+ <p style="margin-top: 1em; color: var(--block-title-text-color)">
150
+ Unfamiliar with Argilla? Here are some docs to help you get started:
151
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
152
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
153
+ </p>
154
+ </div>
155
+ """,
156
+ visible=True,
157
+ height=None,
158
+ min_height=None,
159
+ max_height=None,
160
+ )
161
+
162
+
163
+ def hide_success_message() -> gr.Markdown:
164
+ return gr.Markdown(value="", visible=True, height=100)
165
+
166
+
167
+ def test_max_num_rows(num_rows: int) -> int:
168
+ if num_rows > MAX_NUM_ROWS:
169
+ num_rows = MAX_NUM_ROWS
170
+ gr.Info(
171
+ f"Number of rows is larger than the configured maximum. Setting number of rows to {MAX_NUM_ROWS}. Set environment variable `MAX_NUM_ROWS` to change this behavior."
172
+ )
173
+ return num_rows
174
+
175
+
176
+ def get_iframe(hub_repo_id: str) -> str:
177
+ if not hub_repo_id:
178
+ return ""
179
+
180
+ if not repo_exists(repo_id=hub_repo_id, repo_type="dataset"):
181
+ return ""
182
+
183
+ url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
184
+ iframe = f"""
185
+ <iframe
186
+ src="{url}"
187
+ frameborder="0"
188
+ width="100%"
189
+ height="600px"
190
+ ></iframe>
191
+ """
192
+ return iframe
193
+
194
+
195
+ def _get_valid_columns(dataframe: pd.DataFrame):
196
+ doc_valid_columns = []
197
+
198
+ for col in dataframe.columns:
199
+ sample_val = dataframe[col].iloc[0]
200
+ if isinstance(sample_val, str):
201
+ doc_valid_columns.append(col)
202
+
203
+ return doc_valid_columns
204
+
205
+
206
+ def load_dataset_from_hub(
207
+ repo_id: str,
208
+ num_rows: int = 10,
209
+ token: Union[OAuthToken, None] = None,
210
+ progress=gr.Progress(track_tqdm=True),
211
+ ):
212
+ if not repo_id:
213
+ raise gr.Error("Please provide a Hub repo ID")
214
+ subsets = get_dataset_config_names(repo_id, token=token)
215
+ splits = get_dataset_split_names(repo_id, subsets[0], token=token)
216
+ ds = load_dataset(repo_id, subsets[0], split=splits[0], token=token, streaming=True)
217
+ rows = []
218
+ for idx, row in enumerate(tqdm(ds, desc="Loading the dataset", total=num_rows)):
219
+ rows.append(row)
220
+ if len(rows) >= num_rows:
221
+ break
222
+ ds = Dataset.from_list(rows)
223
+ dataframe = ds.to_pandas()
224
+ doc_valid_columns = _get_valid_columns(dataframe)
225
+ col_doc = doc_valid_columns[0] if doc_valid_columns else ""
226
+ return (
227
+ dataframe,
228
+ gr.Dropdown(
229
+ choices=doc_valid_columns,
230
+ label="Documents column",
231
+ value=col_doc,
232
+ interactive=(False if col_doc == "" else True),
233
+ multiselect=False,
234
+ ),
235
+ )
236
+
237
+
238
+ def preprocess_input_data(
239
+ file_paths: list[str], num_rows: int, progress=gr.Progress(track_tqdm=True)
240
+ ):
241
+ if not file_paths:
242
+ raise gr.Error("Please provide an input file")
243
+
244
+ data = {}
245
+ total_chunks = 0
246
+
247
+ for file_path in tqdm(file_paths, desc="Processing files", total=len(file_paths)):
248
+ partitioned_file = partition(filename=file_path)
249
+ chunks = [str(chunk) for chunk in chunk_by_title(partitioned_file)]
250
+ data[file_path] = chunks
251
+ total_chunks += len(chunks)
252
+ if total_chunks >= num_rows:
253
+ break
254
+
255
+ dataframe = pd.DataFrame.from_records(
256
+ [(k, v) for k, values in data.items() for v in values],
257
+ columns=["filename", "chunks"],
258
+ )
259
+ col_doc = "chunks"
260
+
261
+ return (
262
+ dataframe,
263
+ gr.Dropdown(
264
+ choices=["chunks"],
265
+ label="Documents column",
266
+ value=col_doc,
267
+ interactive=(False if col_doc == "" else True),
268
+ multiselect=False,
269
+ ),
270
+ )
src/synthetic_dataset_generator/apps/chat.py ADDED
@@ -0,0 +1,1142 @@
1
+ import ast
2
+ import json
3
+ import os
4
+ import random
5
+ import uuid
6
+ from typing import Dict, List, Union
7
+
8
+ import argilla as rg
9
+ import gradio as gr
10
+ import pandas as pd
11
+ from datasets import Dataset
12
+ from distilabel.distiset import Distiset
13
+ from gradio.oauth import OAuthToken
14
+ from gradio_huggingfacehub_search import HuggingfaceHubSearch
15
+ from huggingface_hub import HfApi
16
+
17
+ from synthetic_dataset_generator.apps.base import (
18
+ combine_datasets,
19
+ hide_success_message,
20
+ load_dataset_from_hub,
21
+ preprocess_input_data,
22
+ push_pipeline_code_to_hub,
23
+ show_success_message,
24
+ test_max_num_rows,
25
+ validate_argilla_user_workspace_dataset,
26
+ validate_push_to_hub,
27
+ )
28
+ from synthetic_dataset_generator.constants import (
29
+ BASE_URL,
30
+ DEFAULT_BATCH_SIZE,
31
+ MODEL,
32
+ MODEL_COMPLETION,
33
+ SAVE_LOCAL_DIR,
34
+ SFT_AVAILABLE,
35
+ )
36
+ from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
37
+ from synthetic_dataset_generator.pipelines.chat import (
38
+ DEFAULT_DATASET_DESCRIPTIONS,
39
+ generate_pipeline_code,
40
+ get_follow_up_generator,
41
+ get_magpie_generator,
42
+ get_prompt_generator,
43
+ get_response_generator,
44
+ get_sentence_pair_generator,
45
+ )
46
+ from synthetic_dataset_generator.pipelines.embeddings import (
47
+ get_embeddings,
48
+ get_sentence_embedding_dimensions,
49
+ )
50
+ from synthetic_dataset_generator.utils import (
51
+ column_to_list,
52
+ get_argilla_client,
53
+ get_org_dropdown,
54
+ get_random_repo_name,
55
+ swap_visibility,
56
+ )
57
+
58
+
59
+ def _get_dataframe():
60
+ return gr.Dataframe(
61
+ headers=["prompt", "completion"],
62
+ wrap=True,
63
+ interactive=False,
64
+ )
65
+
66
+
67
+ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
68
+ def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
69
+ return ast.literal_eval(
70
+ messages.replace("'user'}", "'user'},")
71
+ .replace("'system'}", "'system'},")
72
+ .replace("'assistant'}", "'assistant'},")
73
+ )
74
+
75
+ if "messages" in dataframe.columns:
76
+ dataframe["messages"] = dataframe["messages"].apply(
77
+ lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
78
+ )
79
+ return dataframe
80
+
81
+
82
+ def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
83
+ progress(0.1, desc="Initializing")
84
+ generate_description = get_prompt_generator()
85
+ progress(0.5, desc="Generating")
86
+ result = next(
87
+ generate_description.process(
88
+ [
89
+ {
90
+ "instruction": dataset_description,
91
+ }
92
+ ]
93
+ )
94
+ )[0]["generation"]
95
+ progress(1.0, desc="Prompt generated")
96
+ return result
97
+
98
+
99
+ def load_dataset_file(
100
+ repo_id: str,
101
+ file_paths: list[str],
102
+ input_type: str,
103
+ num_rows: int = 10,
104
+ token: Union[OAuthToken, None] = None,
105
+ progress=gr.Progress(),
106
+ ):
107
+ progress(0.1, desc="Loading the source data")
108
+ if input_type == "dataset-input":
109
+ return load_dataset_from_hub(repo_id=repo_id, num_rows=num_rows, token=token)
110
+ else:
111
+ return preprocess_input_data(file_paths=file_paths, num_rows=num_rows)
112
+
113
+
114
+ def generate_sample_dataset(
115
+ repo_id: str,
116
+ file_paths: list[str],
117
+ input_type: str,
118
+ system_prompt: str,
119
+ document_column: str,
120
+ num_turns: int,
121
+ num_rows: int,
122
+ oauth_token: Union[OAuthToken, None],
123
+ progress=gr.Progress(),
124
+ ):
125
+ if input_type == "prompt-input":
126
+ dataframe = pd.DataFrame(columns=["prompt", "completion"])
127
+ else:
128
+ dataframe, _ = load_dataset_file(
129
+ repo_id=repo_id,
130
+ file_paths=file_paths,
131
+ input_type=input_type,
132
+ num_rows=num_rows,
133
+ token=oauth_token,
134
+ )
135
+ progress(0.5, desc="Generating sample dataset")
136
+ dataframe = generate_dataset(
137
+ input_type=input_type,
138
+ dataframe=dataframe,
139
+ system_prompt=system_prompt,
140
+ document_column=document_column,
141
+ num_turns=num_turns,
142
+ num_rows=num_rows,
143
+ is_sample=True,
144
+ )
145
+ progress(1.0, desc="Sample dataset generated")
146
+ return dataframe
147
+
148
+
149
+ def generate_dataset_from_prompt(
150
+ system_prompt: str,
151
+ num_turns: int = 1,
152
+ num_rows: int = 10,
153
+ temperature: float = 0.9,
154
+ temperature_completion: Union[float, None] = None,
155
+ is_sample: bool = False,
156
+ progress=gr.Progress(),
157
+ ) -> pd.DataFrame:
158
+ num_rows = test_max_num_rows(num_rows)
159
+ progress(0.0, desc="(1/2) Generating instructions")
160
+ magpie_generator = get_magpie_generator(num_turns, temperature, is_sample)
161
+ response_generator = get_response_generator(
162
+ system_prompt=system_prompt,
163
+ num_turns=num_turns,
164
+ temperature=temperature_completion or temperature,
165
+ is_sample=is_sample,
166
+ )
167
+ total_steps: int = num_rows * 2
168
+ batch_size = DEFAULT_BATCH_SIZE
169
+
170
+ # create prompt rewrites
171
+ prompt_rewrites = get_rewritten_prompts(system_prompt, num_rows)
172
+
173
+ # create instructions
174
+ n_processed = 0
175
+ magpie_results = []
176
+ while n_processed < num_rows:
177
+ progress(
178
+ 0.5 * n_processed / num_rows,
179
+ total=total_steps,
180
+ desc="(1/2) Generating instructions",
181
+ )
182
+ remaining_rows = num_rows - n_processed
183
+ batch_size = min(batch_size, remaining_rows)
184
+ rewritten_system_prompt = random.choice(prompt_rewrites)
185
+ inputs = [{"system_prompt": rewritten_system_prompt} for _ in range(batch_size)]
186
+ batch = list(magpie_generator.process(inputs=inputs))
187
+ magpie_results.extend(batch[0])
188
+ n_processed += batch_size
189
+ random.seed(a=random.randint(0, 2**32 - 1))
190
+ progress(0.5, desc="(1/2) Generating instructions")
191
+
192
+ # generate responses
193
+ n_processed = 0
194
+ response_results = []
195
+ if num_turns == 1:
196
+ while n_processed < num_rows:
197
+ progress(
198
+ 0.5 + 0.5 * n_processed / num_rows,
199
+ total=total_steps,
200
+ desc="(2/2) Generating responses",
201
+ )
202
+ batch = magpie_results[n_processed : n_processed + batch_size]
203
+ responses = list(response_generator.process(inputs=batch))
204
+ response_results.extend(responses[0])
205
+ n_processed += batch_size
206
+ random.seed(a=random.randint(0, 2**32 - 1))
207
+ for result in response_results:
208
+ result["prompt"] = result["instruction"]
209
+ result["completion"] = result["generation"]
210
+ result["system_prompt"] = system_prompt
211
+ else:
212
+ for result in magpie_results:
213
+ result["conversation"].insert(
214
+ 0, {"role": "system", "content": system_prompt}
215
+ )
216
+ result["messages"] = result["conversation"]
217
+ while n_processed < num_rows:
218
+ progress(
219
+ 0.5 + 0.5 * n_processed / num_rows,
220
+ total=total_steps,
221
+ desc="(2/2) Generating responses",
222
+ )
223
+ batch = magpie_results[n_processed : n_processed + batch_size]
224
+ responses = list(response_generator.process(inputs=batch))
225
+ response_results.extend(responses[0])
226
+ n_processed += batch_size
227
+ random.seed(a=random.randint(0, 2**32 - 1))
228
+ for result in response_results:
229
+ result["messages"].append(
230
+ {"role": "assistant", "content": result["generation"]}
231
+ )
232
+ progress(
233
+ 1,
234
+ total=total_steps,
235
+ desc="(2/2) Creating dataset",
236
+ )
237
+
238
+ # create distiset
239
+ distiset_results = []
240
+ for result in response_results:
241
+ record = {}
242
+ for relevant_keys in [
243
+ "messages",
244
+ "prompt",
245
+ "completion",
246
+ "model_name",
247
+ "system_prompt",
248
+ ]:
249
+ if relevant_keys in result:
250
+ record[relevant_keys] = result[relevant_keys]
251
+ distiset_results.append(record)
252
+
253
+ distiset = Distiset(
254
+ {
255
+ "default": Dataset.from_list(distiset_results),
256
+ }
257
+ )
258
+
259
+ # If not pushing to hub generate the dataset directly
260
+ distiset = distiset["default"]
261
+ if num_turns == 1:
262
+ outputs = distiset.to_pandas()[["prompt", "completion", "system_prompt"]]
263
+ else:
264
+ outputs = distiset.to_pandas()[["messages"]]
265
+ dataframe = pd.DataFrame(outputs)
266
+ progress(1.0, desc="Dataset generation completed")
267
+ return dataframe
268
+
269
+
270
+ def generate_dataset_from_seed(
271
+ dataframe: pd.DataFrame,
272
+ document_column: str,
273
+ num_turns: int = 1,
274
+ num_rows: int = 10,
275
+ temperature: float = 0.9,
276
+ temperature_completion: Union[float, None] = None,
277
+ is_sample: bool = False,
278
+ progress=gr.Progress(),
279
+ ) -> pd.DataFrame:
280
+ num_rows = test_max_num_rows(num_rows)
281
+ progress(0.0, desc="Initializing dataset generation")
282
+ document_data = column_to_list(dataframe, document_column)
283
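+ # when there are fewer documents than requested rows, sample extra documents with replacement to reach num_rows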
+ if len(document_data) < num_rows:
284
+ document_data += random.choices(document_data, k=num_rows - len(document_data))
285
+ instruction_generator = get_sentence_pair_generator(
286
+ temperature=temperature, is_sample=is_sample
287
+ )
288
+ response_generator = get_response_generator(
289
+ system_prompt=None,
290
+ num_turns=1,
291
+ temperature=temperature_completion or temperature,
292
+ is_sample=is_sample,
293
+ )
294
+ follow_up_generator_instruction = get_follow_up_generator(
295
+ type="instruction", temperature=temperature, is_sample=is_sample
296
+ )
297
+ follow_up_generator_response = get_follow_up_generator(
298
+ type="response",
299
+ temperature=temperature_completion or temperature,
300
+ is_sample=is_sample,
301
+ )
302
+ steps = 2 * num_turns
303
+ total_steps: int = num_rows * steps
304
+ step_progress = round(1 / steps, 2)
305
+ batch_size = DEFAULT_BATCH_SIZE
306
+
307
+ # create instructions
308
+ n_processed = 0
309
+ instruction_results = []
310
+ while n_processed < num_rows:
311
+ progress(
312
+ step_progress * n_processed / num_rows,
313
+ total=total_steps,
314
+ desc="Generating instructions",
315
+ )
316
+ remaining_rows = num_rows - n_processed
317
+ batch_size = min(batch_size, remaining_rows)
318
+ batch = [
319
+ {"anchor": document}
320
+ for document in document_data[n_processed : n_processed + batch_size]
321
+ ]
322
+ questions = list(instruction_generator.process(inputs=batch))
323
+ instruction_results.extend(questions[0])
324
+ n_processed += batch_size
325
+ for result in instruction_results:
326
+ result["instruction"] = result["positive"]
327
+ result["prompt"] = result.pop("positive")
328
+
329
+ progress(step_progress, desc="Generating instructions")
330
+
331
+ # generate responses
332
+ n_processed = 0
333
+ response_results = []
334
+ while n_processed < num_rows:
335
+ progress(
336
+ step_progress + step_progress * n_processed / num_rows,
337
+ total=total_steps,
338
+ desc="Generating responses",
339
+ )
340
+ batch = instruction_results[n_processed : n_processed + batch_size]
341
+ responses = list(response_generator.process(inputs=batch))
342
+ response_results.extend(responses[0])
343
+ n_processed += batch_size
344
+ for result in response_results:
345
+ result["completion"] = result.pop("generation")
346
+
347
+ # generate follow-ups
348
+ if num_turns > 1:
349
+ n_processed = 0
350
+ final_conversations = []
351
+
352
+ while n_processed < num_rows:
353
+ progress(
354
+ step_progress + step_progress * n_processed / num_rows,
355
+ total=total_steps,
356
+ desc="Generating follow-ups",
357
+ )
358
+ batch = response_results[n_processed : n_processed + batch_size]
359
+ conversations_batch = [
360
+ {
361
+ "messages": [
362
+ {"role": "user", "content": result["prompt"]},
363
+ {"role": "assistant", "content": result["completion"]},
364
+ ]
365
+ }
366
+ for result in batch
367
+ ]
368
+
369
+ for _ in range(num_turns - 1):
370
+ follow_up_instructions = list(
371
+ follow_up_generator_instruction.process(inputs=conversations_batch)
372
+ )
373
+ for conv, follow_up in zip(
374
+ conversations_batch, follow_up_instructions[0]
375
+ ):
376
+ conv["messages"].append(
377
+ {"role": "user", "content": follow_up["generation"]}
378
+ )
379
+
380
+ follow_up_responses = list(
381
+ follow_up_generator_response.process(inputs=conversations_batch)
382
+ )
383
+ for conv, follow_up in zip(conversations_batch, follow_up_responses[0]):
384
+ conv["messages"].append(
385
+ {"role": "assistant", "content": follow_up["generation"]}
386
+ )
387
+
388
+ final_conversations.extend(
389
+ [{"messages": conv["messages"]} for conv in conversations_batch]
390
+ )
391
+ n_processed += batch_size
392
+
393
+ # create distiset
394
+ distiset_results = []
395
+ if num_turns == 1:
396
+ for result in response_results:
397
+ record = {}
398
+ for relevant_keys in ["prompt", "completion"]:
399
+ if relevant_keys in result:
400
+ record[relevant_keys] = result[relevant_keys]
401
+ distiset_results.append(record)
402
+ dataframe = pd.DataFrame(distiset_results)
403
+ else:
404
+ distiset_results = final_conversations
405
+ dataframe = pd.DataFrame(distiset_results)
406
+ dataframe["messages"] = dataframe["messages"].apply(lambda x: json.dumps(x))
407
+
408
+ progress(1.0, desc="Dataset generation completed")
409
+ return dataframe
410
+
411
+
412
+ def generate_dataset(
413
+ input_type: str,
414
+ dataframe: pd.DataFrame,
415
+ system_prompt: str,
416
+ document_column: str,
417
+ num_turns: int = 1,
418
+ num_rows: int = 10,
419
+ temperature: float = 0.9,
420
+ temperature_completion: Union[float, None] = None,
421
+ is_sample: bool = False,
422
+ progress=gr.Progress(),
423
+ ) -> pd.DataFrame:
424
+ if input_type == "prompt-input":
425
+ dataframe = generate_dataset_from_prompt(
426
+ system_prompt=system_prompt,
427
+ num_turns=num_turns,
428
+ num_rows=num_rows,
429
+ temperature=temperature,
430
+ temperature_completion=temperature_completion,
431
+ is_sample=is_sample,
432
+ )
433
+ else:
434
+ dataframe = generate_dataset_from_seed(
435
+ dataframe=dataframe,
436
+ document_column=document_column,
437
+ num_turns=num_turns,
438
+ num_rows=num_rows,
439
+ temperature=temperature,
440
+ temperature_completion=temperature_completion,
441
+ is_sample=is_sample,
442
+ )
443
+ return dataframe
444
+
445
+
446
+ def push_dataset_to_hub(
447
+ dataframe: pd.DataFrame,
448
+ org_name: str,
449
+ repo_name: str,
450
+ oauth_token: Union[gr.OAuthToken, None],
451
+ private: bool,
452
+ pipeline_code: str,
453
+ progress=gr.Progress(),
454
+ ):
455
+ progress(0.0, desc="Validating")
456
+ repo_id = validate_push_to_hub(org_name, repo_name)
457
+ progress(0.3, desc="Converting")
458
+ original_dataframe = dataframe.copy(deep=True)
459
+ dataframe = convert_dataframe_messages(dataframe)
460
+ progress(0.7, desc="Creating dataset")
461
+ dataset = Dataset.from_pandas(dataframe)
462
+ dataset = combine_datasets(repo_id, dataset, oauth_token)
463
+ progress(0.9, desc="Pushing dataset")
464
+ distiset = Distiset({"default": dataset})
465
+ distiset.push_to_hub(
466
+ repo_id=repo_id,
467
+ private=private,
468
+ include_script=False,
469
+ token=oauth_token.token,
470
+ create_pr=False,
471
+ )
472
+ push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
473
+ progress(1.0, desc="Dataset pushed")
474
+ return original_dataframe
475
+
476
+
477
+ def push_dataset(
478
+ org_name: str,
479
+ repo_name: str,
480
+ private: bool,
481
+ original_repo_id: str,
482
+ file_paths: list[str],
483
+ input_type: str,
484
+ system_prompt: str,
485
+ document_column: str,
486
+ num_turns: int = 1,
487
+ num_rows: int = 10,
488
+ temperature: float = 0.9,
489
+ temperature_completion: Union[float, None] = None,
490
+ pipeline_code: str = "",
491
+ oauth_token: Union[gr.OAuthToken, None] = None,
492
+ progress=gr.Progress(),
493
+ ) -> pd.DataFrame:
494
+ if input_type == "prompt-input":
495
+ dataframe = _get_dataframe()
496
+ else:
497
+ dataframe, _ = load_dataset_file(
498
+ repo_id=original_repo_id,
499
+ file_paths=file_paths,
500
+ input_type=input_type,
501
+ num_rows=num_rows,
502
+ token=oauth_token,
503
+ )
504
+ progress(0.5, desc="Generating dataset")
505
+ dataframe = generate_dataset(
506
+ input_type=input_type,
507
+ dataframe=dataframe,
508
+ system_prompt=system_prompt,
509
+ document_column=document_column,
510
+ num_turns=num_turns,
511
+ num_rows=num_rows,
512
+ temperature=temperature,
513
+ temperature_completion=temperature_completion,
514
+ )
515
+ push_dataset_to_hub(
516
+ dataframe=dataframe,
517
+ org_name=org_name,
518
+ repo_name=repo_name,
519
+ oauth_token=oauth_token,
520
+ private=private,
521
+ pipeline_code=pipeline_code,
522
+ )
523
+ try:
524
+ progress(0.1, desc="Setting up user and workspace")
525
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
526
+ client = get_argilla_client()
527
+ if client is None:
528
+ return ""
529
+ progress(0.5, desc="Creating dataset in Argilla")
530
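+ # multi-turn data is logged to Argilla as a single ChatField; single-turn data uses separate system_prompt/prompt/completion text fields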
+ if "messages" in dataframe.columns:
531
+ settings = rg.Settings(
532
+ fields=[
533
+ rg.ChatField(
534
+ name="messages",
535
+ description="The messages in the conversation",
536
+ title="Messages",
537
+ ),
538
+ ],
539
+ questions=[
540
+ rg.RatingQuestion(
541
+ name="rating",
542
+ title="Rating",
543
+ description="The rating of the conversation",
544
+ values=list(range(1, 6)),
545
+ ),
546
+ ],
547
+ metadata=[
548
+ rg.IntegerMetadataProperty(
549
+ name="user_message_length", title="User Message Length"
550
+ ),
551
+ rg.IntegerMetadataProperty(
552
+ name="assistant_message_length",
553
+ title="Assistant Message Length",
554
+ ),
555
+ ],
556
+ vectors=[
557
+ rg.VectorField(
558
+ name="messages_embeddings",
559
+ dimensions=get_sentence_embedding_dimensions(),
560
+ )
561
+ ],
562
+ guidelines="Please review the conversation and provide a score for the assistant's response.",
563
+ )
564
+
565
+ dataframe["user_message_length"] = dataframe["messages"].apply(
566
+ lambda x: sum([len(y["content"]) for y in x if y["role"] == "user"])
567
+ )
568
+ dataframe["assistant_message_length"] = dataframe["messages"].apply(
569
+ lambda x: sum(
570
+ [len(y["content"]) for y in x if y["role"] == "assistant"]
571
+ )
572
+ )
573
+ dataframe["messages_embeddings"] = get_embeddings(
574
+ dataframe["messages"].apply(
575
+ lambda x: " ".join([y["content"] for y in x])
576
+ )
577
+ )
578
+ else:
579
+ settings = rg.Settings(
580
+ fields=[
581
+ rg.TextField(
582
+ name="system_prompt",
583
+ title="System Prompt",
584
+ description="The system prompt used for the conversation",
585
+ required=False,
586
+ ),
587
+ rg.TextField(
588
+ name="prompt",
589
+ title="Prompt",
590
+ description="The prompt used for the conversation",
591
+ ),
592
+ rg.TextField(
593
+ name="completion",
594
+ title="Completion",
595
+ description="The completion from the assistant",
596
+ ),
597
+ ],
598
+ questions=[
599
+ rg.RatingQuestion(
600
+ name="rating",
601
+ title="Rating",
602
+ description="The rating of the conversation",
603
+ values=list(range(1, 6)),
604
+ ),
605
+ ],
606
+ metadata=[
607
+ rg.IntegerMetadataProperty(
608
+ name="prompt_length", title="Prompt Length"
609
+ ),
610
+ rg.IntegerMetadataProperty(
611
+ name="completion_length", title="Completion Length"
612
+ ),
613
+ ],
614
+ vectors=[
615
+ rg.VectorField(
616
+ name="prompt_embeddings",
617
+ dimensions=get_sentence_embedding_dimensions(),
618
+ )
619
+ ],
620
+ guidelines="Please review the conversation and correct the prompt and completion where needed.",
621
+ )
622
+ dataframe["prompt_length"] = dataframe["prompt"].apply(len)
623
+ dataframe["completion_length"] = dataframe["completion"].apply(len)
624
+ dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])
625
+
626
+ rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
627
+ if rg_dataset is None:
628
+ rg_dataset = rg.Dataset(
629
+ name=repo_name,
630
+ workspace=hf_user,
631
+ settings=settings,
632
+ client=client,
633
+ )
634
+ rg_dataset = rg_dataset.create()
635
+ progress(0.7, desc="Pushing dataset to Argilla")
636
+ hf_dataset = Dataset.from_pandas(dataframe)
637
+ rg_dataset.records.log(records=hf_dataset)
638
+ progress(1.0, desc="Dataset pushed to Argilla")
639
+ except Exception as e:
640
+ raise gr.Error(f"Error pushing dataset to Argilla: {e}")
641
+ return ""
642
+
643
+
644
+ def save_local(
645
+ repo_id: str,
646
+ file_paths: list[str],
647
+ input_type: str,
648
+ system_prompt: str,
649
+ document_column: str,
650
+ num_turns: int,
651
+ num_rows: int,
652
+ temperature: float,
653
+ repo_name: str,
654
+ temperature_completion: Union[float, None] = None,
655
+ ) -> pd.DataFrame:
656
+ if input_type == "prompt-input":
657
+ dataframe = _get_dataframe()
658
+ else:
659
+ dataframe, _ = load_dataset_file(
660
+ repo_id=repo_id,
661
+ file_paths=file_paths,
662
+ input_type=input_type,
663
+ num_rows=num_rows,
664
+ )
665
+ dataframe = generate_dataset(
666
+ input_type=input_type,
667
+ dataframe=dataframe,
668
+ system_prompt=system_prompt,
669
+ document_column=document_column,
670
+ num_turns=num_turns,
671
+ num_rows=num_rows,
672
+ temperature=temperature,
673
+ temperature_completion=temperature_completion,
674
+ )
675
+ local_dataset = Dataset.from_pandas(dataframe)
676
+ output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
677
+ output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
678
+ local_dataset.to_csv(output_csv, index=False)
679
+ local_dataset.to_json(output_json, index=False)
680
+ return output_csv, output_json
681
+
682
+
683
+ def show_system_prompt_visibility():
684
+ return {system_prompt: gr.Textbox(visible=True)}
685
+
686
+
687
+ def hide_system_prompt_visibility():
688
+ return {system_prompt: gr.Textbox(visible=False)}
689
+
690
+
691
+ def show_document_column_visibility():
692
+ return {document_column: gr.Dropdown(visible=True)}
693
+
694
+
695
+ def hide_document_column_visibility():
696
+ return {
697
+ document_column: gr.Dropdown(
698
+ choices=["Load your data first in step 1."],
699
+ value="Load your data first in step 1.",
700
+ visible=False,
701
+ )
702
+ }
703
+
704
+
705
+ def show_pipeline_code_visibility():
706
+ return {pipeline_code_ui: gr.Accordion(visible=True)}
707
+
708
+
709
+ def hide_pipeline_code_visibility():
710
+ return {pipeline_code_ui: gr.Accordion(visible=False)}
711
+
712
+
713
+ def show_temperature_completion():
714
+ if MODEL != MODEL_COMPLETION:
715
+ return {temperature_completion: gr.Slider(value=0.9, visible=True)}
716
+
717
+
718
+ def show_save_local_button():
719
+ return {btn_save_local: gr.Button(visible=True)}
720
+
721
+
722
+ def hide_save_local_button():
723
+ return {btn_save_local: gr.Button(visible=False)}
724
+
725
+
726
+ def show_save_local():
727
+ # resize the success message area via the returned component update (a sketch of the apparent intent)
728
+ return {
729
+ csv_file: gr.File(visible=True),
730
+ json_file: gr.File(visible=True),
731
+ success_message: gr.Markdown(min_height=0),
732
+ }
733
+
734
+ def hide_save_local():
735
+ # restore the success message height so progress output stays visible
736
+ return {
737
+ csv_file: gr.File(visible=False),
738
+ json_file: gr.File(visible=False),
739
+ success_message: gr.Markdown(min_height=100),
740
+ }
741
+
742
+
743
+ ######################
744
+ # Gradio UI
745
+ ######################
746
+
747
+
748
+ with gr.Blocks() as app:
749
+ with gr.Column() as main_ui:
750
+ if not SFT_AVAILABLE:
751
+ gr.Markdown(
752
+ value="\n".join(
753
+ [
754
+ "## Supervised Fine-Tuning not available",
755
+ "",
756
+ f"This tool relies on the [Magpie](https://arxiv.org/abs/2406.08464) pre-query template, which is not implemented for {MODEL} with {BASE_URL}.",
757
+ "Use Llama3 or Qwen2 models with Hugging Face Inference Endpoints instead.",
758
+ ]
759
+ )
760
+ )
761
+ else:
762
+ gr.Markdown("## 1. Select your input")
763
+ with gr.Row(equal_height=False):
764
+ with gr.Column(scale=2):
765
+ input_type = gr.Dropdown(
766
+ label="Input type",
767
+ choices=["prompt-input", "dataset-input", "file-input"],
768
+ value="prompt-input",
769
+ multiselect=False,
770
+ visible=False,
771
+ )
772
+ with gr.Tab("Generate from prompt") as tab_prompt_input:
773
+ with gr.Row(equal_height=False):
774
+ with gr.Column(scale=2):
775
+ dataset_description = gr.Textbox(
776
+ label="Dataset description",
777
+ placeholder="Give a precise description of your desired dataset.",
778
+ )
779
+ with gr.Row():
780
+ clear_prompt_btn_part = gr.Button(
781
+ "Clear", variant="secondary"
782
+ )
783
+ load_prompt_btn = gr.Button(
784
+ "Create", variant="primary"
785
+ )
786
+ with gr.Column(scale=3):
787
+ examples = gr.Examples(
788
+ examples=DEFAULT_DATASET_DESCRIPTIONS,
789
+ inputs=[dataset_description],
790
+ cache_examples=False,
791
+ label="Examples",
792
+ )
793
+ with gr.Tab("Load from Hub") as tab_dataset_input:
794
+ with gr.Row(equal_height=False):
795
+ with gr.Column(scale=2):
796
+ search_in = HuggingfaceHubSearch(
797
+ label="Search",
798
+ placeholder="Search for a dataset",
799
+ search_type="dataset",
800
+ sumbit_on_select=True,
801
+ )
802
+ with gr.Row():
803
+ clear_dataset_btn_part = gr.Button(
804
+ "Clear", variant="secondary"
805
+ )
806
+ load_dataset_btn = gr.Button(
807
+ "Load", variant="primary"
808
+ )
809
+ with gr.Column(scale=3):
810
+ examples = gr.Examples(
811
+ examples=[
812
+ "charris/wikipedia_sample",
813
+ "plaguss/argilla_sdk_docs_raw_unstructured",
814
+ "BeIR/hotpotqa-generated-queries",
815
+ ],
816
+ label="Example datasets",
817
+ fn=lambda x: x,
818
+ inputs=[search_in],
819
+ run_on_click=True,
820
+ )
821
+ search_out = gr.HTML(
822
+ label="Dataset preview", visible=False
823
+ )
824
+ with gr.Tab("Load your file") as tab_file_input:
825
+ with gr.Row(equal_height=False):
826
+ with gr.Column(scale=2):
827
+ file_in = gr.File(
828
+ label="Upload your file. Supported formats: .md, .txt, .docx, .pdf",
829
+ file_count="multiple",
830
+ file_types=[".md", ".txt", ".docx", ".pdf"],
831
+ )
832
+ with gr.Row():
833
+ clear_file_btn_part = gr.Button(
834
+ "Clear", variant="secondary"
835
+ )
836
+ load_file_btn = gr.Button("Load", variant="primary")
837
+ with gr.Column(scale=3):
838
+ file_out = gr.HTML(
839
+ label="Dataset preview", visible=False
840
+ )
841
+
842
+ gr.HTML(value="<hr>")
843
+ gr.Markdown(value="## 2. Configure your dataset")
844
+ with gr.Row(equal_height=False):
845
+ with gr.Column(scale=2):
846
+ system_prompt = gr.Textbox(
847
+ label="System prompt",
848
+ placeholder="You are a helpful assistant.",
849
+ )
850
+ document_column = gr.Dropdown(
851
+ label="Document Column",
852
+ info="Select the document column to generate the chat data",
853
+ choices=["Load your data first in step 1."],
854
+ value="Load your data first in step 1.",
855
+ interactive=False,
856
+ multiselect=False,
857
+ allow_custom_value=False,
858
+ visible=False,
859
+ )
860
+ num_turns = gr.Number(
861
+ value=1,
862
+ label="Number of turns in the conversation",
863
+ minimum=1,
864
+ maximum=4,
865
+ step=1,
866
+ interactive=True,
867
+ info="Choose between 1 (single turn with 'prompt' and 'completion' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
868
+ )
869
+ with gr.Row():
870
+ clear_btn_full = gr.Button(
871
+ "Clear",
872
+ variant="secondary",
873
+ )
874
+ btn_apply_to_sample_dataset = gr.Button(
875
+ "Save", variant="primary"
876
+ )
877
+ with gr.Column(scale=3):
878
+ dataframe = _get_dataframe()
879
+
880
+ gr.HTML(value="<hr>")
881
+ gr.Markdown(value="## 3. Generate your dataset")
882
+ with gr.Row(equal_height=False):
883
+ with gr.Column(scale=2):
884
+ org_name = get_org_dropdown()
885
+ repo_name = gr.Textbox(
886
+ label="Repo name",
887
+ placeholder="dataset_name",
888
+ value=f"my-distiset-{str(uuid.uuid4())[:8]}",
889
+ interactive=True,
890
+ )
891
+ num_rows = gr.Number(
892
+ label="Number of rows",
893
+ value=10,
894
+ interactive=True,
895
+ scale=1,
896
+ )
897
+ temperature = gr.Slider(
898
+ label="Temperature",
899
+ minimum=0.1,
900
+ maximum=1.5,
901
+ value=0.9,
902
+ step=0.1,
903
+ interactive=True,
904
+ )
905
+ temperature_completion = gr.Slider(
906
+ label="Temperature for completion",
907
+ minimum=0.1,
908
+ maximum=1.5,
909
+ value=None,
910
+ step=0.1,
911
+ interactive=True,
912
+ visible=False,
913
+ )
914
+ private = gr.Checkbox(
915
+ label="Private dataset",
916
+ value=False,
917
+ interactive=True,
918
+ scale=1,
919
+ )
920
+ btn_push_to_hub = gr.Button(
921
+ "Push to Hub", variant="primary", scale=2
922
+ )
923
+ btn_save_local = gr.Button(
924
+ "Save locally", variant="primary", scale=2, visible=False
925
+ )
926
+ with gr.Column(scale=3):
927
+ csv_file = gr.File(
928
+ label="CSV",
929
+ elem_classes="datasets",
930
+ visible=False,
931
+ )
932
+ json_file = gr.File(
933
+ label="JSON",
934
+ elem_classes="datasets",
935
+ visible=False,
936
+ )
937
+ success_message = gr.Markdown(
938
+ visible=False,
939
+ min_height=0 # don't remove this otherwise progress is not visible
940
+ )
941
+ with gr.Accordion(
942
+ "Customize your pipeline with distilabel",
943
+ open=False,
944
+ visible=False,
945
+ ) as pipeline_code_ui:
946
+ code = generate_pipeline_code(
947
+ repo_id=search_in.value,
948
+ input_type=input_type.value,
949
+ system_prompt=system_prompt.value,
950
+ document_column=document_column.value,
951
+ num_turns=num_turns.value,
952
+ num_rows=num_rows.value,
953
+ )
954
+ pipeline_code = gr.Code(
955
+ value=code,
956
+ language="python",
957
+ label="Distilabel Pipeline Code",
958
+ )
959
+
960
+ tab_prompt_input.select(
961
+ fn=lambda: "prompt-input",
962
+ inputs=[],
963
+ outputs=[input_type],
964
+ ).then(fn=show_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
965
+ fn=hide_document_column_visibility, inputs=[], outputs=[document_column]
966
+ )
967
+
968
+ tab_dataset_input.select(
969
+ fn=lambda: "dataset-input",
970
+ inputs=[],
971
+ outputs=[input_type],
972
+ ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
973
+ fn=show_document_column_visibility, inputs=[], outputs=[document_column]
974
+ )
975
+
976
+ tab_file_input.select(
977
+ fn=lambda: "file-input",
978
+ inputs=[],
979
+ outputs=[input_type],
980
+ ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
981
+ fn=show_document_column_visibility, inputs=[], outputs=[document_column]
982
+ )
983
+
984
+ search_in.submit(
985
+ fn=lambda df: pd.DataFrame(columns=df.columns),
986
+ inputs=[dataframe],
987
+ outputs=[dataframe],
988
+ )
989
+
990
+ load_prompt_btn.click(
991
+ fn=generate_system_prompt,
992
+ inputs=[dataset_description],
993
+ outputs=[system_prompt],
994
+ ).success(
995
+ fn=generate_sample_dataset,
996
+ inputs=[
997
+ search_in,
998
+ file_in,
999
+ input_type,
1000
+ system_prompt,
1001
+ document_column,
1002
+ num_turns,
1003
+ num_rows,
1004
+ ],
1005
+ outputs=dataframe,
1006
+ )
1007
+
1008
+ gr.on(
1009
+ triggers=[load_dataset_btn.click, load_file_btn.click],
1010
+ fn=load_dataset_file,
1011
+ inputs=[search_in, file_in, input_type],
1012
+ outputs=[dataframe, document_column],
1013
+ )
1014
+
1015
+ btn_apply_to_sample_dataset.click(
1016
+ fn=generate_sample_dataset,
1017
+ inputs=[
1018
+ search_in,
1019
+ file_in,
1020
+ input_type,
1021
+ system_prompt,
1022
+ document_column,
1023
+ num_turns,
1024
+ num_rows,
1025
+ ],
1026
+ outputs=dataframe,
1027
+ )
1028
+
1029
+ btn_push_to_hub.click(
1030
+ fn=validate_argilla_user_workspace_dataset,
1031
+ inputs=[repo_name],
1032
+ outputs=[success_message],
1033
+ ).then(
1034
+ fn=validate_push_to_hub,
1035
+ inputs=[org_name, repo_name],
1036
+ outputs=[success_message],
1037
+ ).success(
1038
+ fn=hide_save_local,
1039
+ outputs=[csv_file, json_file, success_message],
1040
+ ).success(
1041
+ fn=hide_success_message,
1042
+ outputs=[success_message],
1043
+ ).success(
1044
+ fn=hide_pipeline_code_visibility,
1045
+ inputs=[],
1046
+ outputs=[pipeline_code_ui],
1047
+ ).success(
1048
+ fn=push_dataset,
1049
+ inputs=[
1050
+ org_name,
1051
+ repo_name,
1052
+ private,
1053
+ search_in,
1054
+ file_in,
1055
+ input_type,
1056
+ system_prompt,
1057
+ document_column,
1058
+ num_turns,
1059
+ num_rows,
1060
+ temperature,
1061
+ temperature_completion,
1062
+ pipeline_code,
1063
+ ],
1064
+ outputs=[success_message],
1065
+ ).success(
1066
+ fn=show_success_message,
1067
+ inputs=[org_name, repo_name],
1068
+ outputs=[success_message],
1069
+ ).success(
1070
+ fn=generate_pipeline_code,
1071
+ inputs=[
1072
+ search_in,
1073
+ input_type,
1074
+ system_prompt,
1075
+ document_column,
1076
+ num_turns,
1077
+ num_rows,
1078
+ ],
1079
+ outputs=[pipeline_code],
1080
+ ).success(
1081
+ fn=show_pipeline_code_visibility,
1082
+ inputs=[],
1083
+ outputs=[pipeline_code_ui],
1084
+ )
1085
+
1086
+ btn_save_local.click(
1087
+ fn=hide_success_message,
1088
+ outputs=[success_message],
1089
+ ).success(
1090
+ fn=hide_pipeline_code_visibility,
1091
+ inputs=[],
1092
+ outputs=[pipeline_code_ui],
1093
+ ).success(
1094
+ fn=show_save_local,
1095
+ inputs=[],
1096
+ outputs=[csv_file, json_file, success_message],
1097
+ ).success(
1098
+ save_local,
1099
+ inputs=[
1100
+ search_in,
1101
+ file_in,
1102
+ input_type,
1103
+ system_prompt,
1104
+ document_column,
1105
+ num_turns,
1106
+ num_rows,
1107
+ temperature,
1108
+ repo_name,
1109
+ temperature_completion,
1110
+ ],
1111
+ outputs=[csv_file, json_file],
1112
+ ).success(
1113
+ fn=generate_pipeline_code,
1114
+ inputs=[
1115
+ search_in,
1116
+ input_type,
1117
+ system_prompt,
1118
+ document_column,
1119
+ num_turns,
1120
+ num_rows,
1121
+ ],
1122
+ outputs=[pipeline_code],
1123
+ ).success(
1124
+ fn=show_pipeline_code_visibility,
1125
+ inputs=[],
1126
+ outputs=[pipeline_code_ui],
1127
+ )
1128
+
1129
+ clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
1130
+ clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
1131
+ clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
1132
+ clear_btn_full.click(
1133
+ fn=lambda df: ("", "", 1, _get_dataframe()),
1134
+ inputs=[dataframe],
1135
+ outputs=[system_prompt, document_column, num_turns, dataframe],
1136
+ )
1137
+ app.load(fn=swap_visibility, outputs=main_ui)
1138
+ app.load(fn=get_org_dropdown, outputs=[org_name])
1139
+ app.load(fn=get_random_repo_name, outputs=[repo_name])
1140
+ app.load(fn=show_temperature_completion, outputs=[temperature_completion])
1141
+ if SAVE_LOCAL_DIR is not None:
1142
+ app.load(fn=show_save_local_button, outputs=btn_save_local)
src/synthetic_dataset_generator/apps/eval.py ADDED
@@ -0,0 +1,894 @@
1
+ import json
2
+ import uuid
3
+ from typing import Union
4
+
5
+ import argilla as rg
6
+ import gradio as gr
7
+ import numpy as np
8
+ import pandas as pd
9
+ from datasets import (
10
+ Dataset,
11
+ get_dataset_config_names,
12
+ get_dataset_split_names,
13
+ load_dataset,
14
+ )
15
+ from distilabel.distiset import Distiset
16
+ from gradio.oauth import OAuthToken
17
+ from gradio_huggingfacehub_search import HuggingfaceHubSearch
18
+ from huggingface_hub import HfApi
19
+
20
+ from synthetic_dataset_generator.apps.base import (
21
+ combine_datasets,
22
+ get_iframe,
23
+ hide_success_message,
24
+ push_pipeline_code_to_hub,
25
+ show_success_message,
26
+ test_max_num_rows,
27
+ validate_argilla_user_workspace_dataset,
28
+ validate_push_to_hub,
29
+ )
30
+ from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
31
+ from synthetic_dataset_generator.pipelines.embeddings import (
32
+ get_embeddings,
33
+ get_sentence_embedding_dimensions,
34
+ )
35
+ from synthetic_dataset_generator.pipelines.eval import (
36
+ generate_pipeline_code,
37
+ get_custom_evaluator,
38
+ get_ultrafeedback_evaluator,
39
+ )
40
+ from synthetic_dataset_generator.utils import (
41
+ column_to_list,
42
+ extract_column_names,
43
+ get_argilla_client,
44
+ get_org_dropdown,
45
+ get_random_repo_name,
46
+ pad_or_truncate_list,
47
+ process_columns,
48
+ swap_visibility,
49
+ )
50
+
51
+
52
+ def get_valid_columns(dataframe: pd.DataFrame):
53
+ instruction_valid_columns = []
54
+ response_valid_columns = []
55
+
56
+ for col in dataframe.columns:
57
+ sample_val = dataframe[col].iloc[0]
58
+ if isinstance(sample_val, str) or (
59
+ isinstance(sample_val, (list, np.ndarray))
60
+ and all(isinstance(item, dict) and "role" in item for item in sample_val)
61
+ ):
62
+ instruction_valid_columns.append(col)
63
+ response_valid_columns.append(col)
64
+ if isinstance(sample_val, (list, np.ndarray)) and all(
65
+ isinstance(item, str) for item in sample_val
66
+ ):
67
+ response_valid_columns.append(col)
68
+
69
+ return instruction_valid_columns, response_valid_columns
70
+
71
+
72
+ def load_dataset_from_hub(
73
+ repo_id: str, num_rows: int = 10, token: Union[OAuthToken, None] = None
74
+ ):
75
+ if not repo_id:
76
+ raise gr.Error("Hub repo id is required")
77
+ subsets = get_dataset_config_names(repo_id, token=token)
78
+ splits = get_dataset_split_names(repo_id, subsets[0], token=token)
79
+ ds = load_dataset(repo_id, subsets[0], split=splits[0], token=token, streaming=True)
80
+ rows = []
81
+ for idx, row in enumerate(ds):
82
+ rows.append(row)
83
+ if idx == num_rows - 1:
84
+ break
85
+ ds = Dataset.from_list(rows)
86
+ dataframe = ds.to_pandas()
87
+ instruction_valid_columns, response_valid_columns = get_valid_columns(dataframe)
88
+ col_instruction = instruction_valid_columns[0] if instruction_valid_columns else ""
89
+ col_response = "No valid response columns found."
90
+ for col in response_valid_columns:
91
+ if col != col_instruction:
92
+ col_response = col
93
+ break
94
+
95
+ prompt_template = gr.Code(
96
+ label="Prompt template",
97
+ value="\n".join(
98
+ [
99
+ "Evaluate the following text based on criteria.",
100
+ "Criteria: quality.",
101
+ "Score: between 1 and 10.",
102
+ "Text: {{" + col_response + "}}",
103
+ ]
104
+ ),
105
+ language="jinja2",
106
+ interactive=True,
107
+ )
108
+ structured_output = gr.Code(
109
+ label="Structured output",
110
+ value=json.dumps(
111
+ {
112
+ "type": "object",
113
+ "properties": {"quality": {"type": "integer"}},
114
+ "required": ["quality"],
115
+ },
116
+ indent=4,
117
+ ),
118
+ language="json",
119
+ interactive=True,
120
+ )
121
+ return (
122
+ dataframe,
123
+ gr.Dropdown(
124
+ choices=instruction_valid_columns,
125
+ label="Instruction column",
126
+ value=col_instruction,
127
+ interactive=True,
128
+ ),
129
+ gr.Dropdown(
130
+ choices=response_valid_columns,
131
+ label="Response column",
132
+ value=col_response,
133
+ interactive=(
134
+ False if col_response == "No valid response columns found." else True
135
+ ),
136
+ ),
137
+ prompt_template,
138
+ structured_output,
139
+ )
140
+
141
+
142
+ def define_evaluation_aspects(task_type: str):
143
+ if task_type == "chat-eval":
144
+ return gr.Dropdown(
145
+ value=["overall-rating"],
146
+ choices=["helpfulness", "truthfulness", "overall-rating", "honesty"],
147
+ label="Evaluation Aspects",
148
+ multiselect=True,
149
+ interactive=True,
150
+ )
151
+ else:
152
+ return gr.Dropdown(interactive=False, visible=False)
153
+
154
+
155
+ def evaluate_instruction_response(
156
+ dataframe: pd.DataFrame,
157
+ aspects: list[str],
158
+ instruction_column: str,
159
+ response_columns: str,
160
+ num_rows: int = 10,
161
+ is_sample: bool = False,
162
+ progress=gr.Progress(),
163
+ ):
164
+ progress(0.0, desc="Evaluating instructions and responses")
165
+ data = process_columns(dataframe, instruction_column, response_columns)
166
+ num_generations = len(data[0]["generations"])
167
+ evaluated_results = []
168
+ for entry in data:
169
+ result_row = {
170
+ "instruction": entry["instruction"],
171
+ "generations": entry["generations"],
172
+ }
173
+ for aspect in aspects:
174
+ result_row[f"ratings_{aspect}"] = None
175
+ result_row[f"rationale_for_ratings_{aspect}"] = None
176
+ if aspect in ["truthfulness", "helpfulness"]:
177
+ result_row[f"type_{aspect}"] = None
178
+ result_row[f"rationale_for_type_{aspect}"] = None
179
+ result_row["model_name"] = None
180
+ evaluated_results.append(result_row)
181
+
182
+ batch_size = DEFAULT_BATCH_SIZE
183
+ total_steps: int = len(aspects) * num_rows
184
+
185
+ # evaluate instructions and responses
186
+ for aspect in aspects:
187
+ ultrafeedback_evaluator = get_ultrafeedback_evaluator(aspect, is_sample)
188
+ n_processed = 0
189
+
190
+ while n_processed < num_rows:
191
+ progress(
192
+ (len(aspects) * n_processed) / total_steps,
193
+ total=total_steps,
194
+ desc=f"Evaluating aspect: {aspect}",
195
+ )
196
+
197
+ remaining_rows = num_rows - n_processed
198
+ batch_size = min(batch_size, remaining_rows)
199
+ inputs = data[n_processed : n_processed + batch_size]
200
+ batch_results = list(ultrafeedback_evaluator.process(inputs=inputs))
201
+ for j, result in enumerate(batch_results[0]):
202
+ idx = n_processed + j
203
+ evaluated_results[idx][f"ratings_{aspect}"] = pad_or_truncate_list(
204
+ result.get("ratings"), num_generations
205
+ )
206
+ evaluated_results[idx]["model_name"] = result.get("model_name")
207
+ if aspect in ["truthfulness", "helpfulness"]:
208
+ evaluated_results[idx][f"type_{aspect}"] = pad_or_truncate_list(
209
+ result.get("types"), num_generations
210
+ )
211
+ evaluated_results[idx][f"rationale_for_type_{aspect}"] = (
212
+ pad_or_truncate_list(result.get("rationales"), num_generations)
213
+ )
214
+ evaluated_results[idx][f"rationale_for_ratings_{aspect}"] = (
215
+ pad_or_truncate_list(
216
+ result.get("rationales-for-ratings"), num_generations
217
+ )
218
+ )
219
+ else:
220
+ evaluated_results[idx][f"rationale_for_ratings_{aspect}"] = (
221
+ pad_or_truncate_list(result.get("rationales"), num_generations)
222
+ )
223
+ n_processed += batch_size
224
+
225
+ # create final dataset
226
+ dataframe = pd.DataFrame(evaluated_results)
227
+ progress(1.0, desc="Dataset evaluation completed")
228
+ return dataframe
229
+
230
+
231
+ def evaluate_custom(
232
+ dataframe: pd.DataFrame,
233
+ prompt_template: str,
234
+ structured_output: dict,
235
+ num_rows: int = 10,
236
+ is_sample: bool = False,
237
+ progress=gr.Progress(),
238
+ ):
239
+ progress(0.0, desc="Evaluating dataset")
240
+ columns = extract_column_names(prompt_template)
241
+ input_columns = {column: column_to_list(dataframe, column) for column in columns}
242
+
243
+ custom_evaluator = get_custom_evaluator(
244
+ prompt_template, structured_output, columns, is_sample
245
+ )
246
+ batch_size = DEFAULT_BATCH_SIZE
247
+
248
+ # evaluate the data
249
+ n_processed = 0
250
+ evaluation_results = []
251
+ while n_processed < num_rows:
252
+ progress(
253
+ n_processed / num_rows,
254
+ desc="Evaluating dataset",
255
+ )
256
+ remaining_rows = num_rows - n_processed
257
+ batch_size = min(batch_size, remaining_rows)
258
+
259
+ inputs = []
260
+ for idx in range(n_processed, n_processed + batch_size):
261
+ input = {column: input_columns[column][idx] for column in input_columns}
262
+ inputs.append(input)
263
+
264
+ batch = list(custom_evaluator.process(inputs=inputs))
265
+ evaluation_results.extend(batch[0])
266
+ n_processed += batch_size
267
+
268
+ # create final dataset
269
+ distiset_results = []
270
+ for result in evaluation_results:
271
+ record = {key: result[key] for key in result if key != "distilabel_metadata"}
272
+ distiset_results.append(record)
273
+
274
+ dataframe = pd.DataFrame(distiset_results)
275
+ progress(1.0, desc="Dataset evaluation completed")
276
+ return dataframe
277
+
278
+
279
+ def _evaluate_dataset(
280
+ dataframe: pd.DataFrame,
281
+ eval_type: str,
282
+ aspects_instruction_response: list[str],
283
+ instruction_instruction_response: str,
284
+ response_instruction_response: str,
285
+ prompt_template: str,
286
+ structured_output: dict,
287
+ num_rows: int = 10,
288
+ is_sample: bool = False,
289
+ ):
290
+ num_rows = test_max_num_rows(num_rows)
291
+ if eval_type == "chat-eval":
292
+ dataframe = evaluate_instruction_response(
293
+ dataframe=dataframe,
294
+ aspects=aspects_instruction_response,
295
+ instruction_column=instruction_instruction_response,
296
+ response_columns=response_instruction_response,
297
+ num_rows=num_rows,
298
+ is_sample=is_sample,
299
+ )
300
+ else:
301
+ dataframe = evaluate_custom(
302
+ dataframe=dataframe,
303
+ prompt_template=prompt_template,
304
+ structured_output=structured_output,
305
+ num_rows=num_rows,
306
+ is_sample=is_sample,
307
+ )
308
+ return dataframe
309
+
310
+
311
+ def evaluate_sample_dataset(
312
+ repo_id: str,
313
+ eval_type: str,
314
+ aspects_instruction_response: list[str],
315
+ instruction_instruction_response: str,
316
+ response_instruction_response: str,
317
+ prompt_template: str,
318
+ structured_output: dict,
319
+ ):
320
+ dataframe, _, _, _, _ = load_dataset_from_hub(repo_id, num_rows=10)
321
+ dataframe = _evaluate_dataset(
322
+ dataframe=dataframe,
323
+ eval_type=eval_type,
324
+ aspects_instruction_response=aspects_instruction_response,
325
+ instruction_instruction_response=instruction_instruction_response,
326
+ response_instruction_response=response_instruction_response,
327
+ prompt_template=prompt_template,
328
+ structured_output=structured_output,
329
+ num_rows=10,
330
+ is_sample=True,
331
+ )
332
+ return dataframe
333
+
334
+
335
+ def push_dataset_to_hub(
336
+ dataframe: pd.DataFrame,
337
+ org_name: str,
338
+ repo_name: str,
339
+ oauth_token: Union[gr.OAuthToken, None],
340
+ private: bool,
341
+ pipeline_code: str,
342
+ progress=gr.Progress(),
343
+ ):
344
+ progress(0.0, desc="Validating")
345
+ repo_id = validate_push_to_hub(org_name, repo_name)
346
+ progress(0.5, desc="Creating dataset")
347
+ dataset = Dataset.from_pandas(dataframe)
348
+ dataset = combine_datasets(repo_id, dataset, oauth_token)
349
+ distiset = Distiset({"default": dataset})
350
+ progress(0.9, desc="Pushing dataset")
351
+ distiset.push_to_hub(
352
+ repo_id=repo_id,
353
+ private=private,
354
+ include_script=False,
355
+ token=oauth_token.token,
356
+ create_pr=False,
357
+ )
358
+ push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
359
+ progress(1.0, desc="Dataset pushed")
360
+ return dataframe
361
+
362
+
363
+ def push_dataset(
364
+ org_name: str,
365
+ repo_name: str,
366
+ private: bool,
367
+ num_rows: int,
368
+ original_repo_id: str,
369
+ eval_type: str,
370
+ aspects_instruction_response: list[str],
371
+ instruction_instruction_response: str,
372
+ response_instruction_response: str,
373
+ prompt_template: str,
374
+ structured_output: dict,
375
+ pipeline_code: str,
376
+ oauth_token: Union[gr.OAuthToken, None] = None,
377
+ progress=gr.Progress(),
378
+ ) -> pd.DataFrame:
379
+ dataframe, _, _, _, _ = load_dataset_from_hub(original_repo_id, num_rows=num_rows)
380
+ dataframe = _evaluate_dataset(
381
+ dataframe=dataframe,
382
+ eval_type=eval_type,
383
+ aspects_instruction_response=aspects_instruction_response,
384
+ instruction_instruction_response=instruction_instruction_response,
385
+ response_instruction_response=response_instruction_response,
386
+ prompt_template=prompt_template,
387
+ structured_output=structured_output,
388
+ num_rows=num_rows,
389
+ )
390
+ push_dataset_to_hub(
391
+ dataframe, org_name, repo_name, oauth_token, private, pipeline_code
392
+ )
393
+ try:
394
+ progress(0.1, desc="Setting up user and workspace")
395
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
396
+ client = get_argilla_client()
397
+ if client is None:
398
+ return ""
399
+ progress(0.5, desc="Creating dataset in Argilla")
400
+ if eval_type == "chat-eval":
401
+ num_generations = len((dataframe["generations"][0]))
402
+ fields = [
403
+ rg.ChatField(
404
+ name=f"chat_{i}",
405
+ title=f"Chat {i+1}",
406
+ description=f"User and assistant conversation for generation {i+1}",
407
+ )
408
+ for i in range(num_generations)
409
+ ]
410
+ questions = []
411
+ for i in range(num_generations):
412
+ for aspect in aspects_instruction_response:
413
+ questions.append(
414
+ rg.RatingQuestion(
415
+ name=f"ratings_{aspect}_{i}",
416
+ values=list(range(11)),
417
+ title=f"Ratings for {aspect} for response {i+1}",
418
+ required=True,
419
+ )
420
+ )
421
+ questions.append(
422
+ rg.TextQuestion(
423
+ name=f"rationale_for_ratings_{aspect}_{i}",
424
+ title=f"Rationale for ratings for {aspect} for response {i+1}",
425
+ required=False,
426
+ use_markdown=True,
427
+ )
428
+ )
429
+ if aspect in ["truthfulness", "helpfulness"]:
430
+ questions.append(
431
+ rg.RatingQuestion(
432
+ name=f"type_{aspect}_{i}",
433
+ values=list(range(1, 6)),
434
+ title=f"The type of the response {i+1} for {aspect}",
435
+ required=True,
436
+ )
437
+ )
438
+ questions.append(
439
+ rg.TextQuestion(
440
+ name=f"rationale_for_type_{aspect}_{i}",
441
+ title=f"Rationale for type of the response {i+1} for {aspect}",
442
+ required=False,
443
+ use_markdown=True,
444
+ )
445
+ )
446
+ metadata = [
447
+ rg.IntegerMetadataProperty(
448
+ name="instruction_length", title="Instruction length"
449
+ ),
450
+ ]
451
+ for i in range(num_generations):
452
+ metadata.append(
453
+ rg.IntegerMetadataProperty(
454
+ name=f"response_{i}_length", title=f"Response {i+1} length"
455
+ )
456
+ )
457
+ vectors = [
458
+ rg.VectorField(
459
+ name="instruction_embeddings",
460
+ dimensions=get_sentence_embedding_dimensions(),
461
+ )
462
+ ]
463
+ settings = rg.Settings(
464
+ fields=fields,
465
+ questions=questions,
466
+ metadata=metadata,
467
+ vectors=vectors,
468
+ guidelines="Please review the conversation and provide an evaluation.",
469
+ )
470
+
471
+ dataframe["instruction_length"] = dataframe["instruction"].apply(len)
472
+ for i in range(num_generations):
473
+ dataframe[f"response_{i}_length"] = dataframe["generations"].apply(
474
+ lambda gens: len(gens[i]) if i < len(gens) else 0
475
+ )
476
+ dataframe["instruction_embeddings"] = get_embeddings(
477
+ dataframe["instruction"].to_list()
478
+ )
479
+
480
+ rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
481
+ if rg_dataset is None:
482
+ rg_dataset = rg.Dataset(
483
+ name=repo_name,
484
+ workspace=hf_user,
485
+ settings=settings,
486
+ client=client,
487
+ )
488
+ rg_dataset = rg_dataset.create()
489
+
490
+ progress(0.7, desc="Pushing dataset to Argilla")
491
+ hf_dataset = Dataset.from_pandas(dataframe)
492
+ records = []
493
+ for sample in hf_dataset:
494
+ fields = {}
495
+ metadata = {"instruction_length": sample.get("instruction_length", 0)}
496
+ vectors = {
497
+ "instruction_embeddings": sample.get("instruction_embeddings", [])
498
+ }
499
+ suggestions = []
500
+ generations = sample.get("generations", [])
501
+ for i in range(num_generations):
502
+ fields[f"chat_{i}"] = [
503
+ {"role": "user", "content": sample.get("instruction", "")},
504
+ {"role": "assistant", "content": generations[i]},
505
+ ]
506
+ metadata[f"response_{i}_length"] = sample.get(
507
+ f"response_{i}_length", 0
508
+ )
509
+
510
+ for aspect in aspects_instruction_response:
511
+ ratings = sample.get(f"ratings_{aspect}", [])
512
+ rationales = sample.get(f"rationale_for_ratings_{aspect}", [])
513
+
514
+ rating_value = (
515
+ ratings[i]
516
+ if ratings and isinstance(ratings[i], int)
517
+ else None
518
+ )
519
+ rationale_value = (
520
+ rationales[i]
521
+ if rationales and isinstance(rationales[i], str)
522
+ else None
523
+ )
524
+
525
+ if rating_value is not None:
526
+ suggestions.append(
527
+ rg.Suggestion(
528
+ question_name=f"ratings_{aspect}_{i}",
529
+ value=rating_value,
530
+ )
531
+ )
532
+ if rationale_value is not None:
533
+ suggestions.append(
534
+ rg.Suggestion(
535
+ question_name=f"rationale_for_ratings_{aspect}_{i}",
536
+ value=rationale_value,
537
+ )
538
+ )
539
+
540
+ if aspect in ["truthfulness", "helpfulness"]:
541
+ types = sample.get(f"type_{aspect}", [])
542
+ rationale_types = sample.get(
543
+ f"rationale_for_type_{aspect}", []
544
+ )
545
+
546
+ type_value = (
547
+ types[i]
548
+ if types and isinstance(types[i], int)
549
+ else None
550
+ )
551
+ rationale_type_value = (
552
+ rationale_types[i]
553
+ if rationale_types
554
+ and isinstance(rationale_types[i], str)
555
+ else None
556
+ )
557
+ if type_value is not None:
558
+ suggestions.append(
559
+ rg.Suggestion(
560
+ question_name=f"type_{aspect}_{i}",
561
+ value=type_value,
562
+ )
563
+ )
564
+ if rationale_type_value is not None:
565
+ suggestions.append(
566
+ rg.Suggestion(
567
+ question_name=f"rationale_for_type_{aspect}_{i}",
568
+ value=rationale_type_value,
569
+ )
570
+ )
571
+ records.append(
572
+ rg.Record(
573
+ fields=fields,
574
+ metadata=metadata,
575
+ vectors=vectors,
576
+ suggestions=suggestions,
577
+ )
578
+ )
579
+ rg_dataset.records.log(records=records)
580
+ progress(1.0, desc="Dataset pushed to Argilla")
581
+ else:
582
+ columns = extract_column_names(prompt_template)
583
+ settings = rg.Settings(
584
+ fields=[
585
+ rg.TextField(
586
+ name=column,
587
+ title=column.capitalize(),
588
+ description="The column content",
589
+ )
590
+ for column in columns
591
+ ],
592
+ questions=[
593
+ rg.TextQuestion(
594
+ name="evaluation",
595
+ title="Evaluation",
596
+ description="The generated evaluation",
597
+ use_markdown=True,
598
+ ),
599
+ ],
600
+ metadata=[
601
+ rg.IntegerMetadataProperty(
602
+ name=f"{column}_length", title=f"{column.capitalize()} length"
603
+ )
604
+ for column in columns
605
+ ],
606
+ vectors=[
607
+ rg.VectorField(
608
+ name=f"{column}_embeddings",
609
+ dimensions=get_sentence_embedding_dimensions(),
610
+ )
611
+ for column in columns
612
+ ],
613
+ guidelines="Please review, correct and provide an accurate evaluation.",
614
+ )
615
+ for column in columns:
616
+ dataframe[f"{column}_length"] = dataframe[column].apply(len)
617
+ dataframe[f"{column}_embeddings"] = get_embeddings(dataframe[column])
618
+
619
+ rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
620
+ if rg_dataset is None:
621
+ rg_dataset = rg.Dataset(
622
+ name=repo_name,
623
+ workspace=hf_user,
624
+ settings=settings,
625
+ client=client,
626
+ )
627
+ rg_dataset = rg_dataset.create()
628
+ progress(0.7, desc="Pushing dataset to Argilla")
629
+ hf_dataset = Dataset.from_pandas(dataframe)
630
+ rg_dataset.records.log(
631
+ records=hf_dataset, mapping={"generation": "evaluation"}
632
+ )
633
+ progress(1.0, desc="Dataset pushed to Argilla")
634
+ except Exception as e:
635
+ raise gr.Error(f"Error pushing dataset to Argilla: {e}")
636
+ return ""
637
+
638
+
639
+ def show_pipeline_code_visibility():
640
+ return {pipeline_code_ui: gr.Accordion(visible=True)}
641
+
642
+
643
+ def hide_pipeline_code_visibility():
644
+ return {pipeline_code_ui: gr.Accordion(visible=False)}
645
+
646
+
647
+ ######################
648
+ # Gradio UI
649
+ ######################
650
+
651
+
652
+ with gr.Blocks() as app:
653
+ with gr.Column() as main_ui:
654
+ gr.Markdown("## 1. Select your input dataset")
655
+ with gr.Row(equal_height=False):
656
+ with gr.Column(scale=2):
657
+ search_in = HuggingfaceHubSearch(
658
+ label="Search",
659
+ placeholder="Search for a dataset",
660
+ search_type="dataset",
661
+ sumbit_on_select=True,
662
+ )
663
+ with gr.Row():
664
+ clear_btn_part = gr.Button("Clear", variant="secondary")
665
+ load_btn = gr.Button("Load", variant="primary")
666
+
667
+ with gr.Column(scale=3):
668
+ examples = gr.Examples(
669
+ examples=[
670
+ "argilla/distilabel-sft-easy",
671
+ "HuggingFaceFW/fineweb-edu",
672
+ "argilla/distilabel-intel-orca-dpo-pairs",
673
+ ],
674
+ label="Example datasets",
675
+ fn=lambda x: x,
676
+ inputs=[search_in],
677
+ run_on_click=True,
678
+ )
679
+ search_out = gr.HTML(label="Dataset preview", visible=False)
680
+
681
+ gr.HTML(value="<hr>")
682
+ gr.Markdown(value="## 2. Configure your task")
683
+ with gr.Row(equal_height=False):
684
+ with gr.Column(scale=2):
685
+ eval_type = gr.Dropdown(
686
+ label="Evaluation type",
687
+ choices=["chat-eval", "custom-eval"],
688
+ value="chat-eval",
689
+ multiselect=False,
690
+ visible=False,
691
+ )
692
+ with gr.Tab("Response Evaluation") as tab_instruction_response:
693
+ aspects_instruction_response = define_evaluation_aspects(
694
+ "chat-eval"
695
+ )
696
+ instruction_instruction_response = gr.Dropdown(
697
+ label="Instruction Column",
698
+ info="Select the instruction column to evaluate",
699
+ choices=["Load your data first in step 1."],
700
+ value="Load your data first in step 1.",
701
+ interactive=False,
702
+ multiselect=False,
703
+ allow_custom_value=False,
704
+ )
705
+ response_instruction_response = gr.Dropdown(
706
+ label="Response Column",
707
+ info="Select the response column(s) to evaluate",
708
+ choices=["Load your data first in step 1."],
709
+ value="Load your data first in step 1.",
710
+ interactive=False,
711
+ multiselect=False,
712
+ allow_custom_value=False,
713
+ )
714
+ tab_instruction_response.select(
715
+ fn=lambda: "chat-eval",
716
+ inputs=[],
717
+ outputs=[eval_type],
718
+ )
719
+ with gr.Tab("Custom Evaluation Prompt") as tab_custom:
720
+ aspects_custom = define_evaluation_aspects("custom-eval")
721
+ prompt_template = gr.Code(
722
+ label="Prompt template",
723
+ value="Load your data first in step 1.",
724
+ language="markdown",
725
+ interactive=False,
726
+ )
727
+ structured_output = gr.Code(
728
+ label="Structured output",
729
+ value="Load your data first in step 1.",
730
+ language="json",
731
+ interactive=False,
732
+ )
733
+ tab_custom.select(
734
+ fn=lambda: "custom-eval",
735
+ inputs=[],
736
+ outputs=[eval_type],
737
+ )
738
+ with gr.Row():
739
+ clear_btn_full = gr.Button("Clear", variant="secondary")
740
+ btn_apply_to_sample_dataset = gr.Button("Save", variant="primary")
741
+ with gr.Column(scale=3):
742
+ dataframe = gr.Dataframe(
743
+ headers=["prompt", "completion", "evaluation"],
744
+ wrap=True,
745
+ interactive=False,
746
+ )
747
+
748
+ gr.HTML(value="<hr>")
749
+ gr.Markdown(value="## 3. Evaluate your dataset")
750
+ with gr.Row(equal_height=False):
751
+ with gr.Column(scale=2):
752
+ org_name = get_org_dropdown()
753
+ repo_name = gr.Textbox(
754
+ label="Repo name",
755
+ placeholder="dataset_name",
756
+ value=f"my-distiset-{str(uuid.uuid4())[:8]}",
757
+ interactive=True,
758
+ )
759
+ num_rows = gr.Number(
760
+ label="Number of rows",
761
+ value=10,
762
+ interactive=True,
763
+ scale=1,
764
+ )
765
+ private = gr.Checkbox(
766
+ label="Private dataset",
767
+ value=False,
768
+ interactive=True,
769
+ scale=1,
770
+ )
771
+ btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
772
+ with gr.Column(scale=3):
773
+ success_message = gr.Markdown(
774
+ visible=True,
775
+ min_height=100, # don't remove this otherwise progress is not visible
776
+ )
777
+ with gr.Accordion(
778
+ "Customize your pipeline with distilabel",
779
+ open=False,
780
+ visible=False,
781
+ ) as pipeline_code_ui:
782
+ code = generate_pipeline_code(
783
+ repo_id=search_in.value,
784
+ aspects=aspects_instruction_response.value,
785
+ instruction_column=instruction_instruction_response.value,
786
+ response_columns=response_instruction_response.value,
787
+ prompt_template=prompt_template.value,
788
+ structured_output=structured_output.value,
789
+ num_rows=num_rows.value,
790
+ eval_type=eval_type.value,
791
+ )
792
+ pipeline_code = gr.Code(
793
+ value=code,
794
+ language="python",
795
+ label="Distilabel Pipeline Code",
796
+ )
797
+
798
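+ # Selecting a dataset renders its preview iframe and resets the sample dataframe to empty columns.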
+ search_in.submit(fn=get_iframe, inputs=search_in, outputs=search_out).then(
799
+ fn=lambda df: pd.DataFrame(columns=df.columns),
800
+ inputs=[dataframe],
801
+ outputs=[dataframe],
802
+ )
803
+
804
+ load_btn.click(
805
+ fn=load_dataset_from_hub,
806
+ inputs=[search_in],
807
+ outputs=[
808
+ dataframe,
809
+ instruction_instruction_response,
810
+ response_instruction_response,
811
+ prompt_template,
812
+ structured_output,
813
+ ],
814
+ )
815
+
816
+ btn_apply_to_sample_dataset.click(
817
+ fn=evaluate_sample_dataset,
818
+ inputs=[
819
+ search_in,
820
+ eval_type,
821
+ aspects_instruction_response,
822
+ instruction_instruction_response,
823
+ response_instruction_response,
824
+ prompt_template,
825
+ structured_output,
826
+ ],
827
+ outputs=dataframe,
828
+ )
829
+
830
+ btn_push_to_hub.click(
831
+ fn=validate_argilla_user_workspace_dataset,
832
+ inputs=[repo_name],
833
+ outputs=[success_message],
834
+ ).then(
835
+ fn=validate_push_to_hub,
836
+ inputs=[org_name, repo_name],
837
+ outputs=[success_message],
838
+ ).success(
839
+ fn=hide_success_message,
840
+ outputs=[success_message],
841
+ ).success(
842
+ fn=hide_pipeline_code_visibility,
843
+ inputs=[],
844
+ outputs=[pipeline_code_ui],
845
+ ).success(
846
+ fn=push_dataset,
847
+ inputs=[
848
+ org_name,
849
+ repo_name,
850
+ private,
851
+ num_rows,
852
+ search_in,
853
+ eval_type,
854
+ aspects_instruction_response,
855
+ instruction_instruction_response,
856
+ response_instruction_response,
857
+ prompt_template,
858
+ structured_output,
859
+ pipeline_code,
860
+ ],
861
+ outputs=[success_message],
862
+ ).success(
863
+ fn=show_success_message,
864
+ inputs=[org_name, repo_name],
865
+ outputs=[success_message],
866
+ ).success(
867
+ fn=generate_pipeline_code,
868
+ inputs=[
869
+ search_in,
870
+ prompt_template,
871
+ structured_output,
872
+ eval_type,
873
+ ],
874
+ outputs=[pipeline_code],
875
+ ).success(
876
+ fn=show_pipeline_code_visibility,
877
+ inputs=[],
878
+ outputs=[pipeline_code_ui],
879
+ )
880
+
881
+ clear_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
882
+ clear_btn_full.click(
883
+ fn=lambda df: ("", "", pd.DataFrame(columns=df.columns)),
884
+ inputs=[dataframe],
885
+ outputs=[
886
+ instruction_instruction_response,
887
+ response_instruction_response,
888
+ dataframe,
889
+ ],
890
+ )
891
+
892
+ app.load(fn=swap_visibility, outputs=main_ui)
893
+ app.load(fn=get_org_dropdown, outputs=[org_name])
894
+ app.load(fn=get_random_repo_name, outputs=[repo_name])
src/synthetic_dataset_generator/apps/rag.py ADDED
@@ -0,0 +1,972 @@
1
+ import os
2
+ import random
3
+ import uuid
4
+ from typing import Union
5
+
6
+ import argilla as rg
7
+ import gradio as gr
8
+ import nltk
9
+ import pandas as pd
10
+ from datasets import Dataset
11
+ from distilabel.distiset import Distiset
12
+ from gradio.oauth import OAuthToken
13
+ from gradio_huggingfacehub_search import HuggingfaceHubSearch
14
+ from huggingface_hub import HfApi
15
+
16
+ from synthetic_dataset_generator.apps.base import (
17
+ combine_datasets,
18
+ hide_success_message,
19
+ load_dataset_from_hub,
20
+ preprocess_input_data,
21
+ push_pipeline_code_to_hub,
22
+ show_success_message,
23
+ test_max_num_rows,
24
+ validate_argilla_user_workspace_dataset,
25
+ validate_push_to_hub,
26
+ )
27
+ from synthetic_dataset_generator.constants import (
28
+ DEFAULT_BATCH_SIZE,
29
+ MODEL,
30
+ MODEL_COMPLETION,
31
+ SAVE_LOCAL_DIR,
32
+ )
33
+ from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
34
+ from synthetic_dataset_generator.pipelines.embeddings import (
35
+ get_embeddings,
36
+ get_sentence_embedding_dimensions,
37
+ )
38
+ from synthetic_dataset_generator.pipelines.rag import (
39
+ DEFAULT_DATASET_DESCRIPTIONS,
40
+ generate_pipeline_code,
41
+ get_chunks_generator,
42
+ get_prompt_generator,
43
+ get_response_generator,
44
+ get_sentence_pair_generator,
45
+ )
46
+ from synthetic_dataset_generator.utils import (
47
+ column_to_list,
48
+ get_argilla_client,
49
+ get_org_dropdown,
50
+ get_random_repo_name,
51
+ swap_visibility,
52
+ )
53
+
54
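+ # Download NLTK tokenizer/tagger data into a local directory (used later when partitioning uploaded documents).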
+ os.makedirs("./nltk_data", exist_ok=True)
55
+ nltk.data.path.append("./nltk_data")
56
+ nltk.download("punkt_tab", download_dir="./nltk_data")
57
+ nltk.download("averaged_perceptron_tagger_eng", download_dir="./nltk_data")
58
+
59
+
60
+ def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
61
+ progress(0.1, desc="Initializing")
62
+ generate_description = get_prompt_generator()
63
+ progress(0.5, desc="Generating")
64
+ result = next(
65
+ generate_description.process(
66
+ [
67
+ {
68
+ "instruction": dataset_description,
69
+ }
70
+ ]
71
+ )
72
+ )[0]["generation"]
73
+ progress(1.0, desc="Prompt generated")
74
+ return result
75
+
76
+
77
+ def load_dataset_file(
78
+ repo_id: str,
79
+ file_paths: list[str],
80
+ input_type: str,
81
+ num_rows: int = 10,
82
+ token: Union[OAuthToken, None] = None,
83
+ progress=gr.Progress(),
84
+ ):
85
+ progress(0.1, desc="Loading the source data")
86
+ if input_type == "dataset-input":
87
+ return load_dataset_from_hub(repo_id=repo_id, num_rows=num_rows, token=token)
88
+ else:
89
+ return preprocess_input_data(file_paths=file_paths, num_rows=num_rows)
90
+
91
+
92
+ def generate_sample_dataset(
93
+ repo_id: str,
94
+ file_paths: list[str],
95
+ input_type: str,
96
+ system_prompt: str,
97
+ document_column: str,
98
+ retrieval_reranking: list[str],
99
+ num_rows: str,
100
+ oauth_token: Union[OAuthToken, None],
101
+ progress=gr.Progress(),
102
+ ):
103
+ retrieval = "Retrieval" in retrieval_reranking
104
+ reranking = "Reranking" in retrieval_reranking
105
+
106
+ if input_type == "prompt-input":
107
+ dataframe = pd.DataFrame(columns=["context", "question", "response"])
108
+ else:
109
+ dataframe, _ = load_dataset_file(
110
+ repo_id=repo_id,
111
+ file_paths=file_paths,
112
+ input_type=input_type,
113
+ num_rows=num_rows,
114
+ token=oauth_token,
115
+ )
116
+ progress(0.5, desc="Generating dataset")
117
+ dataframe = generate_dataset(
118
+ input_type=input_type,
119
+ dataframe=dataframe,
120
+ system_prompt=system_prompt,
121
+ document_column=document_column,
122
+ retrieval=retrieval,
123
+ reranking=reranking,
124
+ num_rows=10,
125
+ is_sample=True,
126
+ )
127
+ progress(1.0, desc="Sample dataset generated")
128
+ return dataframe
129
+
130
+
131
+ def generate_dataset(
132
+ input_type: str,
133
+ dataframe: pd.DataFrame,
134
+ system_prompt: str,
135
+ document_column: str,
136
+ retrieval: bool = False,
137
+ reranking: bool = False,
138
+ num_rows: int = 10,
139
+ temperature: float = 0.7,
140
+ temperature_completion: Union[float, None] = None,
141
+ is_sample: bool = False,
142
+ progress=gr.Progress(),
143
+ ):
144
+ num_rows = test_max_num_rows(num_rows)
145
+ progress(0.0, desc="Initializing dataset generation")
146
+ if input_type == "prompt-input":
147
+ chunk_generator = get_chunks_generator(
148
+ temperature=temperature, is_sample=is_sample
149
+ )
150
+ else:
151
+ document_data = column_to_list(dataframe, document_column)
152
+ if len(document_data) < num_rows:
153
+ document_data += random.choices(
154
+ document_data, k=num_rows - len(document_data)
155
+ )
156
+
157
+ retrieval_generator = get_sentence_pair_generator(
158
+ action="query",
159
+ triplet=True if retrieval else False,
160
+ temperature=temperature,
161
+ is_sample=is_sample,
162
+ )
163
+ response_generator = get_response_generator(
164
+ temperature=temperature_completion or temperature, is_sample=is_sample
165
+ )
166
+ if reranking:
167
+ reranking_generator = get_sentence_pair_generator(
168
+ action="semantically-similar",
169
+ triplet=True,
170
+ temperature=temperature,
171
+ is_sample=is_sample,
172
+ )
173
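+ # Progress accounting: question and response generation always run; chunk generation (prompt input) and reranking add one step each.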
+ steps = 2 + sum([1 if reranking else 0, 1 if input_type == "prompt-input" else 0])
174
+ total_steps: int = num_rows * steps
175
+ step_progress = round(1 / steps, 2)
176
+ batch_size = DEFAULT_BATCH_SIZE
177
+
178
+ # generate chunks
179
+ if input_type == "prompt-input":
180
+ n_processed = 0
181
+ chunk_results = []
182
+ rewritten_system_prompts = get_rewritten_prompts(system_prompt, num_rows)
183
+ while n_processed < num_rows:
184
+ progress(
185
+ step_progress * n_processed / num_rows,
186
+ total=total_steps,
187
+ desc="Generating chunks",
188
+ )
189
+ remaining_rows = num_rows - n_processed
190
+ batch_size = min(batch_size, remaining_rows)
191
+ inputs = [
192
+ {"task": random.choice(rewritten_system_prompts)}
193
+ for _ in range(batch_size)
194
+ ]
195
+ chunks = list(chunk_generator.process(inputs=inputs))
196
+ chunk_results.extend(chunks[0])
197
+ n_processed += batch_size
198
+ random.seed(a=random.randint(0, 2**32 - 1))
199
+ document_data = [chunk["generation"] for chunk in chunk_results]
200
+ progress(step_progress, desc="Generating chunks")
201
+
202
+ # generate questions
203
+ n_processed = 0
204
+ retrieval_results = []
205
+ while n_processed < num_rows:
206
+ progress(
207
+ step_progress * n_processed / num_rows,
208
+ total=total_steps,
209
+ desc="Generating questions",
210
+ )
211
+ remaining_rows = num_rows - n_processed
212
+ batch_size = min(batch_size, remaining_rows)
213
+ inputs = [
214
+ {"anchor": document}
215
+ for document in document_data[n_processed : n_processed + batch_size]
216
+ ]
217
+ questions = list(retrieval_generator.process(inputs=inputs))
218
+ retrieval_results.extend(questions[0])
219
+ n_processed += batch_size
220
+ for result in retrieval_results:
221
+ result["context"] = result["anchor"]
222
+ if retrieval:
223
+ result["question"] = result["positive"]
224
+ result["positive_retrieval"] = result.pop("positive")
225
+ result["negative_retrieval"] = result.pop("negative")
226
+ else:
227
+ result["question"] = result.pop("positive")
228
+
229
+ progress(step_progress, desc="Generating questions")
230
+
231
+ # generate responses
232
+ n_processed = 0
233
+ response_results = []
234
+ while n_processed < num_rows:
235
+ progress(
236
+ step_progress + step_progress * n_processed / num_rows,
237
+ total=total_steps,
238
+ desc="Generating responses",
239
+ )
240
+ batch = retrieval_results[n_processed : n_processed + batch_size]
241
+ responses = list(response_generator.process(inputs=batch))
242
+ response_results.extend(responses[0])
243
+ n_processed += batch_size
244
+ for result in response_results:
245
+ result["response"] = result["generation"]
246
+ progress(step_progress, desc="Generating responses")
247
+
248
+ # generate reranking
249
+ if reranking:
250
+ n_processed = 0
251
+ reranking_results = []
252
+ while n_processed < num_rows:
253
+ progress(
254
+ step_progress * n_processed / num_rows,
255
+ total=total_steps,
256
+ desc="Generating reranking data",
257
+ )
258
+ batch = response_results[n_processed : n_processed + batch_size]
259
+ batch = list(reranking_generator.process(inputs=batch))
260
+ reranking_results.extend(batch[0])
261
+ n_processed += batch_size
262
+ for result in reranking_results:
263
+ result["positive_reranking"] = result.pop("positive")
264
+ result["negative_reranking"] = result.pop("negative")
265
+ progress(
266
+ 1,
267
+ total=total_steps,
268
+ desc="Creating dataset",
269
+ )
270
+
271
+ # create distiset
272
+ distiset_results = []
273
+ source_results = reranking_results if reranking else response_results
274
+ base_keys = ["context", "question", "response"]
275
+ retrieval_keys = ["positive_retrieval", "negative_retrieval"] if retrieval else []
276
+ reranking_keys = ["positive_reranking", "negative_reranking"] if reranking else []
277
+ relevant_keys = base_keys + retrieval_keys + reranking_keys
278
+
279
+ for result in source_results:
280
+ record = {key: result.get(key) for key in relevant_keys if key in result}
281
+ distiset_results.append(record)
282
+
283
+ dataframe = pd.DataFrame(distiset_results)
284
+
285
+ progress(1.0, desc="Dataset generation completed")
286
+ return dataframe
287
+
288
+
289
+ def push_dataset_to_hub(
290
+ dataframe: pd.DataFrame,
291
+ org_name: str,
292
+ repo_name: str,
293
+ oauth_token: Union[gr.OAuthToken, None],
294
+ private: bool,
295
+ pipeline_code: str,
296
+ progress=gr.Progress(),
297
+ ):
298
+ progress(0.0, desc="Validating")
299
+ repo_id = validate_push_to_hub(org_name, repo_name)
300
+ progress(0.5, desc="Creating dataset")
301
+ dataset = Dataset.from_pandas(dataframe)
302
+ dataset = combine_datasets(repo_id, dataset, oauth_token)
303
+ distiset = Distiset({"default": dataset})
304
+ progress(0.9, desc="Pushing dataset")
305
+ distiset.push_to_hub(
306
+ repo_id=repo_id,
307
+ private=private,
308
+ include_script=False,
309
+ token=oauth_token.token,
310
+ create_pr=False,
311
+ )
312
+ push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
313
+ progress(1.0, desc="Dataset pushed")
314
+ return dataframe
315
+
316
+
317
+ def push_dataset(
318
+ org_name: str,
319
+ repo_name: str,
320
+ private: bool,
321
+ original_repo_id: str,
322
+ file_paths: list[str],
323
+ input_type: str,
324
+ system_prompt: str,
325
+ document_column: str,
326
+ retrieval_reranking: list[str],
327
+ num_rows: int,
328
+ temperature: float,
329
+ temperature_completion: float,
330
+ pipeline_code: str,
331
+ oauth_token: Union[gr.OAuthToken, None] = None,
332
+ progress=gr.Progress(),
333
+ ) -> str:
334
+ retrieval = "Retrieval" in retrieval_reranking
335
+ reranking = "Reranking" in retrieval_reranking
336
+
337
+ if input_type == "prompt-input":
338
+ dataframe = pd.DataFrame(columns=["context", "question", "response"])
339
+ else:
340
+ dataframe, _ = load_dataset_file(
341
+ repo_id=original_repo_id,
342
+ file_paths=file_paths,
343
+ input_type=input_type,
344
+ num_rows=num_rows,
345
+ token=oauth_token,
346
+ )
347
+ progress(0.5, desc="Generating dataset")
348
+ dataframe = generate_dataset(
349
+ input_type=input_type,
350
+ dataframe=dataframe,
351
+ system_prompt=system_prompt,
352
+ document_column=document_column,
353
+ retrieval=retrieval,
354
+ reranking=reranking,
355
+ num_rows=num_rows,
356
+ temperature=temperature,
357
+ temperature_completion=temperature_completion,
358
+ is_sample=True,
359
+ )
360
+ push_dataset_to_hub(
361
+ dataframe, org_name, repo_name, oauth_token, private, pipeline_code
362
+ )
363
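+ # Drop rows with missing or empty values before building the Argilla records.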
+ dataframe = dataframe[
364
+ dataframe.applymap(lambda x: str(x).strip() if pd.notna(x) else x).apply(
365
+ lambda row: row.notna().all() and (row != "").all(), axis=1
366
+ )
367
+ ]
368
+ try:
369
+ progress(0.1, desc="Setting up user and workspace")
370
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
371
+ client = get_argilla_client()
372
+ if client is None:
373
+ return ""
374
+
375
+ progress(0.5, desc="Creating dataset in Argilla")
376
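+ # Argilla schema: context and chat fields, plus optional retrieval/reranking text fields when those options are selected.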
+ fields = [
377
+ rg.TextField(
378
+ name="context",
379
+ title="Context",
380
+ description="Context for the generation",
381
+ ),
382
+ rg.ChatField(
383
+ name="chat",
384
+ title="Chat",
385
+ description="User and assistant conversation based on the context",
386
+ ),
387
+ ]
388
+ for item in ["positive", "negative"]:
389
+ if retrieval:
390
+ fields.append(
391
+ rg.TextField(
392
+ name=f"{item}_retrieval",
393
+ title=f"{item.capitalize()} retrieval",
394
+ description=f"The {item} query for retrieval",
395
+ )
396
+ )
397
+ if reranking:
398
+ fields.append(
399
+ rg.TextField(
400
+ name=f"{item}_reranking",
401
+ title=f"{item.capitalize()} reranking",
402
+ description=f"The {item} query for reranking",
403
+ )
404
+ )
405
+
406
+ questions = [
407
+ rg.LabelQuestion(
408
+ name="relevant",
409
+ title="Are the question and response relevant to the given context?",
410
+ labels=["yes", "no"],
411
+ ),
412
+ rg.LabelQuestion(
413
+ name="is_response_correct",
414
+ title="Is the response correct?",
415
+ labels=["yes", "no"],
416
+ ),
417
+ ]
418
+ for item in ["positive", "negative"]:
419
+ if retrieval:
420
+ questions.append(
421
+ rg.LabelQuestion(
422
+ name=f"is_{item}_retrieval_relevant",
423
+ title=f"Is the {item} retrieval relevant?",
424
+ labels=["yes", "no"],
425
+ required=False,
426
+ )
427
+ )
428
+ if reranking:
429
+ questions.append(
430
+ rg.LabelQuestion(
431
+ name=f"is_{item}_reranking_relevant",
432
+ title=f"Is the {item} reranking relevant?",
433
+ labels=["yes", "no"],
434
+ required=False,
435
+ )
436
+ )
437
+ metadata = [
438
+ rg.IntegerMetadataProperty(
439
+ name=f"{item}_length", title=f"{item.capitalize()} length"
440
+ )
441
+ for item in ["context", "question", "response"]
442
+ ]
443
+
444
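+ # Vector fields sized to the embedding model's output dimension, so records can be searched semantically in Argilla.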
+ vectors = [
445
+ rg.VectorField(
446
+ name=f"{item}_embeddings",
447
+ dimensions=get_sentence_embedding_dimensions(),
448
+ )
449
+ for item in ["context", "question", "response"]
450
+ ]
451
+ settings = rg.Settings(
452
+ fields=fields,
453
+ questions=questions,
454
+ metadata=metadata,
455
+ vectors=vectors,
456
+ guidelines="Please review the conversation and provide an evaluation.",
457
+ )
458
+
459
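+ # Render each question/response pair as a two-turn conversation for the Argilla ChatField.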
+ dataframe["chat"] = dataframe.apply(
460
+ lambda row: [
461
+ {"role": "user", "content": row["question"]},
462
+ {"role": "assistant", "content": row["response"]},
463
+ ],
464
+ axis=1,
465
+ )
466
+
467
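+ # Add length metadata and sentence embeddings for each text column before logging.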
+ for item in ["context", "question", "response"]:
468
+ dataframe[f"{item}_length"] = dataframe[item].apply(
469
+ lambda x: len(x) if x is not None else 0
470
+ )
471
+ dataframe[f"{item}_embeddings"] = get_embeddings(
472
+ dataframe[item].apply(lambda x: x if x is not None else "").to_list()
473
+ )
474
+
475
+ rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
476
+ if rg_dataset is None:
477
+ rg_dataset = rg.Dataset(
478
+ name=repo_name,
479
+ workspace=hf_user,
480
+ settings=settings,
481
+ client=client,
482
+ )
483
+ rg_dataset = rg_dataset.create()
484
+
485
+ progress(0.7, desc="Pushing dataset to Argilla")
486
+ hf_dataset = Dataset.from_pandas(dataframe)
487
+ rg_dataset.records.log(records=hf_dataset)
488
+ progress(1.0, desc="Dataset pushed to Argilla")
489
+ except Exception as e:
490
+ raise gr.Error(f"Error pushing dataset to Argilla: {e}")
491
+ return ""
492
+
493
+
494
+ def save_local(
495
+ repo_id: str,
496
+ file_paths: list[str],
497
+ input_type: str,
498
+ system_prompt: str,
499
+ document_column: str,
500
+ retrieval_reranking: list[str],
501
+ num_rows: int,
502
+ temperature: float,
503
+ repo_name: str,
504
+ temperature_completion: float,
505
+ ) -> tuple[str, str]:
506
+ retrieval = "Retrieval" in retrieval_reranking
507
+ reranking = "Reranking" in retrieval_reranking
508
+
509
+ if input_type == "prompt-input":
510
+ dataframe = pd.DataFrame(columns=["context", "question", "response"])
511
+ else:
512
+ dataframe, _ = load_dataset_file(
513
+ repo_id=repo_id,
514
+ file_paths=file_paths,
515
+ input_type=input_type,
516
+ num_rows=num_rows,
517
+ )
518
+ dataframe = generate_dataset(
519
+ input_type=input_type,
520
+ dataframe=dataframe,
521
+ system_prompt=system_prompt,
522
+ document_column=document_column,
523
+ retrieval=retrieval,
524
+ reranking=reranking,
525
+ num_rows=num_rows,
526
+ temperature=temperature,
527
+ temperature_completion=temperature_completion,
528
+ )
529
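+ # Export the generated dataset to CSV and JSON files under SAVE_LOCAL_DIR.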
+ local_dataset = Dataset.from_pandas(dataframe)
530
+ output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
531
+ output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
532
+ local_dataset.to_csv(output_csv, index=False)
533
+ local_dataset.to_json(output_json, index=False)
534
+ return output_csv, output_json
535
+
536
+
537
+ def show_system_prompt_visibility():
538
+ return {system_prompt: gr.Textbox(visible=True)}
539
+
540
+
541
+ def hide_system_prompt_visibility():
542
+ return {system_prompt: gr.Textbox(visible=False)}
543
+
544
+
545
+ def show_document_column_visibility():
546
+ return {document_column: gr.Dropdown(visible=True)}
547
+
548
+
549
+ def hide_document_column_visibility():
550
+ return {
551
+ document_column: gr.Dropdown(
552
+ choices=["Load your data first in step 1."],
553
+ value="Load your data first in step 1.",
554
+ visible=False,
555
+ )
556
+ }
557
+
558
+
559
+ def show_pipeline_code_visibility():
560
+ return {pipeline_code_ui: gr.Accordion(visible=True)}
561
+
562
+
563
+ def hide_pipeline_code_visibility():
564
+ return {pipeline_code_ui: gr.Accordion(visible=False)}
565
+
566
+
567
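+ # Only reveal the completion-temperature slider when a separate completion model is configured.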
+ def show_temperature_completion():
568
+ if MODEL != MODEL_COMPLETION:
569
+ return {temperature_completion: gr.Slider(value=0.9, visible=True)}
570
+
571
+
572
+ def show_save_local_button():
573
+ return {btn_save_local: gr.Button(visible=True)}
574
+
575
+
576
+ def hide_save_local_button():
577
+ return {btn_save_local: gr.Button(visible=False)}
578
+
579
+
580
+ def show_save_local():
581
+ gr.update(success_message, min_height=0)
582
+ return {
583
+ csv_file: gr.File(visible=True),
584
+ json_file: gr.File(visible=True),
585
+ success_message: success_message,
586
+ }
587
+
588
+
589
+ def hide_save_local():
590
+ gr.update(success_message, min_height=100)
591
+ return {
592
+ csv_file: gr.File(visible=False),
593
+ json_file: gr.File(visible=False),
594
+ success_message: success_message,
595
+ }
596
+
597
+
598
+ ######################
599
+ # Gradio UI
600
+ ######################
601
+
602
+
603
+ with gr.Blocks() as app:
604
+ with gr.Column() as main_ui:
605
+ gr.Markdown("## 1. Select your input")
606
+ with gr.Row(equal_height=False):
607
+ with gr.Column(scale=2):
608
+ input_type = gr.Dropdown(
609
+ label="Input type",
610
+ choices=["dataset-input", "file-input", "prompt-input"],
611
+ value="dataset-input",
612
+ multiselect=False,
613
+ visible=False,
614
+ )
615
+ with gr.Tab("Load from Hub") as tab_dataset_input:
616
+ with gr.Row(equal_height=False):
617
+ with gr.Column(scale=2):
618
+ search_in = HuggingfaceHubSearch(
619
+ label="Search",
620
+ placeholder="Search for a dataset",
621
+ search_type="dataset",
622
+ sumbit_on_select=True,
623
+ )
624
+ with gr.Row():
625
+ clear_dataset_btn_part = gr.Button(
626
+ "Clear", variant="secondary"
627
+ )
628
+ load_dataset_btn = gr.Button("Load", variant="primary")
629
+ with gr.Column(scale=3):
630
+ examples = gr.Examples(
631
+ examples=[
632
+ "charris/wikipedia_sample",
633
+ "plaguss/argilla_sdk_docs_raw_unstructured",
634
+ "BeIR/hotpotqa-generated-queries",
635
+ ],
636
+ label="Example datasets",
637
+ fn=lambda x: x,
638
+ inputs=[search_in],
639
+ run_on_click=True,
640
+ )
641
+ search_out = gr.HTML(label="Dataset preview", visible=False)
642
+ with gr.Tab("Load your file") as tab_file_input:
643
+ with gr.Row(equal_height=False):
644
+ with gr.Column(scale=2):
645
+ file_in = gr.File(
646
+ label="Upload your file. Supported formats: .md, .txt, .docx, .pdf",
647
+ file_count="multiple",
648
+ file_types=[".md", ".txt", ".docx", ".pdf"],
649
+ )
650
+ with gr.Row():
651
+ clear_file_btn_part = gr.Button(
652
+ "Clear", variant="secondary"
653
+ )
654
+ load_file_btn = gr.Button("Load", variant="primary")
655
+ with gr.Column(scale=3):
656
+ file_out = gr.HTML(label="Dataset preview", visible=False)
657
+ with gr.Tab("Generate from prompt") as tab_prompt_input:
658
+ with gr.Row(equal_height=False):
659
+ with gr.Column(scale=2):
660
+ dataset_description = gr.Textbox(
661
+ label="Dataset description",
662
+ placeholder="Give a precise description of your desired dataset.",
663
+ )
664
+ with gr.Row():
665
+ clear_prompt_btn_part = gr.Button(
666
+ "Clear", variant="secondary"
667
+ )
668
+ load_prompt_btn = gr.Button("Create", variant="primary")
669
+ with gr.Column(scale=3):
670
+ examples = gr.Examples(
671
+ examples=DEFAULT_DATASET_DESCRIPTIONS,
672
+ inputs=[dataset_description],
673
+ cache_examples=False,
674
+ label="Examples",
675
+ )
676
+
677
+ gr.HTML(value="<hr>")
678
+ gr.Markdown(value="## 2. Configure your task")
679
+ with gr.Row(equal_height=False):
680
+ with gr.Column(scale=2):
681
+ system_prompt = gr.Textbox(
682
+ label="System prompt",
683
+ placeholder="You are a helpful assistant.",
684
+ visible=False,
685
+ )
686
+ document_column = gr.Dropdown(
687
+ label="Document Column",
688
+ info="Select the document column to generate the RAG dataset",
689
+ choices=["Load your data first in step 1."],
690
+ value="Load your data first in step 1.",
691
+ interactive=False,
692
+ multiselect=False,
693
+ allow_custom_value=False,
694
+ )
695
+ retrieval_reranking = gr.CheckboxGroup(
696
+ choices=[("Retrieval", "Retrieval"), ("Reranking", "Reranking")],
697
+ type="value",
698
+ label="Data for RAG",
699
+ info="Indicate the additional data you want to generate for RAG.",
700
+ )
701
+ with gr.Row():
702
+ clear_btn_full = gr.Button("Clear", variant="secondary")
703
+ btn_apply_to_sample_dataset = gr.Button("Save", variant="primary")
704
+ with gr.Column(scale=3):
705
+ dataframe = gr.Dataframe(
706
+ headers=["context", "question", "response"],
707
+ wrap=True,
708
+ interactive=False,
709
+ )
710
+
711
+ gr.HTML(value="<hr>")
712
+ gr.Markdown(value="## 3. Generate your dataset")
713
+ with gr.Row(equal_height=False):
714
+ with gr.Column(scale=2):
715
+ org_name = get_org_dropdown()
716
+ repo_name = gr.Textbox(
717
+ label="Repo name",
718
+ placeholder="dataset_name",
719
+ value=f"my-distiset-{str(uuid.uuid4())[:8]}",
720
+ interactive=True,
721
+ )
722
+ num_rows = gr.Number(
723
+ label="Number of rows",
724
+ value=10,
725
+ interactive=True,
726
+ scale=1,
727
+ )
728
+ temperature = gr.Slider(
729
+ label="Temperature",
730
+ minimum=0.1,
731
+ maximum=1.5,
732
+ value=0.7,
733
+ step=0.1,
734
+ interactive=True,
735
+ )
736
+ temperature_completion = gr.Slider(
737
+ label="Temperature for completion",
738
+ minimum=0.1,
739
+ maximum=1.5,
740
+ value=None,
741
+ step=0.1,
742
+ interactive=True,
743
+ visible=False,
744
+ )
745
+ private = gr.Checkbox(
746
+ label="Private dataset",
747
+ value=False,
748
+ interactive=True,
749
+ scale=1,
750
+ )
751
+ btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
752
+ btn_save_local = gr.Button(
753
+ "Save locally", variant="primary", scale=2, visible=False
754
+ )
755
+ with gr.Column(scale=3):
756
+ csv_file = gr.File(
757
+ label="CSV",
758
+ elem_classes="datasets",
759
+ visible=False,
760
+ )
761
+ json_file = gr.File(
762
+ label="JSON",
763
+ elem_classes="datasets",
764
+ visible=False,
765
+ )
766
+ success_message = gr.Markdown(
767
+ visible=False,
768
+ min_height=0, # don't remove this otherwise progress is not visible
769
+ )
770
+ with gr.Accordion(
771
+ "Customize your pipeline with distilabel",
772
+ open=False,
773
+ visible=False,
774
+ ) as pipeline_code_ui:
775
+ code = generate_pipeline_code(
776
+ repo_id=search_in.value,
777
+ input_type=input_type.value,
778
+ system_prompt=system_prompt.value,
779
+ document_column=document_column.value,
780
+ retrieval_reranking=retrieval_reranking.value,
781
+ num_rows=num_rows.value,
782
+ )
783
+ pipeline_code = gr.Code(
784
+ value=code,
785
+ language="python",
786
+ label="Distilabel Pipeline Code",
787
+ )
788
+
789
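+ # Switching tabs updates input_type and toggles the system-prompt / document-column controls accordingly.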
+ tab_dataset_input.select(
790
+ fn=lambda: "dataset-input",
791
+ inputs=[],
792
+ outputs=[input_type],
793
+ ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
794
+ fn=show_document_column_visibility, inputs=[], outputs=[document_column]
795
+ )
796
+
797
+ tab_file_input.select(
798
+ fn=lambda: "file-input",
799
+ inputs=[],
800
+ outputs=[input_type],
801
+ ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
802
+ fn=show_document_column_visibility, inputs=[], outputs=[document_column]
803
+ )
804
+
805
+ tab_prompt_input.select(
806
+ fn=lambda: "prompt-input",
807
+ inputs=[],
808
+ outputs=[input_type],
809
+ ).then(fn=show_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
810
+ fn=hide_document_column_visibility, inputs=[], outputs=[document_column]
811
+ )
812
+
813
+ search_in.submit(
814
+ fn=lambda df: pd.DataFrame(columns=df.columns),
815
+ inputs=[dataframe],
816
+ outputs=[dataframe],
817
+ )
818
+
819
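+ # Both "Load" buttons (Hub dataset and uploaded files) share one handler; input_type selects the branch inside load_dataset_file.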
+ gr.on(
820
+ triggers=[load_dataset_btn.click, load_file_btn.click],
821
+ fn=load_dataset_file,
822
+ inputs=[search_in, file_in, input_type],
823
+ outputs=[dataframe, document_column],
824
+ )
825
+
826
+ load_prompt_btn.click(
827
+ fn=generate_system_prompt,
828
+ inputs=[dataset_description],
829
+ outputs=[system_prompt],
830
+ ).success(
831
+ fn=generate_sample_dataset,
832
+ inputs=[
833
+ search_in,
834
+ file_in,
835
+ input_type,
836
+ system_prompt,
837
+ document_column,
838
+ retrieval_reranking,
839
+ num_rows,
840
+ ],
841
+ outputs=dataframe,
842
+ )
843
+
844
+ btn_apply_to_sample_dataset.click(
845
+ fn=generate_sample_dataset,
846
+ inputs=[
847
+ search_in,
848
+ file_in,
849
+ input_type,
850
+ system_prompt,
851
+ document_column,
852
+ retrieval_reranking,
853
+ num_rows,
854
+ ],
855
+ outputs=dataframe,
856
+ )
857
+
858
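+ # Push flow: validate user and repo, hide previous messages, generate and push the dataset, then show the success message and pipeline code.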
+ btn_push_to_hub.click(
859
+ fn=validate_argilla_user_workspace_dataset,
860
+ inputs=[repo_name],
861
+ outputs=[success_message],
862
+ ).then(
863
+ fn=validate_push_to_hub,
864
+ inputs=[org_name, repo_name],
865
+ outputs=[success_message],
866
+ ).success(
867
+ fn=hide_save_local,
868
+ outputs=[csv_file, json_file, success_message],
869
+ ).success(
870
+ fn=hide_success_message,
871
+ outputs=[success_message],
872
+ ).success(
873
+ fn=hide_pipeline_code_visibility,
874
+ inputs=[],
875
+ outputs=[pipeline_code_ui],
876
+ ).success(
877
+ fn=push_dataset,
878
+ inputs=[
879
+ org_name,
880
+ repo_name,
881
+ private,
882
+ search_in,
883
+ file_in,
884
+ input_type,
885
+ system_prompt,
886
+ document_column,
887
+ retrieval_reranking,
888
+ num_rows,
889
+ temperature,
890
+ temperature_completion,
891
+ pipeline_code,
892
+ ],
893
+ outputs=[success_message],
894
+ ).success(
895
+ fn=show_success_message,
896
+ inputs=[org_name, repo_name],
897
+ outputs=[success_message],
898
+ ).success(
899
+ fn=generate_pipeline_code,
900
+ inputs=[
901
+ search_in,
902
+ input_type,
903
+ system_prompt,
904
+ document_column,
905
+ retrieval_reranking,
906
+ num_rows,
907
+ ],
908
+ outputs=[pipeline_code],
909
+ ).success(
910
+ fn=show_pipeline_code_visibility,
911
+ inputs=[],
912
+ outputs=[pipeline_code_ui],
913
+ )
914
+
915
+ btn_save_local.click(
916
+ fn=hide_success_message,
917
+ outputs=[success_message],
918
+ ).success(
919
+ fn=hide_pipeline_code_visibility,
920
+ inputs=[],
921
+ outputs=[pipeline_code_ui],
922
+ ).success(
923
+ fn=show_save_local,
924
+ inputs=[],
925
+ outputs=[csv_file, json_file, success_message],
926
+ ).success(
927
+ save_local,
928
+ inputs=[
929
+ search_in,
930
+ file_in,
931
+ input_type,
932
+ system_prompt,
933
+ document_column,
934
+ retrieval_reranking,
935
+ num_rows,
936
+ temperature,
937
+ repo_name,
938
+ temperature_completion,
939
+ ],
940
+ outputs=[csv_file, json_file],
941
+ ).success(
942
+ fn=generate_pipeline_code,
943
+ inputs=[
944
+ search_in,
945
+ input_type,
946
+ system_prompt,
947
+ document_column,
948
+ retrieval_reranking,
949
+ num_rows,
950
+ ],
951
+ outputs=[pipeline_code],
952
+ ).success(
953
+ fn=show_pipeline_code_visibility,
954
+ inputs=[],
955
+ outputs=[pipeline_code_ui],
956
+ )
957
+
958
+ clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
959
+ clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
960
+ clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
961
+ clear_btn_full.click(
962
+ fn=lambda df: ("", [], pd.DataFrame(columns=df.columns)),
963
+ inputs=[dataframe],
964
+ outputs=[document_column, retrieval_reranking, dataframe],
965
+ )
966
+
967
+ app.load(fn=swap_visibility, outputs=main_ui)
968
+ app.load(fn=get_org_dropdown, outputs=[org_name])
969
+ app.load(fn=get_random_repo_name, outputs=[repo_name])
970
+ app.load(fn=show_temperature_completion, outputs=[temperature_completion])
971
+ if SAVE_LOCAL_DIR is not None:
972
+ app.load(fn=show_save_local_button, outputs=btn_save_local)