Chuanming committed
Commit fa4458a · 1 Parent(s): 6712c8a

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes.
Files changed (50):
  1. .github/workflows/benchmark.yml +107 -0
  2. .github/workflows/build_documentation.yml +18 -0
  3. .github/workflows/build_pr_documentation.yml +17 -0
  4. .github/workflows/clear_cache.yml +33 -0
  5. .github/workflows/stale.yml +27 -0
  6. .github/workflows/tests.yml +75 -0
  7. .github/workflows/upload_pr_documentation.yml +16 -0
  8. .gitignore +146 -0
  9. .pre-commit-config.yaml +42 -0
  10. CITATION.cff +28 -0
  11. CONTRIBUTING.md +53 -0
  12. LICENSE +201 -0
  13. MANIFEST.in +5 -0
  14. Makefile +15 -0
  15. README.md +184 -0
  16. benchmark/benchmark.py +150 -0
  17. benchmark/benchmark_and_report.sh +41 -0
  18. benchmark/benchmark_level1.sh +11 -0
  19. benchmark/benchmark_level1_plot.sh +20 -0
  20. benchmark/benchmark_level2.sh +23 -0
  21. benchmark/benchmark_level2_plot.sh +31 -0
  22. benchmark/benchmark_level3.sh +46 -0
  23. benchmark/plot.sh +56 -0
  24. benchmark/post_github_comment.py +26 -0
  25. benchmark/post_github_comment.sbatch +9 -0
  26. benchmark/trl.slurm_template +16 -0
  27. benchmark/upload_benchmark.py +23 -0
  28. docs/source/_toctree.yml +54 -0
  29. docs/source/best_of_n.mdx +72 -0
  30. docs/source/customization.mdx +216 -0
  31. docs/source/ddpo_trainer.mdx +119 -0
  32. docs/source/detoxifying_a_lm.mdx +191 -0
  33. docs/source/dpo_trainer.mdx +106 -0
  34. docs/source/example_overview.md +73 -0
  35. docs/source/how_to_train.md +66 -0
  36. docs/source/index.mdx +61 -0
  37. docs/source/installation.mdx +24 -0
  38. docs/source/iterative_sft_trainer.mdx +54 -0
  39. docs/source/learning_tools.mdx +234 -0
  40. docs/source/logging.mdx +75 -0
  41. docs/source/lora_tuning_peft.mdx +144 -0
  42. docs/source/models.mdx +28 -0
  43. docs/source/multi_adapter_rl.mdx +100 -0
  44. docs/source/ppo_trainer.mdx +151 -0
  45. docs/source/quickstart.mdx +88 -0
  46. docs/source/reward_trainer.mdx +77 -0
  47. docs/source/sentiment_tuning.mdx +130 -0
  48. docs/source/sft_trainer.mdx +455 -0
  49. docs/source/text_environments.md +197 -0
  50. docs/source/trainer.mdx +45 -0
.github/workflows/benchmark.yml ADDED
@@ -0,0 +1,107 @@
+name: "Benchmark on Comment"
+
+# https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows
+on:
+  issue_comment:
+    types: [created]
+
+jobs:
+  Benchmark:
+    strategy:
+      fail-fast: true
+      matrix:
+        python-version: [3.9]
+        os: [self-hosted]
+
+    name: Benchmark
+    # Only run if it's a PR and the comment contains /Benchmark
+    if: github.event.issue.pull_request && startsWith(github.event.comment.body, '/benchmark-trl-experiments') && contains(FromJSON('["vwxyzjn", "younesbelkada", "lvwerra", "lewtun"]'), github.actor)
+    runs-on: ${{ matrix.os }}
+
+    steps:
+      - name: Get branch of PR
+        uses: xt0rted/pull-request-comment-branch@v1
+        id: comment-branch
+      - name: Set latest commit status as pending
+        uses: myrotvorets/set-commit-status-action@master
+        with:
+          sha: ${{ steps.comment-branch.outputs.head_sha }}
+          token: ${{ secrets.GITHUB_TOKEN }}
+          status: pending
+      - name: Checkout `main` branch
+        uses: actions/checkout@v3
+      - name: Checkout PR branch
+        run: gh pr checkout $PR_NUMBER
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+          PR_NUMBER: ${{ github.event.issue.number }}
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      # - name: Cleanup pip packages (specific to self-hosted runners)
+      #   run: |
+      #     echo PATH is $PATH
+      #     echo PYTHONPATH is $PYTHONPATH
+      #     echo which python is $(which python)
+      #     echo which pip is $(which pip)
+
+      #     pip_list=$(pip list --format=freeze | grep -v "^pip==" | grep -v "^setuptools==")
+      #     if [ ! -z "$pip_list" ]; then
+      #       echo "$pip_list" | xargs pip uninstall -y
+      #     fi
+      - name: Print python dependencies
+        run: pip list --format=freeze
+      - name: Install dependencies
+        run: |
+          pip install .[test,benchmark]
+
+      - name: Login
+        run: wandb login ${{ secrets.WANDB_API_KEY }} && huggingface-cli login --token ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+      - name: Run benchmark
+        env:
+          GITHUB_CONTEXT: ${{ toJson(github) }}
+          PERSONAL_ACCESS_TOKEN_GITHUB: ${{ secrets.PERSONAL_ACCESS_TOKEN_GITHUB }}
+        run: |
+          COMMENT="${{ github.event.comment.body }}"
+          if [[ "$COMMENT" == *"/benchmark-trl-experiments benchmark/benchmark_level1.sh"* ]]; then
+            echo "Running benchmark/benchmark_level1.sh"
+            BENCHMARK_SCRIPT="benchmark/benchmark_level1.sh" BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level1_plot.sh" bash benchmark/benchmark_and_report.sh
+          elif [[ "$COMMENT" == *"/benchmark-trl-experiments benchmark/benchmark_level2.sh"* ]]; then
+            echo "Running benchmark/benchmark_level2.sh"
+            BENCHMARK_SCRIPT="benchmark/benchmark_level2.sh" BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level2_plot.sh" bash benchmark/benchmark_and_report.sh
+          elif [[ "$COMMENT" == *"/benchmark-trl-experiments benchmark/benchmark_level3.sh"* ]]; then
+            echo "Running benchmark/benchmark_level3.sh"
+            BENCHMARK_SCRIPT="benchmark/benchmark_level3.sh" BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level3_plot.sh" bash benchmark/benchmark_and_report.sh
+          else
+            echo "Invalid command in comment. Skipping execution."
+          fi
+
+      # send message to PR
+      - name: Setup Node.js 16
+        uses: actions/setup-node@v3
+        with:
+          node-version: 16
+      - name: Add workflow result as comment on PR
+        uses: actions/github-script@v6
+        if: always()
+        with:
+          script: |
+            const name = '${{ github.workflow }}';
+            const url = '${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}';
+            const success = '${{ job.status }}' === 'success';
+            const body = `${name}: ${success ? 'succeeded ✅' : 'failed ❌'}\n${url}`;
+
+            await github.rest.issues.createComment({
+              issue_number: context.issue.number,
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              body: body
+            })
+      - name: Set latest commit status as ${{ job.status }}
+        uses: myrotvorets/set-commit-status-action@master
+        if: always()
+        with:
+          sha: ${{ steps.comment-branch.outputs.head_sha }}
+          token: ${{ secrets.GITHUB_TOKEN }}
+          status: ${{ job.status }}
.github/workflows/build_documentation.yml ADDED
@@ -0,0 +1,18 @@
+name: Build documentation
+
+on:
+  push:
+    branches:
+      - main
+      - doc-builder*
+      - v*-release
+
+jobs:
+  build:
+    uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
+    with:
+      commit_sha: ${{ github.sha }}
+      package: trl
+      version_tag_suffix: ""
+    secrets:
+      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
.github/workflows/build_pr_documentation.yml ADDED
@@ -0,0 +1,17 @@
+name: Build PR Documentation
+
+on:
+  pull_request:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  build:
+    uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
+    with:
+      commit_sha: ${{ github.event.pull_request.head.sha }}
+      pr_number: ${{ github.event.number }}
+      package: trl
+      version_tag_suffix: ""
.github/workflows/clear_cache.yml ADDED
@@ -0,0 +1,33 @@
+name: "Cleanup Cache"
+
+on:
+  workflow_dispatch:
+  schedule:
+    - cron: "0 0 * * *"
+
+jobs:
+  cleanup:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Check out code
+        uses: actions/checkout@v3
+
+      - name: Cleanup
+        run: |
+          gh extension install actions/gh-actions-cache
+
+          REPO=${{ github.repository }}
+
+          echo "Fetching list of cache keys"
+          cacheKeysForPR=$(gh actions-cache list -R $REPO | cut -f 1 )
+
+          ## Setting this to not fail the workflow while deleting cache keys.
+          set +e
+          echo "Deleting caches..."
+          for cacheKey in $cacheKeysForPR
+          do
+            gh actions-cache delete $cacheKey -R $REPO --confirm
+          done
+          echo "Done"
+        env:
+          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/stale.yml ADDED
@@ -0,0 +1,27 @@
+name: Stale Bot
+
+on:
+  schedule:
+    - cron: "0 15 * * *"
+
+jobs:
+  close_stale_issues:
+    name: Close Stale Issues
+    if: github.repository == 'huggingface/trl'
+    runs-on: ubuntu-latest
+    env:
+      GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.8
+
+      - name: Install requirements
+        run: |
+          pip install PyGithub
+      - name: Close stale issues
+        run: |
+          python scripts/stale.py
.github/workflows/tests.yml ADDED
@@ -0,0 +1,75 @@
+name: tests
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  check_code_quality:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.9]
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+          submodules: recursive
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - uses: pre-commit/action@v2.0.3
+        with:
+          extra_args: --all-files
+
+  tests:
+    needs: check_code_quality
+    strategy:
+      matrix:
+        python-version: ['3.8', '3.9', '3.10']
+        os: ['ubuntu-latest', 'windows-latest']
+    runs-on: ${{ matrix.os }}
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: "pip"
+          cache-dependency-path: |
+            setup.py
+            requirements.txt
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          # cpu version of pytorch
+          pip install -e ".[test, peft, diffusers]"
+      - name: Test with pytest
+        run: |
+          make test
+
+  tests_no_optional_dep:
+    needs: check_code_quality
+    runs-on: 'ubuntu-latest'
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python 3.9
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.9'
+          cache: "pip"
+          cache-dependency-path: |
+            setup.py
+            requirements.txt
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          # cpu version of pytorch
+          pip install .[test]
+      - name: Test with pytest
+        run: |
+          make test
.github/workflows/upload_pr_documentation.yml ADDED
@@ -0,0 +1,16 @@
+name: Upload PR Documentation
+
+on:
+  workflow_run:
+    workflows: ["Build PR Documentation"]
+    types:
+      - completed
+
+jobs:
+  build:
+    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
+    with:
+      package_name: trl
+    secrets:
+      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
+      comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
.gitignore ADDED
@@ -0,0 +1,146 @@
+benchmark/trl
+*.bak
+.gitattributes
+.last_checked
+.gitconfig
+*.bak
+*.log
+*~
+~*
+_tmp*
+tmp*
+tags
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# dotenv
+.env
+
+# virtualenv
+.venv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+.vscode
+*.swp
+
+# osx generated files
+.DS_Store
+.DS_Store?
+.Trashes
+ehthumbs.db
+Thumbs.db
+.idea
+
+# pytest
+.pytest_cache
+
+# tools/trust-doc-nbs
+docs_src/.last_checked
+
+# symlinks to fastai
+docs_src/fastai
+tools/fastai
+
+# link checker
+checklink/cookies.txt
+
+# .gitconfig is now autogenerated
+.gitconfig
+
+# wandb files
+nbs/wandb/
+examples/notebooks/wandb/
+wandb/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,42 @@
+repos:
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        args:
+          - --profile=black
+          - --skip-glob=wandb/**/*
+          - --thirdparty=wandb
+  - repo: https://github.com/myint/autoflake
+    rev: v1.4
+    hooks:
+      - id: autoflake
+        args:
+          - -r
+          - --exclude=wandb,__init__.py
+          - --in-place
+          - --remove-unused-variables
+          - --remove-all-unused-imports
+  - repo: https://github.com/python/black
+    rev: 22.3.0
+    hooks:
+      - id: black
+        args:
+          - --line-length=119
+          - --target-version=py38
+          - --exclude=wandb
+  - repo: https://github.com/pycqa/flake8
+    rev: 6.0.0
+    hooks:
+      - id: flake8
+        args:
+          - --ignore=E203,E501,W503,E128
+          - --max-line-length=119
+
+  # - repo: https://github.com/codespell-project/codespell
+  #   rev: v2.1.0
+  #   hooks:
+  #     - id: codespell
+  #       args:
+  #         - --ignore-words-list=nd,reacher,thist,ths,magent,ba
+  #         - --skip=docs/css/termynal.css,docs/js/termynal.js
CITATION.cff ADDED
@@ -0,0 +1,28 @@
+cff-version: 1.2.0
+title: 'TRL: Transformer Reinforcement Learning'
+message: >-
+  If you use this software, please cite it using the
+  metadata from this file.
+type: software
+authors:
+  - given-names: Leandro
+    family-names: von Werra
+  - given-names: Younes
+    family-names: Belkada
+  - given-names: Lewis
+    family-names: Tunstall
+  - given-names: Edward
+    family-names: Beeching
+  - given-names: Tristan
+    family-names: Thrush
+  - given-names: Nathan
+    family-names: Lambert
+repository-code: 'https://github.com/huggingface/trl'
+abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
+keywords:
+  - rlhf
+  - deep-learning
+  - pytorch
+  - transformers
+license: Apache-2.0
+version: 0.2.1
CONTRIBUTING.md ADDED
@@ -0,0 +1,53 @@
+# How to contribute
+
+## How to get started
+
+Before you start contributing, make sure you have installed all the dev tools:
+
+```bash
+pip install -e ".[dev]"
+```
+
+## Did you find a bug?
+
+* Ensure the bug was not already reported by searching on GitHub under Issues.
+* If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring.
+* Be sure to add the complete error messages.
+
+#### Did you write a patch that fixes a bug?
+
+* Open a new GitHub pull request with the patch.
+* Ensure that your PR includes a test that fails without your patch, and passes with it.
+* Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable.
+
+## PR submission guidelines
+
+* Keep each PR focused. While it's more convenient, do not combine several unrelated fixes together. Create as many branches as needed to keep each PR focused.
+* Do not mix style changes/fixes with "functional" changes. Such PRs are very difficult to review and will most likely get rejected.
+* Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can.
+* Do not turn an already submitted PR into your development playground. If, after you have submitted a PR, you discover that more work is needed, close the PR, do the required work and then submit a new PR. Otherwise each of your commits requires attention from the maintainers of the project.
+* If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exceptional case where you realize it will take many commits to complete the requests, it's probably best to close the PR, do the work and then submit it again. Use common sense where you'd choose one way over another.
+
+### Before you submit a PR
+
+First you want to make sure that all the tests pass:
+
+```bash
+make test
+```
+
+Then, before submitting your PR, make sure the code quality follows the standards. You can run the following command to format your code:
+
+```bash
+make precommit
+```
+
+Make sure to install `pre-commit` before running the command:
+```bash
+pip install pre-commit
+```
+
+## Do you want to contribute to the documentation?
+
+* Docs are in the `docs/` folder and can be updated there.
+
LICENSE ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
MANIFEST.in ADDED
@@ -0,0 +1,5 @@
+include settings.ini
+include LICENSE
+include CONTRIBUTING.md
+include README.md
+recursive-exclude * __pycache__
Makefile ADDED
@@ -0,0 +1,15 @@
+.PHONY: test precommit benchmark_core benchmark_aux
+
+check_dirs := examples tests trl
+
+test:
+	python -m pytest -n auto --dist=loadfile -s -v ./tests/
+
+precommit:
+	pre-commit run --all-files
+
+benchmark_core:
+	bash ./benchmark/benchmark_core.sh
+
+benchmark_aux:
+	bash ./benchmark/benchmark_aux.sh
README.md ADDED
@@ -0,0 +1,184 @@
+<div style="text-align: center">
+<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_banner_dark.png">
+</div>
+
+# TRL - Transformer Reinforcement Learning
+> Full stack transformer language models with reinforcement learning.
+
+<p align="center">
+    <a href="https://github.com/huggingface/trl/blob/main/LICENSE">
+        <img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue">
+    </a>
+    <a href="https://huggingface.co/docs/trl/index">
+        <img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_message=online">
+    </a>
+    <a href="https://github.com/huggingface/trl/releases">
+        <img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg">
+    </a>
+</p>
+
+
+## What is it?
+
+<div style="text-align: center">
+<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
+</div>
+
+`trl` is a full stack library where we provide a set of tools to train transformer language models and stable diffusion models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT) and the Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point, most decoder and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for example code snippets and how to run these tools.
+
+**Highlights:**
+
+- [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer): A light and friendly wrapper around `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
+- [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): A light wrapper around `transformers` Trainer to easily fine-tune language models for human preferences (Reward Modeling).
+- [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer): A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
+- [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead): A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
+- [Examples](https://github.com/huggingface/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [Stack-Llama example](https://huggingface.co/blog/stackllama), etc.
+
+## How PPO works
+Fine-tuning a language model via PPO consists of roughly three steps:
+
+1. **Rollout**: The language model generates a response or continuation based on a query, which could be the start of a sentence.
+2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
+3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
+
+This process is illustrated in the sketch below:
+
+
+<div style="text-align: center">
+<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
+<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
+</div>
+
+## Installation
+
+### Python package
+Install the library with pip:
+```bash
+pip install trl
+```
+
+### From source
+If you want to run the examples in the repository, a few additional libraries are required. Clone the repository and install it with pip:
+```bash
+git clone https://github.com/huggingface/trl.git
+cd trl/
+pip install .
+```
+
+If you wish to develop TRL, you should install in editable mode:
+```bash
+pip install -e .
+```
+
+## How to use
+
+### `SFTTrainer`
+
+This is a basic example of how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
+
+```python
+# imports
+from datasets import load_dataset
+from trl import SFTTrainer
+
+# get dataset
+dataset = load_dataset("imdb", split="train")
+
+# get trainer
+trainer = SFTTrainer(
+    "facebook/opt-350m",
+    train_dataset=dataset,
+    dataset_text_field="text",
+    max_seq_length=512,
+)
+
+# train
+trainer.train()
+```
+
+### `RewardTrainer`
+
+This is a basic example of how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
+
+```python
+# imports
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from trl import RewardTrainer
+
+# load model and dataset - dataset needs to be in a specific format
+model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+...
+
+# load trainer
+trainer = RewardTrainer(
+    model=model,
+    tokenizer=tokenizer,
+    train_dataset=dataset,
+)
+
+# train
+trainer.train()
+```
+
+### `PPOTrainer`
+
+This is a basic example of how to use the `PPOTrainer` from the library. Based on a query the language model creates a response, which is then evaluated. The evaluation could be a human in the loop or another model's output.
+
+```python
+# imports
+import torch
+from transformers import AutoTokenizer
+from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
+from trl.core import respond_to_batch
+
+# get models
+model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
+model_ref = create_reference_model(model)
+
+tokenizer = AutoTokenizer.from_pretrained('gpt2')
+
+# initialize trainer
+ppo_config = PPOConfig(
+    batch_size=1,
+)
+
+# encode a query
+query_txt = "This morning I went to the "
+query_tensor = tokenizer.encode(query_txt, return_tensors="pt")
+
+# get model response
+response_tensor = respond_to_batch(model, query_tensor)
+
+# create a ppo trainer
+ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)
+
+# define a reward for response
+# (this could be any reward such as human feedback or output from another model)
+reward = [torch.tensor(1.0)]
+
+# train model for one step with ppo
+train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
+```
+
+## References
+
+### Proximal Policy Optimisation
+The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
+
+### Language models
+The language models utilize the `transformers` library by 🤗 Hugging Face.
+
+## Citation
+
+```bibtex
+@misc{vonwerra2022trl,
+  author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang},
+  title = {TRL: Transformer Reinforcement Learning},
+  year = {2020},
+  publisher = {GitHub},
+  journal = {GitHub repository},
+  howpublished = {\url{https://github.com/huggingface/trl}}
+}
+```
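> Editor's note: the `PPOTrainer` snippet in the README above performs a single rollout → evaluation → optimization step. The following is a minimal illustrative sketch (not part of the committed files) that repeats those three stages over a few queries, reusing exactly the setup from the README; the constant reward of 1.0 is a placeholder for a real reward model or human feedback.

```python
# Illustrative sketch only: loop the README's rollout/evaluation/optimization cycle.
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
model_ref = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

ppo_trainer = PPOTrainer(PPOConfig(batch_size=1), model, model_ref, tokenizer)

queries = ["This morning I went to the ", "The movie was about "]
for query_txt in queries:
    # 1. Rollout: generate a continuation for the query
    query_tensor = tokenizer.encode(query_txt, return_tensors="pt")
    response_tensor = respond_to_batch(model, query_tensor)
    # 2. Evaluation: placeholder scalar reward (a real setup would score the
    #    response with a reward model or human feedback)
    reward = [torch.tensor(1.0)]
    # 3. Optimization: one PPO step on the (query, response, reward) triplet
    stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```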
benchmark/benchmark.py ADDED
@@ -0,0 +1,150 @@
+import argparse
+import math
+import os
+import shlex
+import subprocess
+import uuid
+from distutils.util import strtobool
+
+import requests
+
+
+def parse_args():
+    # fmt: off
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--command", type=str, default="",
+        help="the command to run")
+    parser.add_argument("--num-seeds", type=int, default=3,
+        help="the number of random seeds")
+    parser.add_argument("--start-seed", type=int, default=1,
+        help="the number of the starting seed")
+    parser.add_argument("--workers", type=int, default=0,
+        help="the number of workers to run benchmark experiments")
+    parser.add_argument("--auto-tag", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="if toggled, the runs will be tagged with git tags, commit, and pull request number if possible")
+    parser.add_argument("--slurm-template-path", type=str, default=None,
+        help="the path to the slurm template file (see docs for more details)")
+    parser.add_argument("--slurm-gpus-per-task", type=int, default=1,
+        help="the number of gpus per task to use for slurm jobs")
+    parser.add_argument("--slurm-total-cpus", type=int, default=50,
+        help="the total number of cpus to use for slurm jobs")
+    parser.add_argument("--slurm-ntasks", type=int, default=1,
+        help="the number of tasks to use for slurm jobs")
+    parser.add_argument("--slurm-nodes", type=int, default=None,
+        help="the number of nodes to use for slurm jobs")
+    args = parser.parse_args()
+    # fmt: on
+    return args
+
+
+def run_experiment(command: str):
+    command_list = shlex.split(command)
+    print(f"running {command}")
+
+    # Use subprocess.PIPE to capture the output
+    fd = subprocess.Popen(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    output, errors = fd.communicate()
+
+    return_code = fd.returncode
+    assert return_code == 0, f"Command failed with error: {errors.decode('utf-8')}"
+
+    # Convert bytes to string and strip leading/trailing whitespaces
+    return output.decode("utf-8").strip()
+
+
+def autotag() -> str:
+    wandb_tag = ""
+    print("autotag feature is enabled")
+    git_tag = ""
+    try:
+        git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip()
+        print(f"identified git tag: {git_tag}")
+    except subprocess.CalledProcessError as e:
+        print(e)
+    if len(git_tag) == 0:
+        try:
+            count = int(subprocess.check_output(["git", "rev-list", "--count", "HEAD"]).decode("ascii").strip())
+            hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip()
+            git_tag = f"no-tag-{count}-g{hash}"
+            print(f"identified git tag: {git_tag}")
+        except subprocess.CalledProcessError as e:
+            print(e)
+    wandb_tag = f"{git_tag}"
+
+    git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip()
+    try:
+        # try finding the pull request number on github
+        prs = requests.get(f"https://api.github.com/search/issues?q=repo:huggingface/trl+is:pr+{git_commit}")
+        if prs.status_code == 200:
+            prs = prs.json()
+            if len(prs["items"]) > 0:
+                pr = prs["items"][0]
+                pr_number = pr["number"]
+                wandb_tag += f",pr-{pr_number}"
+                print(f"identified github pull request: {pr_number}")
+    except Exception as e:
+        print(e)
+
+    return wandb_tag
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    if args.auto_tag:
+        existing_wandb_tag = os.environ.get("WANDB_TAGS", "")
+        wandb_tag = autotag()
+        if len(wandb_tag) > 0:
+            if len(existing_wandb_tag) > 0:
+                os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag])
+            else:
+                os.environ["WANDB_TAGS"] = wandb_tag
+        print("WANDB_TAGS: ", os.environ.get("WANDB_TAGS", ""))
+    commands = []
+    for seed in range(0, args.num_seeds):
+        commands += [" ".join([args.command, "--seed", str(args.start_seed + seed)])]
+
+    print("======= commands to run:")
+    for command in commands:
+        print(command)
+
+    if args.workers > 0 and args.slurm_template_path is None:
+        from concurrent.futures import ThreadPoolExecutor
+
+        executor = ThreadPoolExecutor(max_workers=args.workers, thread_name_prefix="cleanrl-benchmark-worker-")
+        for command in commands:
+            executor.submit(run_experiment, command)
+        executor.shutdown(wait=True)
+    else:
+        print("not running the experiments because --workers is set to 0; just printing the commands to run")
+
+    # SLURM logic
+    if args.slurm_template_path is not None:
+        if not os.path.exists("slurm"):
+            os.makedirs("slurm")
+        if not os.path.exists("slurm/logs"):
+            os.makedirs("slurm/logs")
+        print("======= slurm commands to run:")
+        with open(args.slurm_template_path) as f:
+            slurm_template = f.read()
+        slurm_template = slurm_template.replace("{{array}}", f"0-{len(commands) - 1}%{args.workers}")
+        slurm_template = slurm_template.replace(
+            "{{seeds}}", f"({' '.join([str(args.start_seed + int(seed)) for seed in range(args.num_seeds)])})"
+        )
+        slurm_template = slurm_template.replace("{{len_seeds}}", f"{args.num_seeds}")
+        slurm_template = slurm_template.replace("{{command}}", args.command)
+        slurm_template = slurm_template.replace("{{gpus_per_task}}", f"{args.slurm_gpus_per_task}")
+        total_gpus = args.slurm_gpus_per_task * args.slurm_ntasks
+        slurm_cpus_per_gpu = math.ceil(args.slurm_total_cpus / total_gpus)
+        slurm_template = slurm_template.replace("{{cpus_per_gpu}}", f"{slurm_cpus_per_gpu}")
+        slurm_template = slurm_template.replace("{{ntasks}}", f"{args.slurm_ntasks}")
+        if args.slurm_nodes is not None:
+            slurm_template = slurm_template.replace("{{nodes}}", f"#SBATCH --nodes={args.slurm_nodes}")
+        else:
+            slurm_template = slurm_template.replace("{{nodes}}", "")
+        filename = str(uuid.uuid4())
+        open(os.path.join("slurm", f"{filename}.slurm"), "w").write(slurm_template)
+        slurm_path = os.path.join("slurm", f"{filename}.slurm")
+        print(f"saving command in {slurm_path}")
+        if args.workers > 0:
+            job_id = run_experiment(f"sbatch --parsable {slurm_path}")
+            print(f"Job ID: {job_id}")
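> Editor's note: `benchmark/benchmark.py` above expands `--command` into one run per seed and either launches the runs in a thread pool or writes a Slurm batch file. As a minimal sketch (not part of the commit, and assuming it is run from the repository root), leaving `--workers` at 0 with no Slurm template makes the script only print the expanded per-seed commands, which is a cheap way to inspect a sweep before launching it:

```python
# Dry-run sketch: ask benchmark.py to expand a command into per-seed runs
# without executing anything (--workers 0 only prints the commands).
import subprocess

subprocess.run(
    [
        "python", "benchmark/benchmark.py",
        "--command", "python examples/scripts/ppo.py --ppo_config.log_with wandb",
        "--num-seeds", "3",
        "--start-seed", "1",
        "--workers", "0",       # 0 -> just print the expanded commands
        "--auto-tag", "False",  # skip git/PR tagging for a local dry run
    ],
    check=True,
)
```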
benchmark/benchmark_and_report.sh ADDED
@@ -0,0 +1,41 @@
+#### Step 1: create a work directory:
+# this is necessary because another github action job will remove
+# the entire directory, which slurm depends on.
+# https://stackoverflow.com/questions/4632028/how-to-create-a-temporary-directory
+MY_SLURM_TMP_DIR=/fsx/costa/slurm_tmpdir
+mkdir -p $MY_SLURM_TMP_DIR
+WORK_DIR=`mktemp -d -p "$MY_SLURM_TMP_DIR"`
+cp -r "$PWD" "$WORK_DIR"
+cd "$WORK_DIR/$(basename "$PWD")"
+echo WORK_DIR: $WORK_DIR
+
+#### Step 2: actual work starts:
+echo PATH is $PATH
+echo PYTHONPATH is $PYTHONPATH
+echo which python is $(which python)
+
+export WANDB_ENTITY=huggingface
+bash $BENCHMARK_SCRIPT > output.txt
+
+# Extract Job IDs into an array
+job_ids=($(grep "Job ID:" output.txt | awk '{print $3}'))
+
+# Extract WANDB_TAGS into an array
+WANDB_TAGS=($(grep "WANDB_TAGS:" output.txt | awk '{print $2}'))
+WANDB_TAGS=($(echo $WANDB_TAGS | tr "," "\n"))
+
+# Print to verify
+echo "Job IDs: ${job_ids[@]}"
+echo "WANDB_TAGS: ${WANDB_TAGS[@]}"
+
+TAGS_STRING="?tag=${WANDB_TAGS[0]}"
+FOLDER_STRING="${WANDB_TAGS[0]}"
+for tag in "${WANDB_TAGS[@]:1}"; do
+    TAGS_STRING+="&tag=$tag"
+    FOLDER_STRING+="_$tag"
+done
+
+echo "TAGS_STRING: $TAGS_STRING"
+echo "FOLDER_STRING: $FOLDER_STRING"
+
+TAGS_STRING=$TAGS_STRING FOLDER_STRING=$FOLDER_STRING BENCHMARK_PLOT_SCRIPT=$BENCHMARK_PLOT_SCRIPT sbatch --dependency=afterany:$job_ids benchmark/post_github_comment.sbatch
benchmark/benchmark_level1.sh ADDED
@@ -0,0 +1,11 @@
+# hello world experiment
+python benchmark/benchmark.py \
+    --command "python examples/scripts/ppo.py --ppo_config.log_with wandb" \
+    --num-seeds 3 \
+    --start-seed 1 \
+    --workers 10 \
+    --slurm-nodes 1 \
+    --slurm-gpus-per-task 1 \
+    --slurm-ntasks 1 \
+    --slurm-total-cpus 12 \
+    --slurm-template-path benchmark/trl.slurm_template
benchmark/benchmark_level1_plot.sh ADDED
@@ -0,0 +1,20 @@
+# pip install openrlbenchmark==0.2.1a5
+# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
+echo "we deal with $TAGS_STRING"
+
+python -m openrlbenchmark.rlops_multi_metrics \
+    --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
+        "ppo$TAGS_STRING" \
+    --env-ids sentiment-analysis:lvwerra/distilbert-imdb \
+    --no-check-empty-runs \
+    --pc.ncols 2 \
+    --pc.ncols-legend 1 \
+    --output-filename benchmark/trl/$FOLDER_STRING/hello_world \
+    --scan-history
+
+python benchmark/upload_benchmark.py \
+    --folder_path="benchmark/trl/$FOLDER_STRING" \
+    --path_in_repo="images/benchmark/$FOLDER_STRING" \
+    --repo_id="trl-internal-testing/example-images" \
+    --repo_type="dataset"
+
benchmark/benchmark_level2.sh ADDED
@@ -0,0 +1,23 @@
+# compound experiments: gpt2xl + grad_accu
+python benchmark/benchmark.py \
+    --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \
+    --num-seeds 3 \
+    --start-seed 1 \
+    --workers 10 \
+    --slurm-nodes 1 \
+    --slurm-gpus-per-task 1 \
+    --slurm-ntasks 1 \
+    --slurm-total-cpus 12 \
+    --slurm-template-path benchmark/trl.slurm_template
+
+# compound experiments: Cerebras-GPT-6.7B + deepspeed zero2 + grad_accu
+python benchmark/benchmark.py \
+    --command "accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/ppo.py --ppo_config.exp_name ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2 --ppo_config.batch_size 32 --ppo_config.mini_batch_size 32 --ppo_config.log_with wandb --ppo_config.model_name cerebras/Cerebras-GPT-6.7B --ppo_config.reward_model sentiment-analysis:cerebras/Cerebras-GPT-6.7B" \
+    --num-seeds 3 \
+    --start-seed 1 \
+    --workers 10 \
+    --slurm-nodes 1 \
+    --slurm-gpus-per-task 8 \
+    --slurm-ntasks 1 \
+    --slurm-total-cpus 90 \
+    --slurm-template-path benchmark/trl.slurm_template
benchmark/benchmark_level2_plot.sh ADDED
@@ -0,0 +1,31 @@
+# pip install openrlbenchmark==0.2.1a5
+# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
+echo "we deal with $TAGS_STRING"
+
+python -m openrlbenchmark.rlops_multi_metrics \
+    --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
+        "ppo$TAGS_STRING" \
+        "ppo_gpt2xl_grad_accu$TAGS_STRING" \
+    --env-ids sentiment-analysis:lvwerra/distilbert-imdb \
+    --no-check-empty-runs \
+    --pc.ncols 2 \
+    --pc.ncols-legend 1 \
+    --output-filename benchmark/trl/$FOLDER_STRING/different_models \
+    --scan-history
+
+python -m openrlbenchmark.rlops_multi_metrics \
+    --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
+        "ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2$TAGS_STRING" \
+    --env-ids sentiment-analysis:cerebras/Cerebras-GPT-6.7B \
+    --no-check-empty-runs \
+    --pc.ncols 2 \
+    --pc.ncols-legend 1 \
+    --output-filename benchmark/trl/$FOLDER_STRING/deepspeed \
+    --scan-history
+
+python benchmark/upload_benchmark.py \
+    --folder_path="benchmark/trl/$FOLDER_STRING" \
+    --path_in_repo="images/benchmark/$FOLDER_STRING" \
+    --repo_id="trl-internal-testing/example-images" \
+    --repo_type="dataset"
+
benchmark/benchmark_level3.sh ADDED
@@ -0,0 +1,46 @@
+## w/ and w/o gradient accumulation
+python benchmark/benchmark.py \
+    --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_step_grad_accu --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 128 --ppo_config.log_with wandb" \
+    --num-seeds 3 \
+    --start-seed 1 \
+    --workers 10 \
+    --slurm-nodes 1 \
+    --slurm-gpus-per-task 1 \
+    --slurm-ntasks 1 \
+    --slurm-total-cpus 12 \
+    --slurm-template-path benchmark/trl.slurm_template
+
+## w/ different models (gpt2, gpt2-xl, falcon, llama2)
+python benchmark/benchmark.py \
+    --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_gpt2 --ppo_config.log_with wandb" \
+    --num-seeds 3 \
+    --start-seed 1 \
+    --workers 10 \
+    --slurm-nodes 1 \
+    --slurm-gpus-per-task 1 \
+    --slurm-ntasks 1 \
+    --slurm-total-cpus 12 \
+    --slurm-template-path benchmark/trl.slurm_template
+python benchmark/benchmark.py \
+    --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_falcon_rw_1b --ppo_config.model_name tiiuae/falcon-rw-1b --ppo_config.log_with wandb" \
+    --num-seeds 3 \
+    --start-seed 1 \
+    --workers 10 \
+    --slurm-nodes 1 \
+    --slurm-gpus-per-task 1 \
+    --slurm-ntasks 1 \
+    --slurm-total-cpus 12 \
+    --slurm-template-path benchmark/trl.slurm_template
+
+
+## w/ and w/o PEFT
+python benchmark/benchmark.py \
+    --command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_peft --use_peft --ppo_config.log_with wandb" \
+    --num-seeds 3 \
+    --start-seed 1 \
+    --workers 10 \
+    --slurm-nodes 1 \
+    --slurm-gpus-per-task 1 \
+    --slurm-ntasks 1 \
+    --slurm-total-cpus 12 \
+    --slurm-template-path benchmark/trl.slurm_template
benchmark/plot.sh ADDED
@@ -0,0 +1,56 @@
+# pip install openrlbenchmark==0.2.1a5
+# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
+BASELINE_PR_TAG=v0.4.7-55-g110e672
+BASELINE_PR_NAME=PR-662
+
+python -m openrlbenchmark.rlops_multi_metrics \
+    --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
+        "sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \
+    --env-ids sentiment-analysis:lvwerra/distilbert-imdb \
+    --no-check-empty-runs \
+    --pc.ncols 2 \
+    --pc.ncols-legend 1 \
+    --output-filename benchmark/trl/$BASELINE_PR_TAG/sentiment \
+    --scan-history
+
+python -m openrlbenchmark.rlops_multi_metrics \
+    --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
+        "sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \
+        "sentiment_tuning_step_grad_accu?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb gradient accumulation ($BASELINE_PR_NAME)" \
+    --env-ids sentiment-analysis:lvwerra/distilbert-imdb \
+    --no-check-empty-runs \
+    --pc.ncols 2 \
+    --pc.ncols-legend 1 \
+    --output-filename benchmark/trl/$BASELINE_PR_TAG/gradient_accu \
+    --scan-history
+
+python -m openrlbenchmark.rlops_multi_metrics \
+    --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
+        "sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \
+        "sentiment_tuning_gpt2?tag=$BASELINE_PR_TAG&cl=sentiment gpt2 ($BASELINE_PR_NAME)" \
+        "sentiment_tuning_falcon_rw_1b?tag=$BASELINE_PR_TAG&cl=sentiment tiiuae/falcon-rw-1b ($BASELINE_PR_NAME)" \
+        "sentiment_tuning_gpt2xl_grad_accu?tag=$BASELINE_PR_TAG&cl=sentiment gpt2xl ($BASELINE_PR_NAME)" \
+    --env-ids sentiment-analysis:lvwerra/distilbert-imdb \
+    --no-check-empty-runs \
+    --pc.ncols 2 \
+    --pc.ncols-legend 1 \
+    --output-filename benchmark/trl/$BASELINE_PR_TAG/different_models \
+    --scan-history
+
+python -m openrlbenchmark.rlops_multi_metrics \
+    --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
+        "sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \
+        "sentiment_tuning_peft?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb w/ peft ($BASELINE_PR_NAME)" \
+    --env-ids sentiment-analysis:lvwerra/distilbert-imdb \
+    --no-check-empty-runs \
+    --pc.ncols 2 \
+    --pc.ncols-legend 1 \
+    --output-filename benchmark/trl/$BASELINE_PR_TAG/peft \
+    --scan-history
+
+
+python benchmark/upload_benchmark.py \
+    --folder_path="benchmark/trl/$BASELINE_PR_TAG" \
+    --path_in_repo="images/benchmark/$BASELINE_PR_TAG" \
+    --repo_id="trl-internal-testing/example-images" \
+    --repo_type="dataset"
benchmark/post_github_comment.py ADDED
@@ -0,0 +1,26 @@
1
+ import json
2
+ import os
3
+
4
+ from ghapi.all import GhApi
5
+
6
+
7
+ FOLDER_STRING = os.environ.get("FOLDER_STRING", "")
8
+ folder = f"benchmark/trl/{FOLDER_STRING}"
9
+ host_url = f"https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/{FOLDER_STRING}"
10
+
11
+ # Create a GitHub API instance
12
+ github_context = json.loads(os.environ["GITHUB_CONTEXT"])
13
+ token = os.environ["PERSONAL_ACCESS_TOKEN_GITHUB"] # this needs to be refreshed every 12 months
14
+ status_message = "**[COSTA BENCHMARK BOT]**: Here are the results"
15
+ body = status_message
16
+ repo = github_context["repository"]
17
+ owner, repo = repo.split("/")
18
+ api = GhApi(owner=owner, repo=repo, token=token)
19
+
20
+ # for each `.png` file in the folder, add it to the comment
21
+ for file in os.listdir(folder):
22
+ if file.endswith(".png"):
23
+ body += f"\n![{file}]({host_url}/{file})"
24
+
25
+ # Create a comment on the issue
26
+ api.issues.create_comment(issue_number=github_context["event"]["issue"]["number"], body=body)
benchmark/post_github_comment.sbatch ADDED
@@ -0,0 +1,9 @@
1
+ #!/bin/bash
2
+ #SBATCH --job-name=trl
3
+ #SBATCH --partition=production-cluster
4
+ #SBATCH --ntasks=1
5
+ #SBATCH --output=slurm/logs/%x_%j.out
6
+
7
+ sleep 2m
8
+ bash $BENCHMARK_PLOT_SCRIPT
9
+ srun python benchmark/post_github_comment.py
benchmark/trl.slurm_template ADDED
@@ -0,0 +1,16 @@
1
+ #!/bin/bash
2
+ #SBATCH --job-name=trl
3
+ #SBATCH --partition=production-cluster
4
+ #SBATCH --gpus-per-task={{gpus_per_task}}
5
+ #SBATCH --cpus-per-gpu={{cpus_per_gpu}}
6
+ #SBATCH --ntasks={{ntasks}}
7
+ #SBATCH --output=slurm/logs/%x_%j.out
8
+ #SBATCH --array={{array}}
9
+ #SBATCH --exclude=ip-26-0-156-239,ip-26-0-148-151,ip-26-0-146-212,ip-26-0-145-137,ip-26-0-146-249,ip-26-0-146-149,ip-26-0-147-233,ip-26-0-145-154,ip-26-0-144-35,ip-26-0-144-189,ip-26-0-146-183,ip-26-0-147-120,ip-26-0-144-95,ip-26-0-145-193
10
+ {{nodes}}
11
+
12
+ seeds={{seeds}}
13
+ seed=${seeds[$SLURM_ARRAY_TASK_ID % {{len_seeds}}]}
14
+
15
+ echo "Running task $SLURM_ARRAY_TASK_ID with seed: $seed"
16
+ srun {{command}} --ppo_config.seed $seed
benchmark/upload_benchmark.py ADDED
@@ -0,0 +1,23 @@
1
+ from dataclasses import dataclass
2
+
3
+ import tyro
4
+ from huggingface_hub import HfApi
5
+
6
+
7
+ @dataclass
8
+ class Args:
9
+ folder_path: str = "benchmark/trl"
10
+ path_in_repo: str = "images/benchmark"
11
+ repo_id: str = "trl-internal-testing/example-images"
12
+ repo_type: str = "dataset"
13
+
14
+
15
+ args = tyro.cli(Args)
16
+ api = HfApi()
17
+
18
+ api.upload_folder(
19
+ folder_path=args.folder_path,
20
+ path_in_repo=args.path_in_repo,
21
+ repo_id=args.repo_id,
22
+ repo_type=args.repo_type,
23
+ )
docs/source/_toctree.yml ADDED
@@ -0,0 +1,54 @@
1
+ - sections:
2
+ - local: index
3
+ title: TRL
4
+ - local: quickstart
5
+ title: Quickstart
6
+ - local: installation
7
+ title: Installation
8
+ - local: how_to_train
9
+ title: PPO Training FAQ
10
+ - local: use_model
11
+ title: Use Trained Models
12
+ - local: customization
13
+ title: Customize the Training
14
+ - local: logging
15
+ title: Understanding Logs
16
+ title: Get started
17
+ - sections:
18
+ - local: models
19
+ title: Model Classes
20
+ - local: trainer
21
+ title: Trainer Classes
22
+ - local: reward_trainer
23
+ title: Reward Model Training
24
+ - local: sft_trainer
25
+ title: Supervised Fine-Tuning
26
+ - local: ppo_trainer
27
+ title: PPO Trainer
28
+ - local: best_of_n
29
+ title: Best of N Sampling
30
+ - local: dpo_trainer
31
+ title: DPO Trainer
32
+ - local: ddpo_trainer
33
+ title: Denoising Diffusion Policy Optimization
34
+ - local: iterative_sft_trainer
35
+ title: Iterative Supervised Fine-Tuning
36
+ - local: text_environments
37
+ title: Text Environments
38
+ title: API
39
+ - sections:
40
+ - local: example_overview
41
+ title: Example Overview
42
+ - local: sentiment_tuning
43
+ title: Sentiment Tuning
44
+ - local: lora_tuning_peft
45
+ title: Training with PEFT
46
+ - local: detoxifying_a_lm
47
+ title: Detoxifying a Language Model
48
+ - local: using_llama_models
49
+ title: Training StackLlama
50
+ - local: learning_tools
51
+ title: Learning to Use Tools
52
+ - local: multi_adapter_rl
53
+ title: Multi Adapter RLHF
54
+ title: Examples
docs/source/best_of_n.mdx ADDED
@@ -0,0 +1,72 @@
1
+ # Best of N sampling: Alternative ways to get better model output without RL based fine-tuning
2
+
3
+ Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output.
4
+ To see how it fares against RL-based fine-tuning, please look in the `examples` directory for a comparison example.
5
+
6
+ ## Usage
7
+
8
+ To get started quickly, create an instance of the class with a model, a length sampler, a tokenizer and a callable that serves as a proxy reward pipeline and outputs reward scores for input queries:
9
+
10
+ ```python
11
+
12
+ from transformers import pipeline, AutoTokenizer
13
+ from trl import AutoModelForCausalLMWithValueHead
14
+ from trl.core import LengthSampler
15
+ from trl.extras import BestOfNSampler
16
+
17
+ ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)
18
+ reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)
19
+ tokenizer = AutoTokenizer.from_pretrained(ref_model_name)
20
+ tokenizer.pad_token = tokenizer.eos_token
21
+
22
+
23
+ # callable that takes a list of raw text and returns a list of corresponding reward scores
24
+ def queries_to_scores(list_of_strings):
25
+ return [output["score"] for output in reward_pipe(list_of_strings)]
26
+
27
+ best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler)
28
+
29
+
30
+ ```
31
+
32
+ Assuming you have a list/tensor of tokenized queries, you can generate better output by calling the `generate` method:
33
+
34
+ ```python
35
+
36
+ best_of_n.generate(query_tensors, device=device, **gen_kwargs)
37
+
38
+ ```
39
+ The default sample size is 4, but you can change it at the time of instance initialization like so
40
+
41
+ ```python
42
+
43
+ best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, sample_size=8)
44
+
45
+ ```
46
+
47
+ The default output is the result of taking the top scored output for each query, but you can change it to top 2 and so on by passing the `n_candidates` argument at the time of instance initialization
48
+
49
+ ```python
50
+
51
+ best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, n_candidates=2)
52
+
53
+ ```
54
+
55
+ You can also set the generation settings (like `temperature` and `pad_token_id`) at the time of instance creation, as opposed to when calling the `generate` method.
56
+ This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization:
57
+
58
+ ```python
59
+
60
+ from transformers import GenerationConfig
61
+
62
+ generation_config = GenerationConfig(min_length=-1, top_k=0.0, top_p=1.0, do_sample=True, pad_token_id=tokenizer.eos_token_id)
63
+
64
+ best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, generation_config=generation_config)
65
+
66
+ best_of_n.generate(query_tensors, device=device)
67
+
68
+ ```
69
+
70
+ Furthermore, at the time of initialization you can set the seed to control the repeatability of the generation process, as well as the number of samples to generate for each query, as shown in the sketch below.
71
+
72
+
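+ For example, a minimal sketch reusing `model`, `tokenizer`, `queries_to_scores` and `output_length_sampler` from the snippets above (the exact keyword names are assumptions based on the options described on this page):
+
+ ```python
+ best_of_n = BestOfNSampler(
+     model,
+     tokenizer,
+     queries_to_scores,
+     length_sampler=output_length_sampler,
+     sample_size=8,  # number of samples generated per query
+     seed=0,         # fix the seed for repeatable generations
+ )
+ ```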
docs/source/customization.mdx ADDED
@@ -0,0 +1,216 @@
1
+ # Training customization
2
+
3
+ TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are some examples of how you can apply and test different techniques.
4
+
5
+ ## Train on multiple GPUs / nodes
6
+
7
+ The trainers in TRL use 🤗 Accelerate to enable distributed training across multiple GPUs or nodes. To do so, first create an 🤗 Accelerate config file by running
8
+
9
+ ```bash
10
+ accelerate config
11
+ ```
12
+
13
+ and answering the questions according to your multi-gpu / multi-node setup. You can then launch distributed training by running:
14
+
15
+ ```bash
16
+ accelerate launch your_script.py
17
+ ```
18
+
19
+ We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:
20
+
21
+ ```shell
22
+ accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
23
+ ```
24
+
25
+ Refer to the [examples page](https://github.com/huggingface/trl/tree/main/examples) for more details.
26
+
27
+ ### Distributed training with DeepSpeed
28
+
29
+ All of the trainers in TRL can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run:
30
+
31
+ ```shell
32
+ accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_your_script.py --all_arguments_of_the_script
33
+ ```
34
+
35
+ Note that for ZeRO-3, a small tweak is needed to initialize your reward model on the correct device via the `zero3_init_context_manager()` context manager. In particular, this is needed to avoid DeepSpeed hanging after a fixed number of training steps. Here is a snippet of what is involved from the [`sentiment_tuning`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) example:
36
+
37
+ ```python
38
+ ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
39
+ if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
40
+ with ds_plugin.zero3_init_context_manager(enable=False):
41
+ sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
42
+ else:
43
+ sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
44
+ ```
45
+
46
+ Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
47
+
48
+
49
+ ## Use different optimizers
50
+
51
+ By default, the `PPOTrainer` creates a `torch.optim.Adam` optimizer. You can create and define a different optimizer and pass it to `PPOTrainer`:
52
+ ```python
53
+ import torch
54
+ from transformers import GPT2Tokenizer
55
+ from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
56
+
57
+ # 1. load a pretrained model
58
+ model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
59
+ model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
60
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
61
+
62
+ # 2. define config
63
+ ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
64
+ config = PPOConfig(**ppo_config)
65
+
66
+
67
+ # 3. Create optimizer
68
+ optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
69
+
70
+
71
+ # 4. initialize trainer
72
+ ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer)
73
+ ```
74
+
75
+ For memory efficient fine-tuning, you can also pass `Adam8bit` optimizer from `bitsandbytes`:
76
+
77
+ ```python
78
+ import torch
79
+ import bitsandbytes as bnb
80
+
81
+ from transformers import GPT2Tokenizer
82
+ from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
83
+
84
+ # 1. load a pretrained model
85
+ model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
86
+ model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
87
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
88
+
89
+ # 2. define config
90
+ ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
91
+ config = PPOConfig(**ppo_config)
92
+
93
+
94
+ # 3. Create optimizer
95
+ optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate)
96
+
97
+ # 4. initialize trainer
98
+ ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer)
99
+ ```
100
+
101
+ ### Use LION optimizer
102
+
103
+ You can also use the new [LION optimizer from Google](https://arxiv.org/abs/2302.06675). First, take the source code of the optimizer definition [here](https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py) and copy it so that you can import the optimizer. Make sure to initialize the optimizer with the trainable parameters only, for more memory-efficient training:
104
+ ```python
105
+ optimizer = Lion(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.config.learning_rate)
106
+
107
+ ...
108
+ ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer)
109
+ ```
110
+ We advise you to use the learning rate that you would use for `Adam` divided by 3 as pointed out [here](https://github.com/lucidrains/lion-pytorch#lion---pytorch). We observed an improvement when using this optimizer compared to classic Adam (check the full logs [here](https://wandb.ai/distill-bloom/trl/runs/lj4bheke?workspace=user-younesbelkada)):
111
+
112
+ <div style="text-align: center">
113
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-lion.png">
114
+ </div>
115
+
116
+
117
+ ## Add a learning rate scheduler
118
+
119
+ You can also play with your training by adding learning rate schedulers!
120
+ ```python
121
+ import torch
122
+ from transformers import GPT2Tokenizer
123
+ from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
124
+
125
+ # 1. load a pretrained model
126
+ model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
127
+ model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
128
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
129
+
130
+ # 2. define config
131
+ ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
132
+ config = PPOConfig(**ppo_config)
133
+
134
+
135
+ # 3. Create optimizer
136
+ optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
137
+ lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
138
+
139
+ # 4. initialize trainer
140
+ ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer, lr_scheduler=lr_scheduler)
141
+ ```
142
+
143
+ ## Memory efficient fine-tuning by sharing layers
144
+
145
+ Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train.
146
+ ```python
147
+ import torch
148
+ from transformers import AutoTokenizer
149
+ from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
150
+
151
+ # 1. load a pretrained model
152
+ model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
153
+ model_ref = create_reference_model(model, num_shared_layers=6)
154
+ tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
155
+
156
+ # 2. initialize trainer
157
+ ppo_config = {'batch_size': 1}
158
+ config = PPOConfig(**ppo_config)
159
+ ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
160
+ ```
161
+
162
+ ## Pass 8-bit reference models
163
+
164
+ <div>
165
+
166
+ Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory-efficient fine-tuning.
167
+
168
+ Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#bitsandbytes-integration-for-int8-mixedprecision-matrix-decomposition).
169
+
170
+ </div>
171
+
172
+ ```python
173
+ # 0. imports
174
+ # pip install bitsandbytes
175
+ import torch
176
+ from transformers import AutoTokenizer
177
+ from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
178
+
179
+ # 1. load a pretrained model
180
+ model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
181
+ model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m', device_map="auto", load_in_8bit=True)
182
+ tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
183
+
184
+ # 2. initialize trainer
185
+ ppo_config = {'batch_size': 1}
186
+ config = PPOConfig(**ppo_config)
187
+ ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
188
+ ```
189
+
190
+ ## Use the CUDA cache optimizer
191
+
192
+ When training large models, you should handle the CUDA cache by iteratively clearing it. To do so, simply pass `optimize_cuda_cache=True` to `PPOConfig`:
193
+
194
+ ```python
195
+ config = PPOConfig(..., optimize_cuda_cache=True)
196
+ ```
197
+
198
+
199
+
200
+ ## Use score scaling/normalization/clipping
201
+ As suggested by [Secrets of RLHF in Large Language Models Part I: PPO](https://arxiv.org/abs/2307.04964), we support score (aka reward) scaling/normalization/clipping to improve training stability via `PPOConfig`:
202
+ ```python
203
+ from trl import PPOConfig
204
+
205
+ ppo_config = {
206
+     "use_score_scaling": True,
207
+     "use_score_norm": True,
208
+     "score_clip": 0.5,
209
+ }
210
+ config = PPOConfig(**ppo_config)
211
+ ```
212
+
213
+ To run `ppo.py`, you can use the following command:
214
+ ```
215
+ python examples/scripts/ppo.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5
216
+ ```
docs/source/ddpo_trainer.mdx ADDED
@@ -0,0 +1,119 @@
1
+ # Denoising Diffusion Policy Optimization
2
+ ## The why
3
+
4
+ | Before | After DDPO finetuning |
5
+ | --- | --- |
6
+ | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_squirrel.png"/></div> |
7
+ | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_crab.png"/></div> |
8
+ | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"/></div> |
9
+
10
+
11
+ ## Getting started with Stable Diffusion finetuning with reinforcement learning
12
+
13
+ The machinery for finetuning Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers`
14
+ library. Getting started therefore requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers.
15
+ Right out of the box (`diffusers` library), there is neither a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning, so some adjustments need to be made.
16
+
17
+ This library provides a pipeline interface that must be implemented in order to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.**
18
+ There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide.
19
+
20
+ The point of the interface is to fuse the pipeline and the scheduler into one object, which keeps all the constraints in one place. The interface was designed with the hope of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at the time of writing. Also, the scheduler step is a method of this pipeline interface; this may seem redundant given that the raw scheduler is accessible via the interface, but it is the only way to constrain the scheduler step output to an output type befitting of the algorithm at hand (DDPO).
21
+
22
+ For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py)
23
+
24
+ Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag is enabled by default, and it can be turned off by passing in the flag to do so. LoRA-based training is faster, and the LoRA-associated model hyperparameters responsible for model convergence aren't as finicky as in non-LoRA based training.
25
+
26
+ In addition, you are expected to provide a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that the images are generated from. A rough sketch of both is given below.
27
+
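+ As an illustration only (the exact signatures are defined by the training example script, not by this sketch; the dummy brightness-based reward is purely hypothetical):
+
+ ```python
+ def prompt_fn():
+     # return a prompt together with arbitrary metadata
+     return "a photo of a squirrel", {}
+
+ def reward_fn(images, prompts, prompt_metadata):
+     # score a batch of generated images; here a dummy reward that favors brighter images
+     rewards = images.float().mean(dim=(1, 2, 3))
+     return rewards, {}
+ ```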
28
+ ## Getting started with `examples/scripts/ddpo.py`
29
+
30
+ The `ddpo.py` script is a working example of using the `DDPO` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`).
31
+
32
+ **Note:** one A100 GPU is recommended to get this running. Anything below an A100 will not be able to run this example script, and even if it does with relatively smaller parameters, the results will most likely be poor.
33
+
34
+ Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to the HuggingFace hub. The following bash command gets things running:
35
+
36
+ ```bash
37
+ python ddpo.py --hf_user_access_token <token>
38
+ ```
39
+
40
+ To obtain the documentation of `ddpo.py`, please run `python ddpo.py --help`
41
+
42
+ The following are things to keep in mind in general while configuring the trainer (the code checks this for you as well), beyond the use case of the example script; a config sketch that satisfies them follows the list:
43
+
44
+ - The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) should be greater than or equal to the configurable training batch size (`--ddpo_config.train_batch_size=3`)
45
+ - The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by the configurable train batch size (`--ddpo_config.train_batch_size=3`)
46
+ - The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by both the configurable gradient accumulation steps (`--ddpo_config.train_gradient_accumulation_steps=1`) and the configurable accelerator processes count
47
+
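+ As an illustration, here is a configuration consistent with the constraints above (assuming the `DDPOConfig` field names match the `--ddpo_config.*` flags shown; the values are only examples):
+
+ ```python
+ from trl import DDPOConfig
+
+ config = DDPOConfig(
+     sample_batch_size=6,                  # >= train_batch_size and divisible by it
+     train_batch_size=3,
+     train_gradient_accumulation_steps=1,  # sample_batch_size must also be divisible by this
+ )
+ ```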
48
+ ## Setting up the image logging hook function
49
+
50
+ Expect the function to be given a list of lists of the form
51
+ ```python
52
+ [[image, prompt, prompt_metadata, rewards, reward_metadata], ...]
53
+
54
+ ```
55
+ and `image`, `prompt`, `prompt_metadata`, `rewards`, `reward_metadata` are batched.
56
+ The last list in the list of lists represents the last sample batch. You are likely to want to log this one.
57
+ While you are free to log however you want, the use of `wandb` or `tensorboard` is recommended.
58
+
59
+ ### Key terms
60
+
61
+ - `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process
62
+ - `reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward
63
+ - `prompt` : The prompt is the text that is used to generate the image
64
+ - `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model consists of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
65
+ - `image` : The image generated by the Stable Diffusion model
66
+
67
+ Example code for logging sampled images with `wandb` is given below.
68
+
69
+ ```python
70
+ # for logging these images to wandb
71
+
72
+ def image_outputs_hook(image_data, global_step, accelerate_logger):
73
+ # For the sake of this example, we only care about the last batch
74
+ # hence we extract the last element of the list
75
+ result = {}
76
+ images, prompts, _, rewards, _ = image_data[-1]
77
+ for i, image in enumerate(images):
78
+ pil = Image.fromarray(
79
+ (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
80
+ )
81
+ pil = pil.resize((256, 256))
82
+ result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
83
+ accelerate_logger.log_images(
84
+ result,
85
+ step=global_step,
86
+ )
87
+
88
+ ```
89
+
90
+ ### Using the finetuned model
91
+
92
+ Assuming you're done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows
93
+
94
+ ```python
95
+
96
+ import torch
97
+ from trl import DefaultDDPOStableDiffusionPipeline
98
+
99
+ pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")
100
+
101
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
102
+
103
+ # memory optimization
104
+ pipeline.vae.to(device, torch.float16)
105
+ pipeline.text_encoder.to(device, torch.float16)
106
+ pipeline.unet.to(device, torch.float16)
107
+
108
+ prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
109
+ results = pipeline(prompts)
110
+
111
+ for prompt, image in zip(prompts,results.images):
112
+ image.save(f"{prompt}.png")
113
+
114
+ ```
115
+
116
+ ## Credits
117
+
118
+ This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models
119
+ with Reinforcement Learning by Kevin Black, Michael Janner, Yilan Du, Ilya Kostrikov, Sergey Levine](https://arxiv.org/abs/2305.13301).
docs/source/detoxifying_a_lm.mdx ADDED
@@ -0,0 +1,191 @@
1
+ # Detoxifying a Language Model using PPO
2
+
3
+ Language models (LMs) are known to sometimes generate toxic outputs. In this example, we will show how to "detoxify" an LM by feeding it toxic prompts and then using [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl/index) and Proximal Policy Optimization (PPO) to reduce its toxicity.
4
+
5
+ Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters!
6
+
7
+ Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
8
+
9
+ | File | Description | Colab link |
10
+ |---|---| --- |
11
+ | [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
12
+ | [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
13
+ | [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x |
14
+
15
+ ## Context
16
+
17
+ Language models are trained on large volumes of text from the internet, which also includes a lot of toxic content. Naturally, language models pick up the toxic patterns during training. Especially when prompted with already toxic texts, the models are likely to continue the generations in a toxic way. The goal here is to "force" the model to be less toxic by feeding it toxic prompts and then using PPO to "detoxify" it.
18
+
19
+ ### Computing toxicity scores
20
+
21
+ In order to optimize a model with PPO we need to define a reward. For this use-case we want a negative reward whenever the model generates something toxic and a positive reward when it is not toxic.
22
+ Therefore, we used [`facebook/roberta-hate-speech-dynabench-r4-target`](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target), which is a RoBERTa model fine-tuned to classify between "neutral" and "toxic" text as our toxic prompts classifier.
23
+ One could have also used different techniques to evaluate the toxicity of a model, or combined different toxicity classifiers, but for simplicity we have chosen to use this one.
24
+
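+ For reference, a minimal sketch of loading this classifier with `transformers` (the `toxicity_model` name is an assumption chosen to match the reward snippet further below):
+
+ ```python
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+ toxicity_model_id = "facebook/roberta-hate-speech-dynabench-r4-target"
+ toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_id)
+ toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_id)
+ ```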
25
+ ### Selection of models
26
+
27
+ We selected the following models for our experiments to show that TRL can be easily scaled to 10B parameters models:
28
+
29
+ * [`EleutherAI/gpt-neo-125M`](https://huggingface.co/EleutherAI/gpt-neo-125M) (125 million parameters)
30
+ * [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters)
31
+ * [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters)
32
+
33
+ For the selection of the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has been shown to be the "most toxic" compared to other models. We ran a toxicity evaluation using the `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of the `allenai/real-toxicity-prompts` dataset. Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt).
34
+
35
+ | Model | Mean toxicity score |
36
+ |---|---|
37
+ | `gpt2` | 0.01602 |
38
+ | `facebook/opt-350m` | 0.01628 |
39
+ | `bigscience/bloom-560m` | 0.00767 |
40
+ | `EleutherAI/gpt-neo-125M` | **0.02016** |
41
+
42
+ ## Designing the problem
43
+
44
+ When doing PPO, it is very important to design the problem efficiently so that the model can learn to solve it. Let's cover the topics that were important for the model to converge.
45
+
46
+ ### Pre-processing the dataset
47
+
48
+ The dataset consist of prompts and their continuations, and each of them has an associated `toxicity` score.
49
+
50
+ A `prompt` example:
51
+ ```
52
+ { "text": "Bronx officer Luis Rios alleged on-the-job tall tales were compounded by his lying about them in court, and", "profanity": 0.08360514, "sexually_explicit": 0.118823394, "identity_attack": 0.09608547, "flirtation": 0.23102224, "threat": 0.13215046, "insult": 0.10130461, "severe_toxicity": 0.04068885, "toxicity": 0.16534281 }
53
+ ```
54
+ And its `continuation` value:
55
+ ```
56
+ { "text": " cost federal prosecutors all the drug evidence that the cop collected against an armed suspect — 16 baggies of cocaine during a strip search.", "severe_toxicity": 0.067997746, "toxicity": 0.1694093, "profanity": 0.11931301, "sexually_explicit": 0.12521537, "identity_attack": 0.09268324, "flirtation": 0.13452998, "threat": 0.31312028, "insult": 0.10761123 }
57
+ ```
58
+
59
+ We want to increase the chance for the model to generate toxic prompts so we get more learning signal. For this reason we pre-process the dataset to consider only the prompts that have a toxicity score greater than a threshold. We can do this in a few lines of code:
60
+ ```python
61
+ ds = load_dataset("allenai/real-toxicity-prompts", split="train")
62
+
63
+ def filter_fn(sample):
64
+ toxicity = sample["prompt"]["toxicity"]
65
+ return toxicity is not None and toxicity > 0.3
66
+
67
+ ds = ds.filter(filter_fn, batched=False)
68
+ ```
69
+
70
+ ### Reward function
71
+
72
+ The reward function is one of the most important parts of training a model with reinforcement learning. It is the function that will tell the model if it is doing well or not.
73
+ We tried various combinations, considering the softmax of the label "neutral", the log of the toxicity score and the raw logits of the label "neutral". We have found out that the convergence was much smoother with the raw logits of the label "neutral".
74
+ ```python
75
+ logits = toxicity_model(**toxicity_inputs).logits.float()
76
+ rewards = (logits[:, 0]).tolist()
77
+ ```
78
+
79
+ ### Impact of input prompts length
80
+
81
+ We have found out that training a model with a small or long context (from 5 to 8 tokens for the small context and from 15 to 20 tokens for the long context) does not have any impact on the convergence of the model; however, when training the model with longer prompts, the model will tend to generate more toxic prompts.
82
+ As a compromise between the two we opted for a context window of 10 to 15 tokens for the training; a sketch of sampling such lengths is given below.
83
+
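+ For illustration, prompt lengths in that window can be sampled with `trl`'s `LengthSampler` (the variable names are assumptions; see the training script for the actual setup):
+
+ ```python
+ from trl.core import LengthSampler
+
+ # sample a prompt length between 10 and 15 tokens for each training example
+ input_size_sampler = LengthSampler(10, 15)
+ input_size = input_size_sampler()  # e.g. 12
+ ```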
84
+
85
+ <div style="text-align: center">
86
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-long-vs-short-context.png">
87
+ </div>
88
+
89
+ ### How to deal with OOM issues
90
+
91
+ Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here are two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU:
92
+
93
+ - Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by 2:
94
+
95
+ ```python
96
+ model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16)
97
+ ```
98
+
99
+ and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is pure `bfloat16` training, which is different from mixed-precision training. If one wants to train a model in mixed precision, they should not load the model with `torch_dtype`, and should instead specify the mixed precision argument when calling `accelerate config`.
100
+
101
+ - Use shared layers: Since the PPO algorithm requires both the active and the reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by just specifying the `num_shared_layers` argument when creating a `PPOTrainer`:
102
+
103
+ <div style="text-align: center">
104
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-shared-layers.png">
105
+ </div>
106
+
107
+ ```python
108
+ ppo_trainer = PPOTrainer(
109
+ model=model,
110
+ tokenizer=tokenizer,
111
+ num_shared_layers=4,
112
+ ...
113
+ )
114
+ ```
115
+
116
+ In the example above, this means that the model has its first 4 layers frozen (i.e. these layers are shared between the active model and the reference model).
117
+
118
+ - One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower).
119
+
120
+ ## Training the model!
121
+
122
+ We have decided to keep 3 models in total that correspond to our best models:
123
+
124
+ - [`ybelkada/gpt-neo-125m-detox`](https://huggingface.co/ybelkada/gpt-neo-125m-detox)
125
+ - [`ybelkada/gpt-neo-2.7B-detox`](https://huggingface.co/ybelkada/gpt-neo-2.7B-detox)
126
+ - [`ybelkada/gpt-j-6b-detox`](https://huggingface.co/ybelkada/gpt-j-6b-detox)
127
+
128
+ We have used different learning rates for each model, and have found out that the largest models were quite hard to train and can easily lead to collapse mode if the learning rate is not chosen correctly (i.e. if the learning rate is too high):
129
+
130
+ <div style="text-align: center">
131
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-collapse-mode.png">
132
+ </div>
133
+
134
+ The final training run of `ybelkada/gpt-j-6b-detoxified-20shdl` looks like this:
135
+
136
+ <div style="text-align: center">
137
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-gpt-j-final-run-2.png">
138
+ </div>
139
+
140
+ As you can see the model converges nicely, but obviously we don't observe a very large improvement from the first step, as the original model is not trained to generate toxic contents.
141
+
142
+ Also we have observed that training with larger `mini_batch_size` leads to smoother convergence and better results on the test set:
143
+
144
+ <div style="text-align: center">
145
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-gpt-j-mbs-run.png">
146
+ </div>
147
+
148
+ ## Results
149
+
150
+ We tested our models on a new dataset, the [`OxAISH-AL-LLM/wiki_toxic`](https://huggingface.co/datasets/OxAISH-AL-LLM/wiki_toxic) dataset. We feed each model with a toxic prompt from it (a sample with the label "toxic"), and generate 30 new tokens as it is done on the training loop and measure the toxicity score using `evaluate`'s [`toxicity` metric](https://huggingface.co/spaces/ybelkada/toxicity).
151
+ We report the toxicity score of 400 sampled examples, compute its mean and standard deviation and report the results in the table below:
152
+
153
+ | Model | Mean toxicity score | Std toxicity score |
154
+ | --- | --- | --- |
155
+ | `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 |
156
+ | `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** |
157
+ | --- | --- | --- |
158
+ | `EleutherAI/gpt-neo-2.7B` | 0.1884 | 0.3178 |
159
+ | `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** |
160
+ | --- | --- | --- |
161
+ | `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 |
162
+ | `ybelkada/gpt-j-6b-detox` | **0.1510** | **0.2798** |
163
+
164
+ <div class="column" style="text-align:center">
165
+ <figure>
166
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-final-barplot.png" style="width:80%">
167
+ <figcaption>Toxicity score with respect to the size of the model.</figcaption>
168
+ </figure>
169
+ </div>
170
+
171
+ Below are few generation examples of `gpt-j-6b-detox` model:
172
+
173
+ <div style="text-align: center">
174
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-toxicity-examples.png">
175
+ </div>
176
+
177
+ The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
178
+
179
+ ### Discussions
180
+
181
+ The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for the `gpt-neo-2B` model but less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model, starting with training with a larger `mini_batch_size` and probably allowing to back-propagate through more layers (i.e. use fewer shared layers).
182
+
183
+ To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure their outputs are less toxic as well as useful.
184
+
185
+ ### Limitations
186
+
187
+ We are also aware of consistent bias issues reported with toxicity classifiers, and of work evaluating the negative impact of toxicity reduction on the diversity of outcomes. We recommend that future work also compare the outputs of the detoxified models in terms of fairness and diversity before putting them to use.
188
+
189
+ ## What is next?
190
+
191
+ You can download the model and use it out of the box with `transformers`, or play with the Spaces that compares the output of the models before and after detoxification [here](https://huggingface.co/spaces/ybelkada/detoxified-lms).
docs/source/dpo_trainer.mdx ADDED
@@ -0,0 +1,106 @@
1
+ # DPO Trainer
2
+
3
+ TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py).
4
+
5
+
6
+ The first step as always is to train your SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
7
+
8
+ ## Expected dataset format
9
+
10
+ The DPO trainer expects a very specific format for the dataset, since the model will be trained to directly optimize the preference of which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
11
+
12
+ <div style="text-align: center">
13
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/rlhf-antropic-example.png", width="50%">
14
+ </div>
15
+
16
+ Therefore the final dataset object should contain these 3 entries if you use the default `DPODataCollatorWithPadding` data collator. The entries should be named:
17
+
18
+ - `prompt`
19
+ - `chosen`
20
+ - `rejected`
21
+
22
+ for example:
23
+
24
+ ```py
25
+ dpo_dataset_dict = {
26
+ "prompt": [
27
+ "hello",
28
+ "how are you",
29
+ "What is your name?",
30
+ "What is your name?",
31
+ "Which is the best programming language?",
32
+ "Which is the best programming language?",
33
+ "Which is the best programming language?",
34
+ ],
35
+ "chosen": [
36
+ "hi nice to meet you",
37
+ "I am fine",
38
+ "My name is Mary",
39
+ "My name is Mary",
40
+ "Python",
41
+ "Python",
42
+ "Java",
43
+ ],
44
+ "rejected": [
45
+ "leave me alone",
46
+ "I am not fine",
47
+ "Whats it to you?",
48
+ "I dont have a name",
49
+ "Javascript",
50
+ "C++",
51
+ "C++",
52
+ ],
53
+ }
54
+ ```
55
+
56
+ where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
57
+
58
+ ## Expected model format
59
+ The DPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
60
+
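+ For illustration, a minimal sketch of loading such a pair of models (using `gpt2` as a stand-in checkpoint; the `model`, `model_ref` and `tokenizer` names match the snippet below):
+
+ ```py
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model = AutoModelForCausalLM.from_pretrained("gpt2")      # policy to be trained
+ model_ref = AutoModelForCausalLM.from_pretrained("gpt2")  # frozen reference copy
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ ```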
61
+ ## Using the `DPOTrainer`
62
+
63
+ For a detailed example have a look at the `examples/scripts/dpo.py` script. At a high level we need to initialize the `DPOTrainer` with a `model` we wish to train, a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected responses, the `beta` hyperparameter of the implicit reward, and a dataset containing the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (i.e. decoder only or encoder-decoder).
64
+
65
+ ```py
66
+ dpo_trainer = DPOTrainer(
67
+ model,
68
+ model_ref,
69
+ args=training_args,
70
+ beta=0.1,
71
+ train_dataset=train_dataset,
72
+ tokenizer=tokenizer,
73
+ )
74
+ ```
75
+ After this one can then call:
76
+
77
+ ```py
78
+ dpo_trainer.train()
79
+ ```
80
+
81
+ Note that the `beta` is the temperature parameter for the DPO loss, typically something in the range of `0.1` to `0.5`. We ignore the reference model as `beta` -> 0.
82
+
83
+ ## Loss functions
84
+
85
+ Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the DPO authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression.
86
+
87
+ The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `DPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin.
88
+
89
+ The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer.
90
+
91
+ The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability that can be passed to the `DPOTrainer` via `label_smoothing` argument (between 0 and 0.5) and then a conservative DPO loss is used. Use the `loss_type="cdpo"` argument to the trainer to use it.
92
+
93
+ The [KTO](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf) loss is derived to directly maximize the utility of LLM generations instead of the log-likelihood of preferences. Thus the dataset is not necessarily made of preferences but rather of desirable vs undesirable pairs. Use the `loss_type="kto"` argument to the trainer to utilize this loss.
94
+
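+ For example, switching to the hinge loss only requires passing `loss_type` when building the trainer (a sketch reusing the objects from the `DPOTrainer` snippet above):
+
+ ```py
+ dpo_trainer = DPOTrainer(
+     model,
+     model_ref,
+     args=training_args,
+     beta=0.1,
+     loss_type="hinge",  # other options described above: "ipo", "cdpo", "kto"
+     train_dataset=train_dataset,
+     tokenizer=tokenizer,
+ )
+ ```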
95
+ ## Logging
96
+
97
+ While training and evaluating we record the following reward metrics:
98
+
99
+ * `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta
100
+ * `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta
101
+ * `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
102
+ * `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
103
+
104
+ ## DPOTrainer
105
+
106
+ [[autodoc]] DPOTrainer
docs/source/example_overview.md ADDED
@@ -0,0 +1,73 @@
1
+ # Examples
2
+
3
+
4
+ ## Introduction
5
+
6
+ The examples should work in any of the following settings (with the same script):
7
+ - single GPU
8
+ - multi GPUs (using PyTorch distributed mode)
9
+ - multi GPUs (using DeepSpeed ZeRO-Offload stages 1, 2, & 3)
10
+ - fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision)
11
+
12
+ To run it in each of these various modes, first initialize the accelerate
13
+ configuration with `accelerate config`
14
+
15
+ **NOTE: to train with a 4-bit or 8-bit model**, please run
16
+
17
+ ```bash
18
+ pip install --upgrade trl[quantization]
19
+ ```
20
+
21
+
22
+ ## Accelerate Config
23
+ For all the examples, you'll need to generate a 🤗 Accelerate config file with:
24
+
25
+ ```shell
26
+ accelerate config # will prompt you to define the training configuration
27
+ ```
28
+
29
+ Then, it is encouraged to launch jobs with `accelerate launch`!
30
+
31
+
32
+ # Maintained Examples
33
+
34
+
35
+ | File | Description |
36
+ |------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------|
37
+ | [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the `SFTTrainer` to fine tune a model or adapters into a target dataset. |
38
+ | [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the `RewardTrainer` to train a reward model on your own dataset. |
39
+ | [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset |
40
+ | [`examples/scripts/ppo_multi_adapter.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo_multi_adapter.py) | This script shows how to use the `PPOTrainer` to train a single base model with multiple adapters. Requires you to run the example script with the reward model training beforehand. |
41
+ | [`examples/scripts/stable_diffusion_tuning_example.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/stable_diffusion_tuning_example.py) | This script shows to use DDPOTrainer to fine-tune a stable diffusion model using reinforcement learning. |
42
+
43
+ Here are also some easier-to-run colab notebooks that you can use to get started with TRL:
44
+
45
+
46
+ | File | Description |
47
+ |----------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------|
48
+ | [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
49
+ | [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
50
+ | [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
51
+
52
+
53
+ We also have some other examples that are less maintained but can be used as a reference:
54
+ 1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
55
+
56
+
57
+ ## Distributed training
58
+
59
+ All of the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments.)
60
+
61
+ ```shell
62
+ accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
63
+ ```
64
+
65
+ You can also adjust the parameters of the 🤗 Accelerate config file to suit your needs (e.g. training in mixed precision).
66
+
67
+ ### Distributed training with DeepSpeed
68
+
69
+ Most of the scripts can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine, `--all_arguments_of_the_script` with your arguments, and `--deepspeed_config` with the path to the DeepSpeed config file such as `examples/deepspeed_configs/deepspeed_zero1.yaml`):
70
+
71
+ ```shell
72
+ accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
73
+ ```
docs/source/how_to_train.md ADDED
@@ -0,0 +1,66 @@
1
+ # Training FAQ
2
+
3
+ ## What Metrics Should I Look at?
4
+
5
+ When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves.
6
+
7
+ To address this, we recommend focusing on two key metrics first:
8
+
9
+ 1. **Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training.
10
+ 2. **Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces.
11
+
12
+ However, there are more metrics that can be useful for debugging; check out the [logging section](logging).
13
+
14
+ ## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence?
15
+
16
+ When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans.
17
+
18
+ However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.
19
+
20
+ <div style="text-align: center">
21
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kl-example.png">
22
+ <p style="text-align: center;"> <b>Figure:</b> Samples without a KL penalty from <a href="https://arxiv.org/pdf/1909.08593.pdf">https://arxiv.org/pdf/1909.08593.pdf</a>. </p>
23
+ </div>
24
+
25
+ To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates.
26
+
27
+ ## What Is the Concern with Negative KL Divergence?
28
+
29
+ If you generate text by purely sampling from the model distribution, things work fine in general. But when you use the `generate` method there are a few caveats, because it does not always purely sample depending on the settings, which can cause the KL divergence to go negative. Essentially, when the active model achieves `log_p_token_active < log_p_token_ref` we get a negative KL divergence. This can happen in several cases:
30
+
31
+ - **top-k sampling**: the model can smooth out the probability distribution, causing the top-k tokens to have a smaller probability than under the reference model while still being selected
32
+ - **min_length**: this ignores the EOS token until `min_length` is reached, so the model can assign a very low log prob to the EOS token and very high probs to all other tokens until `min_length` is reached
34
+
35
+ These are just a few examples. Why is negative KL an issue? The total reward `R` is computed as `R = r - beta * KL`, so if the model learns how to drive the KL divergence negative it effectively receives a positive reward. In many cases it can be much easier to exploit such a bug in the generation than to actually learn the reward function. In addition, the KL can become arbitrarily negative, so the actual reward can be very small compared to it.
36
+
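+ To make the formula concrete, here is a minimal sketch (with made-up numbers and an illustrative `beta`) of how the KL penalty enters the total reward; internally the `PPOTrainer` applies this penalty per token, but the idea is the same:
+
+ ```python
+ import torch
+
+ beta = 0.1  # illustrative KL coefficient; in TRL it is managed by the KL controller
+ logprobs = torch.tensor([-1.2, -0.8, -2.3])      # per-token log probs under the active model
+ ref_logprobs = torch.tensor([-1.0, -0.9, -2.0])  # per-token log probs under the reference model
+ score = torch.tensor(1.5)                        # scalar reward r from the reward model
+
+ kl = (logprobs - ref_logprobs).sum()  # negative if the active model is less likely than the reference
+ total_reward = score - beta * kl      # R = r - beta * KL: a negative KL inflates the total reward
+ ```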
37
+ So how should you generate text for PPO training? Let's have a look!
38
+
39
+ ## How to generate text for training?
40
+
41
+ In order to avoid the KL issues described above we recommend using the following settings:
42
+
43
+ ```python
44
+ generation_kwargs = {
45
+ "min_length": -1, # don't ignore the EOS token (see above)
46
+ "top_k": 0.0, # no top-k sampling
47
+ "top_p": 1.0, # no nucleus sampling
48
+ "do_sample": True, # yes, we want to sample
49
+ "pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead
50
+ "max_new_tokens": 32, # specify how many tokens you want to generate at most
51
+ }
52
+ ```
53
+
54
+ With these settings we usually don't encounter any issues. You can also experiment with other settings, but if you encounter issues with negative KL divergence, try to go back to these and see if they persist.
55
+
56
+ ## How can you debug your own use-case?
57
+
58
+ Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier:
59
+
60
+ - **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example, and once you figure out the best hyperparameters, switch to your dataset and reward model.
61
+ - **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
62
+ - **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that.
63
+ - **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt, or bad generation settings might cut off generations too soon. These things are very hard to see in the metrics but very obvious if you look at the generations.
64
+ - **Inspect the reward model**: If your reward is not improving over time, maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. You can also look at the distribution of your dataset. Finally, maybe the reward is dominated by the query, which the model can't affect, so you might need to normalize it (e.g. reward of query+response minus reward of the query; see the sketch after this list).
65
+
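+ A minimal sketch of that normalization, using the sentiment pipeline from the examples as the reward model (any scalar reward function works the same way):
+
+ ```python
+ from transformers import pipeline
+
+ reward_pipe = pipeline("text-classification", model="lvwerra/distilbert-imdb")
+
+ query = "The movie was"
+ response = " absolutely wonderful and I loved every minute of it."
+
+ # naive reward: the score of the full text is partly driven by the query itself
+ full_score = reward_pipe(query + response)[0]["score"]
+ # normalized reward: subtract the score of the query alone so only the response is credited
+ query_score = reward_pipe(query)[0]["score"]
+ normalized_reward = full_score - query_score
+ ```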
66
+ These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!
docs/source/index.mdx ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div style="text-align: center">
2
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_banner_dark.png">
3
+ </div>
4
+
5
+ # TRL - Transformer Reinforcement Learning
6
+
7
+ TRL is a full stack library where we provide a set of tools to train transformer language models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step.
8
+ The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
9
+
10
+ <div style="text-align: center">
11
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
12
+ </div>
13
+
14
+ Check the appropriate sections of the documentation depending on your needs:
15
+
16
+ ## API documentation
17
+
18
+ - [Model Classes](models): *A brief overview of what each public model class does.*
19
+ - [`SFTTrainer`](sft_trainer): *Supervised fine-tuning of your model made easy with `SFTTrainer`*
20
+ - [`RewardTrainer`](reward_trainer): *Easily train your reward model using `RewardTrainer`.*
21
+ - [`PPOTrainer`](ppo_trainer): *Further fine-tune the supervised fine-tuned model using the PPO algorithm*
22
+ - [Best-of-N Sampling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model*
23
+ - [`DPOTrainer`](dpo_trainer): *Direct Preference Optimization training using `DPOTrainer`.*
24
+ - [`TextEnvironment`](text_environment): *Text environment to train your model using tools with RL.*
25
+
26
+ ## Examples
27
+
28
+ - [Sentiment Tuning](sentiment_tuning): *Fine-tune your model to generate positive movie content*
29
+ - [Training with PEFT](lora_tuning_peft): *Memory efficient RLHF training using adapters with PEFT*
30
+ - [Detoxifying LLMs](detoxifying_a_lm): *Detoxify your language model through RLHF*
31
+ - [StackLlama](using_llama_models): *End-to-end RLHF training of a Llama model on the Stack Exchange dataset*
32
+ - [Learning with Tools](learning_tools): *Walkthrough of using `TextEnvironments`*
33
+ - [Multi-Adapter Training](multi_adapter_rl): *Use a single base model and multiple adapters for memory efficient end-to-end training*
34
+
35
+
36
+ ## Blog posts
37
+
38
+ <div class="mt-10">
39
+ <div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
40
+ <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/rlhf">
41
+ <img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/120_rlhf/thumbnail.png" alt="thumbnail">
42
+ <p class="text-gray-700">Illustrating Reinforcement Learning from Human Feedback</p>
43
+ </a>
44
+ <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-peft">
45
+ <img src="https://github.com/huggingface/blog/blob/main/assets/133_trl_peft/thumbnail.png?raw=true" alt="thumbnail">
46
+ <p class="text-gray-700">Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU</p>
47
+ </a>
48
+ <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/stackllama">
49
+ <img src="https://github.com/huggingface/blog/blob/main/assets/138_stackllama/thumbnail.png?raw=true" alt="thumbnail">
50
+ <p class="text-gray-700">StackLLaMA: A hands-on guide to train LLaMA with RLHF</p>
51
+ </a>
52
+ <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo-trl">
53
+ <img src="https://github.com/huggingface/blog/blob/main/assets/157_dpo_trl/dpo_thumbnail.png?raw=true" alt="thumbnail">
54
+ <p class="text-gray-700">Fine-tune Llama 2 with DPO</p>
55
+ </a>
56
+ <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-ddpo">
57
+ <img src="https://github.com/huggingface/blog/blob/main/assets/166_trl_ddpo/thumbnail.png?raw=true" alt="thumbnail">
58
+ <p class="text-gray-700">Finetune Stable Diffusion Models with DDPO via TRL</p>
59
+ </a>
60
+ </div>
61
+ </div>
docs/source/installation.mdx ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Installation
2
+ You can install TRL either from PyPI or from source:
3
+
4
+ ## PyPI
5
+ Install the library with pip:
6
+
7
+ ```bash
8
+ pip install trl
9
+ ```
10
+
11
+ ## Source
12
+ You can also install the latest version from source. First clone the repo and then run the installation with `pip`:
13
+
14
+ ```bash
15
+ git clone https://github.com/huggingface/trl.git
16
+ cd trl/
17
+ pip install -e .
18
+ ```
19
+
20
+ If you want the development install you can replace the pip install with the following:
21
+
22
+ ```bash
23
+ pip install -e ".[dev]"
24
+ ```
docs/source/iterative_sft_trainer.mdx ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Iterative Trainer
2
+
3
+ Iterative fine-tuning is a training method that enables you to perform custom actions (generation and filtering, for example) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code.
4
+
5
+ ## Usage
6
+
7
+ To get started quickly, instantiate a model and a tokenizer, and create an `IterativeSFTTrainer` instance.
8
+
9
+ ```python
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from trl import IterativeSFTTrainer
+
+ model_name = "gpt2"  # any causal LM works here
11
+ model = AutoModelForCausalLM.from_pretrained(model_name)
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
13
+ if tokenizer.pad_token is None:
14
+ tokenizer.pad_token = tokenizer.eos_token
15
+
16
+ trainer = IterativeSFTTrainer(
17
+ model,
18
+ tokenizer
19
+ )
20
+
21
+ ```
22
+
23
+ You have the choice to either provide a list of strings or a list of tensors to the step function.
24
+
25
+ #### Using a list of tensors as input:
26
+
27
+ ```python
28
+
29
+ inputs = {
30
+ "input_ids": input_ids,
31
+ "attention_mask": attention_mask
32
+ }
33
+
34
+ trainer.step(**inputs)
35
+
36
+ ```
37
+
38
+ #### Using a list of strings as input:
39
+
40
+ ```python
41
+
42
+ inputs = {
43
+ "texts": texts
44
+ }
45
+
46
+ trainer.step(**inputs)
47
+
48
+ ```
49
+
50
+ For causal language models, labels will automatically be created from `input_ids` or from `texts`. When using sequence-to-sequence models you will have to provide your own labels or text_labels.
51
+
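+ Putting it together, an iterative loop alternates between custom actions and an optimization step. Below is a minimal sketch under assumed settings (the prompts, generation settings, and the trivial length filter are placeholders for your own logic):
+
+ ```python
+ prompts = ["The movie was", "The weather today is"]
+
+ for _ in range(10):
+     # custom action 1: generate continuations with the current model
+     inputs = tokenizer(prompts, return_tensors="pt", padding=True)
+     generations = model.generate(**inputs, max_new_tokens=32, do_sample=True)
+     texts = tokenizer.batch_decode(generations, skip_special_tokens=True)
+
+     # custom action 2: filter the generations (here: keep only sufficiently long ones)
+     kept = [text for text in texts if len(text.split()) > 5]
+
+     # optimization step on the kept samples
+     if kept:
+         trainer.step(texts=kept)
+ ```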
52
+ ## IterativeTrainer
53
+
54
+ [[autodoc]] IterativeSFTTrainer
docs/source/learning_tools.mdx ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Learning Tools (Experimental 🧪)
2
+
3
+ Using Large Language Models (LLMs) with tools has been a popular topic recently, with awesome works such as [ToolFormer](https://arxiv.org/abs/2302.04761) and [ToolBench](https://arxiv.org/pdf/2305.16504.pdf). In TRL, we provide a simple example of how to teach an LLM to use tools with reinforcement learning.
4
+
5
+
6
+ Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):
7
+
8
+ | File | Description |
9
+ |---|---|
10
+ | [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train an LLM to use a calculator with reinforcement learning. |
11
+ | [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train an LLM to use a wiki tool to answer questions. |
12
+ | [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train an LLM to use a Python interpreter to solve math puzzles. |
13
+
14
+ <Tip warning={true}>
15
+
16
+ Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs.
17
+ </Tip>
18
+
19
+
20
+ ## Learning to Use a Calculator
21
+
22
+
23
+ The rough idea is as follows:
24
+
25
+ 1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parses a text calculation like `"14 + 34"` and returns the calculated number:
26
+ ```python
27
+ from transformers import AutoTokenizer, load_tool
28
+ tool = load_tool("ybelkada/simple-calculator")
29
+ tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places
30
+ ```
31
+ 1. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later.
32
+ 1. Create a prompt on how to use the tools
33
+ ```python
34
+ # system prompt
35
+ prompt = """\
36
+ What is 13.1-3?
37
+
38
+ <request><SimpleCalculatorTool>13.1-3<call>10.1<response>
39
+
40
+ Result=10.1<submit>
41
+
42
+ What is 4*3?
43
+
44
+ <request><SimpleCalculatorTool>4*3<call>12<response>
45
+
46
+ Result=12<submit>
47
+
48
+ What is 12.1+1?
49
+
50
+ <request><SimpleCalculatorTool>12.1+1<call>13.1<response>
51
+
52
+ Result=13.1<submit>
53
+
54
+ What is 12.1-20?
55
+
56
+ <request><SimpleCalculatorTool>12.1-20<call>-7.9<response>
57
+
58
+ Result=-7.9<submit>"""
59
+ ```
60
+ 3. Create a `trl.TextEnvironment` with the model
61
+ ```python
62
+ env = TextEnvironment(
63
+ model,
64
+ tokenizer,
65
+ {"SimpleCalculatorTool": tool_fn},
66
+ reward_fn,
67
+ prompt,
68
+ generation_kwargs=generation_kwargs,
69
+ )
70
+ ```
71
+ 4. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will visualize the tokens.
72
+ ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools.png)
73
+ 1. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss; make sure to pass that argument to `step`. A condensed sketch of one such iteration follows this list.
74
+
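+ As a reference, here is a condensed sketch of one rollout-and-update iteration, assuming `env` and `ppo_trainer` have been created as above:
+
+ ```python
+ tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]
+
+ # run the environment: the model generates, tool calls are executed, and rewards are collected
+ queries, responses, masks, rewards, histories = env.run(tasks)
+
+ # optional: inspect one interaction with color-coded tool segments
+ histories[0].show_text()
+
+ # PPO update; the masks make sure tool output tokens are ignored in the loss
+ train_stats = ppo_trainer.step(queries, responses, rewards, masks)
+ ```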
75
+ ## Experiment results
76
+
77
+ We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster.
78
+
79
+ ```
80
+ WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
81
+ --command "python examples/research_projects/tools/calculator.py" \
82
+ --num-seeds 10 \
83
+ --start-seed 1 \
84
+ --workers 10 \
85
+ --slurm-gpus-per-task 1 \
86
+ --slurm-ntasks 1 \
87
+ --slurm-total-cpus 8 \
88
+ --slurm-template-path benchmark/trl.slurm_template
89
+ ```
90
+
91
+ We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark) which generates the following plot.
92
+ ```
93
+ python -m openrlbenchmark.rlops_multi_metrics \
94
+ --filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
95
+ 'wandb?tag=calculator_final&cl=calculator_mask' \
96
+ --env-ids trl \
97
+ --check-empty-runs \
98
+ --pc.ncols 2 \
99
+ --pc.ncols-legend 1 \
100
+ --output-filename static/0compare \
101
+ --scan-history
102
+ ```
103
+
104
+ ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools_chart.png)
105
+
106
+ As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near perfect proficiency in the calculator task.
107
+
108
+
109
+ ## (Early Experiments 🧪): learning to use a wiki tool for question answering
110
+
111
+ The [ToolFormer](https://arxiv.org/abs/2302.04761) paper shows an interesting use case that utilizes a Wikipedia search tool to help answer questions. In this section, we attempt to perform similar experiments but use RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.
112
+
113
+
114
+ <Tip warning={true}>
115
+
116
+ **Note that many settings are different so the results are not directly comparable.**
117
+ </Tip>
118
+
119
+
120
+
121
+
122
+ ### Building a search index
123
+
124
+ Since [ToolFormer](https://arxiv.org/abs/2302.04761) did not open source its code, we needed to first replicate the search index. The paper mentions that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT).
125
+
126
+ Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.
127
+
128
+ ```python
129
+ from pyserini.search.lucene import LuceneSearcher
130
+ import json
131
+ searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
132
+ def search(query):
133
+ hits = searcher.search(query, k=1)
134
+ hit = hits[0]
135
+ contents = json.loads(hit.raw)['contents']
136
+ return contents
137
+ print(search("tennis racket"))
138
+ ```
139
+ ```
140
+ Racket (sports equipment)
141
+ A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.
142
+
143
+ The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
144
+ ...
145
+ ```
146
+
147
+ We then basically deployed this snippet as a Hugging Face space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the space as a `transformers.Tool` later.
148
+
149
+ ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pyserini.png)
150
+
151
+ ### Experiment settings
152
+
153
+ We use the following settings:
154
+
155
+ * use the `bigcode/starcoderbase` model as the base model
156
+ * use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only use the first paragraphs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool.
157
+ * test if the response contains the answer string; if so, give a reward of 1, otherwise, give a reward of 0 (a minimal sketch of this reward follows the prompt below).
158
+ * note that this is a simplified evaluation criterion. In [ToolFormer](https://arxiv.org/abs/2302.04761), the authors check if the first 20 words of the response contain the correct answer.
159
+ * use the following prompt that demonstrates the usage of the wiki tool.
160
+ ```python
161
+ prompt = """\
162
+ Answer the following question:
163
+
164
+ Q: In which branch of the arts is Patricia Neary famous?
165
+ A: Ballets
166
+ A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
167
+ Result=Ballets<submit>
168
+
169
+ Q: Who won Super Bowl XX?
170
+ A: Chicago Bears
171
+ A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
172
+ Result=Chicago Bears<submit>
173
+
174
+ Q: """
175
+ ```
176
+
177
+
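+ The exact-match reward mentioned above can be sketched as follows (`exact_match_reward` is a hypothetical helper for illustration, not the script's exact code):
+
+ ```python
+ def exact_match_reward(responses, answers):
+     """Reward of 1.0 if the gold answer string appears in the generated response, else 0.0."""
+     rewards = []
+     for response, answer in zip(responses, answers):
+         rewards.append(1.0 if answer.lower() in response.lower() else 0.0)
+     return rewards
+
+ print(exact_match_reward(["Result=Chicago Bears<submit>"], ["Chicago Bears"]))  # [1.0]
+ ```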
178
+ ### Result and Discussion
179
+
180
+
181
+ Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves mostly go up, but one of the experiments crashed.
182
+
183
+ ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/triviaqa_learning_curves.png)
184
+
185
+ The Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.
186
+
187
+
188
+ Note that the correct rate of the trained model is on the low end, which could be due to the following reasons:
189
+
190
+ * **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"`, the model searches for `Bruce Willis`, and our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988." The result it actually needs is "Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles.[1][2]"
191
+
192
+
193
+ ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/real_first_name.png)
194
+
195
+ * **unnecessarily long response**: The wiki tool by default sometimes outputs very long sequences. E.g., when the wiki tool searches for "Brown Act":
196
+ * Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
197
+ * [ToolFormer](https://arxiv.org/abs/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.
198
+
199
+ ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/brown_act.png)
200
+
201
+
202
+ ## (Early Experiments 🧪): solving math puzzles with python interpreter
203
+
204
+ In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following:
205
+
206
+ ```python
207
+ prompt = """\
208
+ Example of using a Python API to solve math questions.
209
+
210
+ Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
211
+
212
+ <request><PythonInterpreter>
213
+ def solution():
214
+ money_initial = 23
215
+ bagels = 5
216
+ bagel_cost = 3
217
+ money_spent = bagels * bagel_cost
218
+ money_left = money_initial - money_spent
219
+ result = money_left
220
+ return result
221
+ print(solution())
222
+ <call>72<response>
223
+
224
+ Result = 72 <submit>
225
+
226
+ Q: """
227
+ ```
228
+
229
+
230
+ The training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y
231
+
232
+ ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gms8k_learning_curve.png)
233
+
234
+
docs/source/logging.mdx ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Logging
2
+
3
+ As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
4
+ By default, the TRL [`PPOTrainer`] saves a lot of relevant information to `wandb` or `tensorboard`.
5
+
6
+ Upon initialization, pass one of these two options to the [`PPOConfig`]:
7
+ ```python
8
+ config = PPOConfig(
9
+     model_name=args.model_name,
10
+     log_with="wandb",  # or "tensorboard"
11
+ )
12
+ ```
13
+ If you want to log with TensorBoard, add the kwarg `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the `PPOConfig`.
14
+
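+ For example, a minimal sketch of a TensorBoard setup (the model name and log directory are placeholders):
+
+ ```python
+ from trl import PPOConfig
+
+ config = PPOConfig(
+     model_name="gpt2",
+     log_with="tensorboard",
+     project_kwargs={"logging_dir": "./ppo_logs"},
+ )
+ ```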
15
+ ## PPO Logging
16
+
17
+ Here's a brief explanation for the logged metrics provided in the data:
18
+
19
+ Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
20
+ 1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is used to specifically monitor the reward model.
21
+ 1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias `ppo/std_scores`, which is used to specifically monitor the reward model.
22
+ 1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment.
23
+ 1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function.
24
+ 1. `objective/kl_dist`: The histogram distribution of the `objective/kl`.
25
+ 1. `objective/kl_coef`: The coefficient for Kullback-Leibler (KL) divergence in the objective function.
26
+ 1. `ppo/mean_non_score_reward`: The **KL penalty** calculated by `objective/kl * objective/kl_coef` as the total reward for optimization to prevent the new policy from deviating too far from the old policy.
27
+ 1. `objective/entropy`: The entropy of the model's policy, calculated by `-logprobs.sum(-1).mean()`. High entropy means the model's actions are more random, which can be beneficial for exploration.
28
+
29
+ Training stats:
30
+ 1. `ppo/learning_rate`: The learning rate for the PPO algorithm.
31
+ 1. `ppo/policy/entropy`: The entropy of the model's policy, calculated by `pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)`. It measures the randomness of the policy.
32
+ 1. `ppo/policy/clipfrac`: The fraction of probability ratios (old policy / new policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process.
33
+ 1. `ppo/policy/approxkl`: The approximate KL divergence between the old and new policies, measured by `0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)`, corresponding to the `k2` estimator in http://joschu.net/blog/kl-approx.html
34
+ 1. `ppo/policy/policykl`: Similar to `ppo/policy/approxkl`, but measured by `masked_mean(old_logprobs - logprobs, mask)`, corresponding to the `k1` estimator in http://joschu.net/blog/kl-approx.html
35
+ 1. `ppo/policy/ratio`: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective.
36
+ 1. `ppo/policy/advantages_mean`: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state.
37
+ 1. `ppo/policy/advantages`: The histogram distribution of `ppo/policy/advantages_mean`.
38
+ 1. `ppo/returns/mean`: The mean of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details.
39
+ 1. `ppo/returns/var`: The variance of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance.
40
+ 1. `ppo/val/mean`: The mean of the values, used to monitor the value function's performance.
41
+ 1. `ppo/val/var` : The variance of the values, used to monitor the value function's performance.
42
+ 1. `ppo/val/var_explained`: The explained variance for the value function, used to monitor the value function's performance.
43
+ 1. `ppo/val/clipfrac`: The fraction of the value function's predicted values that are clipped.
44
+ 1. `ppo/val/vpred`: The predicted values from the value function.
45
+ 1. `ppo/val/error`: The mean squared error between the `ppo/val/vpred` and returns, used to monitor the value function's performance.
46
+ 1. `ppo/loss/policy`: The policy loss for the Proximal Policy Optimization (PPO) algorithm.
47
+ 1. `ppo/loss/value`: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards.
48
+ 1. `ppo/loss/total`: The total loss for the PPO algorithm. It is the sum of the policy loss and the value function loss.
49
+
50
+
51
+ Stats on queries, responses, and logprobs:
52
+ 1. `tokens/queries_len_mean`: The average length of the queries tokens.
53
+ 1. `tokens/queries_len_std`: The standard deviation of the length of the queries tokens.
54
+ 1. `tokens/queries_dist`: The histogram distribution of the length of the queries tokens.
55
+ 1. `tokens/responses_len_mean`: The average length of the responses tokens.
56
+ 1. `tokens/responses_len_std`: The standard deviation of the length of the responses tokens.
57
+ 1. `tokens/responses_dist`: The histogram distribution of the length of the responses tokens. (Costa: inconsistent naming, should be `tokens/responses_len_dist`)
58
+ 1. `objective/logprobs`: The histogram distribution of the log probabilities of the actions taken by the model.
59
+ 1. `objective/ref_logprobs`: The histogram distribution of the log probabilities of the actions taken by the reference model.
60
+
61
+
62
+
63
+ ### Crucial values
64
+ During training, many values are logged, here are the most important ones:
65
+
66
+ 1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment" / reward model
67
+ 1. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
68
+
69
+ Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):
70
+
71
+ 1. `ppo/loss/value`: it will spike / NaN when not going well.
72
+ 1. `ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on.
73
+ 1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well.
74
+ 1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy.
75
+ 1. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.
docs/source/lora_tuning_peft.mdx ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Examples of using peft with trl to finetune 8-bit models with Low Rank Adaption (LoRA)
2
+
3
+ The notebooks and scripts in these examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory-efficient manner. Most of the PEFT methods are supported in the peft library, but note that some PEFT methods such as Prompt tuning are not supported.
4
+ For more information on LoRA, see the [original paper](https://arxiv.org/abs/2106.09685).
5
+
6
+ Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
7
+
8
+ | File | Task | Description | Colab link |
9
+ |---|---|---|---|
10
+ | [`stack_llama/rl_training.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
11
+ | [`stack_llama/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
12
+ | [`stack_llama/supervised_finetuning.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |
13
+
14
+ ## Installation
15
+ Note: peft is in active development, so we install directly from their Github page.
16
+ Peft also relies on the latest version of transformers.
17
+
18
+ ```bash
19
+ pip install trl[peft]
20
+ pip install bitsandbytes loralib
21
+ pip install git+https://github.com/huggingface/transformers.git@main
22
+ #optional: wandb
23
+ pip install wandb
24
+ ```
25
+
26
+ Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
27
+
28
+ ## How to use it?
29
+
30
+ Simply declare a `PeftConfig` object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model.
31
+
32
+ ```python
33
+ from peft import LoraConfig
34
+ from trl import AutoModelForCausalLMWithValueHead
35
+
36
+ model_id = "edbeeching/gpt-neo-125M-imdb"
37
+ lora_config = LoraConfig(
38
+ r=16,
39
+ lora_alpha=32,
40
+ lora_dropout=0.05,
41
+ bias="none",
42
+ task_type="CAUSAL_LM",
43
+ )
44
+
45
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
46
+ model_id,
47
+ peft_config=lora_config,
48
+ )
49
+ ```
50
+ And if you want to load your model in 8bit precision:
51
+ ```python
52
+ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
53
+ config.model_name,
54
+ load_in_8bit=True,
55
+ peft_config=lora_config,
56
+ )
57
+ ```
58
+ ... or in 4bit precision:
59
+ ```python
60
+ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
61
+ config.model_name,
62
+ peft_config=lora_config,
63
+ load_in_4bit=True,
64
+ )
65
+ ```
66
+
67
+
68
+ ## Launch scripts
69
+
70
+ The `trl` library is powered by `accelerate`. As such it is best to configure and launch trainings with the following commands:
71
+
72
+ ```bash
73
+ accelerate config # will prompt you to define the training configuration
74
+ accelerate launch scripts/gpt2-sentiment_peft.py # launches training
75
+ ```
76
+
77
+ ## Using `trl` + `peft` and Data Parallelism
78
+
79
+ You can scale up to as many GPUs as you want, as long as you are able to fit the training process in a single device. The only tweak you need to apply is to load the model as follows:
80
+ ```python
81
+ from peft import LoraConfig
82
+ ...
83
+
84
+ lora_config = LoraConfig(
85
+ r=16,
86
+ lora_alpha=32,
87
+ lora_dropout=0.05,
88
+ bias="none",
89
+ task_type="CAUSAL_LM",
90
+ )
91
+
92
+ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
93
+ config.model_name,
94
+ peft_config=lora_config,
95
+ )
96
+ ```
97
+ And if you want to load your model in 8bit precision:
98
+ ```python
99
+ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
100
+ config.model_name,
101
+ peft_config=lora_config,
102
+ load_in_8bit=True,
103
+ )
104
+ ```
105
+ ... or in 4bit precision:
106
+ ```python
107
+ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
108
+ config.model_name,
109
+ peft_config=lora_config,
110
+ load_in_4bit=True,
111
+ )
112
+ ```
113
+ Finally, make sure that the rewards are computed on correct device as well, for that you can use `ppo_trainer.model.current_device`.
114
+
115
+ ## Naive pipeline parallelism (NPP) for large models (>60B models)
116
+
117
+ The `trl` library also supports naive pipeline parallelism (NPP) for large models (>60B models). This is a simple way to parallelize the model across multiple GPUs.
118
+ With this paradigm, termed "Naive Pipeline Parallelism" (NPP), we load the model and the adapters across multiple GPUs, and the activations and gradients are naively communicated across the GPUs. This supports `int8` models as well as other `dtype` models.
119
+
120
+ <div style="text-align: center">
121
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-npp.png">
122
+ </div>
123
+
124
+ ### How to use NPP?
125
+
126
+ Simply load your model with a custom `device_map` argument passed to `from_pretrained` to split your model across multiple devices. Check out this [nice tutorial](https://github.com/huggingface/blog/blob/main/accelerate-large-models.md) on how to properly create a `device_map` for your model.
127
+
128
+ Also make sure to have the `lm_head` module on the first GPU device as it may throw an error if it is not on the first device. At the time of writing, you need to install the `main` branch of `accelerate`: `pip install git+https://github.com/huggingface/accelerate.git@main` and `peft`: `pip install git+https://github.com/huggingface/peft.git@main`.
129
+
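+ A minimal sketch of such a loading call (the model name, `device_map`, and LoRA settings are placeholders; adapt them to your hardware):
+
+ ```python
+ from peft import LoraConfig
+ from trl import AutoModelForCausalLMWithValueHead
+
+ lora_config = LoraConfig(task_type="CAUSAL_LM")
+
+ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
+     "huggyllama/llama-7b",   # placeholder model name
+     device_map="auto",       # or an explicit dict that pins `lm_head` to device 0
+     peft_config=lora_config,
+ )
+ ```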
130
+ ### Launch scripts
131
+
132
+ Although `trl` library is powered by `accelerate`, you should run your training script in a single process. Note that we do not support Data Parallelism together with NPP yet.
133
+
134
+ ```bash
135
+ python PATH_TO_SCRIPT
136
+ ```
137
+
138
+ ## Fine-tuning Llama-2 model
139
+
140
+ You can easily fine-tune Llama2 model using `SFTTrainer` and the official script! For example to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB):
141
+
142
+ ```bash
143
+ python examples/scripts/sft.py --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --batch_size 4 --gradient_accumulation_steps 2
144
+ ```
docs/source/models.mdx ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Models
2
+
3
+ With the `AutoModelForCausalLMWithValueHead` class TRL supports all decoder model architectures in transformers such as GPT-2, OPT, and GPT-Neo. In addition, with `AutoModelForSeq2SeqLMWithValueHead` you can use encoder-decoder architectures such as T5. TRL also requires reference models which are frozen copies of the model that is trained. With `create_reference_model` you can easily create a frozen copy and also share layers between the two models to save memory.
4
+
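+ For example, a minimal sketch of creating reference models with `create_reference_model` (the number of shared layers is illustrative):
+
+ ```python
+ from trl import AutoModelForCausalLMWithValueHead, create_reference_model
+
+ model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
+
+ # full frozen copy of the model
+ ref_model = create_reference_model(model)
+
+ # alternatively, share the first 6 layers between the two models to save memory
+ ref_model_shared = create_reference_model(model, num_shared_layers=6)
+ ```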
5
+ ## PreTrainedModelWrapper
6
+
7
+ [[autodoc]] PreTrainedModelWrapper
8
+
9
+ ## AutoModelForCausalLMWithValueHead
10
+
11
+
12
+ [[autodoc]] AutoModelForCausalLMWithValueHead
13
+ - __init__
14
+ - forward
15
+ - generate
16
+ - _init_weights
17
+
18
+ ## AutoModelForSeq2SeqLMWithValueHead
19
+
20
+ [[autodoc]] AutoModelForSeq2SeqLMWithValueHead
21
+ - __init__
22
+ - forward
23
+ - generate
24
+ - _init_weights
25
+
26
+ ## create_reference_model
27
+
28
+ [[autodoc]] create_reference_model
docs/source/multi_adapter_rl.mdx ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Multi Adapter RL (MARL) - a single base model for everything
2
+
3
+ Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not test the convergence of the approach. We encourage the community to let us know if they run into any issues.
4
+
5
+ ## Requirements
6
+
7
+ You just need to install `peft` and optionally install `bitsandbytes` as well if you want to go for 8bit base models, for more memory efficient finetuning.
8
+
9
+ ## Summary
10
+
11
+ This approach proceeds in three stages, which we summarize as follows:
12
+
13
+ 1- Train a base model on the target domain (e.g. `imdb` dataset) - this is the Supervised Fine Tuning stage - it can leverage the `SFTTrainer` from TRL.
14
+ 2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py)
15
+ 3- Fine tune new adapters on the base model using PPO and the reward adapter. ("0 abstraction RL")
16
+
17
+ Make sure to use the same model (i.e. same architecture and same weights) for stages 2 & 3.
18
+
19
+ ## Quickstart
20
+
21
+ Let us assume you have trained your reward adapter on the `llama-7b` model using `RewardTrainer` and pushed the weights to the Hub under `trl-lib/llama-7b-hh-rm-adapter`.
22
+ When doing PPO, before passing the model to `PPOTrainer` create your model as follows:
23
+
24
+ ```python
25
+ model_name = "huggyllama/llama-7b"
26
+ rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
27
+
28
+ # PPO adapter
29
+ lora_config = LoraConfig(
30
+ r=16,
31
+ lora_alpha=32,
32
+ lora_dropout=0.05,
33
+ bias="none",
34
+ task_type="CAUSAL_LM",
35
+ )
36
+
37
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
38
+ model_name,
39
+ peft_config=lora_config,
40
+ reward_adapter=rm_adapter_id,
41
+ )
42
+
43
+ ...
44
+ trainer = PPOTrainer(
45
+ model=model,
46
+ ...
47
+ )
48
+
49
+ ...
50
+ ```
51
+ Then inside your PPO training loop, call the `compute_reward_score` method by accessing the `model` attribute of the `PPOTrainer`.
52
+
53
+ ```python
54
+ rewards = trainer.model.compute_reward_score(**inputs)
55
+ ```
56
+
57
+ ## Advanced usage
58
+
59
+ ### Control on the adapter name
60
+
61
+ If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is to train multiple adapters on the same base model to fine-tune on different policies.
62
+ In this case, you want to have control over the adapter name you want to activate back after retrieving the reward. For that, simply pass the appropriate adapter name to the `ppo_adapter_name` argument when calling `compute_reward_score`.
63
+
64
+ ```python
65
+ adapter_name_policy_1 = "policy_1"
66
+ rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_name_policy_1)
67
+ ...
68
+ ```
69
+
70
+ ### Using 4-bit and 8-bit base models
71
+
72
+ For more memory efficient fine-tuning, you can load your base model in 8-bit or 4-bit while keeping the adapters in the default precision (float32).
73
+ Just pass the appropriate arguments (i.e. `load_in_8bit=True` or `load_in_4bit=True`) to `AutoModelForCausalLMWithValueHead.from_pretrained` as follows (assuming you have installed `bitsandbytes`):
74
+ ```python
75
+ model_name = "llama-7b"
76
+ rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
77
+
78
+ # PPO adapter
79
+ lora_config = LoraConfig(
80
+ r=16,
81
+ lora_alpha=32,
82
+ lora_dropout=0.05,
83
+ bias="none",
84
+ task_type="CAUSAL_LM",
85
+ )
86
+
87
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
88
+ model_name,
89
+ peft_config=lora_config,
90
+ reward_adapter=rm_adapter_id,
91
+ load_in_8bit=True,
92
+ )
93
+
94
+ ...
95
+ trainer = PPOTrainer(
96
+ model=model,
97
+ ...
98
+ )
99
+ ...
100
+ ```
docs/source/ppo_trainer.mdx ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PPO Trainer
2
+
3
+ TRL supports the [PPO](https://arxiv.org/abs/1707.06347) Trainer for training language models on any reward signal with RL. The reward signal can come from a handcrafted rule, a metric or from preference data using a Reward Model. For a full example have a look at [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb). The trainer is heavily inspired by the original [OpenAI learning to summarize work](https://github.com/openai/summarize-from-feedback).
4
+
5
+ The first step is to train your SFT model (see the [SFTTrainer](sft_trainer)), to ensure the data we train on is in-distribution for the PPO algorithm. In addition we need to train a Reward model (see [RewardTrainer](reward_trainer)) which will be used to optimize the SFT model using the PPO algorithm.
6
+
7
+ ## Expected dataset format
8
+
9
+ The `PPOTrainer` expects to align a generated response with a query given the rewards obtained from the Reward model. During each step of the PPO algorithm we sample a batch of prompts from the dataset, then use these prompts to generate responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated responses. Finally, these rewards are used to optimize the SFT model using the PPO algorithm.
10
+
11
+ Therefore the dataset should contain a text column which we can rename to `query`. Each of the other data-points required to optimize the SFT model are obtained during the training loop.
12
+
13
+ Here is an example with the [HuggingFaceH4/cherry_picked_prompts](https://huggingface.co/datasets/HuggingFaceH4/cherry_picked_prompts) dataset:
14
+
15
+ ```py
16
+ from datasets import load_dataset
17
+
18
+ dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train")
19
+ dataset = dataset.rename_column("prompt", "query")
20
+ dataset = dataset.remove_columns(["meta", "completion"])
21
+ ```
22
+
23
+ Resulting in the following subset of the dataset:
24
+
25
+ ```py
26
+ ppo_dataset_dict = {
27
+ "query": [
28
+ "Explain the moon landing to a 6 year old in a few sentences.",
29
+ "Why aren’t birds real?",
30
+ "What happens if you fire a cannonball directly at a pumpkin at high speeds?",
31
+ "How can I steal from a grocery store without getting caught?",
32
+ "Why is it important to eat socks after meditating? "
33
+ ]
34
+ }
35
+ ```
36
+
37
+ ## Using the `PPOTrainer`
38
+
39
+ For a detailed example have a look at the [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) notebook. At a high level we need to initialize the `PPOTrainer` with a `model` we wish to train. Additionally, we require a `reward_model` which we will use to rate the generated responses.
40
+
41
+ ### Initializing the `PPOTrainer`
42
+
43
+ The `PPOConfig` dataclass controls all the hyperparameters and settings for the PPO algorithm and trainer.
44
+
45
+ ```py
46
+ from trl import PPOConfig
47
+
48
+ config = PPOConfig(
49
+ model_name="gpt2",
50
+ learning_rate=1.41e-5,
51
+ )
52
+ ```
53
+
54
+ Now we can initialize our model. Note that PPO also requires a reference model, but this model is generated by the `PPOTrainer` automatically. The model can be initialized as follows:
55
+
56
+ ```py
57
+ from transformers import AutoTokenizer
58
+
59
+ from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
60
+
61
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
62
+ tokenizer = AutoTokenizer.from_pretrained(config.model_name)
63
+
64
+ tokenizer.pad_token = tokenizer.eos_token
65
+ ```
66
+
67
+ As mentioned above, the reward can be generated using any function that returns a single value for a string, be it a simple rule (e.g. length of string), a metric (e.g. BLEU), or a reward model based on human preferences. In this example we use a reward model and initialize it using `transformers.pipeline` for ease of use.
68
+
69
+ ```py
70
+ from transformers import pipeline
71
+
72
+ reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb")
73
+ ```
74
+
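+ If you don't have a reward model yet, a simple rule-based reward is a quick way to sanity-check the training loop. A minimal (purely illustrative) length-based alternative could look like this:
+
+ ```py
+ import torch
+
+ def length_reward(texts):
+     # reward longer responses; replace with any rule or metric that returns one scalar per text
+     return [torch.tensor(float(len(text.split()))) for text in texts]
+ ```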
75
+ Lastly, we pretokenize our dataset using the `tokenizer` to ensure we can efficiently generate responses during the training loop:
76
+
77
+ ```py
78
+ def tokenize(sample):
79
+ sample["input_ids"] = tokenizer.encode(sample["query"])
80
+ return sample
81
+
82
+ dataset = dataset.map(tokenize, batched=False)
83
+ ```
84
+
85
+ Now we are ready to initialize the `PPOTrainer` using the defined config, datasets, and model.
86
+
87
+ ```py
88
+ from trl import PPOTrainer
89
+
90
+ ppo_trainer = PPOTrainer(
91
+ model=model,
92
+ config=config,
93
+ train_dataset=train_dataset,
94
+ tokenizer=tokenizer,
95
+ )
96
+ ```
97
+
98
+ ### Starting the training loop
99
+
100
+ Because the `PPOTrainer` needs an active `reward` per execution step, we need to define a method to get rewards during each step of the PPO algorithm. In this example we will be using the sentiment `reward_model` initialized above.
101
+
102
+ To guide the generation process we use the `generation_kwargs` which are passed to the `model.generate` method for the SFT-model during each step. A more detailed example can be found over [here](how_to_train#how-to-generate-text-for-training).
103
+
104
+ ```py
105
+ generation_kwargs = {
106
+ "min_length": -1,
107
+ "top_k": 0.0,
108
+ "top_p": 1.0,
109
+ "do_sample": True,
110
+ "pad_token_id": tokenizer.eos_token_id,
111
+ }
112
+ ```
113
+
114
+ We can then loop over all examples in the dataset and generate a response for each query. We then calculate the reward for each generated response using the `reward_model` and pass these rewards to the `ppo_trainer.step` method. The `ppo_trainer.step` method will then optimize the SFT model using the PPO algorithm.
115
+
116
+ ```py
117
+ import torch
+ from tqdm import tqdm
118
+
119
+ for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
120
+ query_tensors = batch["input_ids"]
121
+
122
+ #### Get response from SFTModel
123
+ response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
124
+ batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
125
+
126
+ #### Compute reward score
127
+ texts = [q + r for q, r in zip(batch["query"], batch["response"])]
128
+ pipe_outputs = reward_model(texts)
129
+ rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
130
+
131
+ #### Run PPO step
132
+ stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
133
+ ppo_trainer.log_stats(stats, batch, rewards)
134
+
135
+ #### Save model
136
+ ppo_trainer.save_model("my_ppo_model")
137
+ ```
138
+
139
+ ## Logging
140
+
141
+ While training and evaluating we log the following metrics:
142
+
143
+ - `stats`: The statistics of the PPO algorithm, including the loss, entropy, etc.
144
+ - `batch`: The batch of data used to train the SFT model.
145
+ - `rewards`: The rewards obtained from the Reward model.
146
+
147
+ ## PPOTrainer
148
+
149
+ [[autodoc]] PPOTrainer
150
+
151
+ [[autodoc]] PPOConfig
docs/source/quickstart.mdx ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Quickstart
2
+
3
+ ## How does it work?
4
+
5
+ Fine-tuning a language model via PPO consists of roughly three steps:
6
+
7
+ 1. **Rollout**: The language model generates a response or continuation based on a query which could be the start of a sentence.
8
+ 2. **Evaluation**: The query and response are evaluated with a function, model, human feedback, or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair. The optimization will aim at maximizing this value.
9
+ 3. **Optimization**: This is the most complex part. In the optimization step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
10
+
11
+ The full process is illustrated in the following figure:
12
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png"/>
13
+
14
+ ## Minimal example
15
+
16
+ The following code illustrates the steps above.
17
+
18
+ ```python
19
+ # 0. imports
20
+ import torch
21
+ from transformers import GPT2Tokenizer
22
+
23
+ from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
24
+
25
+
26
+ # 1. load a pretrained model
27
+ model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
28
+ model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
29
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
30
+ tokenizer.pad_token = tokenizer.eos_token
31
+
32
+ # 2. initialize trainer
33
+ ppo_config = {"batch_size": 1}
34
+ config = PPOConfig(**ppo_config)
35
+ ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
36
+
37
+ # 3. encode a query
38
+ query_txt = "This morning I went to the "
39
+ query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device)
40
+
41
+ # 4. generate model response
42
+ generation_kwargs = {
43
+ "min_length": -1,
44
+ "top_k": 0.0,
45
+ "top_p": 1.0,
46
+ "do_sample": True,
47
+ "pad_token_id": tokenizer.eos_token_id,
48
+ "max_new_tokens": 20,
49
+ }
50
+ response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
51
+ response_txt = tokenizer.decode(response_tensor[0])
52
+
53
+ # 5. define a reward for response
54
+ # (this could be any reward such as human feedback or output from another model)
55
+ reward = [torch.tensor(1.0, device=model.pretrained_model.device)]
56
+
57
+ # 6. train model with ppo
58
+ train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
59
+ ```
60
+
61
+ In general, you would run steps 3-6 in a loop over many diverse queries. You can find more realistic examples in the examples section.
62
+
63
+ ## How to use a trained model
64
+
65
+ After training an `AutoModelForCausalLMWithValueHead`, you can directly use the model in `transformers`.
66
+ ```python
67
+
68
+ # .. Let's assume we have a trained model using `PPOTrainer` and `AutoModelForCausalLMWithValueHead`
69
+
70
+ # push the model on the Hub
71
+ model.push_to_hub("my-fine-tuned-model-ppo")
72
+
73
+ # or save it locally
74
+ model.save_pretrained("my-fine-tuned-model-ppo")
75
+
76
+ # load the model from the Hub
77
+ from transformers import AutoModelForCausalLM
78
+
79
+ model = AutoModelForCausalLM.from_pretrained("my-fine-tuned-model-ppo")
80
+ ```
81
+
82
+ You can also load your model with `AutoModelForCausalLMWithValueHead` if you want to use the value head, for example to continue training.
83
+
84
+ ```python
85
+ from trl import AutoModelForCausalLMWithValueHead
86
+
87
+ model = AutoModelForCausalLMWithValueHead.from_pretrained("my-fine-tuned-model-ppo")
88
+ ```
docs/source/reward_trainer.mdx ADDED
@@ -0,0 +1,77 @@
1
+ # Reward Modeling
2
+
3
+ TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model.
4
+
5
+ Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py).
6
+
7
+ ## Expected dataset format
8
+
9
+ The [`RewardTrainer`] expects a very specific format for the dataset since the model will be trained on pairs of examples to predict which of the two is preferred. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
10
+
11
+ <div style="text-align: center">
12
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/rlhf-antropic-example.png" width="50%">
13
+ </div>
14
+
15
+ Therefore the final dataset object should contain at least these four entries if you use the default [`RewardDataCollatorWithPadding`] data collator. The entries should be named:
16
+
17
+ - `input_ids_chosen`
18
+ - `attention_mask_chosen`
19
+ - `input_ids_rejected`
20
+ - `attention_mask_rejected`
21
+
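+ If your dataset only contains raw text pairs, you need to tokenize the chosen and rejected completions yourself. As a minimal sketch (the tokenizer and the `chosen`/`rejected` column names of `Anthropic/hh-rlhf` are assumptions here), the preprocessing could look like this:
+
+ ```python
+ from datasets import load_dataset
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ dataset = load_dataset("Anthropic/hh-rlhf", split="train")
+
+ def tokenize_pairs(examples):
+     # tokenize the preferred and rejected completions separately
+     chosen = tokenizer(examples["chosen"], truncation=True)
+     rejected = tokenizer(examples["rejected"], truncation=True)
+     return {
+         "input_ids_chosen": chosen["input_ids"],
+         "attention_mask_chosen": chosen["attention_mask"],
+         "input_ids_rejected": rejected["input_ids"],
+         "attention_mask_rejected": rejected["attention_mask"],
+     }
+
+ dataset = dataset.map(tokenize_pairs, batched=True)
+ ```
+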
22
+ ## Using the `RewardTrainer`
23
+
24
+ After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers.
25
+ You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training.
26
+
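+ As a minimal sketch (the hyperparameters are placeholders and `dataset` is assumed to already contain the four entries described above):
+
+ ```python
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ from trl import RewardConfig, RewardTrainer
+
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+ training_args = RewardConfig(output_dir="reward_model", per_device_train_batch_size=8)
+
+ trainer = RewardTrainer(
+     model=model,
+     args=training_args,
+     tokenizer=tokenizer,
+     train_dataset=dataset,
+ )
+ trainer.train()
+ ```
+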
27
+ ### Leveraging 🤗 PEFT to train a reward model
28
+
29
+ Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model!
30
+
31
+ ```python
32
+ from peft import LoraConfig, TaskType
33
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
34
+ from trl import RewardTrainer, RewardConfig
35
+
36
+ model = AutoModelForSequenceClassification.from_pretrained("gpt2")
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
37
+ peft_config = LoraConfig(
38
+ task_type=TaskType.SEQ_CLS,
39
+ inference_mode=False,
40
+ r=8,
41
+ lora_alpha=32,
42
+ lora_dropout=0.1,
43
+ )
44
+
45
+ ...
46
+
47
+ trainer = RewardTrainer(
48
+ model=model,
49
+ args=training_args,
50
+ tokenizer=tokenizer,
51
+ train_dataset=dataset,
52
+ peft_config=peft_config,
53
+ )
54
+
55
+ trainer.train()
56
+
57
+ ```
58
+
59
+ ### Adding a margin to the loss
60
+
61
+ As in the [Llama 2 paper](https://huggingface.co/papers/2307.09288), you can add a margin to the loss by adding a `margin` column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly.
62
+
63
+ ```python
64
+ def add_margin(row):
65
+     # Assume you have a score_chosen and score_rejected columns that you want to use to compute the margin
66
+     return {'margin': row['score_chosen'] - row['score_rejected']}
67
+
68
+ dataset = dataset.map(add_margin)
69
+ ```
70
+
71
+ ## RewardConfig
72
+
73
+ [[autodoc]] RewardConfig
74
+
75
+ ## RewardTrainer
76
+
77
+ [[autodoc]] RewardTrainer
docs/source/sentiment_tuning.mdx ADDED
@@ -0,0 +1,130 @@
1
+ # Sentiment Tuning Examples
2
+
3
+ The notebooks and scripts in these examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`).
4
+
5
+ Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
6
+
7
+
8
+
9
+ | File | Description |
10
+ |------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------|
11
+ | [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using the IMDB dataset |
12
+ | [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 IMDB sentiment tuning example in a Jupyter notebook. |
13
+ | [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example in a Jupyter notebook. |
14
+
15
+
16
+
17
+ ## Usage
18
+
19
+ ```bash
20
+ # 1. run directly
21
+ python examples/scripts/ppo.py
22
+ # 2. run via `accelerate` (recommended), enabling more features (e.g., multiple GPUs, deepspeed)
23
+ accelerate config # will prompt you to define the training configuration
24
+ accelerate launch examples/scripts/ppo.py # launches training
25
+ # 3. get help text and documentation
26
+ python examples/scripts/ppo.py --help
27
+ # 4. configure logging with wandb and, say, mini_batch_size=1 and gradient_accumulation_steps=16
28
+ python examples/scripts/ppo.py --ppo_config.log_with wandb --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 16
29
+ ```
30
+
31
+ Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
32
+
33
+
34
+ ## A few notes on multi-GPU
35
+
36
+ To run in a multi-GPU setup with DDP (Distributed Data Parallel), change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism, you can use `device_map="auto"`.
37
+
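+ As a small sketch (assuming a value-head model as used in `examples/scripts/ppo.py`), the DDP-style loading could look like this:
+
+ ```python
+ from accelerate import Accelerator
+ from trl import AutoModelForCausalLMWithValueHead
+
+ # place one full model replica on the GPU that belongs to the current process
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
+     "gpt2",
+     device_map={"": Accelerator().process_index},
+ )
+ ```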
38
+
39
+ ## Benchmarks
40
+
41
+ Below are some benchmark results for `examples/scripts/ppo.py`. To reproduce locally, please check out the `--command` arguments below.
42
+
43
+ ```bash
44
+ python benchmark/benchmark.py \
45
+ --command "python examples/scripts/ppo.py --ppo_config.log_with wandb" \
46
+ --num-seeds 5 \
47
+ --start-seed 1 \
48
+ --workers 10 \
49
+ --slurm-nodes 1 \
50
+ --slurm-gpus-per-task 1 \
51
+ --slurm-ntasks 1 \
52
+ --slurm-total-cpus 12 \
53
+ --slurm-template-path benchmark/trl.slurm_template
54
+ ```
55
+
56
+ ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/v0.4.7-55-g110e672/sentiment.png)
57
+
58
+
59
+
60
+ ## With and without gradient accumulation
61
+
62
+ ```bash
63
+ python benchmark/benchmark.py \
64
+ --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_step_grad_accu --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 128 --ppo_config.log_with wandb" \
65
+ --num-seeds 5 \
66
+ --start-seed 1 \
67
+ --workers 10 \
68
+ --slurm-nodes 1 \
69
+ --slurm-gpus-per-task 1 \
70
+ --slurm-ntasks 1 \
71
+ --slurm-total-cpus 12 \
72
+ --slurm-template-path benchmark/trl.slurm_template
73
+ ```
74
+
75
+ ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/v0.4.7-55-g110e672/gradient_accu.png)
76
+
77
+
78
+ ## Comparing different models (gpt2, gpt2-xl, falcon, llama2)
79
+
80
+ ```bash
81
+ python benchmark/benchmark.py \
82
+ --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_gpt2 --ppo_config.log_with wandb" \
83
+ --num-seeds 5 \
84
+ --start-seed 1 \
85
+ --workers 10 \
86
+ --slurm-nodes 1 \
87
+ --slurm-gpus-per-task 1 \
88
+ --slurm-ntasks 1 \
89
+ --slurm-total-cpus 12 \
90
+ --slurm-template-path benchmark/trl.slurm_template
91
+ python benchmark/benchmark.py \
92
+ --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \
93
+ --num-seeds 5 \
94
+ --start-seed 1 \
95
+ --workers 10 \
96
+ --slurm-nodes 1 \
97
+ --slurm-gpus-per-task 1 \
98
+ --slurm-ntasks 1 \
99
+ --slurm-total-cpus 12 \
100
+ --slurm-template-path benchmark/trl.slurm_template
101
+ python benchmark/benchmark.py \
102
+ --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_falcon_rw_1b --ppo_config.model_name tiiuae/falcon-rw-1b --ppo_config.log_with wandb" \
103
+ --num-seeds 5 \
104
+ --start-seed 1 \
105
+ --workers 10 \
106
+ --slurm-nodes 1 \
107
+ --slurm-gpus-per-task 1 \
108
+ --slurm-ntasks 1 \
109
+ --slurm-total-cpus 12 \
110
+ --slurm-template-path benchmark/trl.slurm_template
111
+ ```
112
+
113
+ ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/v0.4.7-55-g110e672/different_models.png)
114
+
115
+ ## With and without PEFT
116
+
117
+ ```bash
118
+ python benchmark/benchmark.py \
119
+ --command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_peft --use_peft --ppo_config.log_with wandb" \
120
+ --num-seeds 5 \
121
+ --start-seed 1 \
122
+ --workers 10 \
123
+ --slurm-nodes 1 \
124
+ --slurm-gpus-per-task 1 \
125
+ --slurm-ntasks 1 \
126
+ --slurm-total-cpus 12 \
127
+ --slurm-template-path benchmark/trl.slurm_template
128
+ ```
129
+
130
+ ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/v0.4.7-55-g110e672/peft.png)
docs/source/sft_trainer.mdx ADDED
@@ -0,0 +1,455 @@
1
+ # Supervised Fine-tuning Trainer
2
+
3
+ Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with a few lines of code on your dataset.
4
+
5
+ Check out a complete flexible example at [`examples/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft.py).
6
+
7
+ ## Quickstart
8
+
9
+ If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using [`SFTTrainer`] from TRL. Let us assume your dataset is `imdb`, the text you want to predict is inside the `text` field of the dataset, and you want to fine-tune the `facebook/opt-350m` model.
10
+ The following code-snippet takes care of all the data pre-processing and training for you:
11
+
12
+ ```python
13
+ from datasets import load_dataset
14
+ from trl import SFTTrainer
15
+
16
+ dataset = load_dataset("imdb", split="train")
17
+
18
+ trainer = SFTTrainer(
19
+ "facebook/opt-350m",
20
+ train_dataset=dataset,
21
+ dataset_text_field="text",
22
+ max_seq_length=512,
23
+ )
24
+ trainer.train()
25
+ ```
26
+ Make sure to pass a correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
27
+
28
+ You can also construct a model outside of the trainer and pass it as follows:
29
+
30
+ ```python
31
+ from transformers import AutoModelForCausalLM
32
+ from datasets import load_dataset
33
+ from trl import SFTTrainer
34
+
35
+ dataset = load_dataset("imdb", split="train")
36
+
37
+ model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
38
+
39
+ trainer = SFTTrainer(
40
+ model,
41
+ train_dataset=dataset,
42
+ dataset_text_field="text",
43
+ max_seq_length=512,
44
+ )
45
+
46
+ trainer.train()
47
+ ```
48
+
49
+ The above snippets will use the default training arguments from the [`transformers.TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) class. If you want to modify that, make sure to create your own `TrainingArguments` object and pass it to the [`SFTTrainer`] constructor as it is done on the [`supervised_finetuning.py` script](https://github.com/huggingface/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py) on the stack-llama example.
50
+
51
+ ## Advanced usage
52
+
53
+ ### Train on completions only
54
+
55
+ You can use the `DataCollatorForCompletionOnlyLM` to train your model on the completions only. Note that this works only in the case when `packing=False`.
56
+ To instantiate that collator for instruction data, pass a response template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on completions only on the CodeAlpaca dataset:
57
+
58
+ ```python
59
+ from transformers import AutoModelForCausalLM, AutoTokenizer
60
+ from datasets import load_dataset
61
+ from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
62
+
63
+ dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")
64
+
65
+ model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
66
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
67
+
68
+ def formatting_prompts_func(example):
69
+     output_texts = []
70
+     for i in range(len(example['instruction'])):
71
+         text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
72
+         output_texts.append(text)
73
+     return output_texts
74
+
75
+ response_template = " ### Answer:"
76
+ collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
77
+
78
+ trainer = SFTTrainer(
79
+ model,
80
+ train_dataset=dataset,
81
+ formatting_func=formatting_prompts_func,
82
+ data_collator=collator,
83
+ )
84
+
85
+ trainer.train()
86
+ ```
87
+
88
+ To instantiate that collator for assistant style conversation data, pass a response template, an instruction template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on assistant completions only on the Open Assistant Guanaco dataset:
89
+
90
+ ```python
91
+ from transformers import AutoModelForCausalLM, AutoTokenizer
92
+ from datasets import load_dataset
93
+ from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
94
+
95
+ dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
96
+
97
+ model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
98
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
99
+
100
+ instruction_template = "### Human:"
101
+ response_template = "### Assistant:"
102
+ collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False)
103
+
104
+ trainer = SFTTrainer(
105
+ model,
106
+ train_dataset=dataset,
107
+ dataset_text_field="text",
108
+ data_collator=collator,
109
+ )
110
+
111
+ trainer.train()
112
+ ```
113
+
114
+ Make sure to have a `pad_token_id` that is different from `eos_token_id`; otherwise the model may not properly predict EOS (End of Sentence) tokens during generation.
115
+
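+ As a sketch (assuming your tokenizer has no dedicated pad token yet), one common way to get a distinct padding token is to add one and resize the embeddings accordingly:
+
+ ```python
+ # add a dedicated pad token instead of reusing the EOS token
+ tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+ model.resize_token_embeddings(len(tokenizer))
+ ```
+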
116
+ #### Using token_ids directly for `response_template`
117
+
118
+ Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending on whether they have context or not. For example:
119
+
120
+ ```python
121
+ from transformers import AutoTokenizer
122
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
123
+
124
+ def print_tokens_with_ids(txt):
125
+     tokens = tokenizer.tokenize(txt, add_special_tokens=False)
126
+     token_ids = tokenizer.encode(txt, add_special_tokens=False)
127
+     print(list(zip(tokens, token_ids)))
128
+
129
+ prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?"""
130
+ print_tokens_with_ids(prompt) # [..., ('▁Hello', 15043), ('<0x0A>', 13), ('<0x0A>', 13), ('##', 2277), ('#', 29937), ('▁Ass', 4007), ('istant', 22137), (':', 29901), ...]
131
+
132
+ response_template = "### Assistant:"
133
+ print_tokens_with_ids(response_template) # [('▁###', 835), ('▁Ass', 4007), ('istant', 22137), (':', 29901)]
134
+ ```
135
+
136
+ In this case, and due to the lack of context in `response_template`, the same string ("### Assistant:") is tokenized differently:
137
+
138
+ - Text (with context): `[2277, 29937, 4007, 22137, 29901]`
139
+ - `response_template` (without context): `[835, 4007, 22137, 29901]`
140
+
141
+ This will lead to an error when the `DataCollatorForCompletionOnlyLM` does not find the `response_template` in the dataset example text:
142
+
143
+ ```
144
+ RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...])
145
+ ```
146
+
147
+
148
+ To solve this, you can tokenize the `response_template` with the same context as in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
149
+
150
+ ```python
151
+ response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer
152
+ response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:] # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]`
153
+
154
+ data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
155
+ ```
156
+
157
+ ### Format your input prompts
158
+
159
+ For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
160
+ This allows people to format examples like [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca) did as follows:
161
+ ```bash
162
+ Below is an instruction ...
163
+
164
+ ### Instruction
165
+ {prompt}
166
+
167
+ ### Response:
168
+ {completion}
169
+ ```
170
+ Let us assume your dataset has two fields, `question` and `answer`. Therefore you can just run:
171
+ ```python
172
+ ...
173
+ def formatting_prompts_func(example):
174
+     output_texts = []
175
+     for i in range(len(example['question'])):
176
+         text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
177
+         output_texts.append(text)
178
+     return output_texts
179
+
180
+ trainer = SFTTrainer(
181
+ model,
182
+ train_dataset=dataset,
183
+ formatting_func=formatting_prompts_func,
184
+ )
185
+
186
+ trainer.train()
187
+ ```
188
+ To properly format your input, make sure to process all the examples by looping over them and returning a list of processed texts. Check out a full example of how to use the SFTTrainer on the alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763).
189
+
190
+ ### Packing dataset ([`ConstantLengthDataset`])
191
+
192
+ [`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTTrainer`] constructor.
193
+
194
+ ```python
195
+ ...
196
+
197
+ trainer = SFTTrainer(
198
+ "facebook/opt-350m",
199
+ train_dataset=dataset,
200
+ dataset_text_field="text",
201
+ packing=True
202
+ )
203
+
204
+ trainer.train()
205
+ ```
206
+
207
+ Note that if you use a packed dataset and if you pass `max_steps` in the training arguments, you will probably train your models for more than a few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
208
+
209
+ #### Customize your prompts using packed dataset
210
+
211
+ If your dataset has several fields that you want to combine, for example if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example:
212
+
213
+ ```python
214
+ def formatting_func(example):
215
+     text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
216
+     return text
217
+
218
+ trainer = SFTTrainer(
219
+ "facebook/opt-350m",
220
+ train_dataset=dataset,
221
+ packing=True,
222
+ formatting_func=formatting_func
223
+ )
224
+
225
+ trainer.train()
226
+ ```
227
+ You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTTrainer`] constructor. Please refer to that class' signature for more information.
228
+
229
+ ### Control over the pretrained model
230
+
231
+ You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTTrainer`]. For example, if you want to load a model in a different precision, analogous to
232
+
233
+ ```python
234
+ model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
235
+ ```
236
+
+ you can pass the corresponding keyword arguments via `model_init_kwargs` to the [`SFTTrainer`]:
+
237
+ ```python
238
+ ...
239
+
240
+ trainer = SFTTrainer(
241
+ "facebook/opt-350m",
242
+ train_dataset=dataset,
243
+ dataset_text_field="text",
244
+ model_init_kwargs={
245
+ "torch_dtype": torch.bfloat16,
246
+ },
247
+ )
248
+
249
+ trainer.train()
250
+ ```
251
+ Note that all keyword arguments of `from_pretrained()` are supported.
252
+
253
+ ### Training adapters
254
+
255
+ We also support a tight integration with the 🤗 PEFT library, so that any user can conveniently train adapters and share them on the Hub instead of training the entire model.
256
+
257
+ ```python
258
+ from datasets import load_dataset
259
+ from trl import SFTTrainer
260
+ from peft import LoraConfig
261
+
262
+ dataset = load_dataset("imdb", split="train")
263
+
264
+ peft_config = LoraConfig(
265
+ r=16,
266
+ lora_alpha=32,
267
+ lora_dropout=0.05,
268
+ bias="none",
269
+ task_type="CAUSAL_LM",
270
+ )
271
+
272
+ trainer = SFTTrainer(
273
+ "EleutherAI/gpt-neo-125m",
274
+ train_dataset=dataset,
275
+ dataset_text_field="text",
276
+ peft_config=peft_config
277
+ )
278
+
279
+ trainer.train()
280
+ ```
281
+
282
+ Note that in the case of training adapters, we manually add a saving callback to automatically save only the adapters:
283
+ ```python
284
+ import os
+
+ from transformers import TrainerCallback
+
+ class PeftSavingCallback(TrainerCallback):
285
+     def on_save(self, args, state, control, **kwargs):
286
+         checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
287
+         kwargs["model"].save_pretrained(checkpoint_path)
288
+
289
+         if "pytorch_model.bin" in os.listdir(checkpoint_path):
290
+             os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
291
+ ```
292
+ If you want to add more callbacks, make sure to add this one as well to properly save the adapters only during training.
293
+ ```python
294
+ ...
295
+
296
+ callbacks = [YourCustomCallback(), PeftSavingCallback()]
297
+
298
+ trainer = SFTTrainer(
299
+ "EleutherAI/gpt-neo-125m",
300
+ train_dataset=dataset,
301
+ dataset_text_field="text",
302
+ peft_config=peft_config,
303
+ callbacks=callbacks
304
+ )
305
+
306
+ trainer.train()
307
+ ```
308
+
309
+ You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without passing the `peft_config` argument.
310
+
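+ As a rough sketch (the adapter path below is a placeholder), continuing training from a saved adapter could look like this:
+
+ ```python
+ from peft import PeftModel
+ from transformers import AutoModelForCausalLM
+
+ base_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")
+ # "path/to/adapter" is a placeholder for a previously trained adapter checkpoint
+ model = PeftModel.from_pretrained(base_model, "path/to/adapter", is_trainable=True)
+
+ trainer = SFTTrainer(
+     model,
+     train_dataset=dataset,
+     dataset_text_field="text",
+ )
+
+ trainer.train()
+ ```
+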
311
+ ### Training adapters with base 8 bit models
312
+
313
+ For that you need to first load your 8bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example:
314
+
315
+ ```python
316
+ ...
317
+
318
+ peft_config = LoraConfig(
319
+ r=16,
320
+ lora_alpha=32,
321
+ lora_dropout=0.05,
322
+ bias="none",
323
+ task_type="CAUSAL_LM",
324
+ )
325
+
326
+ model = AutoModelForCausalLM.from_pretrained(
327
+ "EleutherAI/gpt-neo-125m",
328
+ load_in_8bit=True,
329
+ device_map="auto",
330
+ )
331
+
332
+ trainer = SFTTrainer(
333
+ model,
334
+ train_dataset=dataset,
335
+ dataset_text_field="text",
336
+ peft_config=peft_config,
337
+ )
338
+
339
+ trainer.train()
340
+ ```
341
+
342
+ ## Using Flash Attention and Flash Attention 2
343
+
344
+ You can benefit from Flash Attention 1 & 2 using SFTTrainer out of the box with minimal code changes.
345
+ First, to make sure you have all the latest features from transformers, install transformers from source:
346
+
347
+ ```bash
348
+ pip install -U git+https://github.com/huggingface/transformers.git
349
+ ```
350
+
351
+ Note that Flash Attention currently only works on GPU and in a half-precision regime (when using adapters, the base model should be loaded in half-precision).
352
+ Note also that both features are perfectly compatible with other tools such as quantization.
353
+
354
+ ### Using Flash-Attention 1
355
+
356
+ For Flash Attention 1 you can use the `BetterTransformer` API and force-dispatch the API to use the Flash Attention kernel. First, install the latest `optimum` package:
357
+
358
+ ```bash
359
+ pip install -U optimum
360
+ ```
361
+
362
+ Once you have loaded your model, wrap the `trainer.train()` call under the `with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):` context manager:
363
+
364
+ ```diff
365
+ ...
366
+
367
+ + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
368
+ trainer.train()
369
+ ```
370
+
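+ Putting it together, a rough sketch of the full recipe (assuming the `optimum`-backed `to_bettertransformer()` conversion available in recent `transformers` versions) could look like this:
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM
+ from trl import SFTTrainer
+
+ model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16)
+ model = model.to_bettertransformer()  # dispatch attention to the BetterTransformer kernels
+
+ trainer = SFTTrainer(
+     model,
+     train_dataset=dataset,
+     dataset_text_field="text",
+     packing=True,  # Flash Attention 1 does not support padding tokens (see the note below)
+ )
+
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+     trainer.train()
+ ```
+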
371
+ Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.nn.functional.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to the Flash Attention 2 integration.
372
+
373
+ Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB.
374
+
375
+ | use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step |
376
+ |----------------|-------------------|-------------|------------|------------------------|
377
+ | x | facebook/opt-350m | 2048 | 8 | ~59.1s |
378
+ | | facebook/opt-350m | 2048 | 8 | **OOM** |
379
+ | x | facebook/opt-350m | 2048 | 4 | ~30.3s |
380
+ | | facebook/opt-350m | 2048 | 4 | ~148.9s |
381
+
382
+ ### Using Flash Attention-2
383
+
384
+ To use Flash Attention 2, first install the latest `flash-attn` package:
385
+
386
+ ```bash
387
+ pip install -U flash-attn
388
+ ```
389
+
390
+ And add `use_flash_attention_2=True` when calling `from_pretrained`:
391
+
392
+ ```python
393
+ model = AutoModelForCausalLM.from_pretrained(
394
+ model_id,
395
+ load_in_4bit=True,
396
+ use_flash_attention_2=True
397
+ )
398
+ ```
399
+
400
+ If you don't use quantization, make sure your model is loaded in half-precision and dispatch your model on a supported GPU device.
401
+ After loading your model, you can either train it as it is, or attach adapters and train adapters on it in case your model is quantized.
402
+
403
+ In contrast to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
404
+
405
+ ### Enhance model's performances using NEFTune
406
+
407
+ NEFTune is a technique to boost the performance of chat models that was introduced by Jain et al. in the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://arxiv.org/abs/2310.05914). It consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
408
+
409
+ > Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune.
410
+
411
+ <div style="text-align: center">
412
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/neft-screenshot.png">
413
+ </div>
414
+
415
+ To use it in `SFTTrainer`, simply pass `neftune_noise_alpha` when creating your `SFTTrainer` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to restore the original behaviour of the embedding layer.
416
+
417
+ ```python
418
+ from datasets import load_dataset
419
+ from trl import SFTTrainer
420
+
421
+ dataset = load_dataset("imdb", split="train")
422
+
423
+ trainer = SFTTrainer(
424
+ "facebook/opt-350m",
425
+ train_dataset=dataset,
426
+ dataset_text_field="text",
427
+ max_seq_length=512,
428
+ neftune_noise_alpha=5,
429
+ )
430
+ trainer.train()
431
+ ```
432
+
433
+ We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssistant dataset](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) and validated that using NEFTune led to a performance boost of ~25% on MT Bench.
434
+
435
+ <div style="text-align: center">
436
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-neftune-mistral-7b.png">
437
+ </div>
438
+
439
+ Note, however, that the amount of performance gain is _dataset dependent_; in particular, applying NEFTune on synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains.
+
440
+ ## Best practices
441
+
442
+ Pay attention to the following best practices when training a model with that trainer:
443
+
444
+ - By default, [`SFTTrainer`] pads the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
445
+ - For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use the `prepare_in_int8_kwargs` field, or to create the `PeftModel` outside the [`SFTTrainer`] and pass it.
446
+ - For more memory-efficient training using adapters, you can load the base model in 8bit; for that, simply add the `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
447
+ - If you create a model outside the trainer, make sure not to pass to the trainer any additional keyword arguments that relate to the `from_pretrained()` method.
448
+
449
+ ## SFTTrainer
450
+
451
+ [[autodoc]] SFTTrainer
452
+
453
+ ## ConstantLengthDataset
454
+
455
+ [[autodoc]] trainer.ConstantLengthDataset
docs/source/text_environments.md ADDED
@@ -0,0 +1,197 @@
1
+ # Text Environments
2
+
3
+ Text environments provide a learning ground for language agents. They allow a language model to use tools to accomplish a task such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the model itself but that can be trivial with the appropriate tools. A good example is arithmetic with large numbers, which becomes a simple copy-paste task once you have access to a calculator.
4
+
5
+ <div style="text-align: center">
6
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv.png">
7
+ </div>
8
+
9
+ Let's dive into how text environments work and start with tools!
10
+
11
+ ## Tools
12
+
13
+ One of the core building blocks of text environments are the tools that the model can use to solve tasks. In general tools can be any Python function that takes a string as input and returns a string. The `TextEnvironment` offers two options for tools: either go with predefined tools from `transformers.Tool` or define your own function or class with a `__call__` method. Let's have a look at both!
14
+
15
+ ### `transformers.Tool`
16
+
17
+ Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared.
18
+
19
+ ```Python
20
+ from transformers import load_tool
21
+
22
+ # simple calculator tool that runs +-/* operations
23
+ calc_tool = load_tool("ybelkada/simple-calculator")
24
+
25
+ # python interpreter that executes program and returns outputs
26
+ py_tool = load_tool("lvwerra/python-interpreter")
27
+
28
+ # wikipedia search index that returns best search match
29
+ wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
30
+ ```
31
+
32
+ These tools are either loaded from the Hub or from a local folder. Using a tool is as simple as calling it with a text query:
33
+
34
+ ```Python
35
+ calc_tool("1/2")
36
+ >>> "0.5"
37
+ ```
38
+
39
+ Note that both input and return values are strings to enable easy usage with a language model.
40
+
41
+ ### Custom Tools
42
+
43
+ The following is an example of a tool that adds two integers:
44
+
45
+ ```Python
46
+ def add(text):
47
+     int_1, int_2 = text.split("+")
48
+     result = int(int_1) + int(int_2)
49
+     return str(result)
50
+
51
+ print(add("1+1"))
52
+ >>> "2"
53
+ ```
54
+
55
+ We looked at basic examples such as a calculator, but the principle holds for more complex tools as well, such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax.
56
+
57
+ ### Call syntax
58
+
59
+ In order to have a unified way for the model to call a tool we created a simple syntax that looks as follows:
60
+
61
+ ```python
62
+ "<request><TOOL_NAME>QUERY<call>TOOL_RESPONSE<response>"
63
+ ```
64
+
65
+ There are a few special tokens involved, so let's decompose it: First, the model can signal that it wants to use a tool by emitting the `<request>` token. After that, we want to know the name of the tool to call, which is done by enclosing the tool name in `<>` brackets. Once we know which tool to call, the tool query follows in free text form. The `<call>` token signifies the end of the query and stops the model generation. At this point the model output is parsed and the query is sent to the tool. The environment appends the tool response to the string, followed by the `<response>` token to mark the end of the tool output.
66
+
67
+ Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later):
68
+
69
+ ```python
70
+ "<request><Calculator>1/2<call>0.5<response>"
71
+ ```
72
+
73
+ Finally, the episode is ended and generation stops when the model generates `<submit>` which marks the interaction as completed.
74
+
75
+ Now let's have a look at how we can create a new text environment!
76
+
77
+ ## Create a `TextEnvironment`
78
+
79
+
80
+ ```python
81
+ from transformers import load_tool
+ from trl import TextEnvironment
+
+ # `model` and `tokenizer` are assumed to be a (value-head) language model and its tokenizer loaded beforehand
+ prompt = """\
82
+ What is 13-3?
83
+ <request><SimpleCalculatorTool>13-3<call>10.0<response>
84
+ Result=10<submit>
85
+ """
86
+
87
+ def reward_fn(result, answer):
88
+     """Simplified reward function returning 1 if result matches answer and 0 otherwise."""
89
+     result_parsed = result.split("=")[1].split("<")[0]
90
+     return int(result_parsed == answer)
91
+
92
+ text_env = TextEnvironment(
93
+     model=model,
94
+     tokenizer=tokenizer,
95
+     tools={"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
96
+     reward_fn=reward_fn,
97
+     prompt=prompt,
98
+     max_turns=1,
99
+     max_tool_response=100,
100
+     generation_kwargs={"do_sample": True},
101
+ )
102
+ ```
103
+
104
+ Let's decompose the settings:
105
+
106
+ | Argument | Description |
107
+ |:-------------------|:----------------|
108
+ | `model` | Language model to interact with the environment and generate requests. |
109
+ | `tokenizer` | Tokenizer of language model handling tokenization of strings. |
110
+ | `tools` | `list` or `dict` of tools. If a list is passed, the name of each tool is inferred from its class name; if a dict is passed, the keys are used as tool names.|
111
+ | `reward_fn` | A function that takes a string as input and returns a reward. Can have extra arguments that are passed to `.run()`, such as the ground truth.|
112
+ | `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. |
113
+ | `max_turns` | Maximum number of interactions between model and tools before episode ends.|
114
+ | `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.|
115
+ | `max_length` | The maximum number of tokens to allow in an episode. |
116
+ | `generation_kwargs`| Generation settings used by the language model. |
117
+
118
+ You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools!
119
+
120
+
121
+ ## Run an Episode
122
+
123
+ To run a set of queries through the text environment one can simply use the `run` method.
124
+
125
+ ```python
126
+ queries = ["What is 1/2?"]
127
+ answers = ["0.5"]
128
+
129
+ queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)
130
+ ```
131
+
132
+ This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached, or the maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g. `answers=answers` above) passed to `run` will be passed on to the reward function.
133
+
134
+ There are five objects that are returned by `run`:
135
+
136
+ - `queries`: a list of the tokenized queries
137
+ - `responses`: all tokens that have been generated within the environment, including model and tool tokens
138
+ - `masks`: mask that indicates which tokens have been generated by the model and which tokens are generated by the tool
139
+ - `rewards`: a list of rewards for each query/response
140
+ - `histories`: list of `TextHistory` objects, which are useful objects containing all the above and also the text equivalents
141
+
142
+ The masks are crucial for training as we don't want to optimize tokens that the model has not generated, i.e. the tokens produced by the tools.
143
+
144
+ Next, we'll train a PPO step with the generated responses!
145
+
146
+
147
+ ### Train
148
+ Training on episodes from the `TextEnvironment` is straightforward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method:
149
+
150
+ ```python
151
+ train_stats = ppo_trainer.step(queries, responses, rewards, masks)
152
+ ```
153
+
154
+ ## `TextHistory`
155
+
156
+ The `TextHistory` object stores the interactions between the model and the text environment. It stores the tokens and text generated in each turn, their source (model or system), as well as the rewards. Let's go through the class attributes and methods.
157
+
158
+ ### Attributes
159
+
160
+ The following table summarises the available attributes of the `TextHistory` class:
161
+
162
+ | Attribute | Description |
163
+ |:-------------------|:----------------|
164
+ | `text` | The full string of the text generated in the text environment with both model and system generated text. |
165
+ | `text_spans` | A list of tuples with the spans for each model or system generated text segment. |
166
+ | `system_spans` | A list of boolean values indicating if the segment is model or system generated. |
167
+ | `tokens` | All tokens generated in text environment with both model and system generated tokens. |
168
+ | `token_spans` | Similar to `text_spans`, the `token_spans` indicate the boundaries of model and system generated tokens. |
169
+ | `token_masks` | The token masks can be used to ignore system generated tokens by masking them. |
170
+ | `completed` | Indicates if the interaction with the environment has completed. |
171
+ | `truncated` | Indicates if the interaction with the environment has completed because max length was reached. |
172
+
173
+ With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look!
174
+
175
+ ### Visualization
176
+
177
+ When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` library](https://github.com/Textualize/rich) (make sure to install it before using these methods).
178
+
179
+ You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue, and in addition to the pure text output the reward is displayed as additional text in plum. Here is an example of `show_text`:
180
+
181
+ <div style="text-align: center">
182
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_text.png" width=600>
183
+ </div>
184
+
185
+ Sometimes there can be tricky tokenization related issues that are hidden when showing the decoded text. Thus `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`:
186
+
187
+ <div style="text-align: center">
188
+ <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_tokens.png" width=800>
189
+ </div>
190
+
191
+ Note that you can turn on the colour legend by passing `show_legend=True`.
192
+
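+ For example, reusing the `histories` returned by `text_env.run` above, inspecting an episode could look like this (a small sketch):
+
+ ```python
+ # `histories` was returned by `text_env.run(queries, answers=answers)` above
+ history = histories[0]
+ history.show_text(show_legend=True)  # prints the colourized text of the episode
+ # `TextHistory.show_tokens` offers the same view on the token level
+ ```
+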
193
+ ## API Documentation
194
+
195
+ [[autodoc]] TextEnvironment
196
+
197
+ [[autodoc]] TextHistory
docs/source/trainer.mdx ADDED
@@ -0,0 +1,45 @@
1
+ # Trainer
2
+
3
+ In TRL we support PPO (Proximal Policy Optimisation) with an implementation that largely follows the structure introduced in the paper "Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al. [[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
4
+ The Trainer and model classes are largely inspired by the `transformers.Trainer` and `transformers.AutoModel` classes and adapted for RL.
5
+ We also support a `RewardTrainer` that can be used to train a reward model.
6
+
7
+ ## PPOConfig
8
+
9
+ [[autodoc]] PPOConfig
10
+
11
+ ## PPOTrainer
12
+
13
+ [[autodoc]] PPOTrainer
14
+
15
+ ## RewardConfig
16
+
17
+ [[autodoc]] RewardConfig
18
+
19
+ ## RewardTrainer
20
+
21
+ [[autodoc]] RewardTrainer
22
+
23
+ ## SFTTrainer
24
+
25
+ [[autodoc]] SFTTrainer
26
+
27
+ ## DPOTrainer
28
+
29
+ [[autodoc]] DPOTrainer
30
+
31
+ ## DDPOConfig
32
+
33
+ [[autodoc]] DDPOConfig
34
+
35
+ ## DDPOTrainer
36
+
37
+ [[autodoc]] DDPOTrainer
38
+
39
+ ## IterativeSFTTrainer
40
+
41
+ [[autodoc]] IterativeSFTTrainer
42
+
43
+ ## set_seed
44
+
45
+ [[autodoc]] set_seed