Spaces:
Build error
Upload folder using huggingface_hub
(This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.)
- .dockerignore +57 -0
- .env.local.template +54 -0
- .gitattributes +5 -0
- .gitignore +173 -0
- LICENSE +201 -0
- README.md +168 -6
- app.py +4 -0
- assets/argilla.png +3 -0
- assets/flow.png +3 -0
- assets/logo.png +0 -0
- assets/logo.svg +1 -0
- assets/ui-full.png +3 -0
- assets/ui.png +3 -0
- docker-compose.yml +17 -0
- docker/.env.docker.template +43 -0
- docker/Dockerfile +45 -0
- docker/README.md +80 -0
- docker/argilla/compose.yml +118 -0
- docker/ollama/compose.yml +48 -0
- docker/ollama/entrypoint.sh +35 -0
- examples/argilla-deployment.py +18 -0
- examples/blog_private_synthetic_data_generation.md +222 -0
- examples/fine-tune-deepseek-reasoning-sft.ipynb +0 -0
- examples/fine-tune-modernbert-classifier.ipynb +538 -0
- examples/fine-tune-modernbert-rag.ipynb +980 -0
- examples/fine-tune-smollm2-on-synthetic-data.ipynb +310 -0
- examples/hf-dedicated-or-tgi-deployment.py +19 -0
- examples/hf-serverless-deployment-deepseek.py +16 -0
- examples/hf-serverless-deployment.py +15 -0
- examples/hf-serverless-different-model-for-completion.py +16 -0
- examples/ollama-deployment.py +22 -0
- examples/ollama-different-model-for-completion.py +26 -0
- examples/openai-deployment.py +18 -0
- examples/vllm-deployment.py +21 -0
- packages.txt +2 -0
- pdm.lock +0 -0
- pyproject.toml +40 -0
- requirements.txt +1 -0
- src/synthetic_dataset_generator/__init__.py +20 -0
- src/synthetic_dataset_generator/__main__.py +4 -0
- src/synthetic_dataset_generator/_distiset.py +148 -0
- src/synthetic_dataset_generator/_inference_endpoints.py +58 -0
- src/synthetic_dataset_generator/_tabbedinterface.py +69 -0
- src/synthetic_dataset_generator/app.py +35 -0
- src/synthetic_dataset_generator/apps/__init__.py +0 -0
- src/synthetic_dataset_generator/apps/about.py +15 -0
- src/synthetic_dataset_generator/apps/base.py +270 -0
- src/synthetic_dataset_generator/apps/chat.py +1142 -0
- src/synthetic_dataset_generator/apps/eval.py +894 -0
- src/synthetic_dataset_generator/apps/rag.py +972 -0
.dockerignore
ADDED
@@ -0,0 +1,57 @@
+# Version control
+.git
+.gitignore
+
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual environments
+.env*
+!.env.example
+.venv
+env/
+venv/
+ENV/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# Testing
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Project specific
+nltk_data/
+.pdm-python
+.pdm.toml
+__pypackages__/
.env.local.template
ADDED
@@ -0,0 +1,54 @@
+# =============================================================================
+# LOCAL/API CONFIGURATION
+# =============================================================================
+
+# -----------------------------------------------------------------------------
+# REQUIRED CONFIGURATION
+# -----------------------------------------------------------------------------
+# Hugging Face token (required for all setups)
+HF_TOKEN=hf_...
+
+# Generation Settings
+MAX_NUM_TOKENS=2048
+MAX_NUM_ROWS=1000
+DEFAULT_BATCH_SIZE=5
+
+# Required for chat data generation with Llama or Qwen models
+# Options: "llama3", "qwen2", or custom template string
+MAGPIE_PRE_QUERY_TEMPLATE=llama3
+
+# -----------------------------------------------------------------------------
+# A. CLOUD API SERVICES
+# -----------------------------------------------------------------------------
+
+# 1. HUGGING FACE INFERENCE API (Default, Recommended)
+MODEL=meta-llama/Llama-3.1-8B-Instruct
+# MODEL=Qwen/Qwen2.5-1.5B-Instruct
+
+# 2. OPENAI API
+# OPENAI_BASE_URL=https://api.openai.com/v1/
+# MODEL=gpt-4
+# API_KEY=sk-...
+
+# 3. HUGGING FACE SPACE FOR ARGILLA (optional)
+# ARGILLA_API_URL=https://your-space.hf.space/
+# ARGILLA_API_KEY=your_key
+
+# -----------------------------------------------------------------------------
+# B. LOCAL SERVICES (Requires Installation)
+# -----------------------------------------------------------------------------
+
+# 1. LOCAL OLLAMA
+# OLLAMA_BASE_URL=http://127.0.0.1:11434/
+# MODEL=llama3.2:1b
+# TOKENIZER_ID=meta-llama/Llama-3.2-1B-Instruct
+
+# 2. LOCAL VLLM
+# VLLM_BASE_URL=http://127.0.0.1:8000/
+# MODEL=Qwen/Qwen2.5-1.5B-Instruct
+# TOKENIZER_ID=Qwen/Qwen2.5-1.5B-Instruct
+
+# 3. LOCAL TGI
+# HUGGINGFACE_BASE_URL=http://127.0.0.1:3000/
+# MODEL=meta-llama/Llama-3.1-8B-Instruct
+# TOKENIZER_ID=meta-llama/Llama-3.1-8B-Instruct
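For reference, the same configuration can be applied programmatically instead of via a `.env` file, mirroring the scripts in the `examples/` folder listed above. A minimal sketch (token and model values are illustrative placeholders, not project defaults):

```python
# Minimal sketch: select one provider block from the template above in Python.
# Set the variables before launching the app.
import os

os.environ["HF_TOKEN"] = "hf_..."      # required for all setups (placeholder)
os.environ["MAX_NUM_ROWS"] = "1000"

# Provider A.1: Hugging Face Inference API (default)
os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct"

# Provider B.1 (alternative): local Ollama — use instead of the block above
# os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/"
# os.environ["MODEL"] = "llama3.2:1b"
# os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.2-1B-Instruct"

from synthetic_dataset_generator import launch

launch()
```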
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/flow.png filter=lfs diff=lfs merge=lfs -text
+*.sh text eol=lf
+assets/argilla.png filter=lfs diff=lfs merge=lfs -text
+assets/ui-full.png filter=lfs diff=lfs merge=lfs -text
+assets/ui.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,173 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm-project.org/#use-with-ide
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+.python-version
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+.DS_Store
+
+# nltk
+nltk_data/
+
+# examples
+models/
+
+# Elasticsearch data
+elasticsearch_data/
LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
README.md
CHANGED
@@ -1,10 +1,172 @@
 ---
 title: Synthetic Data Generator
-
-
-
-
-
+short_description: Build datasets using natural language
+emoji: 🧬
+colorFrom: yellow
+colorTo: pink
+sdk: gradio
+sdk_version: 5.8.0
+app_file: app.py
+pinned: true
+license: apache-2.0
+hf_oauth: true
+#header: mini
+hf_oauth_scopes:
+- read-repos
+- write-repos
+- manage-repos
+- inference-api
 ---
 
-
+> [!IMPORTANT]
+> The original authors have moved on to other projects. While the code might still be functional for its original purpose, please be aware that the original team does not plan to develop new features, bug fixes, or updates. If you'd like to become a maintainer, please open an issue to discuss.
+>
+>
+<br>
+
+<h2 align="center">
+  <a href=""><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" width="80%"></a>
+</h2>
+<h3 align="center">Build datasets using natural language</h3>
+
+
+
+## Introduction
+
+Synthetic Data Generator is a tool that allows you to create high-quality datasets for training and fine-tuning language models. It leverages the power of distilabel and LLMs to generate synthetic data tailored to your specific needs. [The announcement blog](https://huggingface.co/blog/synthetic-data-generator) goes over a practical example of how to use it, but you can also watch the [video](https://www.youtube.com/watch?v=nXjVtnGeEss) to see it in action.
+
+Supported Tasks:
+
+- Text Classification
+- Chat Data for Supervised Fine-Tuning
+- Retrieval Augmented Generation
+
+This tool simplifies the process of creating custom datasets, enabling you to:
+
+- Describe the characteristics of your desired application
+- Iterate on sample datasets
+- Produce full-scale datasets
+- Push your datasets to the [Hugging Face Hub](https://huggingface.co/datasets?other=datacraft) and/or [Argilla](https://docs.argilla.io/)
+
+By using the Synthetic Data Generator, you can rapidly prototype and create datasets, accelerating your AI development process.
+
+<p align="center">
+  <a href="https://twitter.com/argilla_io">
+  <img src="https://img.shields.io/badge/twitter-black?logo=x"/>
+  </a>
+  <a href="https://www.linkedin.com/company/argilla-io">
+  <img src="https://img.shields.io/badge/linkedin-blue?logo=linkedin"/>
+  </a>
+  <a href="http://hf.co/join/discord">
+  <img src="https://img.shields.io/badge/Discord-7289DA?&logo=discord&logoColor=white"/>
+  </a>
+</p>
+
+## Installation
+
+You can simply install the package with:
+
+```bash
+pip install synthetic-dataset-generator
+```
+
+### Quickstart
+
+```python
+from synthetic_dataset_generator import launch
+
+launch()
+```
+
+### Environment Variables
+
+- `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints. You can find some configuration examples in the [examples](examples/) folder.
+
+You can set the following environment variables to customize the generation process.
+
+- `MAX_NUM_TOKENS`: The maximum number of tokens to generate, defaults to `2048`.
+- `MAX_NUM_ROWS`: The maximum number of rows to generate, defaults to `1000`.
+- `DEFAULT_BATCH_SIZE`: The default batch size to use for generating the dataset, defaults to `5`.
+
+Optionally, you can use different API providers and models.
+
+- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`.
+- `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the `HF_TOKEN` environment variable.
+- `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`.
+- `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/`.
+- `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`.
+- `VLLM_BASE_URL`: The base URL for any VLLM compatible API, e.g. `http://localhost:8000/`.
+
+To use a specific model exclusively for generating completions, set the corresponding environment variables by appending `_COMPLETION` to the ones mentioned earlier. For example, you can use `MODEL_COMPLETION` and `OPENAI_BASE_URL_COMPLETION`.
+
+SFT and Chat Data generation is not supported with OpenAI Endpoints. Additionally, you need to configure it per model family based on their prompt templates, using the right `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE` environment variables.
+
+- `TOKENIZER_ID`: The tokenizer ID to use for the magpie pipeline, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`.
+- `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. `llama3` and `qwen2` are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"`, respectively. For other models, you can pass a custom pre-query template string.
+
+Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
+
+- `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla.
+- `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla.
+
+To save the generated datasets to a local directory instead of pushing them to the Hugging Face Hub, set the following environment variable:
+
+- `SAVE_LOCAL_DIR`: The local directory to save the generated datasets to.
+
+You can use our environment template as a starting point:
+
+```bash
+cp .env.local.template .env
+```
+
+### Argilla integration
+
+Argilla is an open source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/).
+
+
+
+## Custom synthetic data generation?
+
+Each pipeline is based on distilabel, so you can easily change the LLM or the pipeline steps.
+
+Check out the [distilabel library](https://github.com/argilla-io/distilabel) for more information.
+
+## Development
+
+Install the dependencies:
+
+```bash
+# Create a virtual environment
+python -m venv .venv
+source .venv/bin/activate
+
+# Install the dependencies
+pip install -e .  # pdm install
+```
+
+Run the app:
+
+```bash
+python app.py
+```
+
+## 🐳 Docker Setup
+
+The containerized tool uses Ollama for local LLM inference and Argilla for data curation. Here's the architecture:
+
+
+
+Quick setup with all services (App + Ollama + Argilla):
+
+```bash
+# Copy environment template
+cp docker/.env.docker.template .env  # Add your HF_TOKEN in .env
+
+# Build all services (this may take a few minutes)
+docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml build
+
+# Start all services
+docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml up -d
+```
+
+> For more detailed Docker configurations and setups, check [docker/README.md](docker/README.md)
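As a concrete illustration of the `_COMPLETION` suffix described in the README above, here is a minimal sketch (model choices are arbitrary examples; the pattern mirrors `examples/hf-serverless-different-model-for-completion.py` from the file list):

```python
# Sketch: one model generates instructions, a different one generates completions.
import os

os.environ["HF_TOKEN"] = "hf_..."  # placeholder token
os.environ["MODEL"] = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # instruction generation
os.environ["MODEL_COMPLETION"] = "Qwen/Qwen2.5-1.5B-Instruct"  # completions only

from synthetic_dataset_generator import launch

launch()
```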
app.py
ADDED
@@ -0,0 +1,4 @@
+from synthetic_dataset_generator import launch
+
+if __name__ == "__main__":
+    launch()
assets/argilla.png
ADDED (Git LFS)

assets/flow.png
ADDED (Git LFS)

assets/logo.png
ADDED

assets/logo.svg
ADDED

assets/ui-full.png
ADDED (Git LFS)

assets/ui.png
ADDED (Git LFS)
docker-compose.yml
ADDED
@@ -0,0 +1,17 @@
+services:
+  app:
+    build:
+      context: .
+      dockerfile: docker/Dockerfile
+    image: synthetic-data-generator:app
+    ports:
+      - "7860:7860"
+    env_file:
+      - .env
+    networks:
+      - app-network
+
+networks:
+  app-network:
+    name: synthetic-data-network
+    driver: bridge
docker/.env.docker.template
ADDED
@@ -0,0 +1,43 @@
+# =============================================================================
+# DOCKER CONFIGURATION ONLY - FULL SETUP (APP + OLLAMA + ARGILLA)
+# =============================================================================
+
+# Note: Before building:
+# 1. Copy this template to the root directory: cp docker/.env.docker.template .env
+# 2. Comment/uncomment the sections you want to use (OLLAMA and/or ARGILLA)
+# 3. Then build and run with the appropriate docker compose command
+
+# Hugging Face token with read/write permissions
+HF_TOKEN=your_token_here
+
+# -----------------------------------------------------------------------------
+# GENERATION SETTINGS
+# -----------------------------------------------------------------------------
+MAX_NUM_TOKENS=2048
+MAX_NUM_ROWS=1000
+DEFAULT_BATCH_SIZE=5
+
+# -----------------------------------------------------------------------------
+# OLLAMA DOCKER CONFIGURATION
+# -----------------------------------------------------------------------------
+OLLAMA_BASE_URL=http://ollama:11434
+OLLAMA_HARDWARE=latest  # latest (for CPU/NVIDIA), rocm (for AMD)
+
+# LLAMA 3.2
+MODEL=llama3.2:1b
+TOKENIZER_ID=meta-llama/Llama-3.2-1B-Instruct
+MAGPIE_PRE_QUERY_TEMPLATE=llama3
+
+# DEEPSEEK R1
+#MODEL=deepseek-r1:1.5b  # must match ollama tags https://ollama.com/library/deepseek-r1:1.5b
+#TOKENIZER_ID=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
+#MAGPIE_PRE_QUERY_TEMPLATE= "<|begin▁of▁sentence|>User: "
+
+# -----------------------------------------------------------------------------
+# ARGILLA DOCKER CONFIGURATION (persistent data)
+# -----------------------------------------------------------------------------
+ARGILLA_API_URL=http://argilla:6900
+ARGILLA_USERNAME=admin
+ARGILLA_PASSWORD=admin1234
+ARGILLA_API_KEY=admin.1234
+ARGILLA_REINDEX_DATASET=1
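For the commented-out DeepSeek block above, the custom Magpie pre-query template can equally be set from Python rather than the `.env` file. A sketch using the exact values from the template:

```python
# Sketch: DeepSeek R1 via Ollama with a custom Magpie pre-query template.
# Values are copied from the commented-out block in the template above.
import os

os.environ["OLLAMA_BASE_URL"] = "http://ollama:11434"
os.environ["MODEL"] = "deepseek-r1:1.5b"  # must match the Ollama tag
os.environ["TOKENIZER_ID"] = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "<|begin▁of▁sentence|>User: "
```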
docker/Dockerfile
ADDED
@@ -0,0 +1,45 @@
+# Use Python slim image as base
+FROM python:3.10-slim
+
+# Set environment variables
+ENV PYTHONUNBUFFERED=1 \
+    PYTHONDONTWRITEBYTECODE=1 \
+    PIP_NO_CACHE_DIR=1
+
+# Create and set working directory
+WORKDIR /app
+
+# Create non-root user first
+RUN useradd -m -u 1000 appuser
+
+# Install system dependencies including build tools
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    curl \
+    build-essential \
+    cmake \
+    libgl1-mesa-glx \
+    libglib2.0-0 \
+    libsm6 \
+    libxext6 \
+    libxrender-dev \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install pdm
+RUN pip install --no-cache-dir pdm
+
+# Copy project files and set permissions
+COPY . .
+RUN chown -R appuser:appuser /app && \
+    chmod -R 755 /app
+
+# Switch to non-root user
+USER appuser
+
+# Install dependencies in a virtual environment
+RUN pdm install --prod --frozen-lockfile
+
+# Expose Gradio port
+EXPOSE 7860
+
+# Start command using pdm run to use the virtual environment
+CMD ["pdm", "run", "python", "-m", "synthetic_dataset_generator"]
docker/README.md
ADDED
@@ -0,0 +1,80 @@
+# Docker Configuration Guide
+
+Each service runs in its own container, communicating through internal networks. The core app connects to Ollama for model inference and Argilla for data review:
+
+
+
+The application can be run with different configurations using Docker Compose:
+
+- `docker-compose.yml`: Core application
+- `docker/ollama/compose.yml`: Ollama service for local LLM inference
+- `docker/argilla/compose.yml`: Argilla service for data curation
+
+## Ollama Integration
+
+The `MODEL` variable in your `.env` file determines which model Ollama will download and use. For example:
+```env
+MODEL=llama3.2:1b
+```
+
+## Setup Options
+
+### Full Setup (App + Ollama + Argilla)
+```bash
+# Keep all sections uncommented in .env
+docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml build
+docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml up -d
+```
+
+### App + Ollama
+```bash
+# Comment out ARGILLA section in .env
+docker compose -f docker-compose.yml -f docker/ollama/compose.yml build
+docker compose -f docker-compose.yml -f docker/ollama/compose.yml up -d
+```
+
+### App + Argilla
+```bash
+# Comment out OLLAMA section in .env
+docker compose -f docker-compose.yml -f docker/argilla/compose.yml build
+docker compose -f docker-compose.yml -f docker/argilla/compose.yml up -d
+```
+
+### App Only
+```bash
+# Comment out both OLLAMA and ARGILLA sections in .env
+docker compose -f docker-compose.yml build
+docker compose -f docker-compose.yml up -d
+```
+
+## Managing Services
+
+Services are built separately but are linked together. If you already have some services built and want to add another:
+
+1. You don't need to rebuild existing services
+2. Just build the new service
+3. Stop everything with `down` and start again with `up`
+
+For example, if you have App + Ollama and want to add Argilla:
+```bash
+docker compose -f docker/argilla/compose.yml build  # only build Argilla
+docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml down
+docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml up -d
+```
+
+Similarly, if you have built all services but want to run only some of them:
+> **Important**: When running specific services, remember to comment out unused services in `.env` first
+
+```bash
+# No need to build again, just start the services you need
+docker compose -f docker-compose.yml -f docker/ollama/compose.yml up -d  # start only App + Ollama
+```
+
+## Service URLs
+
+Once running, access the services at:
+- App: http://localhost:7860
+- Argilla: http://localhost:6900 (if enabled)
+- Ollama: http://localhost:11434 (if enabled)
+
+> Note: Services will be available after a few seconds while they initialize. Ollama models and Argilla datasets are persisted and available after restarts
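A small sketch for checking that the services above are reachable from the host, using only the Python standard library (the URLs match the defaults in this guide; the Argilla path is the same endpoint the compose healthcheck curls):

```python
# Sketch: poll the default service URLs until each responds or times out.
import time
from urllib.error import URLError
from urllib.request import urlopen

SERVICES = {
    "app": "http://localhost:7860",
    "argilla": "http://localhost:6900/api/ready",
    "ollama": "http://localhost:11434",
}

for name, url in SERVICES.items():
    for _ in range(30):  # roughly 30 seconds per service
        try:
            urlopen(url, timeout=2)
            print(f"{name}: up")
            break
        except URLError:
            time.sleep(1)
    else:
        print(f"{name}: not reachable")
```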
docker/argilla/compose.yml
ADDED
@@ -0,0 +1,118 @@
+services:
+  app:
+    extends:
+      file: docker-compose.yml
+      service: app
+    depends_on:
+      argilla:
+        condition: service_healthy
+        required: false
+    environment:
+      - ARGILLA_API_URL=http://argilla:6900
+
+  elasticsearch:
+    image: docker.elastic.co/elasticsearch/elasticsearch:8.17.0
+    environment:
+      - ES_JAVA_OPTS=-Xms512m -Xmx512m
+      - node.name=elasticsearch
+      - cluster.name=es-argilla-local
+      - discovery.type=single-node
+      - cluster.routing.allocation.disk.threshold_enabled=false
+      - xpack.security.enabled=false
+    volumes:
+      - es_data:/usr/share/elasticsearch/data
+    networks:
+      - app-network
+    ports:
+      - "9200:9200"
+      - "9300:9300"
+    ulimits:
+      memlock:
+        soft: -1
+        hard: -1
+      nofile:
+        soft: 65536
+        hard: 65536
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:9200"]
+      interval: 30s
+      timeout: 10s
+      retries: 3
+
+  postgres:
+    image: postgres:14
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_PASSWORD: postgres
+      POSTGRES_DB: argilla
+    networks:
+      - app-network
+    volumes:
+      - postgres_data:/var/lib/postgresql/data
+
+  redis:
+    image: redis
+    networks:
+      - app-network
+
+  argilla:
+    image: argilla/argilla-server:latest
+    ports:
+      - "6900:6900"
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:6900/api/ready"]
+      interval: 30s
+      timeout: 10s
+      retries: 3
+    env_file:
+      - .env
+    environment:
+      - ARGILLA_HOME_PATH=/var/lib/argilla
+      - ARGILLA_ELASTICSEARCH=http://elasticsearch:9200
+      - ARGILLA_DATABASE_URL=postgresql+asyncpg://postgres:postgres@postgres:5432/argilla
+      - ARGILLA_REDIS_URL=redis://redis:6379/0
+      - USERNAME=${ARGILLA_USERNAME}
+      - PASSWORD=${ARGILLA_PASSWORD}
+      - API_KEY=${ARGILLA_API_KEY}
+      - WORKSPACE=default
+    volumes:
+      - argilla_data:/argilla
+    networks:
+      - app-network
+    depends_on:
+      elasticsearch:
+        condition: service_healthy
+      postgres:
+        condition: service_started
+      redis:
+        condition: service_started
+
+  worker:
+    image: argilla/argilla-server:latest
+    env_file:
+      - .env
+    environment:
+      - ARGILLA_HOME_PATH=/var/lib/argilla
+      - ARGILLA_ELASTICSEARCH=http://elasticsearch:9200
+      - ARGILLA_DATABASE_URL=postgresql+asyncpg://postgres:postgres@postgres:5432/argilla
+      - ARGILLA_REDIS_URL=redis://redis:6379/0
+      - BACKGROUND_NUM_WORKERS=2
+      - USERNAME=${ARGILLA_USERNAME}
+      - PASSWORD=${ARGILLA_PASSWORD}
+      - API_KEY=${ARGILLA_API_KEY}
+      - WORKSPACE=default
+    networks:
+      - app-network
+    depends_on:
+      - postgres
+      - elasticsearch
+      - redis
+    command: sh -c 'python -m argilla_server worker --num-workers $${BACKGROUND_NUM_WORKERS}'
+
+volumes:
+  es_data:
+    name: synthetic-data-es
+  argilla_data:
+    name: synthetic-data-argilla
+  postgres_data:
+    name: synthetic-data-postgres
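Once the stack is healthy, the credentials from `docker/.env.docker.template` can be used from the host to connect to this Argilla server. A sketch, assuming the Argilla v2 Python SDK (`pip install argilla`, not a dependency of this compose file) and the default ports and keys above:

```python
# Sketch: connect to the dockerized Argilla server from the host machine.
import argilla as rg

client = rg.Argilla(
    api_url="http://localhost:6900",  # host-side URL; containers use http://argilla:6900
    api_key="admin.1234",             # ARGILLA_API_KEY from the template
)
print(client.me)  # verify the connection by printing the current user
```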
docker/ollama/compose.yml
ADDED
@@ -0,0 +1,48 @@
+services:
+  app:
+    extends:
+      file: docker-compose.yml
+      service: app
+    depends_on:
+      ollama:
+        condition: service_healthy
+        required: true
+    environment:
+      - OLLAMA_BASE_URL=http://ollama:11434
+
+  ollama:
+    image: ollama/ollama:${OLLAMA_HARDWARE:-latest}
+    ports:
+      - "11434:11434"
+    env_file:
+      - .env
+    environment:
+      - OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-}
+    volumes:
+      - ollama_data:/root/.ollama
+      - ./docker/ollama/entrypoint.sh:/entrypoint.sh
+    networks:
+      - app-network
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: all
+              capabilities: [gpu]
+    tty: true
+    entrypoint: ["/usr/bin/bash", "/entrypoint.sh"]
+    healthcheck:
+      test:
+        - "CMD-SHELL"
+        - |
+          test -f /tmp/ollama_ready && \
+          bash -c '</dev/tcp/localhost/11434'
+      interval: 10s
+      timeout: 10s
+      retries: 100
+      start_period: 10s
+
+volumes:
+  ollama_data:
+    name: synthetic-data-ollama
docker/ollama/entrypoint.sh
ADDED
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+# Start Ollama in the background
+/bin/ollama serve &
+# Record Process ID
+pid=$!
+
+# Pause for Ollama to start
+sleep 5
+
+# Extract model name from MODEL variable (removing quotes if present)
+MODEL_NAME=$(echo $MODEL | tr -d '"')
+
+# Verify that MODEL_NAME has a value
+if [ -z "$MODEL_NAME" ]; then
+    echo "❌ No model specified in MODEL environment variable"
+else
+    # Check if model exists
+    if ollama list | grep -q "$MODEL_NAME"; then
+        echo "🟢 Model ($MODEL_NAME) already installed"
+        touch /tmp/ollama_ready
+    else
+        echo "🔴 Retrieving model ($MODEL_NAME)..."
+        # Try to download the model; only create the marker file once the pull is confirmed
+        if ollama pull "$MODEL_NAME" 2>/dev/null && ollama list | grep -q "$MODEL_NAME"; then
+            echo "🟢 Model download complete!"
+            touch /tmp/ollama_ready
+        else
+            echo "❌ Error downloading model ($MODEL_NAME)"
+        fi
+    fi
+fi
+
+# Wait for Ollama process to finish
+wait $pid
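The readiness contract of this script (create `/tmp/ollama_ready` only once the model is confirmed installed, which the compose healthcheck then tests together with an open TCP port) can also be exercised against Ollama's HTTP API. A sketch, assuming Ollama's standard `/api/tags` endpoint and the default model from the Docker template:

```python
# Sketch: ask Ollama which models are installed, mirroring `ollama list` above.
import json
import os
from urllib.request import urlopen

base_url = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
with urlopen(f"{base_url}/api/tags", timeout=5) as resp:
    models = [m["name"] for m in json.load(resp)["models"]]

model = os.environ.get("MODEL", "llama3.2:1b")
print("ready" if any(model in name for name in models) else "missing", model)
```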
examples/argilla-deployment.py
ADDED
@@ -0,0 +1,18 @@
+# /// script
+# requires-python = ">=3.11,<3.12"
+# dependencies = [
+#     "synthetic-dataset-generator",
+# ]
+# ///
+import os
+
+from synthetic_dataset_generator import launch
+
+# Follow https://docs.argilla.io/latest/getting_started/quickstart/ to get your Argilla API key and URL
+os.environ["HF_TOKEN"] = "hf_..."
+os.environ["ARGILLA_API_URL"] = (
+    "https://[your-owner-name]-[your_space_name].hf.space"  # argilla base url
+)
+os.environ["ARGILLA_API_KEY"] = "my_api_key"  # argilla api key
+
+launch()
examples/blog_private_synthetic_data_generation.md
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Private Synthetic Data Generation Made Easy: Out-of-the-Box with Docker, Argilla & Ollama

> "Empowering organizations with a turnkey solution for synthetic dataset creation in private environments."

The increasing adoption of AI solutions across industries has created an unprecedented demand for high-quality training data. As organizations scale their AI initiatives, they face the dual challenge of generating substantial, domain-specific datasets while ensuring data privacy and security. Traditional approaches often involve compromises: either using public datasets that may not fully align with specific needs, or investing heavily in custom data generation infrastructure.

The complexity of this challenge is amplified by regulatory requirements, resource constraints, and the need for specialized expertise. Organizations must navigate GDPR, CCPA, and industry-specific regulations while maintaining efficient data generation pipelines. This has created a pressing need for solutions that can operate entirely within private infrastructure while maintaining enterprise-grade capabilities.

## The Challenge

The development of AI models requires extensive training data, yet organizations face significant obstacles in data generation and management. Privacy regulations and security requirements often prevent the use of public datasets or cloud-based generation services. Additionally, existing solutions typically demand complex infrastructure setups and significant technical expertise, increasing both implementation time and costs.

Modern enterprises require a solution that addresses several critical aspects:
1. Data Privacy: Complete control over data generation and storage
2. Infrastructure Flexibility: Deployment options that fit existing systems
3. Quality Assurance: Tools for data validation and curation
4. Scalability: Ability to grow with increasing data needs
5. Cost Efficiency: Reduction in infrastructure and maintenance costs

## The Solution

This out-of-the-box Synthetic Dataset Generator approach leverages the power of three technologies to create a seamless, private data generation pipeline. At its core is the [Synthetic Dataset Generator](https://github.com/argilla-io/synthetic-data-generator), a tool designed for dataset creation. [Ollama](https://ollama.ai/) ensures secure local LLM inference with [Distilabel](https://github.com/argilla-io/distilabel) integration, while [Argilla's](https://argilla.io/) data curation capabilities complete the workflow, all operating within your secure infrastructure.

This architecture delivers key technical advantages:
- Full data sovereignty with containerized local deployment
- End-to-end pipeline from generation to validation
- Modular design for system integration

Here's how it all fits together:



Let's explore how these components work together in a practical workflow.

## 1. Installation & Setup

### 1.1 Clone Repository
```bash
git clone https://github.com/argilla-io/synthetic-data-generator
cd synthetic-data-generator
```

### 1.2 Environment Setup
```bash
# Copy environment template
cp docker/.env.docker.template .env

# Model configuration in .env (if using Ollama)
MODEL="deepseek-r1:1.5b" # Must match Ollama model name
```

### 1.3 Build & Deploy Services
> Pro tip: Even if you're planning to use just one component initially, we recommend building all services to enable future functionality without rebuilding. For detailed deployment options, check the [Docker documentation](https://github.com/argilla-io/synthetic-data-generator/blob/main/docker/README.md).

> Note: Ollama runs on CPU/GPU for Linux/Windows in Docker. For macOS, only CPU is supported in Docker - for GPU support, install Ollama separately ([details](https://ollama.com/blog/ollama-is-now-available-as-an-official-docker-image)).

```bash
# Build all services
docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml build

# Start all services
docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml up -d
```

To view logs, either:
- Use Docker Desktop's interface
- Remove the `-d` flag when running the above command
- Or execute the following for specific service logs:
```bash
# Core App logs
docker compose logs -f app

# Ollama logs
docker compose -f docker-compose.yml -f docker/ollama/compose.yml logs -f ollama

# Argilla logs
docker compose -f docker-compose.yml -f docker/argilla/compose.yml logs -f argilla
```
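
Once the stack is up, you can confirm that every container is running with a standard Compose status check:
```bash
# List the services from all three compose files and their current state
docker compose -f docker-compose.yml -f docker/ollama/compose.yml -f docker/argilla/compose.yml ps
```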
## 2. Dataset Generation

The tool currently supports **Text Classification**, **Chat**, and **RAG** datasets. The chosen task determines the shape of the generated dataset: classification requires categories, chat requires a conversation format, and RAG requires question-answer pairs with relevant context, with options to generate both retrieval and reranking data for strengthening different stages of an information retrieval system.

For a detailed overview of the generation process, check out the [introduction to the Synthetic Data Generator](https://huggingface.co/blog/synthetic-data-generator).

### 2.1. **Dataset Description**

Let's walk through creating a **RAG dataset**.
```text
A dataset to retrieve information from information security policies
```

The system initializes and processes the prompt:


### 2.2. **Task Configuration & Sample Generation**
The system analyzes the description and automatically generates the system prompt and optimal parameters. Samples are then generated for validation (modify the system prompt or parameters manually if needed, then click save to generate sample data):


### 2.3. **Full Dataset Generation**
After validating the sample data quality, proceed with full dataset generation. Configure the following parameters:

- **Repository Owner**: Your Hugging Face username for dataset hosting
- **Dataset Name**: A descriptive name following standard naming conventions
- **Number of Examples**: Define dataset size (recommended: 100-1000 for initial deployments)
- **Temperature**: Controls generation creativity (default 0.7 balances coherence and diversity)
- **Privacy Settings**: Optional dataset privacy configuration for Hugging Face Hub

The temperature parameter significantly impacts output quality; a quick way to experiment with it is sketched after this list:
- 0.5-0.7: Optimal for technical documentation and factual content
- 0.7-0.8: Balanced for general purpose datasets
- 0.8-1.0: Increased creativity, suitable for conversational data
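
The generator applies this value to the LLM backing the pipeline. As a rough, standalone illustration (not part of the generator's own code), you can probe the effect of different temperatures directly against Ollama's local REST API:

```python
# Standalone sketch: compare Ollama outputs at different temperatures.
# Uses Ollama's documented /api/generate endpoint on its default port.
import requests

for temperature in (0.5, 0.7, 1.0):
    response = requests.post(
        "http://localhost:11434/api/generate",
        json={
            "model": "deepseek-r1:1.5b",  # must match the MODEL value in .env
            "prompt": "Write one question about an information security policy.",
            "options": {"temperature": temperature},
            "stream": False,
        },
    )
    print(f"--- temperature={temperature} ---")
    print(response.json()["response"])
```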
The system initiates the generation pipeline, leveraging Distilabel for structured output:


Upon completion, the dataset is pushed to Hugging Face Hub:


Access your generated dataset through the Hugging Face Hub interface:

<iframe
  src="https://huggingface.co/datasets/daqc/info-security-policies-rag-distiset/embed/viewer/default/train"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>
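
Once pushed, the dataset can also be pulled straight into code. A minimal sketch with the `datasets` library, using the repository id shown in the viewer above:

```python
# Minimal sketch: load the generated dataset from the Hugging Face Hub
from datasets import load_dataset

dataset = load_dataset("daqc/info-security-policies-rag-distiset", split="train")
print(dataset)     # columns and row count
print(dataset[0])  # first generated example
```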
## 3. Data Curation with Argilla

The integration with Argilla provides enterprise-grade dataset curation capabilities through a comprehensive review system. This phase is crucial for ensuring data quality and maintaining high standards in your training datasets.

### Environment Configuration
Before accessing Argilla's features, ensure proper configuration in your `.env` file.
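
As a sketch, the relevant entries look like this; the variable names match those used in the repository's examples, while the values are placeholders for your own deployment:

```bash
# Argilla connection settings in .env (placeholder values)
ARGILLA_API_URL=http://localhost:6900   # Argilla's default port; adjust to your deployment
ARGILLA_API_KEY=my_api_key              # API key from your Argilla instance
```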
### Curation Workflow

1. **Dataset Integration**
Upon generation completion, the dataset is automatically ingested into Argilla. The system maintains data integrity and version control throughout the process. All datasets and progress persist across Docker restarts unless you explicitly remove the Argilla services and volumes.


2. **Quality Assurance Process**
Argilla's interface provides comprehensive tools for dataset validation:
- Semantic analysis of generated content
- Consistency checking across entries
- Metadata validation and enrichment
- Collaborative review capabilities



3. **Dataset Publication**
After thorough review, export your curated dataset to Hugging Face Hub; a scripted alternative is sketched after this list:

> Note: Consider using a new repository name to preserve both raw and curated datasets separately.

- Configure repository settings
- Set visibility and access controls
- Add dataset cards and documentation
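
If you prefer to script this step, here is a minimal sketch with the Argilla Python SDK; the dataset and repository names are illustrative placeholders, and `to_hub` is the SDK's Hub-export helper:

```python
# Minimal sketch: export a reviewed Argilla dataset to a new Hub repository
import argilla as rg

client = rg.Argilla(api_url="http://localhost:6900", api_key="my_api_key")
dataset = client.datasets(name="info-security-policies-rag-distiset")

# Push the curated records to a separate repo, keeping the raw dataset intact
dataset.to_hub(repo_id="your-username/info-security-policies-rag-curated")
```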


The curated dataset maintains full provenance tracking and quality metrics:
<iframe
  src="https://huggingface.co/datasets/daqc/info-security-policies-rag-distiset-argilla/embed/viewer/default/train"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>
# 🎉 You're Done!
Congratulations! You've successfully completed the end-to-end dataset generation and curation process. Your curated dataset is now ready for model training.

## Experience the Solution

For a hands-on preview of the Synthetic Dataset Generator's capabilities, explore the hosted space. This allows you to evaluate the interface and functionality before deploying your own instance:

<iframe
  src="https://argilla-synthetic-data-generator.hf.space"
  frameborder="0"
  width="850"
  height="450"
  referrerpolicy="same-origin"
  sandbox="allow-scripts"
></iframe>

Create your own deployment by <a href="https://huggingface.co/spaces/argilla/synthetic-data-generator?duplicate=true">duplicating this Space</a>.

## What's Next?

After successfully generating your first dataset, several advanced implementation paths are available. Extend your dataset generation capabilities:
- [Fine-tune models on synthetic data](https://huggingface.co/blog/davidberenstein1957/fine-tune-a-smollm-on-synthetic-data-of-llm) for domain-specific tasks
- [Create specialized reasoning datasets](https://huggingface.co/blog/sdiazlor/fine-tune-deepseek-with-a-synthetic-reasoning-data) for advanced model training

## Conclusion

The Synthetic Dataset Generator represents a significant advancement in private data generation technology, addressing the growing need for high-quality training data while maintaining security and control. By leveraging containerized architecture and local LLM inference, organizations can now generate custom datasets without compromising on data privacy or quality.

The solution's modular design enables seamless integration with existing ML pipelines while providing enterprise-grade features like persistent storage, comprehensive monitoring, and scalable infrastructure. Through collaborative validation workflows and structured quality control processes, teams can efficiently create and curate datasets tailored to their specific needs.

This combination of security, efficiency, and flexibility makes the Synthetic Dataset Generator an essential tool for organizations looking to accelerate their AI development while maintaining complete control over their data generation pipeline.

## References & Documentation

- [Synthetic Dataset Generator](https://github.com/argilla-io/synthetic-data-generator): Open-source tool for dataset generation using natural language
- [Distilabel Framework](https://github.com/argilla-io/distilabel): Advanced dataset generation framework
- [Docker Best Practices](https://docs.docker.com/develop/develop-images/dockerfile_best-practices/): Container optimization guidelines
- [Argilla Documentation](https://docs.argilla.io): Data curation platform documentation
- [Ollama Integration](https://github.com/jmorganca/ollama): Local LLM deployment guide
examples/fine-tune-deepseek-reasoning-sft.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
examples/fine-tune-modernbert-classifier.ipynb
ADDED
@@ -0,0 +1,538 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tune ModernBERT for text classification using synthetic data\n",
    "\n",
    "LLMs are great general-purpose models, but they are not always the best choice for a specific task. Therefore, smaller and more specialized models are important for sustainable, efficient, and cheaper AI.\n",
    "A lack of domain-specific datasets is a common problem for smaller and more specialized models. This is because it is difficult to find a dataset that is both representative and diverse enough for a specific task. We solve this problem by generating a synthetic dataset from an LLM using the `synthetic-data-generator`, which is available as a [Hugging Face Space](https://huggingface.co/spaces/argilla/synthetic-data-generator) or on [GitHub](https://github.com/argilla-io/synthetic-data-generator).\n",
    "\n",
    "In this example, we will fine-tune a ModernBERT model on a synthetic dataset generated from the synthetic-data-generator. This demonstrates the effectiveness of synthetic data and the novel ModernBERT model, which is a new and improved version of BERT models, with an 8192 token context length, significantly better downstream performance, and much faster processing speeds.\n",
    "\n",
    "## Install the dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install Pytorch & other libraries\n",
    "%pip install \"torch==2.5.0\" \"torchvision==0.20.0\"\n",
    "%pip install \"setuptools<71.0.0\" scikit-learn\n",
    "\n",
    "# Install Hugging Face libraries\n",
    "%pip install --upgrade \\\n",
    "    \"datasets==3.1.0\" \\\n",
    "    \"accelerate==1.2.1\" \\\n",
    "    \"hf-transfer==0.1.8\"\n",
    "\n",
    "# ModernBERT is not yet available in an official release, so we need to install it from github\n",
    "%pip install \"git+https://github.com/huggingface/transformers.git@6e0515e99c39444caae39472ee1b2fd76ece32f1\" --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The problem\n",
    "\n",
    "The [nvidia/domain-classifier](https://huggingface.co/nvidia/domain-classifier) is a model that can classify the domain of a text, which can help with curating data. This model is cool, but it is based on DeBERTa V3 Base, which is an outdated architecture that requires custom code to run, has a context length of 512 tokens, and is not as fast as the ModernBERT model. The labels for the model are:\n",
    "\n",
    "```\n",
    "'Adult', 'Arts_and_Entertainment', 'Autos_and_Vehicles', 'Beauty_and_Fitness', 'Books_and_Literature', 'Business_and_Industrial', 'Computers_and_Electronics', 'Finance', 'Food_and_Drink', 'Games', 'Health', 'Hobbies_and_Leisure', 'Home_and_Garden', 'Internet_and_Telecom', 'Jobs_and_Education', 'Law_and_Government', 'News', 'Online_Communities', 'People_and_Society', 'Pets_and_Animals', 'Real_Estate', 'Science', 'Sensitive_Subjects', 'Shopping', 'Sports', 'Travel_and_Transportation'\n",
    "```\n",
    "\n",
    "The data on which the model was trained is not available, so we cannot use it for our purposes. We can, however, generate synthetic data to solve this problem."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "source": [
    "## Let's generate some data\n",
    "\n",
    "Let's go to the [hosted Hugging Face Space](https://huggingface.co/spaces/argilla/synthetic-data-generator) to generate the data. This is done in three steps: 1) we come up with a dataset description, 2) iterate on the task configuration, and 3) generate and push the data to Hugging Face. A more detailed flow can be found in [this blogpost](https://huggingface.co/blog/synthetic-data-generator).\n",
    "\n",
    "<iframe\n",
    "\tsrc=\"https://argilla-synthetic-data-generator.hf.space\"\n",
    "\tframeborder=\"0\"\n",
    "\twidth=\"850\"\n",
    "\theight=\"450\"\n",
    "></iframe>\n",
    "\n",
    "For this example, we will generate 1000 examples with a temperature of 1. After some iteration, we come up with the following system prompt:\n",
    "\n",
    "```\n",
    "Long texts (at least 2000 words) from various media sources like Wikipedia, Reddit, Common Crawl, websites, commercials, online forums, books, newspapers and folders that cover multiple topics. Classify the text based on its main subject matter into one of the following categories\n",
    "```\n",
    "\n",
    "We press the \"Push to Hub\" button and wait for the data to be generated. This takes a few minutes and we end up with a dataset with 1000 examples. The labels are nicely distributed across the categories, varied in length, and the texts look diverse and interesting.\n",
    "\n",
    "<iframe\n",
    "  src=\"https://huggingface.co/datasets/argilla/synthetic-domain-text-classification/embed/viewer/default/train\"\n",
    "  frameborder=\"0\"\n",
    "  width=\"100%\"\n",
    "  height=\"560px\"\n",
    "></iframe>\n",
    "\n",
    "The data is also pushed to Argilla, so we recommend inspecting and validating the labels there before fine-tuning the model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finetuning the ModernBERT model\n",
    "\n",
    "We mostly rely on the blog from [Philipp Schmid](https://www.philschmid.de/fine-tune-modern-bert-in-2025). I will use basic consumer hardware, my Apple M1 Max with 32GB of shared memory. We will use the `datasets` library to load the data and the `transformers` library to finetune the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/davidberenstein/Documents/programming/argilla/synthetic-data-generator/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'text': 'Recently, there has been an increase in property values within the suburban areas of several cities due to improvements in infrastructure and lifestyle amenities such as parks, retail stores, and educational institutions nearby. Additionally, new housing developments are emerging, catering to different family needs with varying sizes and price ranges. These changes have influenced investment decisions for many looking to buy or sell properties.',\n",
       " 'label': 14}"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "from datasets.arrow_dataset import Dataset\n",
    "from datasets.dataset_dict import DatasetDict, IterableDatasetDict\n",
    "from datasets.iterable_dataset import IterableDataset\n",
    "\n",
    "# Dataset id from huggingface.co/dataset\n",
    "dataset_id = \"argilla/synthetic-domain-text-classification\"\n",
    "\n",
    "# Load raw dataset\n",
    "train_dataset = load_dataset(dataset_id, split='train')\n",
    "\n",
    "split_dataset = train_dataset.train_test_split(test_size=0.1)\n",
    "split_dataset['train'][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we need to tokenize the data. We will use the `AutoTokenizer` class from the `transformers` library to load the tokenizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 900/900 [00:00<00:00, 4787.61 examples/s]\n",
      "Map: 100%|██████████| 100/100 [00:00<00:00, 4163.70 examples/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "dict_keys(['labels', 'input_ids', 'attention_mask'])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Model id to load the tokenizer\n",
    "model_id = \"answerdotai/ModernBERT-base\"\n",
    "\n",
    "# Load Tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "# Tokenize helper function\n",
    "def tokenize(batch):\n",
    "    return tokenizer(batch['text'], padding=True, truncation=True, return_tensors=\"pt\")\n",
    "\n",
    "# Tokenize dataset\n",
    "if \"label\" in split_dataset[\"train\"].features.keys():\n",
    "    split_dataset = split_dataset.rename_column(\"label\", \"labels\")  # to match Trainer\n",
    "tokenized_dataset = split_dataset.map(tokenize, batched=True, remove_columns=[\"text\"])\n",
    "\n",
    "tokenized_dataset[\"train\"].features.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we need to prepare the model. We will use the `AutoModelForSequenceClassification` class from the `transformers` library to load the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "# Model id to load the tokenizer\n",
    "model_id = \"answerdotai/ModernBERT-base\"\n",
    "\n",
    "# Prepare model labels - useful for inference\n",
    "labels = tokenized_dataset[\"train\"].features[\"labels\"].names\n",
    "num_labels = len(labels)\n",
    "label2id, id2label = dict(), dict()\n",
    "for i, label in enumerate(labels):\n",
    "    label2id[label] = str(i)\n",
    "    id2label[str(i)] = label\n",
    "\n",
    "# Download the model from huggingface.co/models\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    model_id, num_labels=num_labels, label2id=label2id, id2label=id2label,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use a simple F1 score as the evaluation metric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Metric helper method\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    predictions = np.argmax(predictions, axis=1)\n",
    "    score = f1_score(\n",
    "        labels, predictions, labels=labels, pos_label=1, average=\"weighted\"\n",
    "    )\n",
    "    return {\"f1\": float(score) if score == 1 else score}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we need to define the training arguments. We will use the `TrainingArguments` class from the `transformers` library to define the training arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/davidberenstein/Documents/programming/argilla/synthetic-data-generator/.venv/lib/python3.11/site-packages/transformers/training_args.py:2241: UserWarning: `use_mps_device` is deprecated and will be removed in version 5.0 of 🤗 Transformers. `mps` device will be used by default if available similar to the way `cuda` device is used.Therefore, no action from user is required. \n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from huggingface_hub import HfFolder\n",
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "# Define training args\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"ModernBERT-domain-classifier\",\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=16,\n",
    "    learning_rate=5e-5,\n",
    "    num_train_epochs=5,\n",
    "    bf16=True,  # bfloat16 training\n",
    "    optim=\"adamw_torch_fused\",  # improved optimizer\n",
    "    # logging & evaluation strategies\n",
    "    logging_strategy=\"steps\",\n",
    "    logging_steps=100,\n",
    "    eval_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    save_total_limit=2,\n",
    "    load_best_model_at_end=True,\n",
    "    use_mps_device=True,\n",
    "    metric_for_best_model=\"f1\",\n",
    "    # push to hub parameters\n",
    "    push_to_hub=True,\n",
    "    hub_strategy=\"every_save\",\n",
    "    hub_token=HfFolder.get_token(),\n",
    ")\n",
    "\n",
    "# Create a Trainer instance\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_dataset[\"train\"],\n",
    "    eval_dataset=tokenized_dataset[\"test\"],\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \n",
      " 20%|██ | 29/145 [11:32<33:16, 17.21s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.729780912399292, 'eval_f1': 0.7743598318036522, 'eval_runtime': 3.5337, 'eval_samples_per_second': 28.299, 'eval_steps_per_second': 1.981, 'epoch': 1.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \n",
      " 40%|████ | 58/145 [22:57<25:56, 17.89s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.4369044005870819, 'eval_f1': 0.8310764765820946, 'eval_runtime': 3.3266, 'eval_samples_per_second': 30.061, 'eval_steps_per_second': 2.104, 'epoch': 2.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \n",
      " 60%|██████ | 87/145 [35:16<17:06, 17.70s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.6091340184211731, 'eval_f1': 0.8399274488570763, 'eval_runtime': 3.2772, 'eval_samples_per_second': 30.514, 'eval_steps_per_second': 2.136, 'epoch': 3.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 69%|██████▉ | 100/145 [41:03<18:02, 24.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.7663, 'grad_norm': 7.232136249542236, 'learning_rate': 1.5517241379310346e-05, 'epoch': 3.45}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \n",
      " 80%|████████ | 116/145 [47:23<08:50, 18.30s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.43516409397125244, 'eval_f1': 0.8797674004703547, 'eval_runtime': 3.2975, 'eval_samples_per_second': 30.326, 'eval_steps_per_second': 2.123, 'epoch': 4.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \n",
      "100%|██████████| 145/145 [1:00:40<00:00, 19.18s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.39272159337997437, 'eval_f1': 0.8914389523348718, 'eval_runtime': 3.5564, 'eval_samples_per_second': 28.118, 'eval_steps_per_second': 1.968, 'epoch': 5.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 145/145 [1:00:42<00:00, 25.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'train_runtime': 3642.7783, 'train_samples_per_second': 1.235, 'train_steps_per_second': 0.04, 'train_loss': 0.535627057634551, 'epoch': 5.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "events.out.tfevents.1735555878.Davids-MacBook-Pro.local.23438.0: 100%|██████████| 9.32k/9.32k [00:00<00:00, 55.0kB/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/davidberenstein1957/domain-classifier/commit/915f4b03c230cc8f376f13729728f14347400041', commit_message='End of training', commit_description='', oid='915f4b03c230cc8f376f13729728f14347400041', pr_url=None, repo_url=RepoUrl('https://huggingface.co/davidberenstein1957/domain-classifier', endpoint='https://huggingface.co', repo_type='model', repo_id='davidberenstein1957/domain-classifier'), pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()\n",
    "# Save processor and create model card\n",
    "tokenizer.save_pretrained(\"ModernBERT-domain-classifier\")\n",
    "trainer.create_model_card()\n",
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We get an F1 score of 0.89 on the test set, which is pretty good for the small dataset and time spent."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run inference\n",
    "\n",
    "We can now load the model and run inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Device set to use mps:0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'label': 'health', 'score': 0.6779336333274841}]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# load model from huggingface.co/models using our repository id\n",
    "classifier = pipeline(\n",
    "    task=\"text-classification\",\n",
    "    model=\"argilla/ModernBERT-domain-classifier\",\n",
    "    device=0,\n",
    ")\n",
    "\n",
    "sample = \"Smoking is bad for your health.\"\n",
    "\n",
    "classifier(sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "We have shown that we can generate a synthetic dataset from an LLM and finetune a ModernBERT model on it. This demonstrates the effectiveness of synthetic data and the novel ModernBERT model, which is a new and improved version of BERT models, with an 8192 token context length, significantly better downstream performance, and much faster processing speeds.\n",
    "\n",
    "Pretty cool for 20 minutes of generating data, and an hour of fine-tuning on consumer hardware."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
examples/fine-tune-modernbert-rag.ipynb
ADDED
@@ -0,0 +1,980 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tune ModernBERT with Synthetic Data for RAG\n",
    "\n",
    "This notebook demonstrates the fine-tuning process of `modernbert-embed-base` using synthetic data tailored for Retrieval-Augmented Generation (RAG).\n",
    "\n",
    "It provides a complete walkthrough of the fine-tuning process after generating synthetic data using the Synthetic Data Generator. For a comprehensive explanation of the methodology and additional details, refer to the blog post: [Fine-tune ModernBERT for RAG with Synthetic Data](https://huggingface.co/blog/fine-tune-modernbert-for-rag-with-synthetic-data)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting Started"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install the Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install torch\n",
    "!pip install datasets\n",
    "!pip install sentence-transformers\n",
    "!pip install haystack-ai\n",
    "!pip install git+https://github.com/huggingface/transformers.git  # for the latest version of transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import the Required Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict\n",
    "\n",
    "\n",
    "from sentence_transformers import (\n",
    "    SentenceTransformer,\n",
    "    SentenceTransformerModelCardData,\n",
    "    CrossEncoder,\n",
    "    InputExample,\n",
    "    SentenceTransformerTrainer,\n",
    ")\n",
    "from sentence_transformers.losses import TripletLoss\n",
    "from sentence_transformers.training_args import (\n",
    "    SentenceTransformerTrainingArguments,\n",
    "    BatchSamplers,\n",
    ")\n",
    "from sentence_transformers.evaluation import TripletEvaluator\n",
    "from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator\n",
    "\n",
    "\n",
    "from haystack import Document, Pipeline\n",
    "from haystack.document_stores.in_memory import InMemoryDocumentStore\n",
    "from haystack.components.embedders import (\n",
    "    SentenceTransformersDocumentEmbedder,\n",
    "    SentenceTransformersTextEmbedder,\n",
    ")\n",
    "from haystack.components.rankers import SentenceTransformersDiversityRanker\n",
    "from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever\n",
    "from haystack.components.builders import ChatPromptBuilder\n",
    "from haystack.components.generators.chat import HuggingFaceAPIChatGenerator\n",
    "from haystack.dataclasses import ChatMessage\n",
    "from haystack.utils import Secret\n",
    "from haystack.utils.hf import HFGenerationAPIType"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configure the Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL = \"nomic-ai/modernbert-embed-base\"\n",
    "REPO_NAME = \"sdiazlor\"  # your HF username here\n",
    "MODEL_NAME_BIENCODER = \"modernbert-embed-base-biencoder-human-rights\"\n",
    "MODEL_NAME_CROSSENCODER = \"modernbert-embed-base-crossencoder-human-rights\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: mps\n"
     ]
    }
   ],
   "source": [
    "if torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "elif torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pre-process the Synthetic Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['context', 'question', 'response', 'positive_retrieval', 'negative_retrieval', 'positive_reranking', 'negative_reranking'],\n",
       "    num_rows: 1000\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Combine the generated datasets from files and prompts\n",
    "\n",
    "dataset_rag_from_file = load_dataset(f\"{REPO_NAME}/rag-human-rights-from-files\", split=\"train\")\n",
    "dataset_rag_from_prompt = load_dataset(f\"{REPO_NAME}/rag-human-rights-from-prompt\", split=\"train\")\n",
    "\n",
    "combined_rag_dataset = concatenate_datasets(\n",
    "    [dataset_rag_from_file, dataset_rag_from_prompt]\n",
    ")\n",
    "\n",
    "combined_rag_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['context', 'question', 'response', 'positive_retrieval', 'negative_retrieval', 'positive_reranking', 'negative_reranking'],\n",
       "    num_rows: 828\n",
       "})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Filter out examples with empty or NaN values\n",
    "\n",
    "def filter_empty_or_nan(example):\n",
    "    return all(\n",
    "        value is not None and str(value).strip() != \"\" for value in example.values()\n",
    "    )\n",
    "\n",
    "filtered_rag_dataset = combined_rag_dataset.filter(filter_empty_or_nan).shuffle(seed=42)\n",
    "filtered_rag_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['anchor', 'positive', 'negative'],\n",
      "    num_rows: 828\n",
      "})\n",
      "Dataset({\n",
      "    features: ['anchor', 'positive'],\n",
      "    num_rows: 828\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "# Rename, select and reorder columns according to the expected format for the SentenceTransformer and CrossEncoder models\n",
    "\n",
    "def rename_and_reorder_columns(dataset, rename_map, selected_columns):\n",
    "    for old_name, new_name in rename_map.items():\n",
    "        if old_name in dataset.column_names:\n",
    "            dataset = dataset.rename_column(old_name, new_name)\n",
    "    dataset = dataset.select_columns(selected_columns)\n",
    "    return dataset\n",
    "\n",
    "clean_rag_dataset_biencoder = rename_and_reorder_columns(\n",
    "    filtered_rag_dataset,\n",
    "    rename_map={\"context\": \"anchor\", \"positive_retrieval\": \"positive\", \"negative_retrieval\": \"negative\"},\n",
    "    selected_columns=[\"anchor\", \"positive\", \"negative\"],\n",
    ")\n",
    "\n",
    "clean_rag_dataset_crossencoder = rename_and_reorder_columns(\n",
    "    filtered_rag_dataset,\n",
    "    rename_map={\"context\": \"anchor\", \"positive_retrieval\": \"positive\"},  #TODO\n",
    "    selected_columns=[\"anchor\", \"positive\"],\n",
    ")\n",
    "\n",
    "print(clean_rag_dataset_biencoder)\n",
    "print(clean_rag_dataset_crossencoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at Snowflake/snowflake-arctic-embed-m-v1.5 and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "406c4d22f43f41d592d3b94da2955444",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/828 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['anchor', 'positive', 'score'],\n",
       "    num_rows: 828\n",
       "})"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Add scores to train the CrossEncoder model, which requires sentence pairs with a score indicating how related they are.\n",
    "# Check the available models: https://huggingface.co/spaces/mteb/leaderboard\n",
    "\n",
    "model_reranking = CrossEncoder(\n",
    "    model_name=\"Snowflake/snowflake-arctic-embed-m-v1.5\", device=device\n",
    ")\n",
    "\n",
    "def add_reranking_scores(batch):\n",
    "    pairs = list(zip(batch[\"anchor\"], batch[\"positive\"]))\n",
    "    batch[\"score\"] = model_reranking.predict(pairs)\n",
    "    return batch\n",
    "\n",
    "clean_rag_dataset_crossencoder = clean_rag_dataset_crossencoder.map(\n",
    "    add_reranking_scores, batched=True, batch_size=250\n",
    ")\n",
    "clean_rag_dataset_crossencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['anchor', 'positive', 'negative'],\n",
      "        num_rows: 662\n",
      "    })\n",
      "    eval: Dataset({\n",
      "        features: ['anchor', 'positive', 'negative'],\n",
      "        num_rows: 166\n",
      "    })\n",
      "})\n",
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['anchor', 'positive', 'score'],\n",
      "        num_rows: 662\n",
      "    })\n",
      "    eval: Dataset({\n",
      "        features: ['anchor', 'positive', 'score'],\n",
      "        num_rows: 166\n",
      "    })\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "# Split the datasets into training and evaluation sets\n",
    "def split_dataset(dataset, train_size=0.8, seed=42):\n",
    "    train_eval_split = dataset.train_test_split(test_size=1 - train_size, seed=seed)\n",
    "\n",
    "    dataset_dict = DatasetDict(\n",
    "        {\"train\": train_eval_split[\"train\"], \"eval\": train_eval_split[\"test\"]}\n",
    "    )\n",
    "\n",
    "    return dataset_dict\n",
    "\n",
    "dataset_rag_biencoder = split_dataset(clean_rag_dataset_biencoder)\n",
    "dataset_rag_crossencoder = split_dataset(clean_rag_dataset_crossencoder)\n",
    "\n",
    "print(dataset_rag_biencoder)\n",
    "print(dataset_rag_crossencoder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
|
| 364 |
+
"## Train the Bi-Encoder model for Retrieval"
|
| 365 |
+
]
|
| 366 |
+
},
|
| 367 |
+
{
|
| 368 |
+
"cell_type": "code",
|
| 369 |
+
"execution_count": null,
|
| 370 |
+
"metadata": {},
|
| 371 |
+
"outputs": [],
|
| 372 |
+
"source": [
|
| 373 |
+
"# Load the base model and create the SentenceTransformer model\n",
|
| 374 |
+
"model_biencoder = SentenceTransformer(\n",
|
| 375 |
+
" MODEL,\n",
|
| 376 |
+
" model_card_data=SentenceTransformerModelCardData(\n",
|
| 377 |
+
" language=\"en\",\n",
|
| 378 |
+
" license=\"apache-2.0\",\n",
|
| 379 |
+
" model_name=MODEL_NAME_BIENCODER,\n",
|
| 380 |
+
" ),\n",
|
| 381 |
+
")\n",
|
| 382 |
+
"model_biencoder.gradient_checkpointing_enable() # Enable gradient checkpointing to save memory"
|
| 383 |
+
]
|
| 384 |
+
},
|
| 385 |
+
{
|
| 386 |
+
"cell_type": "code",
|
| 387 |
+
"execution_count": null,
|
| 388 |
+
"metadata": {},
|
| 389 |
+
"outputs": [],
|
| 390 |
+
"source": [
|
| 391 |
+
"# Select the TripleLoss loss function which requires sentence triplets (anchor, positive, negative)\n",
|
| 392 |
+
"# Check the available losses: https://sbert.net/docs/sentence_transformer/loss_overview.html\n",
|
| 393 |
+
"\n",
|
| 394 |
+
"loss_biencoder = TripletLoss"
|
| 395 |
+
]
|
| 396 |
+
},
|
| 397 |
+
{
|
| 398 |
+
"cell_type": "code",
|
| 399 |
+
"execution_count": null,
|
| 400 |
+
"metadata": {},
|
| 401 |
+
"outputs": [
|
| 402 |
+
{
|
| 403 |
+
"name": "stderr",
|
| 404 |
+
"output_type": "stream",
|
| 405 |
+
"text": [
|
| 406 |
+
"/Users/sdiazlor/.pyenv/versions/3.11.4/envs/distilabel-tutorials/lib/python3.11/site-packages/transformers/training_args.py:2243: UserWarning: `use_mps_device` is deprecated and will be removed in version 5.0 of 🤗 Transformers. `mps` device will be used by default if available similar to the way `cuda` device is used.Therefore, no action from user is required. \n",
|
| 407 |
+
" warnings.warn(\n"
|
| 408 |
+
]
|
| 409 |
+
}
|
| 410 |
+
],
|
| 411 |
+
"source": [
|
| 412 |
+
"# Define the training arguments for the SentenceTransformer model\n",
|
| 413 |
+
"# Customize them as needed for your requirements\n",
|
| 414 |
+
"\n",
|
| 415 |
+
"training_args = SentenceTransformerTrainingArguments(\n",
|
| 416 |
+
" output_dir=f\"models/{MODEL_NAME_BIENCODER}\",\n",
|
| 417 |
+
" num_train_epochs=3,\n",
|
| 418 |
+
" per_device_train_batch_size=4,\n",
|
| 419 |
+
" gradient_accumulation_steps=4,\n",
|
| 420 |
+
" per_device_eval_batch_size=4,\n",
|
| 421 |
+
" warmup_ratio=0.1,\n",
|
| 422 |
+
" learning_rate=2e-5,\n",
|
| 423 |
+
" lr_scheduler_type=\"cosine\",\n",
|
| 424 |
+
" fp16=False, # or True if stable on your MPS device\n",
|
| 425 |
+
" bf16=False,\n",
|
| 426 |
+
" batch_sampler=BatchSamplers.NO_DUPLICATES,\n",
|
| 427 |
+
" eval_strategy=\"epoch\",\n",
|
| 428 |
+
" save_strategy=\"epoch\",\n",
|
| 429 |
+
" save_total_limit=2,\n",
|
| 430 |
+
" logging_steps=100,\n",
|
| 431 |
+
" load_best_model_at_end=True,\n",
|
| 432 |
+
" use_mps_device=(device == \"mps\"),\n",
|
| 433 |
+
")"
|
| 434 |
+
]
|
| 435 |
+
},
|
| 436 |
+
{
|
| 437 |
+
"cell_type": "code",
|
| 438 |
+
"execution_count": null,
|
| 439 |
+
"metadata": {},
|
| 440 |
+
"outputs": [],
|
| 441 |
+
"source": [
|
| 442 |
+
"# Define the evaluator to assess the performance of the model\n",
|
| 443 |
+
"triplet_evaluator = TripletEvaluator(\n",
|
| 444 |
+
" anchors=dataset_rag_biencoder[\"eval\"][\"anchor\"],\n",
|
| 445 |
+
" positives=dataset_rag_biencoder[\"eval\"][\"positive\"],\n",
|
| 446 |
+
" negatives=dataset_rag_biencoder[\"eval\"][\"negative\"],\n",
|
| 447 |
+
")"
|
| 448 |
+
]
|
| 449 |
+
},
|
| 450 |
+
{
|
| 451 |
+
"cell_type": "code",
|
| 452 |
+
"execution_count": null,
|
| 453 |
+
"metadata": {},
|
| 454 |
+
"outputs": [
|
| 455 |
+
{
|
| 456 |
+
"name": "stderr",
|
| 457 |
+
"output_type": "stream",
|
| 458 |
+
"text": [
|
| 459 |
+
"/Users/sdiazlor/.pyenv/versions/3.11.4/envs/distilabel-tutorials/lib/python3.11/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
|
| 460 |
+
" with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]\n"
|
| 461 |
+
]
|
| 462 |
+
},
|
| 463 |
+
{
|
| 464 |
+
"data": {
|
| 465 |
+
"text/html": [
|
| 466 |
+
"\n",
|
| 467 |
+
" <div>\n",
|
| 468 |
+
" \n",
|
| 469 |
+
" <progress value='123' max='123' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
| 470 |
+
" [123/123 25:34, Epoch 2/3]\n",
|
| 471 |
+
" </div>\n",
|
| 472 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
| 473 |
+
" <thead>\n",
|
| 474 |
+
" <tr style=\"text-align: left;\">\n",
|
| 475 |
+
" <th>Epoch</th>\n",
|
| 476 |
+
" <th>Training Loss</th>\n",
|
| 477 |
+
" <th>Validation Loss</th>\n",
|
| 478 |
+
" <th>Cosine Accuracy</th>\n",
|
| 479 |
+
" </tr>\n",
|
| 480 |
+
" </thead>\n",
|
| 481 |
+
" <tbody>\n",
|
| 482 |
+
" <tr>\n",
|
| 483 |
+
" <td>1</td>\n",
|
| 484 |
+
" <td>No log</td>\n",
|
| 485 |
+
" <td>3.655929</td>\n",
|
| 486 |
+
" <td>0.969880</td>\n",
|
| 487 |
+
" </tr>\n",
|
| 488 |
+
" <tr>\n",
|
| 489 |
+
" <td>2</td>\n",
|
| 490 |
+
" <td>14.374000</td>\n",
|
| 491 |
+
" <td>3.498395</td>\n",
|
| 492 |
+
" <td>0.981928</td>\n",
|
| 493 |
+
" </tr>\n",
|
| 494 |
+
" </tbody>\n",
|
| 495 |
+
"</table><p>"
|
| 496 |
+
],
|
| 497 |
+
"text/plain": [
|
| 498 |
+
"<IPython.core.display.HTML object>"
|
| 499 |
+
]
|
| 500 |
+
},
|
| 501 |
+
"metadata": {},
|
| 502 |
+
"output_type": "display_data"
|
| 503 |
+
},
|
| 504 |
+
{
|
| 505 |
+
"data": {
|
| 506 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 507 |
+
"model_id": "faad6e9752f34babadff7a966ae55d87",
|
| 508 |
+
"version_major": 2,
|
| 509 |
+
"version_minor": 0
|
| 510 |
+
},
|
| 511 |
+
"text/plain": [
|
| 512 |
+
"Computing widget examples: 0%| | 0/1 [00:00<?, ?example/s]"
|
| 513 |
+
]
|
| 514 |
+
},
|
| 515 |
+
"metadata": {},
|
| 516 |
+
"output_type": "display_data"
|
| 517 |
+
},
|
| 518 |
+
{
|
| 519 |
+
"name": "stderr",
|
| 520 |
+
"output_type": "stream",
|
| 521 |
+
"text": [
|
| 522 |
+
"/Users/sdiazlor/.pyenv/versions/3.11.4/envs/distilabel-tutorials/lib/python3.11/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
|
| 523 |
+
" with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]\n",
|
| 524 |
+
"/Users/sdiazlor/.pyenv/versions/3.11.4/envs/distilabel-tutorials/lib/python3.11/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
|
| 525 |
+
" with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]\n"
|
| 526 |
+
]
|
| 527 |
+
}
|
| 528 |
+
],
|
| 529 |
+
"source": [
|
| 530 |
+
"# Train the model. This will take some time depending on the size of the dataset and the model\n",
|
| 531 |
+
"# Remember to adjust the training arguments according to your requirements\n",
|
| 532 |
+
"\n",
|
| 533 |
+
"trainer = SentenceTransformerTrainer(\n",
|
| 534 |
+
" model=model_biencoder,\n",
|
| 535 |
+
" args=training_args,\n",
|
| 536 |
+
" train_dataset=dataset_rag_biencoder[\"train\"],\n",
|
| 537 |
+
" eval_dataset=dataset_rag_biencoder[\"eval\"],\n",
|
| 538 |
+
" loss=loss_biencoder,\n",
|
| 539 |
+
" evaluator=triplet_evaluator,\n",
|
| 540 |
+
")\n",
|
| 541 |
+
"trainer.train()"
|
| 542 |
+
]
|
| 543 |
+
},
|
| 544 |
+
{
|
| 545 |
+
"cell_type": "code",
|
| 546 |
+
"execution_count": null,
|
| 547 |
+
"metadata": {},
|
| 548 |
+
"outputs": [],
|
| 549 |
+
"source": [
|
| 550 |
+
"# Save the model to the local directory and push it to the Hub\n",
|
| 551 |
+
"model_biencoder.save_pretrained(f\"models/{MODEL_NAME_BIENCODER}\")\n",
|
| 552 |
+
"model_biencoder.push_to_hub(f\"{REPO_NAME}/{MODEL_NAME_BIENCODER}\")"
|
| 553 |
+
]
|
| 554 |
+
},
|
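{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, run a quick sanity check on the fine-tuned bi-encoder. This is a sketch added for illustration (it was not part of the original run) and the query/context strings are made up: the relevant context should get a clearly higher cosine similarity than the unrelated one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check with illustrative strings: the relevant context should score higher\n",
"from sentence_transformers import util\n",
"\n",
"query_embedding = model_biencoder.encode(\"What does Article 6 guarantee?\")\n",
"context_embeddings = model_biencoder.encode(\n",
" [\"Article 6 protects the right to a fair trial.\", \"The weather today is sunny.\"]\n",
")\n",
"print(util.cos_sim(query_embedding, context_embeddings))"
]
},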
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the Cross-Encoder model for Ranking"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Prepare the training and evaluation samples for the CrossEncoder model\n",
"\n",
"train_samples = []\n",
"for row in dataset_rag_crossencoder[\"train\"]:\n",
" # Suppose 'score' is a float or an integer that you want to predict\n",
" train_samples.append(\n",
" InputExample(texts=[row[\"anchor\"], row[\"positive\"]], label=float(row[\"score\"]))\n",
" )\n",
"\n",
"eval_samples = []\n",
"for row in dataset_rag_crossencoder[\"eval\"]:\n",
" eval_samples.append(\n",
" InputExample(texts=[row[\"anchor\"], row[\"positive\"]], label=float(row[\"score\"]))\n",
" )\n",
"\n",
"# Initialize the DataLoader for the training samples\n",
"train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize the CrossEncoder model. Set the number of labels to 1 for regression tasks\n",
"model_crossencoder = CrossEncoder(model_name=MODEL, num_labels=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define the evaluator\n",
"evaluator = CECorrelationEvaluator.from_input_examples(eval_samples)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9517a852f3d34cff86808c4b10cf8973",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch: 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6e942043c5a24e77bd6172cb5492d2a7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Iteration: 0%| | 0/166 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d039d5acf3ed424e9ff6d0b30b51aceb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Iteration: 0%| | 0/166 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5fd5d0442b76448e8cab18b652e29ad8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Iteration: 0%| | 0/166 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Train the CrossEncoder model\n",
"\n",
"model_crossencoder.fit(\n",
" train_dataloader=train_dataloader,\n",
" evaluator=evaluator,\n",
" epochs=3,\n",
" warmup_steps=500,\n",
" output_path=f\"models/{MODEL_NAME_CROSSENCODER}\",\n",
" save_best_model=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the model to the local directory and push it to the Hub\n",
"model_crossencoder.save_pretrained(f\"models/{MODEL_NAME_CROSSENCODER}\")\n",
"model_crossencoder.push_to_hub(f\"{REPO_NAME}/{MODEL_NAME_CROSSENCODER}\")"
]
},
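{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, sanity-check the fine-tuned cross-encoder in the same way (a sketch with made-up strings, not part of the original run): `predict` returns one relevance score per pair, and the related pair should score higher than the unrelated one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Score query/passage pairs with the fine-tuned cross-encoder (illustrative strings)\n",
"pairs = [\n",
" [\"What does Article 6 guarantee?\", \"Article 6 protects the right to a fair trial.\"],\n",
" [\"What does Article 6 guarantee?\", \"The weather today is sunny.\"],\n",
"]\n",
"print(model_crossencoder.predict(pairs))"
]
},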
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build the RAG Pipeline\n",
"\n",
"The following section is inspired by the Haystack tutorial; check it for further details: [Creating Your First QA Pipeline with Retrieval-Augmentation](https://haystack.deepset.ai/tutorials/27_first_rag_pipeline)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Add the documents to the DocumentStore\n",
"# Use the already chunked documents from the original datasets\n",
"\n",
"df = combined_rag_dataset.to_pandas()\n",
"df = df.drop_duplicates(subset=[\"context\"]) # drop duplicates based on \"context\" column\n",
"df = df.sample(n=10, random_state=42) # optional: sample a subset of the dataset\n",
"dataset = Dataset.from_pandas(df)\n",
"\n",
"docs = [Document(content=doc[\"context\"]) for doc in dataset]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize the document store and store the documents with the embeddings using our bi-encoder model\n",
"\n",
"document_store = InMemoryDocumentStore()\n",
"doc_embedder = SentenceTransformersDocumentEmbedder(\n",
" model=f\"{REPO_NAME}/{MODEL_NAME_BIENCODER}\",\n",
")\n",
"doc_embedder.warm_up()\n",
"\n",
"docs_with_embeddings = doc_embedder.run(docs)\n",
"document_store.write_documents(docs_with_embeddings[\"documents\"])\n",
"\n",
"text_embedder = SentenceTransformersTextEmbedder(\n",
" model=f\"{REPO_NAME}/{MODEL_NAME_BIENCODER}\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize the retriever (our bi-encoder model) and the ranker (our cross-encoder model)\n",
"\n",
"retriever = InMemoryEmbeddingRetriever(document_store)\n",
"ranker = SentenceTransformersDiversityRanker(\n",
" model=f\"{REPO_NAME}/{MODEL_NAME_CROSSENCODER}\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define the prompt builder and the chat generator to interact with the models using the HF Serverless Inference API\n",
"\n",
"template = [\n",
" ChatMessage.from_user(\n",
" \"\"\"\n",
"Given the following information, answer the question.\n",
"\n",
"Context:\n",
"{% for document in documents %}\n",
" {{ document.content }}\n",
"{% endfor %}\n",
"\n",
"Question: {{question}}\n",
"Answer:\n",
"\"\"\"\n",
" )\n",
"]\n",
"\n",
"prompt_builder = ChatPromptBuilder(template=template)\n",
"\n",
"chat_generator = HuggingFaceAPIChatGenerator(\n",
" api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,\n",
" api_params={\"model\": \"meta-llama/Llama-3.1-8B-Instruct\"},\n",
" token=Secret.from_env_var(\"HF_TOKEN\"),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize the pipeline with the components\n",
"\n",
"rag_pipeline = Pipeline()\n",
"rag_pipeline.add_component(\"text_embedder\", text_embedder)\n",
"rag_pipeline.add_component(\"retriever\", retriever)\n",
"rag_pipeline.add_component(\"ranker\", ranker)\n",
"rag_pipeline.add_component(\"prompt_builder\", prompt_builder)\n",
"rag_pipeline.add_component(\"llm\", chat_generator)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<haystack.core.pipeline.pipeline.Pipeline object at 0x32e75b4d0>\n",
"🚅 Components\n",
" - text_embedder: SentenceTransformersTextEmbedder\n",
" - retriever: InMemoryEmbeddingRetriever\n",
" - ranker: SentenceTransformersDiversityRanker\n",
" - prompt_builder: ChatPromptBuilder\n",
" - llm: HuggingFaceAPIChatGenerator\n",
"🛤️ Connections\n",
" - text_embedder.embedding -> retriever.query_embedding (List[float])\n",
" - retriever.documents -> ranker.documents (List[Document])\n",
" - ranker.documents -> prompt_builder.documents (List[Document])\n",
" - prompt_builder.prompt -> llm.messages (List[ChatMessage])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Connect the components to each other\n",
"\n",
"rag_pipeline.connect(\"text_embedder.embedding\", \"retriever.query_embedding\")\n",
"rag_pipeline.connect(\"retriever.documents\", \"ranker.documents\")\n",
"rag_pipeline.connect(\"ranker\", \"prompt_builder\")\n",
"rag_pipeline.connect(\"prompt_builder.prompt\", \"llm.messages\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "80c813c847524f1493067f6dbe65c725",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Batches: 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"It seems that there is not enough information given in the human rights protocols provided to accurately answer the question. However, we can inform you that there are several types of human rights documents that this could be referring too. Event the most widely respected declared world document on human rights for Example - Exernal and some Individual (Part 1 Art.) and some other attempted Separation apart include: The convention lists several key rights such as \n",
"\n",
"1. Right to Life \n",
"2. Right to Liberty and Security \n",
"3. Freedom from Torture \n",
"4. Freedom from Slavery \n",
"5. Right to a Fair Trial \n",
"6. No Punishment without Law \n",
"7. Respect for Family Life \n",
"... (and throughout given information 44 protocals - are actually chapter and not... How is the answer \n",
" \n",
"\n",
"Not possible to answer your question due to lack of information, however we can tell you Event the most widely respected declared world document on human rights.\n"
]
}
],
"source": [
"# Query the pipeline about something that is not covered by your documents\n",
"question = \"How many human rights there are?\"\n",
"\n",
"response = rag_pipeline.run(\n",
" {\n",
" \"text_embedder\": {\"text\": question},\n",
" \"prompt_builder\": {\"question\": question},\n",
" \"ranker\": {\"query\": question},\n",
" }\n",
")\n",
"\n",
"print(response[\"llm\"][\"replies\"][0].text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2995f14154d148589129a3f449adc5d5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Batches: 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The information you provided does not directly list the \"Right of Fair Trial\" but looking under articles of the Convention for the Protection of Human Rights and Fundamental Freedoms, Article 6, also known as the Right to a Fair Trial, gives a clear idea.\n",
"\n",
" Article 6. Right to a fair Trial\n",
" \n",
"\n",
"1. Everyone is entitled to a fair and public hearing within a reasonable time by an independent and impartial tribunal established by law.\n",
" \n",
"2, everybody shall be presumed innocent until proven guilty by a final decision of a competent court.\n",
" \n",
"3. Everyone charged with a criminal offence has the following minimum rights:\n",
"\n",
" a to be informed promptly, in a language which he understands and in detail, of the charges, if any, against him.\n",
" b to have adequate time and facilities for the preparation of his defence.\n",
" c to defend himself in person or through legal assistance of his own choosing or, if he has not sufficient means to pay for legal assistance, to be given it free when the interests of justice so require.\n",
" d to be tried in his presence, and to defend himself in person or through legal assistance of his own choosing; to be informed, if he does not have legal assistance chosen or appointed under Article 5 Part 3 of this Convention, to communicate with the defence he has chosen\n",
" e to have the free assistance of an interpreter if he cannot understand or speak the language used in court.\n",
" \n",
" \n",
"4. Everyone sentenced has the right to, review by a higher tribunal according to law\n",
"\n",
"5. Everyone sentenced has the right to, take up or pursue his occupation.\n",
"\n",
"6. Sentences may, also include restoration of rights or removal of disabilities\n"
]
}
],
"source": [
"# Query the pipeline about something that is covered by your documents\n",
"question = \"What's the Right of Fair Trial?\"\n",
"\n",
"response = rag_pipeline.run(\n",
" {\n",
" \"text_embedder\": {\"text\": question},\n",
" \"prompt_builder\": {\"question\": question},\n",
" \"ranker\": {\"query\": question},\n",
" }\n",
")\n",
"\n",
"print(response[\"llm\"][\"replies\"][0].text)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "distilabel-tutorials",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
examples/fine-tune-smollm2-on-synthetic-data.ipynb
ADDED
@@ -0,0 +1,310 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fine-tune a SmolLM on domain-specific synthetic data from an LLM\n",
"\n",
"Yes, smoll models can beat GPT4-like models on domain-specific tasks, but don't expect miracles. When comparing smoll vs large, consider all costs and gains, like differences in performance and the value of using private, local models and data that you own.\n",
"\n",
"The [Hugging Face SmolLM models](https://github.com/huggingface/smollm) are blazingly fast and remarkably powerful. With 135M, 360M and 1.7B parameter variants, the family is a great choice when you need a small and fast model. The great thing about SmolLM is that it is a general-purpose model that can be fine-tuned on domain-specific data.\n",
"\n",
"A lack of domain-specific datasets is a common problem for smaller and more specialized models. This is because it is difficult to find a dataset that is both representative and diverse enough for a specific task. We solve this problem by generating a synthetic dataset from an LLM using the `synthetic-data-generator`, which is available as a [Hugging Face Space](https://huggingface.co/spaces/argilla/synthetic-data-generator) or on [GitHub](https://github.com/argilla-io/synthetic-data-generator).\n",
"\n",
"In this example, we will fine-tune a SmolLM2 model on a synthetic dataset generated from `meta-llama/Meta-Llama-3.1-8B-Instruct` with the `synthetic-data-generator`.\n",
"\n",
"## Install the dependencies\n",
"\n",
"We will install some basic dependencies for fine-tuning with `trl`, but we will use the Synthetic Data Generator UI to generate the synthetic dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install transformers datasets trl torch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The problem\n",
"\n",
"Reasoning data has proven to be a fundamental driver of the performance of generative models. Reasoning is amazing, but it also makes the model more \"chatty\" during token generation, which makes it slower and more expensive. For this reason, we want to create a model that can reason without being too chatty. Therefore, we will generate a concise reasoning dataset and fine-tune a SmolLM2 model on it.\n",
"\n",
"## Let's generate some data\n",
"\n",
"Let's go to the [hosted Hugging Face Space](https://huggingface.co/spaces/argilla/synthetic-data-generator) to generate the data. This is done in three steps: 1) we come up with a dataset description, 2) we iterate on the task configuration, and 3) we generate and push the data to Hugging Face. A more detailed flow can be found in [this blog post](https://huggingface.co/blog/synthetic-data-generator).\n",
"\n",
"<iframe\n",
"\tsrc=\"https://argilla-synthetic-data-generator.hf.space\"\n",
"\tframeborder=\"0\"\n",
"\twidth=\"850\"\n",
"\theight=\"450\"\n",
"></iframe>\n",
"\n",
"For this example, we will generate 5000 chat data examples for a single turn in the conversation. All examples have been generated with a temperature of 1. After some iteration, we come up with the following system prompt:\n",
"\n",
"```\n",
"You are an AI assistant who provides brief and to-the-point responses with logical step-by-step reasoning. Your purpose is to offer straightforward explanations and answers so that you can get to the heart of the issue. Respond with extremely concise, direct justifications and evidence-based conclusions. User questions are direct and concise.\n",
"```\n",
"\n",
"We press the \"Push to Hub\" button and wait for the data to be generated. This takes a few hours, and we end up with a dataset of 5000 examples, which is the maximum number of examples we can generate in a single run. You can scale this up by deploying a private instance of the Synthetic Data Generator.\n",
"\n",
"<iframe\n",
" src=\"https://huggingface.co/datasets/argilla/synthetic-concise-reasoning-sft-filtered/embed/viewer/default/train\"\n",
" frameborder=\"0\"\n",
" width=\"100%\"\n",
" height=\"560px\"\n",
"></iframe>\n",
"\n",
"The data is pushed to Argilla too, so we recommend inspecting and validating the data before fine-tuning the actual model. We applied some basic filters and transformations to the data to make it more suitable for fine-tuning.\n",
"\n",
"## Fine-tune the model\n",
"\n",
"We will use TRL to fine-tune the model. It is part of the Hugging Face ecosystem and works seamlessly on top of datasets generated by the synthetic data generator without needing to do any data transformations.\n",
"\n",
"### Load the model\n",
"\n",
"We will first load the model and tokenizer and set up the chat format."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Import necessary libraries\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from datasets import load_dataset\n",
"from trl import SFTConfig, SFTTrainer, setup_chat_format\n",
"import torch\n",
"import os\n",
"\n",
"device = (\n",
" \"cuda\"\n",
" if torch.cuda.is_available()\n",
" else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
")\n",
"\n",
"# Load the model and tokenizer\n",
"model_name = \"HuggingFaceTB/SmolLM2-360M\"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" pretrained_model_name_or_path=model_name\n",
")\n",
"tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)\n",
"\n",
"# Set up the chat format\n",
"model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Test the base model\n",
"\n",
"We will first test the base model to see how it performs on the task. During this step, we also define a prompt for the model to respond to, so we can compare its behavior before and after fine-tuning."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Device set to use mps:0\n"
]
},
{
"data": {
"text/plain": [
"[{'generated_text': 'What is the primary function of mitochondria within a cell?\\n\\nMitochondria are the powerhouses of the cell. They are responsible for the production of ATP (adenosine triphosphate) and the energy required for cellular processes.\\n\\nWhat is the function of the mitochondria in the cell?\\n\\nThe mitochondria are the powerhouses of the cell. They are responsible for the production of ATP (adenosine triphosphate) and the energy required for cellular processes.\\n\\nWhat is the function of the mitochondria in the cell?\\n\\nThe'}]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers import pipeline\n",
"\n",
"prompt = \"What is the primary function of mitochondria within a cell?\"\n",
"\n",
"pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer, device=device)\n",
"pipe(prompt, max_new_tokens=100)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the dataset\n",
"\n",
"For fine-tuning, we need to load the dataset and tokenize it. We will use the `synthetic-concise-reasoning-sft-filtered` dataset that we generated in the previous step."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map: 100%|██████████| 4133/4133 [00:00<00:00, 18478.53 examples/s]\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"ds = load_dataset(\"argilla/synthetic-concise-reasoning-sft-filtered\")\n",
"def tokenize_function(examples):\n",
" examples[\"text\"] = tokenizer.apply_chat_template([{\"role\": \"user\", \"content\": examples[\"prompt\"].strip()}, {\"role\": \"assistant\", \"content\": examples[\"completion\"].strip()}], tokenize=False)\n",
" return examples\n",
"ds = ds.map(tokenize_function)\n",
"ds = ds.shuffle()"
]
},
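{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, print one formatted example to confirm the chat template wrapped the prompt and completion as expected. This is a quick check added here for illustration; which example you see depends on the shuffle."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect one formatted training example (the exact example depends on the shuffle)\n",
"print(ds[\"train\"][0][\"text\"])"
]
},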
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fine-tune the model\n",
"\n",
"We will now fine-tune the model. We will use the `SFTTrainer` from the `trl` library to fine-tune the model. We will use a batch size of 4 and a learning rate of 5e-5. We will also use the `use_mps_device` flag to use the MPS device if available."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"os.environ[\"PYTORCH_MPS_HIGH_WATERMARK_RATIO\"] = \"0.0\"\n",
"\n",
"# Configure the SFTTrainer\n",
"sft_config = SFTConfig(\n",
" output_dir=\"./sft_output\",\n",
" num_train_epochs=1,\n",
" per_device_train_batch_size=4, # Set according to your GPU memory capacity\n",
" learning_rate=5e-5, # Common starting point for fine-tuning\n",
" logging_steps=100, # Frequency of logging training metrics\n",
" use_mps_device=(device == \"mps\"),\n",
" hub_model_id=\"argilla/SmolLM2-360M-synthetic-concise-reasoning\", # Set a unique name for your model\n",
" push_to_hub=True,\n",
")\n",
"\n",
"# Initialize the SFTTrainer\n",
"trainer = SFTTrainer(\n",
" model=model,\n",
" args=sft_config,\n",
" train_dataset=ds[\"train\"],\n",
" tokenizer=tokenizer,\n",
")\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"# {'loss': 1.4498, 'grad_norm': 2.3919131755828857, 'learning_rate': 4e-05, 'epoch': 0.1}\n",
"# {'loss': 1.362, 'grad_norm': 1.6650595664978027, 'learning_rate': 3e-05, 'epoch': 0.19}\n",
"# {'loss': 1.3778, 'grad_norm': 1.4778285026550293, 'learning_rate': 2e-05, 'epoch': 0.29}\n",
"# {'loss': 1.3735, 'grad_norm': 2.1424977779388428, 'learning_rate': 1e-05, 'epoch': 0.39}\n",
"# {'loss': 1.3512, 'grad_norm': 2.3498542308807373, 'learning_rate': 0.0, 'epoch': 0.48}\n",
"# {'train_runtime': 1911.514, 'train_samples_per_second': 1.046, 'train_steps_per_second': 0.262, 'train_loss': 1.3828572998046875, 'epoch': 0.48}\n",
"```\n",
"\n",
"For this example, we did not use a specific validation set, but we can see the loss is decreasing, so we assume the model is generalising well to the training data. If you do want a quick validation signal during training, a held-out split is easy to add; see the sketch below. To get a better understanding of the model's performance, let's test it again with the same prompt."
]
},
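{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of that validation setup, assuming you hold out 10% of the data (the split size and evaluation settings are illustrative choices, not what was run above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: hold out a validation split so the trainer can report eval loss\n",
"split_ds = ds[\"train\"].train_test_split(test_size=0.1, seed=42)\n",
"# Then pass eval_dataset=split_ds[\"test\"] to SFTTrainer and set\n",
"# eval_strategy=\"steps\" (evaluation_strategy on older transformers) in SFTConfig."
]
},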
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run inference\n",
"\n",
"We can now run inference with [the fine-tuned model](https://huggingface.co/argilla/SmolLM2-360M-synthetic-concise-reasoning/blob/main/README.md)."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Device set to use mps\n"
]
},
{
"data": {
"text/plain": [
"'The primary function of mitochondria is to generate energy for the cell. They are organelles found in eukaryotic cells that convert nutrients into ATP (adenosine triphosphate), which is the primary source of energy for cellular processes.\\nMitochondria are responsible for:\\n\\nEnergy production: Mitochondria produce ATP through a process called oxidative phosphorylation, which involves the transfer of electrons from food molecules to oxygen.\\nEnergy storage: Mitochondria store energy in the form of adenosine triphosphate (ATP), which is used by the cell for various cellular processes.\\nCellular respiration: Mitochondria also participate in cellular respiration, a'"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prompt = \"What is the primary function of mitochondria within a cell?\"\n",
"\n",
"generator = pipeline(\n",
" \"text-generation\",\n",
" model=\"argilla/SmolLM2-360M-synthetic-concise-reasoning\",\n",
" device=\"mps\",\n",
")\n",
"generator(\n",
" [{\"role\": \"user\", \"content\": prompt}], max_new_tokens=128, return_full_text=False\n",
")[0][\"generated_text\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"We have fine-tuned a SmolLM2 model on a synthetic dataset generated from a large language model. We have seen that the model performs well on the task and that synthetic data is a great way to generate diverse and representative data for supervised fine-tuning.\n",
"\n",
"In practice, you would likely want to spend more time on data quality and fine-tuning the model, but this flow shows that the Synthetic Data Generator is a great tool to generate synthetic data for any task.\n",
"\n",
"Overall, I think it is pretty cool for a couple of hours of generation and fine-tuning on consumer hardware.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
examples/hf-dedicated-or-tgi-deployment.py
ADDED
@@ -0,0 +1,19 @@
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
#     "synthetic-dataset-generator",
# ]
# ///
import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
os.environ["HUGGINGFACE_BASE_URL"] = "http://127.0.0.1:3000/" # dedicated endpoint/TGI
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # magpie template
os.environ["TOKENIZER_ID"] = (
    "meta-llama/Llama-3.1-8B-Instruct" # tokenizer for model hosted on endpoint
)
os.environ.pop("MODEL", None) # leave MODEL unset: the model is linked to the endpoint

launch()
examples/hf-serverless-deployment-deepseek.py
ADDED
@@ -0,0 +1,16 @@
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
#     "synthetic-dataset-generator",
# ]
# ///
import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
os.environ["MODEL"] = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" # use model for instructions
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "<|begin▁of▁sentence|>User: " # use the custom template for the model


launch()
examples/hf-serverless-deployment.py
ADDED
@@ -0,0 +1,15 @@
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
#     "synthetic-dataset-generator",
# ]
# ///
import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct" # use model for generation
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # use the template for the model

launch()
|
examples/hf-serverless-different-model-for-completion.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /// script
|
| 2 |
+
# requires-python = ">=3.11,<3.12"
|
| 3 |
+
# dependencies = [
|
| 4 |
+
# "synthetic-dataset-generator",
|
| 5 |
+
# ]
|
| 6 |
+
# ///
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
from synthetic_dataset_generator import launch
|
| 10 |
+
|
| 11 |
+
os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
|
| 12 |
+
os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct" # use model for instruction generation
|
| 13 |
+
os.environ["MODEL_COMPLETION"] = "meta-llama/Llama-3.1-70B-Instruct" # use model for completion generation
|
| 14 |
+
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # use the template for the model
|
| 15 |
+
|
| 16 |
+
launch()
|
examples/ollama-deployment.py
ADDED
@@ -0,0 +1,22 @@
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
#     "synthetic-dataset-generator",
# ]
# ///
# ollama serve
# ollama run qwen2.5:32b-instruct-q5_K_S
import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/" # ollama base url
os.environ["MODEL"] = "qwen2.5:32b-instruct-q5_K_S" # model id
os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-32B-Instruct" # tokenizer id
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "qwen2"
os.environ["MAX_NUM_ROWS"] = "10000"
os.environ["DEFAULT_BATCH_SIZE"] = "2"
os.environ["MAX_NUM_TOKENS"] = "1024"

launch()
examples/ollama-different-model-for-completion.py
ADDED
@@ -0,0 +1,26 @@
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
#     "synthetic-dataset-generator",
# ]
# ///
# ollama serve
# ollama run llama3.2
# ollama run llama3.2:1b
import os

from synthetic_dataset_generator import launch

os.environ["OLLAMA_BASE_URL"] = (
    "http://127.0.0.1:11434/" # in this case, the same base url for both models
)

os.environ["MODEL"] = "llama3.2" # model for instruction generation
os.environ["MODEL_COMPLETION"] = "llama3.2:1b" # model for completion generation

os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.2-3B-Instruct" # tokenizer for instruction generation
os.environ["TOKENIZER_ID_COMPLETION"] = "meta-llama/Llama-3.2-1B-Instruct" # tokenizer for completion generation

os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # magpie template required for instruction generation

launch()
examples/openai-deployment.py
ADDED
@@ -0,0 +1,18 @@
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
#     "synthetic-dataset-generator",
# ]
# ///

import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/" # openai base url
os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY") # openai api key
os.environ["MODEL"] = "gpt-4o" # model id
os.environ.pop("MAGPIE_PRE_QUERY_TEMPLATE", None) # leave unset: chat data not supported with OpenAI

launch()
examples/vllm-deployment.py
ADDED
@@ -0,0 +1,21 @@
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
#     "synthetic-dataset-generator",
# ]
# ///
# vllm serve Qwen/Qwen2.5-1.5B-Instruct
import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
os.environ["VLLM_BASE_URL"] = "http://127.0.0.1:8000/" # vllm base url
os.environ["MODEL"] = "Qwen/Qwen2.5-1.5B-Instruct" # model id
os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-1.5B-Instruct" # tokenizer id
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "qwen2"
os.environ["MAX_NUM_ROWS"] = "10000"
os.environ["DEFAULT_BATCH_SIZE"] = "2"
os.environ["MAX_NUM_TOKENS"] = "1024"

launch()
packages.txt
ADDED
@@ -0,0 +1,2 @@
poppler-utils
tesseract-ocr
pdm.lock
ADDED
The diff for this file is too large to render. See raw diff
pyproject.toml
ADDED
@@ -0,0 +1,40 @@
[project]
name = "synthetic-dataset-generator"
version = "0.2.0"
description = "Build datasets using natural language"
authors = [
    {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"},
]
keywords = [
    "gradio",
    "synthetic-data",
    "huggingface",
    "argilla",
    "generative-ai",
    "ai",
]
requires-python = "<3.13,>=3.10"
readme = "README.md"
license = {text = "Apache 2"}

dependencies = [
    "argilla>=2.4.0,<3.0.0",
    "distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines,vllm,vision]>=1.5.0,<2.00",
    "gradio[oauth]>=5.4.0,<6.0.0",
    "gradio-huggingfacehub-search>=0.0.12,<1.0.0",
    "huggingface-hub>=0.26.0,<0.28.0",
    "model2vec>=0.2.4,<1.0.0",
    "nltk>=3.9.1,<4.0.0",
    "pydantic>=2.10.5,<3.0.0",
    "sentence-transformers>=3.2.0,<4.0.0",
    "transformers>=4.44.2,<5.0.0",
    "unstructured[md,pdf,docx]>=0.16.3,<1.0.0",
    "setuptools",
]

[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"

[tool.pdm]
distribution = true
requirements.txt
ADDED
@@ -0,0 +1 @@
-e git+https://github.com/argilla-io/synthetic-data-generator.git#egg=synthetic-dataset-generator
src/synthetic_dataset_generator/__init__.py
ADDED
@@ -0,0 +1,20 @@
import inspect

from gradio import TabbedInterface

from synthetic_dataset_generator import (  # noqa
    _distiset,
    _inference_endpoints,
)


def launch(*args, **kwargs):
    """Launch the synthetic dataset generator.
    Based on the `TabbedInterface` from Gradio.
    Parameters: https://www.gradio.app/docs/gradio/tabbedinterface
    """
    from synthetic_dataset_generator.app import demo
    return demo.launch(*args, server_name="0.0.0.0", **kwargs)


launch.__doc__ = TabbedInterface.launch.__doc__
launch.__signature__ = inspect.signature(TabbedInterface.launch)
launch.__annotations__ = TabbedInterface.launch.__annotations__
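A note on ordering that the deployment examples above rely on: `launch` imports the app lazily, so environment variables set after importing the package are still read in time, and extra keyword arguments pass straight through to Gradio's `Blocks.launch`. A minimal sketch (port and model values are illustrative, and provider variables such as `HF_TOKEN` are assumed to be set as in the examples above):

import os

from synthetic_dataset_generator import launch  # no app modules loaded yet

os.environ["MODEL"] = "gpt-4o"  # picked up later, when launch() imports the app

# server_name is pinned to 0.0.0.0; everything else goes to Blocks.launch
launch(server_port=7860, share=False)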
src/synthetic_dataset_generator/__main__.py
ADDED
@@ -0,0 +1,4 @@
if __name__ == "__main__":
    from synthetic_dataset_generator import launch

    launch()
src/synthetic_dataset_generator/_distiset.py
ADDED
@@ -0,0 +1,148 @@
from typing import Optional

import distilabel
import distilabel.distiset
import gradio as gr
from distilabel.utils.card.dataset_card import (
    DistilabelDatasetCard,
    size_categories_parser,
)
from huggingface_hub import DatasetCardData, HfApi


class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
    def _generate_card(
        self,
        repo_id: str,
        token: str,
        include_script: bool = False,
        filename_py: Optional[str] = None,
    ) -> None:
        """Generates a dataset card and pushes it to the Hugging Face Hub, and
        if the `pipeline.yaml` path is available in the `Distiset`, uploads that
        to the same repository.

        Args:
            repo_id: The ID of the repository to push to, from the `push_to_hub` method.
            token: The token to authenticate with the Hugging Face Hub, from the `push_to_hub` method.
            include_script: Whether to upload the script to the hugging face repository.
            filename_py: The name of the script. If `include_script` is True, the script will
                be uploaded to the repository using this name, otherwise it won't be used.
        """
        card = self._get_card(
            repo_id=repo_id,
            token=token,
            include_script=include_script,
            filename_py=filename_py,
        )

        card.push_to_hub(
            repo_id,
            repo_type="dataset",
            token=token,
        )
        if self.pipeline_path:
            # If the pipeline.yaml is available, upload it to the Hugging Face Hub as well.
            HfApi().upload_file(
                path_or_fileobj=self.pipeline_path,
                path_in_repo=distilabel.distiset.PIPELINE_CONFIG_FILENAME,
                repo_id=repo_id,
                repo_type="dataset",
                token=token,
            )

    def _get_card(
        self,
        repo_id: str,
        token: Optional[str] = None,
        include_script: bool = False,
        filename_py: Optional[str] = None,
    ) -> DistilabelDatasetCard:
        """Generates the dataset card for the `Distiset`.

        Note:
            If `repo_id` and `token` are provided, it will extract the metadata from the README.md file
            on the hub.

        Args:
            repo_id: Name of the repository to push to, or the path for the distiset if saved to disk.
            token: The token to authenticate with the Hugging Face Hub.
                We assume that if it's provided, the dataset will be in the Hugging Face Hub,
                so the README metadata will be extracted from there.
            include_script: Whether to upload the script to the hugging face repository.
            filename_py: The name of the script. If `include_script` is True, the script will
                be uploaded to the repository using this name, otherwise it won't be used.

        Returns:
            The dataset card for the `Distiset`.
        """
        sample_records = {}
        for name, dataset in self.items():
            sample_records[name] = (
                dataset[0] if not isinstance(dataset, dict) else dataset["train"][0]
            )

        columns = self["default"].column_names

        if ("label" in columns and "text" in columns) or (
            "labels" in columns and "text" in columns
        ):
            task_categories = ["text-classification"]
        elif ("prompt" in columns and "completion" in columns) or (
            "messages" in columns
        ):
            task_categories: list[str] = [
                "text-generation",
                "text2text-generation",
                "question-answering",
            ]
        elif "context" in columns and "question" in columns and "response" in columns:
            task_categories: list[str] = [
                "text-generation",
                "text2text-generation",
                "text-retrieval",
                "question-answering"
            ]
            if (
                "positive_retrieval" in columns and "negative_retrieval" in columns
            ) or ("positive_reranking" in columns and "negative_reranking" in columns):
                task_categories.append("sentence-similarity")
        else:
            task_categories: list[str] = []
            gr.Info(
                f"No task categories found for dataset with columns: {columns}. "
                "Please notify the distilabel team if you think this is an error."
            )

        readme_metadata = {}
        if repo_id and token:
            readme_metadata = self._extract_readme_metadata(repo_id, token)

        metadata = {
            **readme_metadata,
            "size_categories": size_categories_parser(
                max(len(dataset) for dataset in self.values())
            ),
            "task_categories": task_categories,
            "tags": [
                "synthetic",
                "distilabel",
                "rlaif",
                "datacraft",
            ],
        }

        card = DistilabelDatasetCard.from_template(
            card_data=DatasetCardData(**metadata),
            repo_id=repo_id,
            sample_records=sample_records,
            include_script=include_script,
            filename_py=filename_py,
            references=self.citations,
        )

        return card


distilabel.distiset.Distiset = CustomDistisetWithAdditionalTag
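The last line above rebinds `distilabel.distiset.Distiset` as an import side effect, which is why the package `__init__.py` imports this module eagerly. A minimal sketch confirming the patch is in place:

import distilabel.distiset

import synthetic_dataset_generator._distiset  # noqa: F401  # importing applies the patch

assert distilabel.distiset.Distiset.__name__ == "CustomDistisetWithAdditionalTag"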
src/synthetic_dataset_generator/_inference_endpoints.py
ADDED
@@ -0,0 +1,58 @@
import warnings

import distilabel
import distilabel.distiset
from distilabel.models import InferenceEndpointsLLM
from pydantic import (
    ValidationError,
    model_validator,
)


class CustomInferenceEndpointsLLM(InferenceEndpointsLLM):
    @model_validator(mode="after")  # type: ignore
    def only_one_of_model_id_endpoint_name_or_base_url_provided(
        self,
    ) -> "InferenceEndpointsLLM":
        """Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also
        provided, a warning will be shown informing the user that the provided `base_url` will be ignored in
        favour of the dynamically calculated one."""

        if self.base_url and (self.model_id or self.endpoint_name):
            warnings.warn(  # type: ignore
                f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
                " or `endpoint_name` is also provided, the `base_url` will either be ignored"
                " or overwritten with the one generated from either of those args, for serverless"
                " or dedicated inference endpoints, respectively."
            )

        if self.use_magpie_template and self.tokenizer_id is None:
            raise ValueError(
                "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
                " set a `tokenizer_id` and try again."
            )

        if (
            self.model_id
            and self.tokenizer_id is None
            and self.structured_output is not None
        ):
            self.tokenizer_id = self.model_id

        if self.base_url and not (self.model_id or self.endpoint_name):
            return self

        if self.model_id and not self.endpoint_name:
            return self

        if self.endpoint_name and not self.model_id:
            return self

        raise ValidationError(
            f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
            f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
            f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
        )


distilabel.models.llms.InferenceEndpointsLLM = CustomInferenceEndpointsLLM
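The validator runs at construction time through pydantic, so its constraints can be exercised without contacting any endpoint. A minimal sketch (the model id is illustrative):

from synthetic_dataset_generator._inference_endpoints import (
    CustomInferenceEndpointsLLM,
)

# exactly one of model_id / endpoint_name passes validation; no request
# is sent to the Hugging Face API at construction time
llm = CustomInferenceEndpointsLLM(model_id="meta-llama/Llama-3.1-8B-Instruct")

# per the validator above, enabling the magpie template without a
# tokenizer_id would raise a ValueError instead:
# CustomInferenceEndpointsLLM(model_id="...", use_magpie_template=True)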
src/synthetic_dataset_generator/_tabbedinterface.py
ADDED
@@ -0,0 +1,69 @@
"""
This file defines a useful high-level abstraction to build Gradio apps: TabbedInterface.
"""

from __future__ import annotations

from collections.abc import Sequence

import gradio as gr
from gradio.blocks import Blocks
from gradio.layouts import Tab, Tabs
from gradio.themes import ThemeClass as Theme
from gradio_client.documentation import document


@document()
class TabbedInterface(Blocks):
    """
    A TabbedInterface is created by providing a list of Interfaces or Blocks, each of which gets
    rendered in a separate tab. Only the components from the Interface/Blocks will be rendered in the tab.
    Certain high-level attributes of the Blocks (e.g. custom `css`, `js`, and `head` attributes) will not be loaded.

    Demos: tabbed_interface_lite
    """

    def __init__(
        self,
        interface_list: Sequence[Blocks],
        tab_names: list[str] | None = None,
        title: str | None = None,
        theme: Theme | str | None = None,
        analytics_enabled: bool | None = None,
        css: str | None = None,
        js: str | None = None,
        head: str | None = None,
    ):
        """
        Parameters:
            interface_list: A list of Interfaces (or Blocks) to be rendered in the tabs.
            tab_names: A list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
            title: The tab title to display when this demo is opened in a browser window.
            theme: A Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the Hugging Face Hub (e.g. "gradio/monochrome"). If None, will use the Default theme.
            analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
            css: Custom css as a string or path to a css file. This css will be included in the demo webpage.
            js: Custom js as a string or path to a js file. The custom js should in the form of a single js function. This function will automatically be executed when the page loads. For more flexibility, use the head parameter to insert js inside <script> tags.
            head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, multiple scripts, stylesheets, etc. to the page.
        Returns:
            a Gradio Tabbed Interface for the given interfaces
        """
        super().__init__(
            title="Synthetic Data Generator",
            theme=theme,
            analytics_enabled=analytics_enabled,
            mode="tabbed_interface",
            css=css,
            js=js,
            head=head,
        )
        if tab_names is None:
            tab_names = [f"Tab {i}" for i in range(len(interface_list))]
        with self:
            h3 = "<div style='text-align: center;'><h2>Build datasets using natural language</h2></div>"
            if title:
                gr.HTML(value=title + h3)
            gr.LoginButton(value="Sign in", variant="primary", elem_id="sign_in_button")
            with Tabs():
                for interface, tab_name in zip(interface_list, tab_names, strict=False):
                    with Tab(label=tab_name):
                        interface.render()
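A minimal standalone sketch of the constructor, mirroring how `app.py` below wires the real tabs together (tab contents here are placeholders):

import gradio as gr

from synthetic_dataset_generator._tabbedinterface import TabbedInterface

with gr.Blocks() as tab_a:
    gr.Markdown("first tab")
with gr.Blocks() as tab_b:
    gr.Markdown("second tab")

demo = TabbedInterface([tab_a, tab_b], tab_names=["A", "B"], title="<h1>Demo</h1>")
# demo.launch()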
src/synthetic_dataset_generator/app.py
ADDED
@@ -0,0 +1,35 @@
from synthetic_dataset_generator._tabbedinterface import TabbedInterface

# from synthetic_dataset_generator.apps.eval import app as eval_app
from synthetic_dataset_generator.apps.rag import app as rag_app
from synthetic_dataset_generator.apps.about import app as about_app
from synthetic_dataset_generator.apps.chat import app as chat_app
from synthetic_dataset_generator.apps.textcat import app as textcat_app

theme = "argilla/argilla-theme"

css = """
.main_ui_logged_out{opacity: 0.3; pointer-events: none}
button[role="tab"][aria-selected="true"] { border: 0; background: var(--button-primary-background-fill); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill); background: var(--button-primary-background-fill-hover)}
.tabitem {border: 0; padding-inline: 0}
.gallery-item {background: var(--background-fill-secondary); text-align: left}
.table-wrap .tbody td {vertical-align: top}
#system_prompt_examples {color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
.container {padding-inline: 0 !important}
.gradio-container { width: 100% !important; }
.gradio-row { display: flex !important; flex-direction: row !important; }
.gradio-column { flex: 1 !important; min-width: 0 !important; }
#sign_in_button {flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto;}
.datasets {height: 70px;}
"""

image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""

demo = TabbedInterface(
    [textcat_app, chat_app, rag_app, about_app],
    ["Text Classification", "Chat Data", "RAG", "About"],
    css=css,
    title=image,
    theme=theme,
)
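Since `demo` is assembled at import time, it can also be served directly; a minimal sketch, equivalent to what `launch()` in `__init__.py` does:

from synthetic_dataset_generator.app import demo

demo.launch(server_name="0.0.0.0")  # launch() pins this and forwards other kwargs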
src/synthetic_dataset_generator/apps/__init__.py
ADDED
File without changes
src/synthetic_dataset_generator/apps/about.py
ADDED
@@ -0,0 +1,15 @@
import gradio as gr

with gr.Blocks() as app:
    gr.Markdown(
        """
    Synthetic data is artificially generated information that mimics real-world data. It allows overcoming data limitations by expanding or enhancing datasets.

    Introducing the Synthetic Data Generator, a user-friendly application that takes a no-code approach to creating custom datasets with Large Language Models (LLMs). The best part: a simple step-by-step process that makes dataset creation a non-technical breeze, allowing anyone to create datasets and models in minutes without writing any code.

    The synthetic data generator takes your custom prompt and returns a dataset for your use case, using a synthetic data pipeline. In the background this is powered by [distilabel](https://distilabel.argilla.io/latest/) and the [free Hugging Face text-generation API](https://huggingface.co/docs/api-inference/en/index), but you don't need to worry about these complexities and can focus on using the UI.

    - Read more in [our announcement blog post](https://huggingface.co/blog/synthetic-data-generator)
    - Find the library on [GitHub](https://github.com/argilla-io/synthetic-data-generator)
    """
    )
src/synthetic_dataset_generator/apps/base.py
ADDED
@@ -0,0 +1,270 @@
import io
import uuid
from tqdm import tqdm
from typing import Union

import argilla as rg
import gradio as gr
import pandas as pd
from datasets import Dataset, concatenate_datasets, get_dataset_config_names, get_dataset_split_names, load_dataset
from gradio import OAuthToken
from huggingface_hub import HfApi, upload_file, repo_exists
from unstructured.chunking.title import chunk_by_title
from unstructured.partition.auto import partition

from synthetic_dataset_generator.constants import MAX_NUM_ROWS, SAVE_LOCAL_DIR
from synthetic_dataset_generator.utils import get_argilla_client

if SAVE_LOCAL_DIR is not None:
    import os
    os.makedirs(SAVE_LOCAL_DIR, exist_ok=True)


def validate_argilla_user_workspace_dataset(
    dataset_name: str,
    add_to_existing_dataset: bool = True,
    oauth_token: Union[OAuthToken, None] = None,
    progress=gr.Progress(),
) -> str:
    progress(0.1, desc="Validating dataset configuration")
    hf_user = HfApi().whoami(token=oauth_token.token)["name"]
    client = get_argilla_client()
    if dataset_name is None or dataset_name == "":
        raise gr.Error("Dataset name is required")
    # Create user if it doesn't exist
    rg_user = client.users(username=hf_user)
    if rg_user is None:
        rg_user = client.users.add(
            rg.User(username=hf_user, role="admin", password=str(uuid.uuid4()))
        )
    # Create workspace if it doesn't exist
    workspace = client.workspaces(name=hf_user)
    if workspace is None:
        workspace = client.workspaces.add(rg.Workspace(name=hf_user))
        workspace.add_user(hf_user)
    # Check if dataset exists
    dataset = client.datasets(name=dataset_name, workspace=hf_user)
    if dataset and not add_to_existing_dataset:
        raise gr.Error(f"Dataset {dataset_name} already exists")
    progress(1.0, desc="Dataset configuration validated")
    return ""


def push_pipeline_code_to_hub(
    pipeline_code: str,
    org_name: str,
    repo_name: str,
    oauth_token: Union[OAuthToken, None] = None,
    progress=gr.Progress(),
):
    repo_id: str | None = validate_push_to_hub(org_name, repo_name)
    progress(0.1, desc="Uploading pipeline code")
    with io.BytesIO(pipeline_code.encode("utf-8")) as f:
        upload_file(
            path_or_fileobj=f,
            path_in_repo="pipeline.py",
            repo_id=repo_id,
            repo_type="dataset",
            token=oauth_token.token,
            commit_message="Include pipeline script",
            create_pr=False,
        )
    progress(1.0, desc="Pipeline code uploaded")


def validate_push_to_hub(org_name: str, repo_name: str):
    repo_id = (
        f"{org_name}/{repo_name}"
        if repo_name is not None and org_name is not None
        else None
    )
    if repo_id is not None:
        if not all([repo_id, org_name, repo_name]):
            raise gr.Error(
                "Please provide a `repo_name` and `org_name` to push the dataset to."
            )
    return repo_id


def combine_datasets(
    repo_id: str, dataset: Dataset, oauth_token: Union[OAuthToken, None]
) -> Dataset:
    try:
        new_dataset = load_dataset(
            repo_id,
            split="train",
            download_mode="force_redownload",
            token=oauth_token.token,
        )
        return concatenate_datasets([dataset, new_dataset])
    except Exception:
        return dataset


def show_success_message(org_name: str, repo_name: str) -> gr.Markdown:
    client = get_argilla_client()
    if client is None:
        return gr.Markdown(
            value=f"""
            <div style="padding: 1em; background-color: var(--block-background-fill); border-color: var(--border-color-primary); border-width: 1px; border-radius: 5px;">
                <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
                <p style="margin-top: 0.5em;">
                    The generated dataset is in the right format for fine-tuning with TRL, AutoTrain, or other frameworks.
                    <div style="display: flex; gap: 10px;">
                        <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" class="lg primary svelte-1137axg" style="color: white !important; margin-top: 0.5em; text-decoration: none;">
                            Open in Hugging Face
                        </a>
                    </div>
                </p>
                <p style="margin-top: 1em; color: var(--block-title-text-color)">
                    By configuring an `ARGILLA_API_URL` and `ARGILLA_API_KEY` you can curate the dataset in Argilla.
                    Unfamiliar with Argilla? Here are some docs to help you get started:
                    <br>• <a href="https://docs.argilla.io/latest/getting_started/quickstart/" target="_blank">How to get started with Argilla</a>
                    <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
                    <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
                </p>
            </div>
            """,
            visible=True,
            height=None,
            min_height=None,
            max_height=None,
        )
    argilla_api_url = client.api_url
    # Transform Docker internal URL to localhost if needed
    if "argilla:" in argilla_api_url:
        argilla_api_url = argilla_api_url.replace("argilla:", "127.0.0.1:")
    return gr.Markdown(
        value=f"""
        <div style="padding: 1em; background-color: var(--block-background-fill); border-color: var(--border-color-primary); border-width: 1px; border-radius: 5px;">
            <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
            <p style="margin-top: 0.5em;">
                The generated dataset is <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank">available in the Hub</a>. It is in the right format for fine-tuning with TRL, AutoTrain, or other frameworks.
                <div style="display: flex; gap: 10px;">
                    <a href="{argilla_api_url}" target="_blank" class="lg primary svelte-1137axg" style="color: white !important; margin-top: 0.5em; text-decoration: none;">
                        Open in Argilla
                    </a>
                </div>
            </p>
            <p style="margin-top: 1em; color: var(--block-title-text-color)">
                Unfamiliar with Argilla? Here are some docs to help you get started:
                <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
                <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
            </p>
        </div>
        """,
        visible=True,
        height=None,
        min_height=None,
        max_height=None,
    )


def hide_success_message() -> gr.Markdown:
    return gr.Markdown(value="", visible=True, height=100)


def test_max_num_rows(num_rows: int) -> int:
    if num_rows > MAX_NUM_ROWS:
        num_rows = MAX_NUM_ROWS
        gr.Info(
            f"Number of rows is larger than the configured maximum. Setting number of rows to {MAX_NUM_ROWS}. Set environment variable `MAX_NUM_ROWS` to change this behavior."
        )
    return num_rows


def get_iframe(hub_repo_id: str) -> str:
    if not hub_repo_id:
        return ""

    if not repo_exists(repo_id=hub_repo_id, repo_type="dataset"):
        return ""

    url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
    iframe = f"""
    <iframe
        src="{url}"
        frameborder="0"
        width="100%"
        height="600px"
    ></iframe>
    """
    return iframe


def _get_valid_columns(dataframe: pd.DataFrame):
    doc_valid_columns = []

    for col in dataframe.columns:
        sample_val = dataframe[col].iloc[0]
        if isinstance(sample_val, str):
            doc_valid_columns.append(col)

    return doc_valid_columns


def load_dataset_from_hub(
    repo_id: str,
    num_rows: int = 10,
    token: Union[OAuthToken, None] = None,
    progress=gr.Progress(track_tqdm=True),
):
    if not repo_id:
        raise gr.Error("Please provide a Hub repo ID")
    subsets = get_dataset_config_names(repo_id, token=token)
    splits = get_dataset_split_names(repo_id, subsets[0], token=token)
    ds = load_dataset(repo_id, subsets[0], split=splits[0], token=token, streaming=True)
    rows = []
    for idx, row in enumerate(tqdm(ds, desc="Loading the dataset", total=num_rows)):
        rows.append(row)
        if idx == num_rows:
            break
    ds = Dataset.from_list(rows)
    dataframe = ds.to_pandas()
    doc_valid_columns = _get_valid_columns(dataframe)
    col_doc = doc_valid_columns[0] if doc_valid_columns else ""
    return (
        dataframe,
        gr.Dropdown(
            choices=doc_valid_columns,
            label="Documents column",
            value=col_doc,
            interactive=(False if col_doc == "" else True),
            multiselect=False,
        ),
    )


def preprocess_input_data(
    file_paths: list[str], num_rows: int, progress=gr.Progress(track_tqdm=True)
):
    if not file_paths:
        raise gr.Error("Please provide an input file")

    data = {}
    total_chunks = 0

    for file_path in tqdm(file_paths, desc="Processing files", total=len(file_paths)):
        partitioned_file = partition(filename=file_path)
        chunks = [str(chunk) for chunk in chunk_by_title(partitioned_file)]
        data[file_path] = chunks
        total_chunks += len(chunks)
        if total_chunks >= num_rows:
            break

    dataframe = pd.DataFrame.from_records(
        [(k, v) for k, values in data.items() for v in values],
        columns=["filename", "chunks"],
    )
    col_doc = "chunks"

    return (
        dataframe,
        gr.Dropdown(
            choices=["chunks"],
            label="Documents column",
            value=col_doc,
            interactive=(False if col_doc == "" else True),
            multiselect=False,
        ),
    )
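A minimal sketch of the column filter used by `load_dataset_from_hub` above: a column qualifies as a document column when its first value is a string (this assumes the package is configured, e.g. `HF_TOKEN` set, so the constants import succeeds):

import pandas as pd

from synthetic_dataset_generator.apps.base import _get_valid_columns

df = pd.DataFrame({"text": ["a first chunk"], "score": [0.5]})
assert _get_valid_columns(df) == ["text"]  # only the string column qualifies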
src/synthetic_dataset_generator/apps/chat.py
ADDED
|
@@ -0,0 +1,1142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ast
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import uuid
|
| 6 |
+
from typing import Dict, List, Union
|
| 7 |
+
|
| 8 |
+
import argilla as rg
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import pandas as pd
|
| 11 |
+
from datasets import Dataset
|
| 12 |
+
from distilabel.distiset import Distiset
|
| 13 |
+
from gradio.oauth import OAuthToken
|
| 14 |
+
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
| 15 |
+
from huggingface_hub import HfApi
|
| 16 |
+
|
| 17 |
+
from synthetic_dataset_generator.apps.base import (
|
| 18 |
+
combine_datasets,
|
| 19 |
+
hide_success_message,
|
| 20 |
+
load_dataset_from_hub,
|
| 21 |
+
preprocess_input_data,
|
| 22 |
+
push_pipeline_code_to_hub,
|
| 23 |
+
show_success_message,
|
| 24 |
+
test_max_num_rows,
|
| 25 |
+
validate_argilla_user_workspace_dataset,
|
| 26 |
+
validate_push_to_hub,
|
| 27 |
+
)
|
| 28 |
+
from synthetic_dataset_generator.constants import (
|
| 29 |
+
BASE_URL,
|
| 30 |
+
DEFAULT_BATCH_SIZE,
|
| 31 |
+
MODEL,
|
| 32 |
+
MODEL_COMPLETION,
|
| 33 |
+
SAVE_LOCAL_DIR,
|
| 34 |
+
SFT_AVAILABLE,
|
| 35 |
+
)
|
| 36 |
+
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
| 37 |
+
from synthetic_dataset_generator.pipelines.chat import (
|
| 38 |
+
DEFAULT_DATASET_DESCRIPTIONS,
|
| 39 |
+
generate_pipeline_code,
|
| 40 |
+
get_follow_up_generator,
|
| 41 |
+
get_magpie_generator,
|
| 42 |
+
get_prompt_generator,
|
| 43 |
+
get_response_generator,
|
| 44 |
+
get_sentence_pair_generator,
|
| 45 |
+
)
|
| 46 |
+
from synthetic_dataset_generator.pipelines.embeddings import (
|
| 47 |
+
get_embeddings,
|
| 48 |
+
get_sentence_embedding_dimensions,
|
| 49 |
+
)
|
| 50 |
+
from synthetic_dataset_generator.utils import (
|
| 51 |
+
column_to_list,
|
| 52 |
+
get_argilla_client,
|
| 53 |
+
get_org_dropdown,
|
| 54 |
+
get_random_repo_name,
|
| 55 |
+
swap_visibility,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _get_dataframe():
|
| 60 |
+
return gr.Dataframe(
|
| 61 |
+
headers=["prompt", "completion"],
|
| 62 |
+
wrap=True,
|
| 63 |
+
interactive=False,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
|
| 68 |
+
def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
|
| 69 |
+
return ast.literal_eval(
|
| 70 |
+
messages.replace("'user'}", "'user'},")
|
| 71 |
+
.replace("'system'}", "'system'},")
|
| 72 |
+
.replace("'assistant'}", "'assistant'},")
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
if "messages" in dataframe.columns:
|
| 76 |
+
dataframe["messages"] = dataframe["messages"].apply(
|
| 77 |
+
lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
|
| 78 |
+
)
|
| 79 |
+
return dataframe
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
|
| 83 |
+
progress(0.1, desc="Initializing")
|
| 84 |
+
generate_description = get_prompt_generator()
|
| 85 |
+
progress(0.5, desc="Generating")
|
| 86 |
+
result = next(
|
| 87 |
+
generate_description.process(
|
| 88 |
+
[
|
| 89 |
+
{
|
| 90 |
+
"instruction": dataset_description,
|
| 91 |
+
}
|
| 92 |
+
]
|
| 93 |
+
)
|
| 94 |
+
)[0]["generation"]
|
| 95 |
+
progress(1.0, desc="Prompt generated")
|
| 96 |
+
return result
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def load_dataset_file(
|
| 100 |
+
repo_id: str,
|
| 101 |
+
file_paths: list[str],
|
| 102 |
+
input_type: str,
|
| 103 |
+
num_rows: int = 10,
|
| 104 |
+
token: Union[OAuthToken, None] = None,
|
| 105 |
+
progress=gr.Progress(),
|
| 106 |
+
):
|
| 107 |
+
progress(0.1, desc="Loading the source data")
|
| 108 |
+
if input_type == "dataset-input":
|
| 109 |
+
return load_dataset_from_hub(repo_id=repo_id, num_rows=num_rows, token=token)
|
| 110 |
+
else:
|
| 111 |
+
return preprocess_input_data(file_paths=file_paths, num_rows=num_rows)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def generate_sample_dataset(
|
| 115 |
+
repo_id: str,
|
| 116 |
+
file_paths: list[str],
|
| 117 |
+
input_type: str,
|
| 118 |
+
system_prompt: str,
|
| 119 |
+
document_column: str,
|
| 120 |
+
num_turns: int,
|
| 121 |
+
num_rows: int,
|
| 122 |
+
oauth_token: Union[OAuthToken, None],
|
| 123 |
+
progress=gr.Progress(),
|
| 124 |
+
):
|
| 125 |
+
if input_type == "prompt-input":
|
| 126 |
+
dataframe = pd.DataFrame(columns=["prompt", "completion"])
|
| 127 |
+
else:
|
| 128 |
+
dataframe, _ = load_dataset_file(
|
| 129 |
+
repo_id=repo_id,
|
| 130 |
+
file_paths=file_paths,
|
| 131 |
+
input_type=input_type,
|
| 132 |
+
num_rows=num_rows,
|
| 133 |
+
token=oauth_token,
|
| 134 |
+
)
|
| 135 |
+
progress(0.5, desc="Generating sample dataset")
|
| 136 |
+
dataframe = generate_dataset(
|
| 137 |
+
input_type=input_type,
|
| 138 |
+
dataframe=dataframe,
|
| 139 |
+
system_prompt=system_prompt,
|
| 140 |
+
document_column=document_column,
|
| 141 |
+
num_turns=num_turns,
|
| 142 |
+
num_rows=num_rows,
|
| 143 |
+
is_sample=True,
|
| 144 |
+
)
|
| 145 |
+
progress(1.0, desc="Sample dataset generated")
|
| 146 |
+
return dataframe
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def generate_dataset_from_prompt(
|
| 150 |
+
system_prompt: str,
|
| 151 |
+
num_turns: int = 1,
|
| 152 |
+
num_rows: int = 10,
|
| 153 |
+
temperature: float = 0.9,
|
| 154 |
+
temperature_completion: Union[float, None] = None,
|
| 155 |
+
is_sample: bool = False,
|
| 156 |
+
progress=gr.Progress(),
|
| 157 |
+
) -> pd.DataFrame:
|
| 158 |
+
num_rows = test_max_num_rows(num_rows)
|
| 159 |
+
progress(0.0, desc="(1/2) Generating instructions")
|
| 160 |
+
magpie_generator = get_magpie_generator(num_turns, temperature, is_sample)
|
| 161 |
+
response_generator = get_response_generator(
|
| 162 |
+
system_prompt=system_prompt,
|
| 163 |
+
num_turns=num_turns,
|
| 164 |
+
temperature=temperature or temperature_completion,
|
| 165 |
+
is_sample=is_sample,
|
| 166 |
+
)
|
| 167 |
+
total_steps: int = num_rows * 2
|
| 168 |
+
batch_size = DEFAULT_BATCH_SIZE
|
| 169 |
+
|
| 170 |
+
# create prompt rewrites
|
| 171 |
+
prompt_rewrites = get_rewritten_prompts(system_prompt, num_rows)
|
| 172 |
+
|
| 173 |
+
# create instructions
|
| 174 |
+
n_processed = 0
|
| 175 |
+
magpie_results = []
|
| 176 |
+
while n_processed < num_rows:
|
| 177 |
+
progress(
|
| 178 |
+
0.5 * n_processed / num_rows,
|
| 179 |
+
total=total_steps,
|
| 180 |
+
desc="(1/2) Generating instructions",
|
| 181 |
+
)
|
| 182 |
+
remaining_rows = num_rows - n_processed
|
| 183 |
+
batch_size = min(batch_size, remaining_rows)
|
| 184 |
+
rewritten_system_prompt = random.choice(prompt_rewrites)
|
| 185 |
+
inputs = [{"system_prompt": rewritten_system_prompt} for _ in range(batch_size)]
|
| 186 |
+
batch = list(magpie_generator.process(inputs=inputs))
|
| 187 |
+
magpie_results.extend(batch[0])
|
| 188 |
+
n_processed += batch_size
|
| 189 |
+
random.seed(a=random.randint(0, 2**32 - 1))
|
| 190 |
+
progress(0.5, desc="(1/2) Generating instructions")
|
| 191 |
+
|
| 192 |
+
# generate responses
|
| 193 |
+
n_processed = 0
|
| 194 |
+
response_results = []
|
| 195 |
+
if num_turns == 1:
|
| 196 |
+
while n_processed < num_rows:
|
| 197 |
+
progress(
|
| 198 |
+
0.5 + 0.5 * n_processed / num_rows,
|
| 199 |
+
total=total_steps,
|
| 200 |
+
desc="(2/2) Generating responses",
|
| 201 |
+
)
|
| 202 |
+
batch = magpie_results[n_processed : n_processed + batch_size]
|
| 203 |
+
responses = list(response_generator.process(inputs=batch))
|
| 204 |
+
response_results.extend(responses[0])
|
| 205 |
+
n_processed += batch_size
|
| 206 |
+
random.seed(a=random.randint(0, 2**32 - 1))
|
| 207 |
+
for result in response_results:
|
| 208 |
+
result["prompt"] = result["instruction"]
|
| 209 |
+
result["completion"] = result["generation"]
|
| 210 |
+
result["system_prompt"] = system_prompt
|
| 211 |
+
else:
|
| 212 |
+
for result in magpie_results:
|
| 213 |
+
result["conversation"].insert(
|
| 214 |
+
0, {"role": "system", "content": system_prompt}
|
| 215 |
+
)
|
| 216 |
+
result["messages"] = result["conversation"]
|
| 217 |
+
while n_processed < num_rows:
|
| 218 |
+
progress(
|
| 219 |
+
0.5 + 0.5 * n_processed / num_rows,
|
| 220 |
+
total=total_steps,
|
| 221 |
+
desc="(2/2) Generating responses",
|
| 222 |
+
)
|
| 223 |
+
batch = magpie_results[n_processed : n_processed + batch_size]
|
| 224 |
+
responses = list(response_generator.process(inputs=batch))
|
| 225 |
+
response_results.extend(responses[0])
|
| 226 |
+
n_processed += batch_size
|
| 227 |
+
random.seed(a=random.randint(0, 2**32 - 1))
|
| 228 |
+
for result in response_results:
|
| 229 |
+
result["messages"].append(
|
| 230 |
+
{"role": "assistant", "content": result["generation"]}
|
| 231 |
+
)
|
| 232 |
+
progress(
|
| 233 |
+
1,
|
| 234 |
+
total=total_steps,
|
| 235 |
+
desc="(2/2) Creating dataset",
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
# create distiset
|
| 239 |
+
distiset_results = []
|
| 240 |
+
for result in response_results:
|
| 241 |
+
record = {}
|
| 242 |
+
for relevant_keys in [
|
| 243 |
+
"messages",
|
| 244 |
+
"prompt",
|
| 245 |
+
"completion",
|
| 246 |
+
"model_name",
|
| 247 |
+
"system_prompt",
|
| 248 |
+
]:
|
| 249 |
+
if relevant_keys in result:
|
| 250 |
+
record[relevant_keys] = result[relevant_keys]
|
| 251 |
+
distiset_results.append(record)
|
| 252 |
+
|
| 253 |
+
distiset = Distiset(
|
| 254 |
+
{
|
| 255 |
+
"default": Dataset.from_list(distiset_results),
|
| 256 |
+
}
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
# If not pushing to hub generate the dataset directly
|
| 260 |
+
distiset = distiset["default"]
|
| 261 |
+
if num_turns == 1:
|
| 262 |
+
outputs = distiset.to_pandas()[["prompt", "completion", "system_prompt"]]
|
| 263 |
+
else:
|
| 264 |
+
outputs = distiset.to_pandas()[["messages"]]
|
| 265 |
+
dataframe = pd.DataFrame(outputs)
|
| 266 |
+
progress(1.0, desc="Dataset generation completed")
|
| 267 |
+
return dataframe
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def generate_dataset_from_seed(
|
| 271 |
+
dataframe: pd.DataFrame,
|
| 272 |
+
document_column: str,
|
| 273 |
+
num_turns: int = 1,
|
| 274 |
+
num_rows: int = 10,
|
| 275 |
+
temperature: float = 0.9,
|
| 276 |
+
temperature_completion: Union[float, None] = None,
|
| 277 |
+
is_sample: bool = False,
|
| 278 |
+
progress=gr.Progress(),
|
| 279 |
+
) -> pd.DataFrame:
|
| 280 |
+
num_rows = test_max_num_rows(num_rows)
|
| 281 |
+
progress(0.0, desc="Initializing dataset generation")
|
| 282 |
+
document_data = column_to_list(dataframe, document_column)
|
| 283 |
+
if len(document_data) < num_rows:
|
| 284 |
+
document_data += random.choices(document_data, k=num_rows - len(document_data))
|
| 285 |
+
instruction_generator = get_sentence_pair_generator(
|
| 286 |
+
temperature=temperature, is_sample=is_sample
|
| 287 |
+
)
|
| 288 |
+
response_generator = get_response_generator(
|
| 289 |
+
system_prompt=None,
|
| 290 |
+
num_turns=1,
|
| 291 |
+
temperature=temperature or temperature_completion,
|
| 292 |
+
is_sample=is_sample,
|
| 293 |
+
)
|
| 294 |
+
follow_up_generator_instruction = get_follow_up_generator(
|
| 295 |
+
type="instruction", temperature=temperature, is_sample=is_sample
|
| 296 |
+
)
|
| 297 |
+
follow_up_generator_response = get_follow_up_generator(
|
| 298 |
+
type="response",
|
| 299 |
+
temperature=temperature or temperature_completion,
|
| 300 |
+
is_sample=is_sample,
|
| 301 |
+
)
|
| 302 |
+
steps = 2 * num_turns
|
| 303 |
+
total_steps: int = num_rows * steps
|
| 304 |
+
step_progress = round(1 / steps, 2)
|
| 305 |
+
batch_size = DEFAULT_BATCH_SIZE
|
| 306 |
+
|
| 307 |
+
# create instructions
|
| 308 |
+
n_processed = 0
|
| 309 |
+
instruction_results = []
|
| 310 |
+
while n_processed < num_rows:
|
| 311 |
+
progress(
|
| 312 |
+
step_progress * n_processed / num_rows,
|
| 313 |
+
total=total_steps,
|
| 314 |
+
desc="Generating instructions",
|
| 315 |
+
)
|
| 316 |
+
remaining_rows = num_rows - n_processed
|
| 317 |
+
batch_size = min(batch_size, remaining_rows)
|
| 318 |
+
batch = [
|
| 319 |
+
{"anchor": document}
|
| 320 |
+
for document in document_data[n_processed : n_processed + batch_size]
|
| 321 |
+
]
|
| 322 |
+
questions = list(instruction_generator.process(inputs=batch))
|
| 323 |
+
instruction_results.extend(questions[0])
|
| 324 |
+
n_processed += batch_size
|
| 325 |
+
for result in instruction_results:
|
| 326 |
+
result["instruction"] = result["positive"]
|
| 327 |
+
result["prompt"] = result.pop("positive")
|
| 328 |
+
|
| 329 |
+
progress(step_progress, desc="Generating instructions")
|
| 330 |
+
|
| 331 |
+
# generate responses
|
| 332 |
+
n_processed = 0
|
| 333 |
+
response_results = []
|
| 334 |
+
while n_processed < num_rows:
|
| 335 |
+
progress(
|
| 336 |
+
step_progress + step_progress * n_processed / num_rows,
|
| 337 |
+
total=total_steps,
|
| 338 |
+
desc="Generating responses",
|
| 339 |
+
)
|
| 340 |
+
batch = instruction_results[n_processed : n_processed + batch_size]
|
| 341 |
+
responses = list(response_generator.process(inputs=batch))
|
| 342 |
+
response_results.extend(responses[0])
|
| 343 |
+
n_processed += batch_size
|
| 344 |
+
for result in response_results:
|
| 345 |
+
result["completion"] = result.pop("generation")
|
| 346 |
+
|
| 347 |
+
# generate follow-ups
|
| 348 |
+
if num_turns > 1:
|
| 349 |
+
n_processed = 0
|
| 350 |
+
final_conversations = []
|
| 351 |
+
|
| 352 |
+
while n_processed < num_rows:
|
| 353 |
+
progress(
|
| 354 |
+
step_progress + step_progress * n_processed / num_rows,
|
| 355 |
+
total=total_steps,
|
| 356 |
+
desc="Generating follow-ups",
|
| 357 |
+
)
|
| 358 |
+
batch = response_results[n_processed : n_processed + batch_size]
|
| 359 |
+
conversations_batch = [
|
| 360 |
+
{
|
| 361 |
+
"messages": [
|
| 362 |
+
{"role": "user", "content": result["prompt"]},
|
| 363 |
+
{"role": "assistant", "content": result["completion"]},
|
| 364 |
+
]
|
| 365 |
+
}
|
| 366 |
+
for result in batch
|
| 367 |
+
]
|
| 368 |
+
|
| 369 |
+
for _ in range(num_turns - 1):
|
| 370 |
+
follow_up_instructions = list(
|
| 371 |
+
follow_up_generator_instruction.process(inputs=conversations_batch)
|
| 372 |
+
)
|
| 373 |
+
for conv, follow_up in zip(
|
| 374 |
+
conversations_batch, follow_up_instructions[0]
|
| 375 |
+
):
|
| 376 |
+
conv["messages"].append(
|
| 377 |
+
{"role": "user", "content": follow_up["generation"]}
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
follow_up_responses = list(
|
| 381 |
+
follow_up_generator_response.process(inputs=conversations_batch)
|
| 382 |
+
)
|
| 383 |
+
for conv, follow_up in zip(conversations_batch, follow_up_responses[0]):
|
| 384 |
+
conv["messages"].append(
|
| 385 |
+
{"role": "assistant", "content": follow_up["generation"]}
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
final_conversations.extend(
|
| 389 |
+
[{"messages": conv["messages"]} for conv in conversations_batch]
|
| 390 |
+
)
|
| 391 |
+
n_processed += batch_size
|
| 392 |
+
|
| 393 |
+
# create distiset
|
| 394 |
+
distiset_results = []
|
| 395 |
+
if num_turns == 1:
|
| 396 |
+
for result in response_results:
|
| 397 |
+
record = {}
|
| 398 |
+
for relevant_keys in ["prompt", "completion"]:
|
| 399 |
+
if relevant_keys in result:
|
| 400 |
+
record[relevant_keys] = result[relevant_keys]
|
| 401 |
+
distiset_results.append(record)
|
| 402 |
+
dataframe = pd.DataFrame(distiset_results)
|
| 403 |
+
else:
|
| 404 |
+
distiset_results = final_conversations
|
| 405 |
+
dataframe = pd.DataFrame(distiset_results)
|
| 406 |
+
dataframe["messages"] = dataframe["messages"].apply(lambda x: json.dumps(x))
|
| 407 |
+
|
| 408 |
+
progress(1.0, desc="Dataset generation completed")
|
| 409 |
+
return dataframe
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def generate_dataset(
|
| 413 |
+
input_type: str,
|
| 414 |
+
dataframe: pd.DataFrame,
|
| 415 |
+
system_prompt: str,
|
| 416 |
+
document_column: str,
|
| 417 |
+
num_turns: int = 1,
|
| 418 |
+
num_rows: int = 10,
|
| 419 |
+
temperature: float = 0.9,
|
| 420 |
+
temperature_completion: Union[float, None] = None,
|
| 421 |
+
is_sample: bool = False,
|
| 422 |
+
progress=gr.Progress(),
|
| 423 |
+
) -> pd.DataFrame:
|
| 424 |
+
if input_type == "prompt-input":
|
| 425 |
+
dataframe = generate_dataset_from_prompt(
|
| 426 |
+
system_prompt=system_prompt,
|
| 427 |
+
num_turns=num_turns,
|
| 428 |
+
num_rows=num_rows,
|
| 429 |
+
temperature=temperature,
|
| 430 |
+
temperature_completion=temperature_completion,
|
| 431 |
+
is_sample=is_sample,
|
| 432 |
+
)
|
| 433 |
+
else:
|
| 434 |
+
dataframe = generate_dataset_from_seed(
|
| 435 |
+
dataframe=dataframe,
|
| 436 |
+
document_column=document_column,
|
| 437 |
+
num_turns=num_turns,
|
| 438 |
+
num_rows=num_rows,
|
| 439 |
+
temperature=temperature,
|
| 440 |
+
temperature_completion=temperature_completion,
|
| 441 |
+
is_sample=is_sample,
|
| 442 |
+
)
|
| 443 |
+
return dataframe
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def push_dataset_to_hub(
    dataframe: pd.DataFrame,
    org_name: str,
    repo_name: str,
    oauth_token: Union[gr.OAuthToken, None],
    private: bool,
    pipeline_code: str,
    progress=gr.Progress(),
):
    progress(0.0, desc="Validating")
    repo_id = validate_push_to_hub(org_name, repo_name)
    progress(0.3, desc="Converting")
    original_dataframe = dataframe.copy(deep=True)
    dataframe = convert_dataframe_messages(dataframe)
    progress(0.7, desc="Creating dataset")
    dataset = Dataset.from_pandas(dataframe)
    dataset = combine_datasets(repo_id, dataset, oauth_token)
    progress(0.9, desc="Pushing dataset")
    distiset = Distiset({"default": dataset})
    distiset.push_to_hub(
        repo_id=repo_id,
        private=private,
        include_script=False,
        token=oauth_token.token,
        create_pr=False,
    )
    push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
    progress(1.0, desc="Dataset pushed")
    return original_dataframe

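# End-to-end push: regenerate the dataset from the configured input, push it
# to the Hub, and mirror it to Argilla (with length metadata and embeddings)
# when an Argilla client is configured.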
def push_dataset(
    org_name: str,
    repo_name: str,
    private: bool,
    original_repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    num_turns: int = 1,
    num_rows: int = 10,
    temperature: float = 0.9,
    temperature_completion: Union[float, None] = None,
    pipeline_code: str = "",
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
) -> pd.DataFrame:
    if input_type == "prompt-input":
        dataframe = _get_dataframe()
    else:
        dataframe, _ = load_dataset_file(
            repo_id=original_repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
            token=oauth_token,
        )
    progress(0.5, desc="Generating dataset")
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
    )
    push_dataset_to_hub(
        dataframe=dataframe,
        org_name=org_name,
        repo_name=repo_name,
        oauth_token=oauth_token,
        private=private,
        pipeline_code=pipeline_code,
    )
    try:
        progress(0.1, desc="Setting up user and workspace")
        hf_user = HfApi().whoami(token=oauth_token.token)["name"]
        client = get_argilla_client()
        if client is None:
            return ""
        progress(0.5, desc="Creating dataset in Argilla")
        if "messages" in dataframe.columns:
            settings = rg.Settings(
                fields=[
                    rg.ChatField(
                        name="messages",
                        description="The messages in the conversation",
                        title="Messages",
                    ),
                ],
                questions=[
                    rg.RatingQuestion(
                        name="rating",
                        title="Rating",
                        description="The rating of the conversation",
                        values=list(range(1, 6)),
                    ),
                ],
                metadata=[
                    rg.IntegerMetadataProperty(
                        name="user_message_length", title="User Message Length"
                    ),
                    rg.IntegerMetadataProperty(
                        name="assistant_message_length",
                        title="Assistant Message Length",
                    ),
                ],
                vectors=[
                    rg.VectorField(
                        name="messages_embeddings",
                        dimensions=get_sentence_embedding_dimensions(),
                    )
                ],
                guidelines="Please review the conversation and provide a score for the assistant's response.",
            )

            dataframe["user_message_length"] = dataframe["messages"].apply(
                lambda x: sum([len(y["content"]) for y in x if y["role"] == "user"])
            )
            dataframe["assistant_message_length"] = dataframe["messages"].apply(
                lambda x: sum(
                    [len(y["content"]) for y in x if y["role"] == "assistant"]
                )
            )
            dataframe["messages_embeddings"] = get_embeddings(
                dataframe["messages"].apply(
                    lambda x: " ".join([y["content"] for y in x])
                )
            )
        else:
            settings = rg.Settings(
                fields=[
                    rg.TextField(
                        name="system_prompt",
                        title="System Prompt",
                        description="The system prompt used for the conversation",
                        required=False,
                    ),
                    rg.TextField(
                        name="prompt",
                        title="Prompt",
                        description="The prompt used for the conversation",
                    ),
                    rg.TextField(
                        name="completion",
                        title="Completion",
                        description="The completion from the assistant",
                    ),
                ],
                questions=[
                    rg.RatingQuestion(
                        name="rating",
                        title="Rating",
                        description="The rating of the conversation",
                        values=list(range(1, 6)),
                    ),
                ],
                metadata=[
                    rg.IntegerMetadataProperty(
                        name="prompt_length", title="Prompt Length"
                    ),
                    rg.IntegerMetadataProperty(
                        name="completion_length", title="Completion Length"
                    ),
                ],
                vectors=[
                    rg.VectorField(
                        name="prompt_embeddings",
                        dimensions=get_sentence_embedding_dimensions(),
                    )
                ],
                guidelines="Please review the conversation and correct the prompt and completion where needed.",
            )
            dataframe["prompt_length"] = dataframe["prompt"].apply(len)
            dataframe["completion_length"] = dataframe["completion"].apply(len)
            dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])

        rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
        if rg_dataset is None:
            rg_dataset = rg.Dataset(
                name=repo_name,
                workspace=hf_user,
                settings=settings,
                client=client,
            )
            rg_dataset = rg_dataset.create()
        progress(0.7, desc="Pushing dataset to Argilla")
        hf_dataset = Dataset.from_pandas(dataframe)
        rg_dataset.records.log(records=hf_dataset)
        progress(1.0, desc="Dataset pushed to Argilla")
    except Exception as e:
        raise gr.Error(f"Error pushing dataset to Argilla: {e}")
    return ""

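# Generate the dataset and write it to SAVE_LOCAL_DIR as both CSV and JSON.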
def save_local(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    num_turns: int,
    num_rows: int,
    temperature: float,
    repo_name: str,
    temperature_completion: Union[float, None] = None,
) -> pd.DataFrame:
    if input_type == "prompt-input":
        dataframe = _get_dataframe()
    else:
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
        )
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
    )
    local_dataset = Dataset.from_pandas(dataframe)
    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
    local_dataset.to_csv(output_csv, index=False)
    local_dataset.to_json(output_json, index=False)
    return output_csv, output_json

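# Small UI callbacks: each returns a Gradio update dict keyed by a component
# defined in the Blocks context below.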
def show_system_prompt_visibility():
    return {system_prompt: gr.Textbox(visible=True)}


def hide_system_prompt_visibility():
    return {system_prompt: gr.Textbox(visible=False)}


def show_document_column_visibility():
    return {document_column: gr.Dropdown(visible=True)}


def hide_document_column_visibility():
    return {
        document_column: gr.Dropdown(
            choices=["Load your data first in step 1."],
            value="Load your data first in step 1.",
            visible=False,
        )
    }


def show_pipeline_code_visibility():
    return {pipeline_code_ui: gr.Accordion(visible=True)}


def hide_pipeline_code_visibility():
    return {pipeline_code_ui: gr.Accordion(visible=False)}


def show_temperature_completion():
    if MODEL != MODEL_COMPLETION:
        return {temperature_completion: gr.Slider(value=0.9, visible=True)}


def show_save_local_button():
    return {btn_save_local: gr.Button(visible=True)}


def hide_save_local_button():
    return {btn_save_local: gr.Button(visible=False)}


def show_save_local():
    gr.update(success_message, min_height=0)
    return {
        csv_file: gr.File(visible=True),
        json_file: gr.File(visible=True),
        success_message: success_message,
    }


def hide_save_local():
    gr.update(success_message, min_height=100)
    return {
        csv_file: gr.File(visible=False),
        json_file: gr.File(visible=False),
        success_message: success_message,
    }

######################
# Gradio UI
######################


with gr.Blocks() as app:
    with gr.Column() as main_ui:
        if not SFT_AVAILABLE:
            gr.Markdown(
                value="\n".join(
                    [
                        "## Supervised Fine-Tuning not available",
                        "",
                        f"This tool relies on the [Magpie](https://arxiv.org/abs/2406.08464) prequery template, which is not implemented for the {MODEL} with {BASE_URL}.",
                        "Use Llama3 or Qwen2 models with Hugging Face Inference Endpoints.",
                    ]
                )
            )
        else:
            gr.Markdown("## 1. Select your input")
            with gr.Row(equal_height=False):
                with gr.Column(scale=2):
                    input_type = gr.Dropdown(
                        label="Input type",
                        choices=["prompt-input", "dataset-input", "file-input"],
                        value="prompt-input",
                        multiselect=False,
                        visible=False,
                    )
                    with gr.Tab("Generate from prompt") as tab_prompt_input:
                        with gr.Row(equal_height=False):
                            with gr.Column(scale=2):
                                dataset_description = gr.Textbox(
                                    label="Dataset description",
                                    placeholder="Give a precise description of your desired dataset.",
                                )
                                with gr.Row():
                                    clear_prompt_btn_part = gr.Button(
                                        "Clear", variant="secondary"
                                    )
                                    load_prompt_btn = gr.Button(
                                        "Create", variant="primary"
                                    )
                            with gr.Column(scale=3):
                                examples = gr.Examples(
                                    examples=DEFAULT_DATASET_DESCRIPTIONS,
                                    inputs=[dataset_description],
                                    cache_examples=False,
                                    label="Examples",
                                )
                    with gr.Tab("Load from Hub") as tab_dataset_input:
                        with gr.Row(equal_height=False):
                            with gr.Column(scale=2):
                                search_in = HuggingfaceHubSearch(
                                    label="Search",
                                    placeholder="Search for a dataset",
                                    search_type="dataset",
                                    sumbit_on_select=True,
                                )
                                with gr.Row():
                                    clear_dataset_btn_part = gr.Button(
                                        "Clear", variant="secondary"
                                    )
                                    load_dataset_btn = gr.Button(
                                        "Load", variant="primary"
                                    )
                            with gr.Column(scale=3):
                                examples = gr.Examples(
                                    examples=[
                                        "charris/wikipedia_sample",
                                        "plaguss/argilla_sdk_docs_raw_unstructured",
                                        "BeIR/hotpotqa-generated-queries",
                                    ],
                                    label="Example datasets",
                                    fn=lambda x: x,
                                    inputs=[search_in],
                                    run_on_click=True,
                                )
                                search_out = gr.HTML(
                                    label="Dataset preview", visible=False
                                )
                    with gr.Tab("Load your file") as tab_file_input:
                        with gr.Row(equal_height=False):
                            with gr.Column(scale=2):
                                file_in = gr.File(
                                    label="Upload your file. Supported formats: .md, .txt, .docx, .pdf",
                                    file_count="multiple",
                                    file_types=[".md", ".txt", ".docx", ".pdf"],
                                )
                                with gr.Row():
                                    clear_file_btn_part = gr.Button(
                                        "Clear", variant="secondary"
                                    )
                                    load_file_btn = gr.Button("Load", variant="primary")
                            with gr.Column(scale=3):
                                file_out = gr.HTML(
                                    label="Dataset preview", visible=False
                                )

            gr.HTML(value="<hr>")
            gr.Markdown(value="## 2. Configure your dataset")
            with gr.Row(equal_height=False):
                with gr.Column(scale=2):
                    system_prompt = gr.Textbox(
                        label="System prompt",
                        placeholder="You are a helpful assistant.",
                    )
                    document_column = gr.Dropdown(
                        label="Document Column",
                        info="Select the document column to generate the chat data",
                        choices=["Load your data first in step 1."],
                        value="Load your data first in step 1.",
                        interactive=False,
                        multiselect=False,
                        allow_custom_value=False,
                        visible=False,
                    )
                    num_turns = gr.Number(
                        value=1,
                        label="Number of turns in the conversation",
                        minimum=1,
                        maximum=4,
                        step=1,
                        interactive=True,
                        info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
                    )
                    with gr.Row():
                        clear_btn_full = gr.Button(
                            "Clear",
                            variant="secondary",
                        )
                        btn_apply_to_sample_dataset = gr.Button(
                            "Save", variant="primary"
                        )
                with gr.Column(scale=3):
                    dataframe = _get_dataframe()

            gr.HTML(value="<hr>")
            gr.Markdown(value="## 3. Generate your dataset")
            with gr.Row(equal_height=False):
                with gr.Column(scale=2):
                    org_name = get_org_dropdown()
                    repo_name = gr.Textbox(
                        label="Repo name",
                        placeholder="dataset_name",
                        value=f"my-distiset-{str(uuid.uuid4())[:8]}",
                        interactive=True,
                    )
                    num_rows = gr.Number(
                        label="Number of rows",
                        value=10,
                        interactive=True,
                        scale=1,
                    )
                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.1,
                        maximum=1.5,
                        value=0.9,
                        step=0.1,
                        interactive=True,
                    )
                    temperature_completion = gr.Slider(
                        label="Temperature for completion",
                        minimum=0.1,
                        maximum=1.5,
                        value=None,
                        step=0.1,
                        interactive=True,
                        visible=False,
                    )
                    private = gr.Checkbox(
                        label="Private dataset",
                        value=False,
                        interactive=True,
                        scale=1,
                    )
                    btn_push_to_hub = gr.Button(
                        "Push to Hub", variant="primary", scale=2
                    )
                    btn_save_local = gr.Button(
                        "Save locally", variant="primary", scale=2, visible=False
                    )
                with gr.Column(scale=3):
                    csv_file = gr.File(
                        label="CSV",
                        elem_classes="datasets",
                        visible=False,
                    )
                    json_file = gr.File(
                        label="JSON",
                        elem_classes="datasets",
                        visible=False,
                    )
                    success_message = gr.Markdown(
                        visible=False,
                        min_height=0,  # don't remove this otherwise progress is not visible
                    )
                    with gr.Accordion(
                        "Customize your pipeline with distilabel",
                        open=False,
                        visible=False,
                    ) as pipeline_code_ui:
                        code = generate_pipeline_code(
                            repo_id=search_in.value,
                            input_type=input_type.value,
                            system_prompt=system_prompt.value,
                            document_column=document_column.value,
                            num_turns=num_turns.value,
                            num_rows=num_rows.value,
                        )
                        pipeline_code = gr.Code(
                            value=code,
                            language="python",
                            label="Distilabel Pipeline Code",
                        )

    tab_prompt_input.select(
        fn=lambda: "prompt-input",
        inputs=[],
        outputs=[input_type],
    ).then(fn=show_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
        fn=hide_document_column_visibility, inputs=[], outputs=[document_column]
    )

    tab_dataset_input.select(
        fn=lambda: "dataset-input",
        inputs=[],
        outputs=[input_type],
    ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
        fn=show_document_column_visibility, inputs=[], outputs=[document_column]
    )

    tab_file_input.select(
        fn=lambda: "file-input",
        inputs=[],
        outputs=[input_type],
    ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
        fn=show_document_column_visibility, inputs=[], outputs=[document_column]
    )

    search_in.submit(
        fn=lambda df: pd.DataFrame(columns=df.columns),
        inputs=[dataframe],
        outputs=[dataframe],
    )

    load_prompt_btn.click(
        fn=generate_system_prompt,
        inputs=[dataset_description],
        outputs=[system_prompt],
    ).success(
        fn=generate_sample_dataset,
        inputs=[
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            num_turns,
            num_rows,
        ],
        outputs=dataframe,
    )

    gr.on(
        triggers=[load_dataset_btn.click, load_file_btn.click],
        fn=load_dataset_file,
        inputs=[search_in, file_in, input_type],
        outputs=[dataframe, document_column],
    )

    btn_apply_to_sample_dataset.click(
        fn=generate_sample_dataset,
        inputs=[
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            num_turns,
            num_rows,
        ],
        outputs=dataframe,
    )

    btn_push_to_hub.click(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[repo_name],
        outputs=[success_message],
    ).then(
        fn=validate_push_to_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    ).success(
        fn=hide_save_local,
        outputs=[csv_file, json_file, success_message],
    ).success(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=hide_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    ).success(
        fn=push_dataset,
        inputs=[
            org_name,
            repo_name,
            private,
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            num_turns,
            num_rows,
            temperature,
            temperature_completion,
            pipeline_code,
        ],
        outputs=[success_message],
    ).success(
        fn=show_success_message,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    ).success(
        fn=generate_pipeline_code,
        inputs=[
            search_in,
            input_type,
            system_prompt,
            document_column,
            num_turns,
            num_rows,
        ],
        outputs=[pipeline_code],
    ).success(
        fn=show_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    )

    btn_save_local.click(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=hide_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    ).success(
        fn=show_save_local,
        inputs=[],
        outputs=[csv_file, json_file, success_message],
    ).success(
        save_local,
        inputs=[
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            num_turns,
            num_rows,
            temperature,
            repo_name,
            temperature_completion,
        ],
        outputs=[csv_file, json_file],
    ).success(
        fn=generate_pipeline_code,
        inputs=[
            search_in,
            input_type,
            system_prompt,
            document_column,
            num_turns,
            num_rows,
        ],
        outputs=[pipeline_code],
    ).success(
        fn=show_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    )

    clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
    clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
    clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
    clear_btn_full.click(
        fn=lambda df: ("", "", [], _get_dataframe()),
        inputs=[dataframe],
        outputs=[system_prompt, document_column, num_turns, dataframe],
    )
    app.load(fn=swap_visibility, outputs=main_ui)
    app.load(fn=get_org_dropdown, outputs=[org_name])
    app.load(fn=get_random_repo_name, outputs=[repo_name])
    app.load(fn=show_temperature_completion, outputs=[temperature_completion])
    if SAVE_LOCAL_DIR is not None:
        app.load(fn=show_save_local_button, outputs=btn_save_local)
src/synthetic_dataset_generator/apps/eval.py
ADDED
@@ -0,0 +1,894 @@
import json
import uuid
from typing import Union

import argilla as rg
import gradio as gr
import numpy as np
import pandas as pd
from datasets import (
    Dataset,
    get_dataset_config_names,
    get_dataset_split_names,
    load_dataset,
)
from distilabel.distiset import Distiset
from gradio.oauth import OAuthToken
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import HfApi

from synthetic_dataset_generator.apps.base import (
    combine_datasets,
    get_iframe,
    hide_success_message,
    push_pipeline_code_to_hub,
    show_success_message,
    test_max_num_rows,
    validate_argilla_user_workspace_dataset,
    validate_push_to_hub,
)
from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
from synthetic_dataset_generator.pipelines.embeddings import (
    get_embeddings,
    get_sentence_embedding_dimensions,
)
from synthetic_dataset_generator.pipelines.eval import (
    generate_pipeline_code,
    get_custom_evaluator,
    get_ultrafeedback_evaluator,
)
from synthetic_dataset_generator.utils import (
    column_to_list,
    extract_column_names,
    get_argilla_client,
    get_org_dropdown,
    get_random_repo_name,
    pad_or_truncate_list,
    process_columns,
    swap_visibility,
)

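# Heuristic column detection: plain strings (or OpenAI-style message lists)
# can serve as instructions; lists of strings can additionally serve as
# response candidates.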
def get_valid_columns(dataframe: pd.DataFrame):
    instruction_valid_columns = []
    response_valid_columns = []

    for col in dataframe.columns:
        sample_val = dataframe[col].iloc[0]
        if isinstance(sample_val, str) or (
            isinstance(sample_val, (list, np.ndarray))
            and all(isinstance(item, dict) and "role" in item for item in sample_val)
        ):
            instruction_valid_columns.append(col)
            response_valid_columns.append(col)
        if isinstance(sample_val, (list, np.ndarray)) and all(
            isinstance(item, str) for item in sample_val
        ):
            response_valid_columns.append(col)

    return instruction_valid_columns, response_valid_columns

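# Stream the first rows of the first config/split, infer usable columns, and
# pre-build the prompt-template and structured-output editors for the UI.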
def load_dataset_from_hub(
    repo_id: str, num_rows: int = 10, token: Union[OAuthToken, None] = None
):
    if not repo_id:
        raise gr.Error("Hub repo id is required")
    subsets = get_dataset_config_names(repo_id, token=token)
    splits = get_dataset_split_names(repo_id, subsets[0], token=token)
    ds = load_dataset(repo_id, subsets[0], split=splits[0], token=token, streaming=True)
    rows = []
    for idx, row in enumerate(ds):
        rows.append(row)
        if idx == num_rows:
            break
    ds = Dataset.from_list(rows)
    dataframe = ds.to_pandas()
    instruction_valid_columns, response_valid_columns = get_valid_columns(dataframe)
    col_instruction = instruction_valid_columns[0] if instruction_valid_columns else ""
    col_response = "No valid response columns found."
    for col in response_valid_columns:
        if col != col_instruction:
            col_response = col
            break

    prompt_template = gr.Code(
        label="Prompt template",
        value="\n".join(
            [
                "Evaluate the following text based on criteria.",
                "Criteria: quality.",
                "Score: between 1 and 10.",
                "Text: {{" + col_response + "}}",
            ]
        ),
        language="jinja2",
        interactive=True,
    )
    structured_output = gr.Code(
        label="Structured output",
        value=json.dumps(
            {
                "type": "object",
                "properties": {"quality": {"type": "integer"}},
                "required": ["quality"],
            },
            indent=4,
        ),
        language="json",
        interactive=True,
    )
    return (
        dataframe,
        gr.Dropdown(
            choices=instruction_valid_columns,
            label="Instruction column",
            value=col_instruction,
            interactive=True,
        ),
        gr.Dropdown(
            choices=response_valid_columns,
            label="Response column",
            value=col_response,
            interactive=(
                False if col_response == "No valid response columns found." else True
            ),
        ),
        prompt_template,
        structured_output,
    )

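# UltraFeedback-style aspects only apply to the chat evaluation task.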
def define_evaluation_aspects(task_type: str):
    if task_type == "chat-eval":
        return gr.Dropdown(
            value=["overall-rating"],
            choices=["helpfulness", "truthfulness", "overall-rating", "honesty"],
            label="Evaluation Aspects",
            multiselect=True,
            interactive=True,
        )
    else:
        return gr.Dropdown(interactive=False, visible=False)

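# Evaluate each (instruction, generations) pair per aspect with the
# UltraFeedback evaluator, batching rows and padding ratings/rationales to
# the number of generations.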
def evaluate_instruction_response(
    dataframe: pd.DataFrame,
    aspects: list[str],
    instruction_column: str,
    response_columns: str,
    num_rows: int = 10,
    is_sample: bool = False,
    progress=gr.Progress(),
):
    progress(0.0, desc="Evaluating instructions and responses")
    data = process_columns(dataframe, instruction_column, response_columns)
    num_generations = len(data[0]["generations"])
    evaluated_results = []
    for entry in data:
        result_row = {
            "instruction": entry["instruction"],
            "generations": entry["generations"],
        }
        for aspect in aspects:
            result_row[f"ratings_{aspect}"] = None
            result_row[f"rationale_for_ratings_{aspect}"] = None
            if aspect in ["truthfulness", "helpfulness"]:
                result_row[f"type_{aspect}"] = None
                result_row[f"rationale_for_type_{aspect}"] = None
        result_row["model_name"] = None
        evaluated_results.append(result_row)

    batch_size = DEFAULT_BATCH_SIZE
    total_steps: int = len(aspects) * num_rows

    # evaluate instructions and responses
    for aspect in aspects:
        ultrafeedback_evaluator = get_ultrafeedback_evaluator(aspect, is_sample)
        n_processed = 0

        while n_processed < num_rows:
            progress(
                (len(aspects) * n_processed) / total_steps,
                total=total_steps,
                desc=f"Evaluating aspect: {aspect}",
            )

            remaining_rows = num_rows - n_processed
            batch_size = min(batch_size, remaining_rows)
            inputs = data[n_processed : n_processed + batch_size]
            batch_results = list(ultrafeedback_evaluator.process(inputs=inputs))
            for j, result in enumerate(batch_results[0]):
                idx = n_processed + j
                evaluated_results[idx][f"ratings_{aspect}"] = pad_or_truncate_list(
                    result.get("ratings"), num_generations
                )
                evaluated_results[idx]["model_name"] = result.get("model_name")
                if aspect in ["truthfulness", "helpfulness"]:
                    evaluated_results[idx][f"type_{aspect}"] = pad_or_truncate_list(
                        result.get("types"), num_generations
                    )
                    evaluated_results[idx][f"rationale_for_type_{aspect}"] = (
                        pad_or_truncate_list(result.get("rationales"), num_generations)
                    )
                    evaluated_results[idx][f"rationale_for_ratings_{aspect}"] = (
                        pad_or_truncate_list(
                            result.get("rationales-for-ratings"), num_generations
                        )
                    )
                else:
                    evaluated_results[idx][f"rationale_for_ratings_{aspect}"] = (
                        pad_or_truncate_list(result.get("rationales"), num_generations)
                    )
            n_processed += batch_size

    # create final dataset
    dataframe = pd.DataFrame(evaluated_results)
    progress(1.0, desc="Dataset evaluation completed")
    return dataframe

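# Evaluate rows against a user-supplied Jinja2 prompt template whose output is
# constrained by a JSON schema (structured generation), e.g. a template like
#   "Score the following text from 1-10. Text: {{ completion }}"
# with {"type": "object", "properties": {"quality": {"type": "integer"}}}.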
def evaluate_custom(
    dataframe: pd.DataFrame,
    prompt_template: str,
    structured_output: dict,
    num_rows: int = 10,
    is_sample: bool = False,
    progress=gr.Progress(),
):
    progress(0.0, desc="Evaluating dataset")
    columns = extract_column_names(prompt_template)
    input_columns = {column: column_to_list(dataframe, column) for column in columns}

    custom_evaluator = get_custom_evaluator(
        prompt_template, structured_output, columns, is_sample
    )
    batch_size = DEFAULT_BATCH_SIZE

    # evaluate the data
    n_processed = 0
    evaluation_results = []
    while n_processed < num_rows:
        progress(
            n_processed / num_rows,
            desc="Evaluating dataset",
        )
        remaining_rows = num_rows - n_processed
        batch_size = min(batch_size, remaining_rows)

        inputs = []
        for idx in range(n_processed, n_processed + batch_size):
            input = {column: input_columns[column][idx] for column in input_columns}
            inputs.append(input)

        batch = list(custom_evaluator.process(inputs=inputs))
        evaluation_results.extend(batch[0])
        n_processed += batch_size

    # create final dataset
    distiset_results = []
    for result in evaluation_results:
        record = {key: result[key] for key in result if key != "distilabel_metadata"}
        distiset_results.append(record)

    dataframe = pd.DataFrame(distiset_results)
    progress(1.0, desc="Dataset evaluation completed")
    return dataframe

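# Route evaluation to the UltraFeedback flow or the custom-template flow,
# capping the number of rows to the configured maximum.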
def _evaluate_dataset(
    dataframe: pd.DataFrame,
    eval_type: str,
    aspects_instruction_response: list[str],
    instruction_instruction_response: str,
    response_instruction_response: str,
    prompt_template: str,
    structured_output: dict,
    num_rows: int = 10,
    is_sample: bool = False,
):
    num_rows = test_max_num_rows(num_rows)
    if eval_type == "chat-eval":
        dataframe = evaluate_instruction_response(
            dataframe=dataframe,
            aspects=aspects_instruction_response,
            instruction_column=instruction_instruction_response,
            response_columns=response_instruction_response,
            num_rows=num_rows,
            is_sample=is_sample,
        )
    else:
        dataframe = evaluate_custom(
            dataframe=dataframe,
            prompt_template=prompt_template,
            structured_output=structured_output,
            num_rows=num_rows,
            is_sample=is_sample,
        )
    return dataframe

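# Quick preview: evaluate only the first 10 rows with the sample settings.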
def evaluate_sample_dataset(
    repo_id: str,
    eval_type: str,
    aspects_instruction_response: list[str],
    instruction_instruction_response: str,
    response_instruction_response: str,
    prompt_template: str,
    structured_output: dict,
):
    dataframe, _, _, _, _ = load_dataset_from_hub(repo_id, num_rows=10)
    dataframe = _evaluate_dataset(
        dataframe=dataframe,
        eval_type=eval_type,
        aspects_instruction_response=aspects_instruction_response,
        instruction_instruction_response=instruction_instruction_response,
        response_instruction_response=response_instruction_response,
        prompt_template=prompt_template,
        structured_output=structured_output,
        num_rows=10,
        is_sample=True,
    )
    return dataframe

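# Push the evaluated dataframe to the Hub as a Distiset, together with the
# distilabel pipeline code that reproduces it.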
def push_dataset_to_hub(
    dataframe: pd.DataFrame,
    org_name: str,
    repo_name: str,
    oauth_token: Union[gr.OAuthToken, None],
    private: bool,
    pipeline_code: str,
    progress=gr.Progress(),
):
    progress(0.0, desc="Validating")
    repo_id = validate_push_to_hub(org_name, repo_name)
    progress(0.5, desc="Creating dataset")
    dataset = Dataset.from_pandas(dataframe)
    dataset = combine_datasets(repo_id, dataset, oauth_token)
    distiset = Distiset({"default": dataset})
    progress(0.9, desc="Pushing dataset")
    distiset.push_to_hub(
        repo_id=repo_id,
        private=private,
        include_script=False,
        token=oauth_token.token,
        create_pr=False,
    )
    push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
    progress(1.0, desc="Dataset pushed")
    return dataframe

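# Full push: load, evaluate, push to the Hub, then mirror to Argilla with
# per-generation rating/rationale questions (chat-eval) or per-column text
# fields (custom-eval).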
def push_dataset(
    org_name: str,
    repo_name: str,
    private: bool,
    num_rows: int,
    original_repo_id: str,
    eval_type: str,
    aspects_instruction_response: list[str],
    instruction_instruction_response: str,
    response_instruction_response: str,
    prompt_template: str,
    structured_output: dict,
    pipeline_code: str,
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
) -> pd.DataFrame:
    dataframe, _, _, _, _ = load_dataset_from_hub(original_repo_id, num_rows=num_rows)
    dataframe = _evaluate_dataset(
        dataframe=dataframe,
        eval_type=eval_type,
        aspects_instruction_response=aspects_instruction_response,
        instruction_instruction_response=instruction_instruction_response,
        response_instruction_response=response_instruction_response,
        prompt_template=prompt_template,
        structured_output=structured_output,
        num_rows=num_rows,
    )
    push_dataset_to_hub(
        dataframe, org_name, repo_name, oauth_token, private, pipeline_code
    )
    try:
        progress(0.1, desc="Setting up user and workspace")
        hf_user = HfApi().whoami(token=oauth_token.token)["name"]
        client = get_argilla_client()
        if client is None:
            return ""
        progress(0.5, desc="Creating dataset in Argilla")
        if eval_type == "chat-eval":
            num_generations = len(dataframe["generations"][0])
            fields = [
                rg.ChatField(
                    name=f"chat_{i}",
                    title=f"Chat {i+1}",
                    description=f"User and assistant conversation for generation {i+1}",
                )
                for i in range(num_generations)
            ]
            questions = []
            for i in range(num_generations):
                for aspect in aspects_instruction_response:
                    questions.append(
                        rg.RatingQuestion(
                            name=f"ratings_{aspect}_{i}",
                            values=list(range(11)),
                            title=f"Ratings for {aspect} for response {i+1}",
                            required=True,
                        )
                    )
                    questions.append(
                        rg.TextQuestion(
                            name=f"rationale_for_ratings_{aspect}_{i}",
                            title=f"Rationale for ratings for {aspect} for response {i+1}",
                            required=False,
                            use_markdown=True,
                        )
                    )
                    if aspect in ["truthfulness", "helpfulness"]:
                        questions.append(
                            rg.RatingQuestion(
                                name=f"type_{aspect}_{i}",
                                values=list(range(1, 6)),
                                title=f"The type of the response {i+1} for {aspect}",
                                required=True,
                            )
                        )
                        questions.append(
                            rg.TextQuestion(
                                name=f"rationale_for_type_{aspect}_{i}",
                                title=f"Rationale for type of the response {i+1} for {aspect}",
                                required=False,
                                use_markdown=True,
                            )
                        )
            metadata = [
                rg.IntegerMetadataProperty(
                    name="instruction_length", title="Instruction length"
                ),
            ]
            for i in range(num_generations):
                metadata.append(
                    rg.IntegerMetadataProperty(
                        name=f"response_{i}_length", title=f"Response {i+1} length"
                    )
                )
            vectors = [
                rg.VectorField(
                    name="instruction_embeddings",
                    dimensions=get_sentence_embedding_dimensions(),
                )
            ]
            settings = rg.Settings(
                fields=fields,
                questions=questions,
                metadata=metadata,
                vectors=vectors,
                guidelines="Please review the conversation and provide an evaluation.",
            )

            dataframe["instruction_length"] = dataframe["instruction"].apply(len)
            for i in range(num_generations):
                dataframe[f"response_{i}_length"] = dataframe["generations"].apply(
                    lambda gens: len(gens[i]) if i < len(gens) else 0
                )
            dataframe["instruction_embeddings"] = get_embeddings(
                dataframe["instruction"].to_list()
            )

            rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
            if rg_dataset is None:
                rg_dataset = rg.Dataset(
                    name=repo_name,
                    workspace=hf_user,
                    settings=settings,
                    client=client,
                )
                rg_dataset = rg_dataset.create()

            progress(0.7, desc="Pushing dataset to Argilla")
            hf_dataset = Dataset.from_pandas(dataframe)
            records = []
            for sample in hf_dataset:
                fields = {}
                metadata = {"instruction_length": sample.get("instruction_length", 0)}
                vectors = {
                    "instruction_embeddings": sample.get("instruction_embeddings", [])
                }
                suggestions = []
                generations = sample.get("generations", [])
                for i in range(num_generations):
                    fields[f"chat_{i}"] = [
                        {"role": "user", "content": sample.get("instruction", "")},
                        {"role": "assistant", "content": generations[i]},
                    ]
                    metadata[f"response_{i}_length"] = sample.get(
                        f"response_{i}_length", 0
                    )

                    for aspect in aspects_instruction_response:
                        ratings = sample.get(f"ratings_{aspect}", [])
                        # note: key must match the single-underscore key written
                        # by evaluate_instruction_response above
                        rationales = sample.get(f"rationale_for_ratings_{aspect}", [])

                        rating_value = (
                            ratings[i]
                            if ratings and isinstance(ratings[i], int)
                            else None
                        )
                        rationale_value = (
                            rationales[i]
                            if rationales and isinstance(rationales[i], str)
                            else None
                        )

                        if rating_value is not None:
                            suggestions.append(
                                rg.Suggestion(
                                    question_name=f"ratings_{aspect}_{i}",
                                    value=rating_value,
                                )
                            )
                        if rationale_value is not None:
                            suggestions.append(
                                rg.Suggestion(
                                    question_name=f"rationale_for_ratings_{aspect}_{i}",
                                    value=rationale_value,
                                )
                            )

                        if aspect in ["truthfulness", "helpfulness"]:
                            types = sample.get(f"type_{aspect}", [])
                            rationale_types = sample.get(
                                f"rationale_for_type_{aspect}", []
                            )

                            type_value = (
                                types[i]
                                if types and isinstance(types[i], int)
                                else None
                            )
                            rationale_type_value = (
                                rationale_types[i]
                                if rationale_types
                                and isinstance(rationale_types[i], str)
                                else None
                            )
                            if type_value is not None:
                                suggestions.append(
                                    rg.Suggestion(
                                        question_name=f"type_{aspect}_{i}",
                                        value=type_value,
                                    )
                                )
                            if rationale_type_value is not None:
                                suggestions.append(
                                    rg.Suggestion(
                                        question_name=f"rationale_for_type_{aspect}_{i}",
                                        value=rationale_type_value,
                                    )
                                )
                records.append(
                    rg.Record(
                        fields=fields,
                        metadata=metadata,
                        vectors=vectors,
                        suggestions=suggestions,
                    )
                )
            rg_dataset.records.log(records=records)
            progress(1.0, desc="Dataset pushed to Argilla")
        else:
            columns = extract_column_names(prompt_template)
            settings = rg.Settings(
                fields=[
                    rg.TextField(
                        name=column,
                        title=column.capitalize(),
                        description="The column content",
                    )
                    for column in columns
                ],
                questions=[
                    rg.TextQuestion(
                        name="evaluation",
                        title="Evaluation",
                        description="The generated evaluation",
                        use_markdown=True,
                    ),
                ],
                metadata=[
                    rg.IntegerMetadataProperty(
                        name=f"{column}_length", title=f"{column.capitalize()} length"
                    )
                    for column in columns
                ],
                vectors=[
                    rg.VectorField(
                        name=f"{column}_embeddings",
                        dimensions=get_sentence_embedding_dimensions(),
                    )
                    for column in columns
                ],
                guidelines="Please review, correct and provide an accurate evaluation.",
            )
            for column in columns:
                dataframe[f"{column}_length"] = dataframe[column].apply(len)
                dataframe[f"{column}_embeddings"] = get_embeddings(dataframe[column])

            rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
            if rg_dataset is None:
                rg_dataset = rg.Dataset(
                    name=repo_name,
                    workspace=hf_user,
                    settings=settings,
                    client=client,
                )
                rg_dataset = rg_dataset.create()
            progress(0.7, desc="Pushing dataset to Argilla")
            hf_dataset = Dataset.from_pandas(dataframe)
            rg_dataset.records.log(
                records=hf_dataset, mapping={"generation": "evaluation"}
            )
            progress(1.0, desc="Dataset pushed to Argilla")
    except Exception as e:
        raise gr.Error(f"Error pushing dataset to Argilla: {e}")
    return ""

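# UI callbacks toggling the pipeline-code accordion.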
def show_pipeline_code_visibility():
    return {pipeline_code_ui: gr.Accordion(visible=True)}


def hide_pipeline_code_visibility():
    return {pipeline_code_ui: gr.Accordion(visible=False)}


######################
# Gradio UI
######################

with gr.Blocks() as app:
|
    with gr.Column() as main_ui:
        gr.Markdown("## 1. Select your input dataset")
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                search_in = HuggingfaceHubSearch(
                    label="Search",
                    placeholder="Search for a dataset",
                    search_type="dataset",
                    sumbit_on_select=True,
                )
                with gr.Row():
                    clear_btn_part = gr.Button("Clear", variant="secondary")
                    load_btn = gr.Button("Load", variant="primary")

            with gr.Column(scale=3):
                examples = gr.Examples(
                    examples=[
                        "argilla/distilabel-sft-easy",
                        "HuggingFaceFW/fineweb-edu",
                        "argilla/distilabel-intel-orca-dpo-pairs",
                    ],
                    label="Example datasets",
                    fn=lambda x: x,
                    inputs=[search_in],
                    run_on_click=True,
                )
                search_out = gr.HTML(label="Dataset preview", visible=False)

        gr.HTML(value="<hr>")
        gr.Markdown(value="## 2. Configure your task")
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                eval_type = gr.Dropdown(
                    label="Evaluation type",
                    choices=["chat-eval", "custom-eval"],
                    value="chat-eval",
                    multiselect=False,
                    visible=False,
                )
                with gr.Tab("Response Evaluation") as tab_instruction_response:
                    aspects_instruction_response = define_evaluation_aspects(
                        "chat-eval"
                    )
                    instruction_instruction_response = gr.Dropdown(
                        label="Instruction Column",
                        info="Select the instruction column to evaluate",
                        choices=["Load your data first in step 1."],
                        value="Load your data first in step 1.",
                        interactive=False,
                        multiselect=False,
                        allow_custom_value=False,
                    )
                    response_instruction_response = gr.Dropdown(
                        label="Response Column",
                        info="Select the response column(s) to evaluate",
                        choices=["Load your data first in step 1."],
                        value="Load your data first in step 1.",
                        interactive=False,
                        multiselect=False,
                        allow_custom_value=False,
                    )
                    tab_instruction_response.select(
                        fn=lambda: "chat-eval",
                        inputs=[],
                        outputs=[eval_type],
                    )
                with gr.Tab("Custom Evaluation Prompt") as tab_custom:
                    aspects_custom = define_evaluation_aspects("custom-eval")
                    prompt_template = gr.Code(
                        label="Prompt template",
                        value="Load your data first in step 1.",
                        language="markdown",
                        interactive=False,
                    )
                    structured_output = gr.Code(
                        label="Structured output",
                        value="Load your data first in step 1.",
                        language="json",
                        interactive=False,
                    )
                    tab_custom.select(
                        fn=lambda: "custom-eval",
                        inputs=[],
                        outputs=[eval_type],
                    )
                with gr.Row():
                    clear_btn_full = gr.Button("Clear", variant="secondary")
                    btn_apply_to_sample_dataset = gr.Button("Save", variant="primary")
            with gr.Column(scale=3):
                dataframe = gr.Dataframe(
                    headers=["prompt", "completion", "evaluation"],
                    wrap=True,
                    interactive=False,
                )

        gr.HTML(value="<hr>")
        gr.Markdown(value="## 3. Evaluate your dataset")
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                org_name = get_org_dropdown()
                repo_name = gr.Textbox(
                    label="Repo name",
                    placeholder="dataset_name",
                    value=f"my-distiset-{str(uuid.uuid4())[:8]}",
                    interactive=True,
                )
                num_rows = gr.Number(
                    label="Number of rows",
                    value=10,
                    interactive=True,
                    scale=1,
                )
                private = gr.Checkbox(
                    label="Private dataset",
                    value=False,
                    interactive=True,
                    scale=1,
                )
                btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
            with gr.Column(scale=3):
                success_message = gr.Markdown(
                    visible=True,
                    min_height=100,  # don't remove this otherwise progress is not visible
                )
                with gr.Accordion(
                    "Customize your pipeline with distilabel",
                    open=False,
                    visible=False,
                ) as pipeline_code_ui:
                    code = generate_pipeline_code(
                        repo_id=search_in.value,
                        aspects=aspects_instruction_response.value,
                        instruction_column=instruction_instruction_response.value,
                        response_columns=response_instruction_response.value,
                        prompt_template=prompt_template.value,
                        structured_output=structured_output.value,
                        num_rows=num_rows.value,
                        eval_type=eval_type.value,
                    )
                    pipeline_code = gr.Code(
                        value=code,
                        language="python",
                        label="Distilabel Pipeline Code",
                    )

    search_in.submit(fn=get_iframe, inputs=search_in, outputs=search_out).then(
        fn=lambda df: pd.DataFrame(columns=df.columns),
        inputs=[dataframe],
        outputs=[dataframe],
    )

    load_btn.click(
        fn=load_dataset_from_hub,
        inputs=[search_in],
        outputs=[
            dataframe,
            instruction_instruction_response,
            response_instruction_response,
            prompt_template,
            structured_output,
        ],
    )

    btn_apply_to_sample_dataset.click(
        fn=evaluate_sample_dataset,
        inputs=[
            search_in,
            eval_type,
            aspects_instruction_response,
            instruction_instruction_response,
            response_instruction_response,
            prompt_template,
            structured_output,
        ],
        outputs=dataframe,
    )

    btn_push_to_hub.click(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[repo_name],
        outputs=[success_message],
    ).then(
        fn=validate_push_to_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    ).success(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=hide_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    ).success(
        fn=push_dataset,
        inputs=[
            org_name,
            repo_name,
            private,
            num_rows,
            search_in,
            eval_type,
            aspects_instruction_response,
            instruction_instruction_response,
            response_instruction_response,
            prompt_template,
            structured_output,
            pipeline_code,
        ],
        outputs=[success_message],
    ).success(
        fn=show_success_message,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    ).success(
        fn=generate_pipeline_code,
        inputs=[
            search_in,
            prompt_template,
            structured_output,
            eval_type,
        ],
        outputs=[pipeline_code],
    ).success(
        fn=show_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    )

    clear_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
    clear_btn_full.click(
        fn=lambda df: ("", "", pd.DataFrame(columns=df.columns)),
        inputs=[dataframe],
        outputs=[
            instruction_instruction_response,
            response_instruction_response,
            dataframe,
        ],
    )

    app.load(fn=swap_visibility, outputs=main_ui)
    app.load(fn=get_org_dropdown, outputs=[org_name])
    app.load(fn=get_random_repo_name, outputs=[repo_name])
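
Note on the event wiring above: the `click(...).then(...).success(...)` chains rely on Gradio's built-in event chaining, where `.then` always schedules the next step and `.success` only fires if the previous step finished without raising. A minimal sketch of the same pattern, using hypothetical `validate` and `push` functions rather than the app's real handlers:

import gradio as gr

def validate(name):
    # raising gr.Error aborts the chain, so the .success step never runs
    if not name:
        raise gr.Error("Name is required")
    return f"Validated {name}"

def push(name):
    return f"Pushed {name}"

with gr.Blocks() as demo:
    name = gr.Textbox(label="Dataset name")
    status = gr.Markdown()
    btn = gr.Button("Push")
    # .success chains the push step onto a successful validation
    btn.click(fn=validate, inputs=[name], outputs=[status]).success(
        fn=push, inputs=[name], outputs=[status]
    )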
src/synthetic_dataset_generator/apps/rag.py
ADDED
@@ -0,0 +1,972 @@
import os
import random
import uuid
from typing import Union

import argilla as rg
import gradio as gr
import nltk
import pandas as pd
from datasets import Dataset
from distilabel.distiset import Distiset
from gradio.oauth import OAuthToken
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import HfApi

from synthetic_dataset_generator.apps.base import (
    combine_datasets,
    hide_success_message,
    load_dataset_from_hub,
    preprocess_input_data,
    push_pipeline_code_to_hub,
    show_success_message,
    test_max_num_rows,
    validate_argilla_user_workspace_dataset,
    validate_push_to_hub,
)
from synthetic_dataset_generator.constants import (
    DEFAULT_BATCH_SIZE,
    MODEL,
    MODEL_COMPLETION,
    SAVE_LOCAL_DIR,
)
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
from synthetic_dataset_generator.pipelines.embeddings import (
    get_embeddings,
    get_sentence_embedding_dimensions,
)
from synthetic_dataset_generator.pipelines.rag import (
    DEFAULT_DATASET_DESCRIPTIONS,
    generate_pipeline_code,
    get_chunks_generator,
    get_prompt_generator,
    get_response_generator,
    get_sentence_pair_generator,
)
from synthetic_dataset_generator.utils import (
    column_to_list,
    get_argilla_client,
    get_org_dropdown,
    get_random_repo_name,
    swap_visibility,
)

os.makedirs("./nltk_data", exist_ok=True)
nltk.data.path.append("./nltk_data")
nltk.download("punkt_tab", download_dir="./nltk_data")
nltk.download("averaged_perceptron_tagger_eng", download_dir="./nltk_data")


def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
    progress(0.1, desc="Initializing")
    generate_description = get_prompt_generator()
    progress(0.5, desc="Generating")
    result = next(
        generate_description.process(
            [
                {
                    "instruction": dataset_description,
                }
            ]
        )
    )[0]["generation"]
    progress(1.0, desc="Prompt generated")
    return result


def load_dataset_file(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    num_rows: int = 10,
    token: Union[OAuthToken, None] = None,
    progress=gr.Progress(),
):
    progress(0.1, desc="Loading the source data")
    if input_type == "dataset-input":
        return load_dataset_from_hub(repo_id=repo_id, num_rows=num_rows, token=token)
    else:
        return preprocess_input_data(file_paths=file_paths, num_rows=num_rows)


def generate_sample_dataset(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    retrieval_reranking: list[str],
    num_rows: str,
    oauth_token: Union[OAuthToken, None],
    progress=gr.Progress(),
):
    retrieval = "Retrieval" in retrieval_reranking
    reranking = "Reranking" in retrieval_reranking

    if input_type == "prompt-input":
        dataframe = pd.DataFrame(columns=["context", "question", "response"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
            token=oauth_token,
        )
    progress(0.5, desc="Generating dataset")
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        retrieval=retrieval,
        reranking=reranking,
        num_rows=10,
        is_sample=True,
    )
    progress(1.0, desc="Sample dataset generated")
    return dataframe


def generate_dataset(
    input_type: str,
    dataframe: pd.DataFrame,
    system_prompt: str,
    document_column: str,
    retrieval: bool = False,
    reranking: bool = False,
    num_rows: int = 10,
    temperature: float = 0.7,
    temperature_completion: Union[float, None] = None,
    is_sample: bool = False,
    progress=gr.Progress(),
):
    num_rows = test_max_num_rows(num_rows)
    progress(0.0, desc="Initializing dataset generation")
    if input_type == "prompt-input":
        chunk_generator = get_chunks_generator(
            temperature=temperature, is_sample=is_sample
        )
    else:
        document_data = column_to_list(dataframe, document_column)
        if len(document_data) < num_rows:
            document_data += random.choices(
                document_data, k=num_rows - len(document_data)
            )

    retrieval_generator = get_sentence_pair_generator(
        action="query",
        triplet=True if retrieval else False,
        temperature=temperature,
        is_sample=is_sample,
    )
    response_generator = get_response_generator(
        temperature=temperature_completion or temperature, is_sample=is_sample
    )
    if reranking:
        reranking_generator = get_sentence_pair_generator(
            action="semantically-similar",
            triplet=True,
            temperature=temperature,
            is_sample=is_sample,
        )
    steps = 2 + sum([1 if reranking else 0, 1 if input_type == "prompt-input" else 0])
    total_steps: int = num_rows * steps
    step_progress = round(1 / steps, 2)
    batch_size = DEFAULT_BATCH_SIZE

    # generate chunks
    if input_type == "prompt-input":
        n_processed = 0
        chunk_results = []
        rewritten_system_prompts = get_rewritten_prompts(system_prompt, num_rows)
        while n_processed < num_rows:
            progress(
                step_progress * n_processed / num_rows,
                total=total_steps,
                desc="Generating chunks",
            )
            remaining_rows = num_rows - n_processed
            batch_size = min(batch_size, remaining_rows)
            inputs = [
                {"task": random.choice(rewritten_system_prompts)}
                for _ in range(batch_size)
            ]
            chunks = list(chunk_generator.process(inputs=inputs))
            chunk_results.extend(chunks[0])
            n_processed += batch_size
            random.seed(a=random.randint(0, 2**32 - 1))
        document_data = [chunk["generation"] for chunk in chunk_results]
        progress(step_progress, desc="Generating chunks")

    # generate questions
    n_processed = 0
    retrieval_results = []
    while n_processed < num_rows:
        progress(
            step_progress * n_processed / num_rows,
            total=total_steps,
            desc="Generating questions",
        )
        remaining_rows = num_rows - n_processed
        batch_size = min(batch_size, remaining_rows)
        inputs = [
            {"anchor": document}
            for document in document_data[n_processed : n_processed + batch_size]
        ]
        questions = list(retrieval_generator.process(inputs=inputs))
        retrieval_results.extend(questions[0])
        n_processed += batch_size
    for result in retrieval_results:
        result["context"] = result["anchor"]
        if retrieval:
            result["question"] = result["positive"]
            result["positive_retrieval"] = result.pop("positive")
            result["negative_retrieval"] = result.pop("negative")
        else:
            result["question"] = result.pop("positive")

    progress(step_progress, desc="Generating questions")

    # generate responses
    n_processed = 0
    response_results = []
    while n_processed < num_rows:
        progress(
            step_progress + step_progress * n_processed / num_rows,
            total=total_steps,
            desc="Generating responses",
        )
        batch = retrieval_results[n_processed : n_processed + batch_size]
        responses = list(response_generator.process(inputs=batch))
        response_results.extend(responses[0])
        n_processed += batch_size
    for result in response_results:
        result["response"] = result["generation"]
    progress(step_progress, desc="Generating responses")

    # generate reranking
    if reranking:
        n_processed = 0
        reranking_results = []
        while n_processed < num_rows:
            progress(
                step_progress * n_processed / num_rows,
                total=total_steps,
                desc="Generating reranking data",
            )
            batch = response_results[n_processed : n_processed + batch_size]
            batch = list(reranking_generator.process(inputs=batch))
            reranking_results.extend(batch[0])
            n_processed += batch_size
        for result in reranking_results:
            result["positive_reranking"] = result.pop("positive")
            result["negative_reranking"] = result.pop("negative")
    progress(
        1,
        total=total_steps,
        desc="Creating dataset",
    )

    # create distiset
    distiset_results = []
    source_results = reranking_results if reranking else response_results
    base_keys = ["context", "question", "response"]
    retrieval_keys = ["positive_retrieval", "negative_retrieval"] if retrieval else []
    reranking_keys = ["positive_reranking", "negative_reranking"] if reranking else []
    relevant_keys = base_keys + retrieval_keys + reranking_keys

    for result in source_results:
        record = {key: result.get(key) for key in relevant_keys if key in result}
        distiset_results.append(record)

    dataframe = pd.DataFrame(distiset_results)

    progress(1.0, desc="Dataset generation completed")
    return dataframe

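
All of the generation stages above share one batched-loop shape: take up to `batch_size` inputs, call the task's `process` method, extend an accumulated results list, and report fractional progress. A minimal sketch of that loop in isolation, with `generator` standing in for any of the tasks returned by the `get_*_generator` helpers (the `run_batched` name is illustrative, not part of the module):

def run_batched(generator, inputs, batch_size=5):
    results = []
    n_processed = 0
    while n_processed < len(inputs):
        batch = inputs[n_processed : n_processed + batch_size]
        # the task yields lists of dicts; the first yielded item holds the batch
        outputs = list(generator.process(inputs=batch))
        results.extend(outputs[0])
        n_processed += len(batch)
    return results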

def push_dataset_to_hub(
    dataframe: pd.DataFrame,
    org_name: str,
    repo_name: str,
    oauth_token: Union[gr.OAuthToken, None],
    private: bool,
    pipeline_code: str,
    progress=gr.Progress(),
):
    progress(0.0, desc="Validating")
    repo_id = validate_push_to_hub(org_name, repo_name)
    progress(0.5, desc="Creating dataset")
    dataset = Dataset.from_pandas(dataframe)
    dataset = combine_datasets(repo_id, dataset, oauth_token)
    distiset = Distiset({"default": dataset})
    progress(0.9, desc="Pushing dataset")
    distiset.push_to_hub(
        repo_id=repo_id,
        private=private,
        include_script=False,
        token=oauth_token.token,
        create_pr=False,
    )
    push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
    progress(1.0, desc="Dataset pushed")
    return dataframe


def push_dataset(
    org_name: str,
    repo_name: str,
    private: bool,
    original_repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    retrieval_reranking: list[str],
    num_rows: int,
    temperature: float,
    temperature_completion: float,
    pipeline_code: str,
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
) -> pd.DataFrame:
    retrieval = "Retrieval" in retrieval_reranking
    reranking = "Reranking" in retrieval_reranking

    if input_type == "prompt-input":
        dataframe = pd.DataFrame(columns=["context", "question", "response"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=original_repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
            token=oauth_token,
        )
    progress(0.5, desc="Generating dataset")
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        retrieval=retrieval,
        reranking=reranking,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
        is_sample=True,
    )
    push_dataset_to_hub(
        dataframe, org_name, repo_name, oauth_token, private, pipeline_code
    )
    dataframe = dataframe[
        dataframe.applymap(lambda x: str(x).strip() if pd.notna(x) else x).apply(
            lambda row: row.notna().all() and (row != "").all(), axis=1
        )
    ]
    try:
        progress(0.1, desc="Setting up user and workspace")
        hf_user = HfApi().whoami(token=oauth_token.token)["name"]
        client = get_argilla_client()
        if client is None:
            return ""

        progress(0.5, desc="Creating dataset in Argilla")
        fields = [
            rg.TextField(
                name="context",
                title="Context",
                description="Context for the generation",
            ),
            rg.ChatField(
                name="chat",
                title="Chat",
                description="User and assistant conversation based on the context",
            ),
        ]
        for item in ["positive", "negative"]:
            if retrieval:
                fields.append(
                    rg.TextField(
                        name=f"{item}_retrieval",
                        title=f"{item.capitalize()} retrieval",
                        description=f"The {item} query for retrieval",
                    )
                )
            if reranking:
                fields.append(
                    rg.TextField(
                        name=f"{item}_reranking",
                        title=f"{item.capitalize()} reranking",
                        description=f"The {item} query for reranking",
                    )
                )

        questions = [
            rg.LabelQuestion(
                name="relevant",
                title="Are the question and response relevant to the given context?",
                labels=["yes", "no"],
            ),
            rg.LabelQuestion(
                name="is_response_correct",
                title="Is the response correct?",
                labels=["yes", "no"],
            ),
        ]
        for item in ["positive", "negative"]:
            if retrieval:
                questions.append(
                    rg.LabelQuestion(
                        name=f"is_{item}_retrieval_relevant",
                        title=f"Is the {item} retrieval relevant?",
                        labels=["yes", "no"],
                        required=False,
                    )
                )
            if reranking:
                questions.append(
                    rg.LabelQuestion(
                        name=f"is_{item}_reranking_relevant",
                        title=f"Is the {item} reranking relevant?",
                        labels=["yes", "no"],
                        required=False,
                    )
                )
        metadata = [
            rg.IntegerMetadataProperty(
                name=f"{item}_length", title=f"{item.capitalize()} length"
            )
            for item in ["context", "question", "response"]
        ]

        vectors = [
            rg.VectorField(
                name=f"{item}_embeddings",
                dimensions=get_sentence_embedding_dimensions(),
            )
            for item in ["context", "question", "response"]
        ]
        settings = rg.Settings(
            fields=fields,
            questions=questions,
            metadata=metadata,
            vectors=vectors,
            guidelines="Please review the conversation and provide an evaluation.",
        )

        dataframe["chat"] = dataframe.apply(
            lambda row: [
                {"role": "user", "content": row["question"]},
                {"role": "assistant", "content": row["response"]},
            ],
            axis=1,
        )

        for item in ["context", "question", "response"]:
            dataframe[f"{item}_length"] = dataframe[item].apply(
                lambda x: len(x) if x is not None else 0
            )
            dataframe[f"{item}_embeddings"] = get_embeddings(
                dataframe[item].apply(lambda x: x if x is not None else "").to_list()
            )

        rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
        if rg_dataset is None:
            rg_dataset = rg.Dataset(
                name=repo_name,
                workspace=hf_user,
                settings=settings,
                client=client,
            )
            rg_dataset = rg_dataset.create()

        progress(0.7, desc="Pushing dataset to Argilla")
        hf_dataset = Dataset.from_pandas(dataframe)
        rg_dataset.records.log(records=hf_dataset)
        progress(1.0, desc="Dataset pushed to Argilla")
    except Exception as e:
        raise gr.Error(f"Error pushing dataset to Argilla: {e}")
    return ""

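
For reference, the Argilla objects used in `push_dataset` compose as follows. A minimal sketch, assuming a reachable Argilla server with credentials in the ARGILLA_API_URL and ARGILLA_API_KEY environment variables (the dataset and field names below are placeholders):

import argilla as rg

client = rg.Argilla()  # picks up ARGILLA_API_URL / ARGILLA_API_KEY from the environment
settings = rg.Settings(
    fields=[rg.TextField(name="context")],
    questions=[rg.LabelQuestion(name="relevant", labels=["yes", "no"])],
)
dataset = rg.Dataset(name="rag-demo", settings=settings, client=client).create()
# records are logged as dicts keyed by field name
dataset.records.log(records=[{"context": "Some document chunk"}])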

def save_local(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    retrieval_reranking: list[str],
    num_rows: int,
    temperature: float,
    repo_name: str,
    temperature_completion: float,
) -> pd.DataFrame:
    retrieval = "Retrieval" in retrieval_reranking
    reranking = "Reranking" in retrieval_reranking

    if input_type == "prompt-input":
        dataframe = pd.DataFrame(columns=["context", "question", "response"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
        )
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        retrieval=retrieval,
        reranking=reranking,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
    )
    local_dataset = Dataset.from_pandas(dataframe)
    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
    local_dataset.to_csv(output_csv, index=False)
    local_dataset.to_json(output_json, index=False)
    return output_csv, output_json


def show_system_prompt_visibility():
    return {system_prompt: gr.Textbox(visible=True)}


def hide_system_prompt_visibility():
    return {system_prompt: gr.Textbox(visible=False)}


def show_document_column_visibility():
    return {document_column: gr.Dropdown(visible=True)}


def hide_document_column_visibility():
    return {
        document_column: gr.Dropdown(
            choices=["Load your data first in step 1."],
            value="Load your data first in step 1.",
            visible=False,
        )
    }


def show_pipeline_code_visibility():
    return {pipeline_code_ui: gr.Accordion(visible=True)}


def hide_pipeline_code_visibility():
    return {pipeline_code_ui: gr.Accordion(visible=False)}


def show_temperature_completion():
    if MODEL != MODEL_COMPLETION:
        return {temperature_completion: gr.Slider(value=0.9, visible=True)}
    # keep the completion slider hidden when a single model handles both tasks
    return {temperature_completion: gr.Slider(visible=False)}


def show_save_local_button():
    return {btn_save_local: gr.Button(visible=True)}


def hide_save_local_button():
    return {btn_save_local: gr.Button(visible=False)}


def show_save_local():
    gr.update(success_message, min_height=0)
    return {
        csv_file: gr.File(visible=True),
        json_file: gr.File(visible=True),
        success_message: success_message,
    }


def hide_save_local():
    gr.update(success_message, min_height=100)
    return {
        csv_file: gr.File(visible=False),
        json_file: gr.File(visible=False),
        success_message: success_message,
    }


######################
# Gradio UI
######################


with gr.Blocks() as app:
    with gr.Column() as main_ui:
        gr.Markdown("## 1. Select your input")
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                input_type = gr.Dropdown(
                    label="Input type",
                    choices=["dataset-input", "file-input", "prompt-input"],
                    value="dataset-input",
                    multiselect=False,
                    visible=False,
                )
                with gr.Tab("Load from Hub") as tab_dataset_input:
                    with gr.Row(equal_height=False):
                        with gr.Column(scale=2):
                            search_in = HuggingfaceHubSearch(
                                label="Search",
                                placeholder="Search for a dataset",
                                search_type="dataset",
                                sumbit_on_select=True,
                            )
                            with gr.Row():
                                clear_dataset_btn_part = gr.Button(
                                    "Clear", variant="secondary"
                                )
                                load_dataset_btn = gr.Button("Load", variant="primary")
                        with gr.Column(scale=3):
                            examples = gr.Examples(
                                examples=[
                                    "charris/wikipedia_sample",
                                    "plaguss/argilla_sdk_docs_raw_unstructured",
                                    "BeIR/hotpotqa-generated-queries",
                                ],
                                label="Example datasets",
                                fn=lambda x: x,
                                inputs=[search_in],
                                run_on_click=True,
                            )
                            search_out = gr.HTML(label="Dataset preview", visible=False)
                with gr.Tab("Load your file") as tab_file_input:
                    with gr.Row(equal_height=False):
                        with gr.Column(scale=2):
                            file_in = gr.File(
                                label="Upload your file. Supported formats: .md, .txt, .docx, .pdf",
                                file_count="multiple",
                                file_types=[".md", ".txt", ".docx", ".pdf"],
                            )
                            with gr.Row():
                                clear_file_btn_part = gr.Button(
                                    "Clear", variant="secondary"
                                )
                                load_file_btn = gr.Button("Load", variant="primary")
                        with gr.Column(scale=3):
                            file_out = gr.HTML(label="Dataset preview", visible=False)
                with gr.Tab("Generate from prompt") as tab_prompt_input:
                    with gr.Row(equal_height=False):
                        with gr.Column(scale=2):
                            dataset_description = gr.Textbox(
                                label="Dataset description",
                                placeholder="Give a precise description of your desired dataset.",
                            )
                            with gr.Row():
                                clear_prompt_btn_part = gr.Button(
                                    "Clear", variant="secondary"
                                )
                                load_prompt_btn = gr.Button("Create", variant="primary")
                        with gr.Column(scale=3):
                            examples = gr.Examples(
                                examples=DEFAULT_DATASET_DESCRIPTIONS,
                                inputs=[dataset_description],
                                cache_examples=False,
                                label="Examples",
                            )

        gr.HTML(value="<hr>")
        gr.Markdown(value="## 2. Configure your task")
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                system_prompt = gr.Textbox(
                    label="System prompt",
                    placeholder="You are a helpful assistant.",
                    visible=False,
                )
                document_column = gr.Dropdown(
                    label="Document Column",
                    info="Select the document column to generate the RAG dataset",
                    choices=["Load your data first in step 1."],
                    value="Load your data first in step 1.",
                    interactive=False,
                    multiselect=False,
                    allow_custom_value=False,
                )
                retrieval_reranking = gr.CheckboxGroup(
                    choices=[("Retrieval", "Retrieval"), ("Reranking", "Reranking")],
                    type="value",
                    label="Data for RAG",
                    info="Indicate the additional data you want to generate for RAG.",
                )
                with gr.Row():
                    clear_btn_full = gr.Button("Clear", variant="secondary")
                    btn_apply_to_sample_dataset = gr.Button("Save", variant="primary")
            with gr.Column(scale=3):
                dataframe = gr.Dataframe(
                    headers=["context", "question", "response"],
                    wrap=True,
                    interactive=False,
                )

        gr.HTML(value="<hr>")
        gr.Markdown(value="## 3. Generate your dataset")
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                org_name = get_org_dropdown()
                repo_name = gr.Textbox(
                    label="Repo name",
                    placeholder="dataset_name",
                    value=f"my-distiset-{str(uuid.uuid4())[:8]}",
                    interactive=True,
                )
                num_rows = gr.Number(
                    label="Number of rows",
                    value=10,
                    interactive=True,
                    scale=1,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    interactive=True,
                )
                temperature_completion = gr.Slider(
                    label="Temperature for completion",
                    minimum=0.1,
                    maximum=1.5,
                    value=None,
                    step=0.1,
                    interactive=True,
                    visible=False,
                )
                private = gr.Checkbox(
                    label="Private dataset",
                    value=False,
                    interactive=True,
                    scale=1,
                )
                btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
                btn_save_local = gr.Button(
                    "Save locally", variant="primary", scale=2, visible=False
                )
            with gr.Column(scale=3):
                csv_file = gr.File(
                    label="CSV",
                    elem_classes="datasets",
                    visible=False,
                )
                json_file = gr.File(
                    label="JSON",
                    elem_classes="datasets",
                    visible=False,
                )
                success_message = gr.Markdown(
                    visible=False,
                    min_height=0,  # don't remove this otherwise progress is not visible
                )
                with gr.Accordion(
                    "Customize your pipeline with distilabel",
                    open=False,
                    visible=False,
                ) as pipeline_code_ui:
                    code = generate_pipeline_code(
                        repo_id=search_in.value,
                        input_type=input_type.value,
                        system_prompt=system_prompt.value,
                        document_column=document_column.value,
                        retrieval_reranking=retrieval_reranking.value,
                        num_rows=num_rows.value,
                    )
                    pipeline_code = gr.Code(
                        value=code,
                        language="python",
                        label="Distilabel Pipeline Code",
                    )

    tab_dataset_input.select(
        fn=lambda: "dataset-input",
        inputs=[],
        outputs=[input_type],
    ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
        fn=show_document_column_visibility, inputs=[], outputs=[document_column]
    )

    tab_file_input.select(
        fn=lambda: "file-input",
        inputs=[],
        outputs=[input_type],
    ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
        fn=show_document_column_visibility, inputs=[], outputs=[document_column]
    )

    tab_prompt_input.select(
        fn=lambda: "prompt-input",
        inputs=[],
        outputs=[input_type],
    ).then(fn=show_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
        fn=hide_document_column_visibility, inputs=[], outputs=[document_column]
    )

    search_in.submit(
        fn=lambda df: pd.DataFrame(columns=df.columns),
        inputs=[dataframe],
        outputs=[dataframe],
    )

    gr.on(
        triggers=[load_dataset_btn.click, load_file_btn.click],
        fn=load_dataset_file,
        inputs=[search_in, file_in, input_type],
        outputs=[dataframe, document_column],
    )

    load_prompt_btn.click(
        fn=generate_system_prompt,
        inputs=[dataset_description],
        outputs=[system_prompt],
    ).success(
        fn=generate_sample_dataset,
        inputs=[
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            retrieval_reranking,
            num_rows,
        ],
        outputs=dataframe,
    )

    btn_apply_to_sample_dataset.click(
        fn=generate_sample_dataset,
        inputs=[
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            retrieval_reranking,
            num_rows,
        ],
        outputs=dataframe,
    )

    btn_push_to_hub.click(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[repo_name],
        outputs=[success_message],
    ).then(
        fn=validate_push_to_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    ).success(
        fn=hide_save_local,
        outputs=[csv_file, json_file, success_message],
    ).success(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=hide_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    ).success(
        fn=push_dataset,
        inputs=[
            org_name,
            repo_name,
            private,
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            retrieval_reranking,
            num_rows,
            temperature,
            temperature_completion,
            pipeline_code,
        ],
        outputs=[success_message],
    ).success(
        fn=show_success_message,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    ).success(
        fn=generate_pipeline_code,
        inputs=[
            search_in,
            input_type,
            system_prompt,
            document_column,
            retrieval_reranking,
            num_rows,
        ],
        outputs=[pipeline_code],
    ).success(
        fn=show_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    )

    btn_save_local.click(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=hide_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    ).success(
        fn=show_save_local,
        inputs=[],
        outputs=[csv_file, json_file, success_message],
    ).success(
        save_local,
        inputs=[
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            retrieval_reranking,
            num_rows,
            temperature,
            repo_name,
            temperature_completion,
        ],
        outputs=[csv_file, json_file],
    ).success(
        fn=generate_pipeline_code,
        inputs=[
            search_in,
            input_type,
            system_prompt,
            document_column,
            retrieval_reranking,
            num_rows,
        ],
        outputs=[pipeline_code],
    ).success(
        fn=show_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    )

    clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
    clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
    clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
    clear_btn_full.click(
        fn=lambda df: ("", [], pd.DataFrame(columns=df.columns)),
        inputs=[dataframe],
        outputs=[document_column, retrieval_reranking, dataframe],
    )

    app.load(fn=swap_visibility, outputs=main_ui)
    app.load(fn=get_org_dropdown, outputs=[org_name])
    app.load(fn=get_random_repo_name, outputs=[repo_name])
    app.load(fn=show_temperature_completion, outputs=[temperature_completion])
    if SAVE_LOCAL_DIR is not None:
        app.load(fn=show_save_local_button, outputs=btn_save_local)
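
As a usage sketch: with SAVE_LOCAL_DIR set, the "Save locally" button above ultimately reduces to a direct call to `save_local`. The argument values below are placeholders, and the module expects its model-related environment variables to be configured before import:

from synthetic_dataset_generator.apps.rag import save_local

csv_path, json_path = save_local(
    repo_id="charris/wikipedia_sample",  # one of the example datasets above
    file_paths=[],
    input_type="dataset-input",
    system_prompt="",
    document_column="text",  # assumed column name for the example dataset
    retrieval_reranking=["Retrieval"],
    num_rows=10,
    temperature=0.7,
    repo_name="my-distiset",
    temperature_completion=0.9,
)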