Step 6000
Browse files- .gitignore +215 -0
- README.md +77 -0
- baseline/adversarial_config.yaml +32 -0
- baseline/baseline_config.yaml +32 -0
- baseline/custom_dataset.py +110 -0
- baseline/custom_params.py +114 -0
- baseline/full_finetune.py +455 -0
- colorful/adversarial_config.yaml +39 -0
- colorful/basic_config.yaml +39 -0
- colorful/custom_dataset.py +179 -0
- colorful/custom_model.py +267 -0
- colorful/custom_params.py +110 -0
- colorful/full_finetune.py +511 -0
- colorful/masked_apply.py +73 -0
.gitignore
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Created by https://www.toptal.com/developers/gitignore/api/python,macos
|
2 |
+
# Edit at https://www.toptal.com/developers/gitignore?templates=python,macos
|
3 |
+
|
4 |
+
### TorchTune ###
|
5 |
+
|
6 |
+
output/
|
7 |
+
model/
|
8 |
+
wandb/
|
9 |
+
|
10 |
+
### macOS ###
|
11 |
+
# General
|
12 |
+
.DS_Store
|
13 |
+
.AppleDouble
|
14 |
+
.LSOverride
|
15 |
+
|
16 |
+
# Icon must end with two \r
|
17 |
+
Icon
|
18 |
+
|
19 |
+
|
20 |
+
# Thumbnails
|
21 |
+
._*
|
22 |
+
|
23 |
+
# Files that might appear in the root of a volume
|
24 |
+
.DocumentRevisions-V100
|
25 |
+
.fseventsd
|
26 |
+
.Spotlight-V100
|
27 |
+
.TemporaryItems
|
28 |
+
.Trashes
|
29 |
+
.VolumeIcon.icns
|
30 |
+
.com.apple.timemachine.donotpresent
|
31 |
+
|
32 |
+
# Directories potentially created on remote AFP share
|
33 |
+
.AppleDB
|
34 |
+
.AppleDesktop
|
35 |
+
Network Trash Folder
|
36 |
+
Temporary Items
|
37 |
+
.apdisk
|
38 |
+
|
39 |
+
### macOS Patch ###
|
40 |
+
# iCloud generated files
|
41 |
+
*.icloud
|
42 |
+
|
43 |
+
### Python ###
|
44 |
+
# Byte-compiled / optimized / DLL files
|
45 |
+
__pycache__/
|
46 |
+
*.py[cod]
|
47 |
+
*$py.class
|
48 |
+
|
49 |
+
# C extensions
|
50 |
+
*.so
|
51 |
+
|
52 |
+
# Distribution / packaging
|
53 |
+
.Python
|
54 |
+
build/
|
55 |
+
develop-eggs/
|
56 |
+
dist/
|
57 |
+
downloads/
|
58 |
+
eggs/
|
59 |
+
.eggs/
|
60 |
+
lib/
|
61 |
+
lib64/
|
62 |
+
parts/
|
63 |
+
sdist/
|
64 |
+
var/
|
65 |
+
wheels/
|
66 |
+
share/python-wheels/
|
67 |
+
*.egg-info/
|
68 |
+
.installed.cfg
|
69 |
+
*.egg
|
70 |
+
MANIFEST
|
71 |
+
|
72 |
+
# PyInstaller
|
73 |
+
# Usually these files are written by a python script from a template
|
74 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
75 |
+
*.manifest
|
76 |
+
*.spec
|
77 |
+
|
78 |
+
# Installer logs
|
79 |
+
pip-log.txt
|
80 |
+
pip-delete-this-directory.txt
|
81 |
+
|
82 |
+
# Unit test / coverage reports
|
83 |
+
htmlcov/
|
84 |
+
.tox/
|
85 |
+
.nox/
|
86 |
+
.coverage
|
87 |
+
.coverage.*
|
88 |
+
.cache
|
89 |
+
nosetests.xml
|
90 |
+
coverage.xml
|
91 |
+
*.cover
|
92 |
+
*.py,cover
|
93 |
+
.hypothesis/
|
94 |
+
.pytest_cache/
|
95 |
+
cover/
|
96 |
+
|
97 |
+
# Translations
|
98 |
+
*.mo
|
99 |
+
*.pot
|
100 |
+
|
101 |
+
# Django stuff:
|
102 |
+
*.log
|
103 |
+
local_settings.py
|
104 |
+
db.sqlite3
|
105 |
+
db.sqlite3-journal
|
106 |
+
|
107 |
+
# Flask stuff:
|
108 |
+
instance/
|
109 |
+
.webassets-cache
|
110 |
+
|
111 |
+
# Scrapy stuff:
|
112 |
+
.scrapy
|
113 |
+
|
114 |
+
# Sphinx documentation
|
115 |
+
docs/_build/
|
116 |
+
|
117 |
+
# PyBuilder
|
118 |
+
.pybuilder/
|
119 |
+
target/
|
120 |
+
|
121 |
+
# Jupyter Notebook
|
122 |
+
.ipynb_checkpoints
|
123 |
+
|
124 |
+
# IPython
|
125 |
+
profile_default/
|
126 |
+
ipython_config.py
|
127 |
+
|
128 |
+
# pyenv
|
129 |
+
# For a library or package, you might want to ignore these files since the code is
|
130 |
+
# intended to run in multiple environments; otherwise, check them in:
|
131 |
+
# .python-version
|
132 |
+
|
133 |
+
# pipenv
|
134 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
135 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
136 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
137 |
+
# install all needed dependencies.
|
138 |
+
#Pipfile.lock
|
139 |
+
|
140 |
+
# poetry
|
141 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
142 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
143 |
+
# commonly ignored for libraries.
|
144 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
145 |
+
#poetry.lock
|
146 |
+
|
147 |
+
# pdm
|
148 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
149 |
+
#pdm.lock
|
150 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
151 |
+
# in version control.
|
152 |
+
# https://pdm.fming.dev/#use-with-ide
|
153 |
+
.pdm.toml
|
154 |
+
|
155 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
156 |
+
__pypackages__/
|
157 |
+
|
158 |
+
# Celery stuff
|
159 |
+
celerybeat-schedule
|
160 |
+
celerybeat.pid
|
161 |
+
|
162 |
+
# SageMath parsed files
|
163 |
+
*.sage.py
|
164 |
+
|
165 |
+
# Environments
|
166 |
+
.env
|
167 |
+
.venv
|
168 |
+
env/
|
169 |
+
venv/
|
170 |
+
ENV/
|
171 |
+
env.bak/
|
172 |
+
venv.bak/
|
173 |
+
|
174 |
+
# Spyder project settings
|
175 |
+
.spyderproject
|
176 |
+
.spyproject
|
177 |
+
|
178 |
+
# Rope project settings
|
179 |
+
.ropeproject
|
180 |
+
|
181 |
+
# mkdocs documentation
|
182 |
+
/site
|
183 |
+
|
184 |
+
# mypy
|
185 |
+
.mypy_cache/
|
186 |
+
.dmypy.json
|
187 |
+
dmypy.json
|
188 |
+
|
189 |
+
# Pyre type checker
|
190 |
+
.pyre/
|
191 |
+
|
192 |
+
# pytype static type analyzer
|
193 |
+
.pytype/
|
194 |
+
|
195 |
+
# Cython debug symbols
|
196 |
+
cython_debug/
|
197 |
+
|
198 |
+
# PyCharm
|
199 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
200 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
201 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
202 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
203 |
+
#.idea/
|
204 |
+
|
205 |
+
### Python Patch ###
|
206 |
+
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
207 |
+
poetry.toml
|
208 |
+
|
209 |
+
# ruff
|
210 |
+
.ruff_cache/
|
211 |
+
|
212 |
+
# LSP config files
|
213 |
+
pyrightconfig.json
|
214 |
+
|
215 |
+
# End of https://www.toptal.com/developers/gitignore/api/python,macos
|
README.md
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# torchtune research repo: token coloring (colorful llama)
|
2 |
+
|
3 |
+
Playground to try out [token coloring](https://docs.google.com/document/d/1Win9vhddD-pu5P3SsG7E-dzN5oQl5DYWW1DhO7sBOgI/edit#heading=h.oqq00pt8expe) with TorchTune.
|
4 |
+
|
5 |
+
The repo was generated using the alpha version of [torchtune](https://github.com/pytorch-labs/torchtune).
|
6 |
+
|
7 |
+
Brief notes:
|
8 |
+
|
9 |
+
- The starting recipe is based on the Alpaca Llama2 7B full finetune recipe (switched to bf16).
|
10 |
+
- I assume `output/` is used to store model outputs and `model/` is used to store the base model checkpoints.
|
11 |
+
|
12 |
+
For the `colorful` recipe:
|
13 |
+
|
14 |
+
- I copied a lot of functionality (like the actual model definition, dataset, etc) from torchtune repository directly since I needed to make changes.
|
15 |
+
- I reduced the flexiblity of the recipe (e.g. cannot specify the model or tokenizer) and increased it in other ways (e.g. can pass in a dataset path directly).
|
16 |
+
- I added intermediate checkpointing (i.e. every `n` steps) and automatically upload the checkpoint to HuggingFace Hub.
|
17 |
+
|
18 |
+
## Getting started
|
19 |
+
|
20 |
+
The below instructions can be copy-pasted as is on to a running instance. They assume that the `HF_TOKEN` environment variable is set with a valid token.
|
21 |
+
|
22 |
+
```bash
|
23 |
+
# for RunPod
|
24 |
+
cd /workspace
|
25 |
+
git clone git@github.com:pytorch-labs/torchtune.git
|
26 |
+
cd torchtune
|
27 |
+
pip install -e .
|
28 |
+
|
29 |
+
cd /workspace
|
30 |
+
git clone git@github.com:laurencer/torchtune-colorful-llama.git
|
31 |
+
cd torchtune-colorful-llama
|
32 |
+
|
33 |
+
# for wandb support
|
34 |
+
pip install wandb
|
35 |
+
```
|
36 |
+
|
37 |
+
```bash
|
38 |
+
mkdir -p model/
|
39 |
+
tune download --repo-id meta-llama/Llama-2-7b --output-dir model/
|
40 |
+
```
|
41 |
+
|
42 |
+
```bash
|
43 |
+
tune convert_checkpoint --checkpoint-path model/consolidated.00.pth --output-path model/llama2_native.tune
|
44 |
+
```
|
45 |
+
|
46 |
+
```bash
|
47 |
+
mkdir -p output/
|
48 |
+
# tune --nnodes 1 --nproc_per_node 1 ./colorful/full_finetune.py --config ./colorful/basic_config.yaml
|
49 |
+
nohup tune --nnodes 1 --nproc_per_node 1 ./colorful/full_finetune.py --config ./colorful/basic_config.yaml 2>&1 > training_log_$(date "+%Y.%m.%d_%H.%M.%S").log &
|
50 |
+
sleep 1
|
51 |
+
tail -f training_log_*.log
|
52 |
+
```
|
53 |
+
|
54 |
+
## Baselines
|
55 |
+
|
56 |
+
Two baseline configs are provided in the `baseline` directory.
|
57 |
+
We forked the original recipe to support customizing the location/path of the Alpaca dataset.
|
58 |
+
|
59 |
+
```bash
|
60 |
+
# tune --nnodes 1 --nproc_per_node 1 ./baseline/full_finetune.py --config ./baseline/baseline_config.yaml
|
61 |
+
nohup tune --nnodes 1 --nproc_per_node 1 ./baseline/full_finetune.py --config ./baseline/baseline_config.yaml 2>&1 > training_log_$(date "+%Y.%m.%d_%H.%M.%S").log &
|
62 |
+
sleep 1
|
63 |
+
tail -f training_log_*.log
|
64 |
+
```
|
65 |
+
|
66 |
+
The adversarial config uses a dataset that is equivalent to 4x the original alpaca cleaned dataset with extra examples that include prompt injection attempts. See [token coloring description](https://docs.google.com/document/d/1Win9vhddD-pu5P3SsG7E-dzN5oQl5DYWW1DhO7sBOgI/edit#heading=h.oqq00pt8expe) for more info.
|
67 |
+
|
68 |
+
```bash
|
69 |
+
# tune --nnodes 1 --nproc_per_node 1 ./baseline/full_finetune.py --config ./baseline/adversarial_config.yaml
|
70 |
+
nohup tune --nnodes 1 --nproc_per_node 1 ./baseline/full_finetune.py --config ./baseline/adversarial_config.yaml 2>&1 > training_log_$(date "+%Y.%m.%d_%H.%M.%S").log &
|
71 |
+
sleep 1
|
72 |
+
tail -f training_log_*.log
|
73 |
+
```
|
74 |
+
|
75 |
+
## Colorful
|
76 |
+
|
77 |
+
The `colorful` directory implements the changes required to support token coloring. This includes a custom dataset implementation and training script.
|
baseline/adversarial_config.yaml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Runs the full_finetune.py recipe
|
2 |
+
#
|
3 |
+
# To launch, run the following command from root:
|
4 |
+
# tune --nnodes 1 --nproc_per_node 1 --config alpaca_llama2_full_finetune --override model_checkpoint=<your_checkpoint_dir> ...
|
5 |
+
|
6 |
+
# Dataset and Dataloader
|
7 |
+
dataset: laurencer/yahma-alpaca-cleaned-adversarial
|
8 |
+
seed: 42
|
9 |
+
shuffle: True
|
10 |
+
|
11 |
+
# Model Arguments
|
12 |
+
model: llama2_7b
|
13 |
+
model_checkpoint: model/llama2_native.tune
|
14 |
+
tokenizer: llama2_tokenizer
|
15 |
+
tokenizer_checkpoint: model/tokenizer.model
|
16 |
+
|
17 |
+
# Fine-tuning arguments
|
18 |
+
batch_size: 8
|
19 |
+
lr: 2e-5
|
20 |
+
epochs: 1
|
21 |
+
optimizer: SGD
|
22 |
+
loss: CrossEntropyLoss
|
23 |
+
output_dir: output/alpaca-llama2-adversarial
|
24 |
+
device: cuda
|
25 |
+
dtype: bf16
|
26 |
+
enable_fsdp: False
|
27 |
+
enable_activation_checkpointing: True
|
28 |
+
resume_from_checkpoint: False
|
29 |
+
|
30 |
+
# Logging arguments
|
31 |
+
metric_logger_type: wandb
|
32 |
+
project: torchtune
|
baseline/baseline_config.yaml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Runs the full_finetune.py recipe
|
2 |
+
#
|
3 |
+
# To launch, run the following command from root:
|
4 |
+
# tune --nnodes 1 --nproc_per_node 1 --config alpaca_llama2_full_finetune --override model_checkpoint=<your_checkpoint_dir> ...
|
5 |
+
|
6 |
+
# Dataset and Dataloader
|
7 |
+
dataset: yahma/alpaca-cleaned
|
8 |
+
seed: 42
|
9 |
+
shuffle: True
|
10 |
+
|
11 |
+
# Model Arguments
|
12 |
+
model: llama2_7b
|
13 |
+
model_checkpoint: model/llama2_native.tune
|
14 |
+
tokenizer: llama2_tokenizer
|
15 |
+
tokenizer_checkpoint: model/tokenizer.model
|
16 |
+
|
17 |
+
# Fine-tuning arguments
|
18 |
+
batch_size: 8
|
19 |
+
lr: 2e-5
|
20 |
+
epochs: 4
|
21 |
+
optimizer: SGD
|
22 |
+
loss: CrossEntropyLoss
|
23 |
+
output_dir: output/alpaca-llama2-baseline
|
24 |
+
device: cuda
|
25 |
+
dtype: bf16
|
26 |
+
enable_fsdp: False
|
27 |
+
enable_activation_checkpointing: True
|
28 |
+
resume_from_checkpoint: False
|
29 |
+
|
30 |
+
# Logging arguments
|
31 |
+
metric_logger_type: wandb
|
32 |
+
project: torchtune
|
baseline/custom_dataset.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import List, Tuple
|
3 |
+
|
4 |
+
from datasets import load_dataset
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
|
7 |
+
# Not ideal to import this type here but it's needed for the transform function
|
8 |
+
from torchtune.modules import Tokenizer
|
9 |
+
|
10 |
+
|
11 |
+
CROSS_ENTROPY_IGNORE_IDX = -100
|
12 |
+
|
13 |
+
_PROMPT_TEMPLATE = {
|
14 |
+
"prompt_input": (
|
15 |
+
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
16 |
+
"Write a response that appropriately completes the request.\n\n"
|
17 |
+
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
18 |
+
),
|
19 |
+
"prompt_no_input": (
|
20 |
+
"Below is an instruction that describes a task. "
|
21 |
+
"Write a response that appropriately completes the request.\n\n"
|
22 |
+
"### Instruction:\n{instruction}\n\n### Response:\n"
|
23 |
+
),
|
24 |
+
}
|
25 |
+
|
26 |
+
|
27 |
+
class AlpacaDataset(Dataset):
|
28 |
+
"""
|
29 |
+
See torchtune.datasets.AlpacaDataset for the original implementation.
|
30 |
+
This version supports custom dataset paths.
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
dataset_path: str,
|
36 |
+
tokenizer: Tokenizer,
|
37 |
+
train_on_input: bool = True,
|
38 |
+
**kwargs
|
39 |
+
) -> None:
|
40 |
+
self._data = load_dataset(dataset_path, split="train")
|
41 |
+
self._tokenizer = tokenizer
|
42 |
+
self.train_on_input = train_on_input
|
43 |
+
|
44 |
+
def __len__(self):
|
45 |
+
return len(self._data)
|
46 |
+
|
47 |
+
def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
|
48 |
+
sample = self._data[index]
|
49 |
+
|
50 |
+
return self._transform(
|
51 |
+
instruction=sample["instruction"],
|
52 |
+
input=sample["input"],
|
53 |
+
output=sample["output"],
|
54 |
+
)
|
55 |
+
|
56 |
+
def _transform(
|
57 |
+
self, instruction: str, input: str, output: str
|
58 |
+
) -> Tuple[List[int], List[int]]:
|
59 |
+
"""
|
60 |
+
Split a sample on ``response`` tag to create input and labels.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
instruction (str): Instruction text.
|
64 |
+
input (str): Input text. Can be an empty string. Determines the prompt generation template
|
65 |
+
used.
|
66 |
+
output (str): Response text.
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
Tuple of encoded inputs and labels.
|
70 |
+
"""
|
71 |
+
prompt = self._generate_prompt(instruction, input)
|
72 |
+
prompt_with_response = prompt + output
|
73 |
+
|
74 |
+
# add bos always; LlamaTokenizer sets this to True by default and neither
|
75 |
+
# alpaca-lora or the original authors change this
|
76 |
+
encoded_prompt = self._tokenizer.encode(
|
77 |
+
text=prompt, add_bos=True, add_eos=False
|
78 |
+
)
|
79 |
+
encoded_prompt_with_response = self._tokenizer.encode(
|
80 |
+
text=prompt_with_response, add_bos=True, add_eos=True
|
81 |
+
)
|
82 |
+
labels = encoded_prompt_with_response.copy()
|
83 |
+
|
84 |
+
if not self.train_on_input:
|
85 |
+
labels[: len(encoded_prompt)] = [CROSS_ENTROPY_IGNORE_IDX] * len(
|
86 |
+
encoded_prompt
|
87 |
+
)
|
88 |
+
|
89 |
+
assert len(encoded_prompt_with_response) == len(labels)
|
90 |
+
|
91 |
+
return encoded_prompt_with_response, labels
|
92 |
+
|
93 |
+
def _generate_prompt(self, instruction: str, input: str) -> str:
|
94 |
+
"""
|
95 |
+
Generate prompt from instruction and input.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
instruction (str): Instruction text.
|
99 |
+
input (str): Input text.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
Prompt text.
|
103 |
+
"""
|
104 |
+
if input:
|
105 |
+
prompt = _PROMPT_TEMPLATE["prompt_input"].format(
|
106 |
+
instruction=instruction, input=input
|
107 |
+
)
|
108 |
+
else:
|
109 |
+
prompt = _PROMPT_TEMPLATE["prompt_no_input"].format(instruction=instruction)
|
110 |
+
return prompt
|
baseline/custom_params.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Customized to remove dataset validation.
|
2 |
+
|
3 |
+
from dataclasses import dataclass, field, fields
|
4 |
+
from typing import List, Optional
|
5 |
+
|
6 |
+
from torchtune.datasets import ALL_DATASETS
|
7 |
+
from torchtune.models import ALL_MODELS, ALL_TOKENIZERS
|
8 |
+
from torchtune.utils.metric_logging import ALL_METRIC_LOGGERS
|
9 |
+
from torchtune.utils.precision import PRECISION_STR_TO_DTYPE
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class FullFinetuneParams:
|
14 |
+
"""Arguments for the finetune_llm recipe.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
device (str): Device to use for training. Options are "cpu" and "cuda"
|
18 |
+
dtype (str): Data type to use for training.
|
19 |
+
seed (int): Random seed to use for training.
|
20 |
+
model (str): String specifying model architecture to fine-tune. See ``torchtune.models.get_model`` for options.
|
21 |
+
model_checkpoint (str): Local path to load model checkpoint from.
|
22 |
+
tokenizer (str): String specifying tokenizer to use. See ``torchtune.models.get_tokenizer`` for options.
|
23 |
+
tokenizer_checkpoint (str): Local path to load tokenizer checkpoint from.
|
24 |
+
dataset (str): String specifying dataset to use. See ``torchtune.datasets.get_dataset`` for options.
|
25 |
+
Currently, only predefined datasets in library are supported.
|
26 |
+
shuffle (bool): Whether to shuffle dataset.
|
27 |
+
batch_size (int): Batch size to use for training.
|
28 |
+
epochs (int): Number of epochs to train for.
|
29 |
+
optimizer (str): String specifying optimizer to use. See ``torchtune.optim.get_optimizer`` for options.
|
30 |
+
loss (str): String specifying loss function to use. See ``torchtune.losses.get_loss`` for options.
|
31 |
+
lr (float): Learning rate to use for optimizer.
|
32 |
+
activation_checkpointing (bool): Whether to use activation checkpointing.
|
33 |
+
output_dir (str): Local path to save checkpoints and logs to.
|
34 |
+
run_generation (int): Run eval on a prompt every ``run_generation`` steps. Set to 0 to disable.
|
35 |
+
max_steps_per_epoch (int): Maximum number of steps to take per epoch.
|
36 |
+
metric_logger_type (str): String specifying metric logger to use. See ``torchtune.utils.get_metric_logger``
|
37 |
+
for options.
|
38 |
+
project (str): Project name to use for logging. Used by ``WandBLogger``.
|
39 |
+
resume_from_previous_checkpoint (bool): Whether to resume fine-tuning from a previous checkpoint.
|
40 |
+
cpu_offload (bool): Whether to offload model to CPU.
|
41 |
+
|
42 |
+
Raises:
|
43 |
+
ValueError: If ``cpu_offload`` is ``True`` but ``device`` is not ``cuda`` and <= 1 GPUs.
|
44 |
+
"""
|
45 |
+
|
46 |
+
# Model
|
47 |
+
model: str = ""
|
48 |
+
model_checkpoint: str = ""
|
49 |
+
|
50 |
+
# Tokenizer
|
51 |
+
tokenizer: str = ""
|
52 |
+
tokenizer_checkpoint: str = ""
|
53 |
+
|
54 |
+
# Dataset and Sampler
|
55 |
+
dataset: str = ""
|
56 |
+
train_on_input: bool = True
|
57 |
+
shuffle: bool = True
|
58 |
+
batch_size: int = 2
|
59 |
+
|
60 |
+
# Optimizer and Scheduler
|
61 |
+
optimizer: str = "SGD"
|
62 |
+
lr: float = 2e-5
|
63 |
+
loss: str = "CrossEntropyLoss"
|
64 |
+
gradient_accumulation_steps: int = 1
|
65 |
+
|
66 |
+
# Training
|
67 |
+
epochs: int = 3
|
68 |
+
max_steps_per_epoch: Optional[int] = None
|
69 |
+
resume_from_checkpoint: bool = False
|
70 |
+
run_generation: Optional[int] = None
|
71 |
+
|
72 |
+
# Distributed
|
73 |
+
cpu_offload: bool = False
|
74 |
+
enable_fsdp: bool = True
|
75 |
+
enable_activation_checkpointing: bool = True
|
76 |
+
|
77 |
+
# Environment
|
78 |
+
device: str = "cuda"
|
79 |
+
dtype: str = "fp32"
|
80 |
+
seed: Optional[int] = None
|
81 |
+
|
82 |
+
# Logging
|
83 |
+
output_dir: str = "/tmp/full_finetune_output"
|
84 |
+
metric_logger_type: str = "disk"
|
85 |
+
project: Optional[str] = None
|
86 |
+
log_every_n_steps: Optional[int] = None
|
87 |
+
|
88 |
+
def __post_init__(self):
|
89 |
+
for param in fields(self):
|
90 |
+
if getattr(self, param.name) == "":
|
91 |
+
raise TypeError(f"{param.name} needs to be specified")
|
92 |
+
|
93 |
+
if self.cpu_offload and self.device != "cuda":
|
94 |
+
raise ValueError(
|
95 |
+
"Cannot offload model to CPU if device is not cuda or <= 1 GPUs."
|
96 |
+
)
|
97 |
+
if self.enable_fsdp and self.device == "cpu":
|
98 |
+
raise ValueError("FSDP is not supported on CPU.")
|
99 |
+
if self.model not in ALL_MODELS:
|
100 |
+
raise ValueError(
|
101 |
+
f"Model not recognized. Expected one of {ALL_MODELS}, received {self.model}."
|
102 |
+
)
|
103 |
+
if self.tokenizer not in ALL_TOKENIZERS:
|
104 |
+
raise ValueError(
|
105 |
+
f"Tokenizer not recognized. Expected one of {ALL_TOKENIZERS}, received {self.tokenizer}."
|
106 |
+
)
|
107 |
+
if self.metric_logger_type not in ALL_METRIC_LOGGERS:
|
108 |
+
raise ValueError(
|
109 |
+
f"Metric logger not recognized. Expected one of {ALL_METRIC_LOGGERS}, received {self.metric_logger_type}."
|
110 |
+
)
|
111 |
+
if self.dtype not in PRECISION_STR_TO_DTYPE:
|
112 |
+
raise ValueError(
|
113 |
+
f"Dtype {self.dtype} must be one of {', '.join(PRECISION_STR_TO_DTYPE.keys())} for finetuning."
|
114 |
+
)
|
baseline/full_finetune.py
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
|
11 |
+
from functools import partial
|
12 |
+
from typing import Any, Dict, Optional, Tuple
|
13 |
+
from warnings import warn
|
14 |
+
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from torch import nn
|
18 |
+
from torch.cuda.amp import GradScaler
|
19 |
+
from torch.distributed import init_process_group
|
20 |
+
from torch.optim import Optimizer
|
21 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
22 |
+
|
23 |
+
from torchtune import models, modules, utils
|
24 |
+
from torchtune.utils.constants import (
|
25 |
+
EPOCHS_KEY,
|
26 |
+
MAX_STEPS_KEY,
|
27 |
+
MODEL_KEY,
|
28 |
+
OPT_KEY,
|
29 |
+
SEED_KEY,
|
30 |
+
TOTAL_EPOCHS_KEY,
|
31 |
+
)
|
32 |
+
|
33 |
+
from tqdm import tqdm
|
34 |
+
|
35 |
+
from recipes.interfaces import FTRecipeInterface
|
36 |
+
|
37 |
+
|
38 |
+
from custom_params import FullFinetuneParams
|
39 |
+
from custom_dataset import AlpacaDataset
|
40 |
+
|
41 |
+
log = utils.get_logger("DEBUG")
|
42 |
+
|
43 |
+
|
44 |
+
class FullFinetuneRecipe(FTRecipeInterface):
|
45 |
+
"""
|
46 |
+
Full finetuning recipe for dense transformer-based LLMs such as Llama2.
|
47 |
+
|
48 |
+
This recipe supports:
|
49 |
+
- FSDP and activation checkpointing. This is enabled by default but can be
|
50 |
+
configured using the ``enable_fsdp`` and ``enable_activation_checkpointing`` flags.
|
51 |
+
- Mixed precision training - fp32, fp16 and bf16 are supported.
|
52 |
+
- Checkpointing of model weights, optimizer state and the recipe state (epoch and seed).
|
53 |
+
- Resuming from checkpoints saved using the ``save_checkpoint`` functionality.
|
54 |
+
- Logging to terminal. WandB and TensorBoard are currently not supported.
|
55 |
+
|
56 |
+
Assumptions:
|
57 |
+
- Training is launched with the Tune CLI (recommended) which uses TorchRun under the
|
58 |
+
hood. Setting up the env variables is handled by TorchRun.
|
59 |
+
- Training happens on CUDA (CPU training is not supported)
|
60 |
+
- Checkpoints are ONLY saved at epoch boundaries. Mid-epoch checkpointing is NOT supported.
|
61 |
+
- Datasets are Map-style and data fits in memory (not streamed).
|
62 |
+
"""
|
63 |
+
|
64 |
+
def __init__(self, params: FullFinetuneParams) -> None:
|
65 |
+
|
66 |
+
self._device = utils.get_device(device=params.device)
|
67 |
+
self._dtype = utils.get_dtype(dtype=params.dtype)
|
68 |
+
|
69 |
+
# logging attributes
|
70 |
+
self._output_dir = params.output_dir
|
71 |
+
self._metric_logger = utils.get_metric_logger(
|
72 |
+
metric_logger_type=params.metric_logger_type,
|
73 |
+
project=params.project,
|
74 |
+
log_dir=params.output_dir,
|
75 |
+
)
|
76 |
+
self._log_every_n_steps = (
|
77 |
+
params.log_every_n_steps if params.log_every_n_steps else 1
|
78 |
+
)
|
79 |
+
|
80 |
+
# _is_rank_zero is used primarily for logging. In the future, the logger
|
81 |
+
# should directly take care of this
|
82 |
+
_, rank = utils.get_world_size_and_rank()
|
83 |
+
self._is_rank_zero = rank == 0
|
84 |
+
|
85 |
+
# Training params
|
86 |
+
self._resume_from_checkpoint = params.resume_from_checkpoint
|
87 |
+
self._enable_fsdp = params.enable_fsdp
|
88 |
+
self._gradient_accumulation_steps = params.gradient_accumulation_steps
|
89 |
+
|
90 |
+
# These are public properties which are updated by the checkpoint loader
|
91 |
+
# when ``resume_from_checkpoint`` is `True` or validated in tests
|
92 |
+
self.seed = utils.set_seed(seed=params.seed)
|
93 |
+
self.epochs_run = 0
|
94 |
+
self.total_epochs = params.epochs
|
95 |
+
self.max_steps_per_epoch = params.max_steps_per_epoch
|
96 |
+
self.total_training_steps = 0
|
97 |
+
|
98 |
+
def load_checkpoint(self, ckpt_path: str):
|
99 |
+
"""
|
100 |
+
Extract the checkpoint state from file and validate.
|
101 |
+
"""
|
102 |
+
ckpt_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
103 |
+
utils.validate_checkpoint(ckpt_dict, self._resume_from_checkpoint)
|
104 |
+
return ckpt_dict
|
105 |
+
|
106 |
+
def setup(self, params: FullFinetuneParams) -> None:
|
107 |
+
"""
|
108 |
+
Sets up the recipe state correctly. This includes setting recipe attributes based
|
109 |
+
on the ``resume_from_checkpoint`` flag.
|
110 |
+
"""
|
111 |
+
|
112 |
+
ckpt_dict = self.load_checkpoint(ckpt_path=params.model_checkpoint)
|
113 |
+
|
114 |
+
# If we're resuming from checkpoint, the recipe's state should be updated before
|
115 |
+
# initializing the training components. This ensures that the seed is correctly
|
116 |
+
# propagated to the relevant components
|
117 |
+
if self._resume_from_checkpoint:
|
118 |
+
self._update_recipe_state(ckpt_dict)
|
119 |
+
|
120 |
+
# ``_setup_model`` handles initialization and loading the state dict. This method
|
121 |
+
# should be called before ``_setup_optimizer`` since transforming the optimizer
|
122 |
+
# state dict requires the model
|
123 |
+
self._model = self._setup_model(
|
124 |
+
model=params.model,
|
125 |
+
enable_fsdp=params.enable_fsdp,
|
126 |
+
enable_activation_checkpointing=params.enable_activation_checkpointing,
|
127 |
+
model_state_dict=ckpt_dict[MODEL_KEY],
|
128 |
+
)
|
129 |
+
|
130 |
+
self._tokenizer = self._setup_tokenizer(
|
131 |
+
tokenizer=params.tokenizer, tokenizer_checkpoint=params.tokenizer_checkpoint
|
132 |
+
)
|
133 |
+
|
134 |
+
# _setup_optimizer should take in ckpt_dict only if training is resumed from
|
135 |
+
# checkpoint. Transforming the opt state dict is handled by this method
|
136 |
+
self._optimizer = self._setup_optimizer(
|
137 |
+
optimizer=params.optimizer,
|
138 |
+
lr=params.lr,
|
139 |
+
opt_state_dict=ckpt_dict[OPT_KEY] if self._resume_from_checkpoint else None,
|
140 |
+
)
|
141 |
+
|
142 |
+
self._loss_fn = self._setup_loss(loss=params.loss)
|
143 |
+
|
144 |
+
# sampler and dataloader depend on the tokenizer and loss_fn and should be
|
145 |
+
# setup after both of these are initialized
|
146 |
+
self._sampler, self._dataloader = self._setup_data(
|
147 |
+
dataset=params.dataset,
|
148 |
+
train_on_input=params.train_on_input,
|
149 |
+
shuffle=params.shuffle,
|
150 |
+
batch_size=params.batch_size,
|
151 |
+
)
|
152 |
+
|
153 |
+
# training setup
|
154 |
+
self._autocast = utils.get_autocast(self._dtype, self._device)
|
155 |
+
self._grad_scaler = None
|
156 |
+
if self._dtype == torch.float16:
|
157 |
+
self._grad_scaler = utils.get_gradient_scaler(fsdp=params.enable_fsdp)
|
158 |
+
else:
|
159 |
+
self._grad_scaler = GradScaler(enabled=False)
|
160 |
+
|
161 |
+
# Finally update the recipe state which can only be correctly set after all of the
|
162 |
+
# other components have been initialized and updated.
|
163 |
+
#
|
164 |
+
# Number of training steps in each epoch depends on the number of batches produced
|
165 |
+
# by the dataloader, the max_steps_per_epoch param set by the user and the
|
166 |
+
# gradient_accumulation_steps param. This value is used for logging and tracking
|
167 |
+
# training state. The computation should happen after the dataloader has been setup
|
168 |
+
self._steps_per_epoch = (
|
169 |
+
len(self._dataloader) // self._gradient_accumulation_steps
|
170 |
+
)
|
171 |
+
if (
|
172 |
+
self.max_steps_per_epoch is not None
|
173 |
+
and self.max_steps_per_epoch < self._steps_per_epoch
|
174 |
+
):
|
175 |
+
self._steps_per_epoch = self.max_steps_per_epoch
|
176 |
+
self.total_training_steps = self.epochs_run * self._steps_per_epoch
|
177 |
+
|
178 |
+
def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
|
179 |
+
"""
|
180 |
+
Updates the recipe state from checkpoint.
|
181 |
+
"""
|
182 |
+
# If seed, total_epoch or max_steps_per_epoch don't match,
|
183 |
+
# warn the user and overwrite
|
184 |
+
if (
|
185 |
+
self.seed != ckpt_dict[SEED_KEY]
|
186 |
+
or self.total_epochs != ckpt_dict[TOTAL_EPOCHS_KEY]
|
187 |
+
or self.max_steps_per_epoch != ckpt_dict[MAX_STEPS_KEY]
|
188 |
+
):
|
189 |
+
warn(
|
190 |
+
message="""Configured value for seed, epochs or max_steps_per_epoch
|
191 |
+
does not match the value stored in checkpoint."""
|
192 |
+
)
|
193 |
+
self.seed = utils.set_seed(seed=ckpt_dict[SEED_KEY])
|
194 |
+
self.epochs_run = ckpt_dict[EPOCHS_KEY]
|
195 |
+
self.total_epochs = ckpt_dict[TOTAL_EPOCHS_KEY]
|
196 |
+
self.max_steps_per_epoch = ckpt_dict[MAX_STEPS_KEY]
|
197 |
+
|
198 |
+
def _setup_model(
|
199 |
+
self,
|
200 |
+
model: str,
|
201 |
+
enable_fsdp: bool,
|
202 |
+
enable_activation_checkpointing: bool,
|
203 |
+
model_state_dict: Dict[str, Any],
|
204 |
+
) -> nn.Module:
|
205 |
+
"""
|
206 |
+
Set up the model including enabling FSDP and activation checkpointing. For this recipe,
|
207 |
+
``enable_fsdp`` should always be ``True``. This is currently a configurable flag for
|
208 |
+
running tests on CPUs.
|
209 |
+
"""
|
210 |
+
model = models.get_model(model, device=self._device)
|
211 |
+
model = (
|
212 |
+
utils.wrap_fsdp(
|
213 |
+
model=model,
|
214 |
+
device=self._device,
|
215 |
+
dtype=self._dtype,
|
216 |
+
strategy="FULL_SHARD",
|
217 |
+
auto_wrap_policy={modules.TransformerDecoderLayer},
|
218 |
+
)
|
219 |
+
if enable_fsdp
|
220 |
+
else model
|
221 |
+
)
|
222 |
+
if enable_activation_checkpointing:
|
223 |
+
utils.set_activation_checkpointing(
|
224 |
+
model, auto_wrap_policy={modules.TransformerDecoderLayer}
|
225 |
+
)
|
226 |
+
|
227 |
+
model.load_state_dict(model_state_dict)
|
228 |
+
|
229 |
+
if self._is_rank_zero:
|
230 |
+
log.info(
|
231 |
+
"Model is initialized. FSDP and Activation Checkpointing are enabled."
|
232 |
+
)
|
233 |
+
return model
|
234 |
+
|
235 |
+
def _setup_tokenizer(
|
236 |
+
self, tokenizer: str, tokenizer_checkpoint: str
|
237 |
+
) -> modules.Tokenizer:
|
238 |
+
"""
|
239 |
+
Unlike ```setup_model```, this takes in the checkpoint and loads the sentencepiece
|
240 |
+
tokenizer model. This is related to how the tokenizer is implemented and should
|
241 |
+
change in a future iteration.
|
242 |
+
"""
|
243 |
+
tokenizer = models.get_tokenizer(tokenizer, path=tokenizer_checkpoint)
|
244 |
+
|
245 |
+
if self._is_rank_zero:
|
246 |
+
log.info("Tokenizer is initialized from file.")
|
247 |
+
return tokenizer
|
248 |
+
|
249 |
+
def _setup_optimizer(
|
250 |
+
self, optimizer: str, lr: float, opt_state_dict: Optional[Dict[str, Any]] = None
|
251 |
+
) -> Optimizer:
|
252 |
+
"""
|
253 |
+
Set up the optimizer. This method also handles transforing the state dict
|
254 |
+
for FSDP.
|
255 |
+
"""
|
256 |
+
optimizer = modules.get_optimizer(optimizer, self._model, lr)
|
257 |
+
if opt_state_dict:
|
258 |
+
opt_state_dict = utils.transform_opt_state_dict(
|
259 |
+
opt_state_dict, self._model, optimizer
|
260 |
+
)
|
261 |
+
optimizer.load_state_dict(opt_state_dict)
|
262 |
+
|
263 |
+
if self._is_rank_zero:
|
264 |
+
log.info("Optimizer is initialized.")
|
265 |
+
return optimizer
|
266 |
+
|
267 |
+
def _setup_loss(self, loss: str) -> nn.Module:
|
268 |
+
loss_fn = modules.get_loss(loss)
|
269 |
+
|
270 |
+
if self._is_rank_zero:
|
271 |
+
log.info("Loss is initialized.")
|
272 |
+
|
273 |
+
return loss_fn
|
274 |
+
|
275 |
+
def _setup_data(
|
276 |
+
self, dataset: str, shuffle: bool, batch_size: int, train_on_input: bool
|
277 |
+
) -> Tuple[DistributedSampler, DataLoader]:
|
278 |
+
"""
|
279 |
+
All data related setup happens here. Currently this recipe only supports the
|
280 |
+
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
|
281 |
+
iterable datasets and streaming datasets are not supported.
|
282 |
+
"""
|
283 |
+
world_size, rank = utils.get_world_size_and_rank()
|
284 |
+
ds = AlpacaDataset(dataset, tokenizer=self._tokenizer, train_on_input=train_on_input)
|
285 |
+
|
286 |
+
sampler = DistributedSampler(
|
287 |
+
ds,
|
288 |
+
num_replicas=world_size,
|
289 |
+
rank=rank,
|
290 |
+
shuffle=shuffle,
|
291 |
+
seed=0,
|
292 |
+
)
|
293 |
+
dataloader = DataLoader(
|
294 |
+
dataset=ds,
|
295 |
+
batch_size=batch_size,
|
296 |
+
sampler=sampler,
|
297 |
+
collate_fn=partial(
|
298 |
+
utils.padded_collate,
|
299 |
+
padding_idx=self._tokenizer.pad_id,
|
300 |
+
ignore_idx=self._loss_fn.ignore_index, # TODO support loss without ignore_index
|
301 |
+
),
|
302 |
+
)
|
303 |
+
|
304 |
+
if self._is_rank_zero:
|
305 |
+
log.info("Dataset and Sampler are initialized.")
|
306 |
+
|
307 |
+
return sampler, dataloader
|
308 |
+
|
309 |
+
def save_checkpoint(self, epoch: int) -> None:
|
310 |
+
"""
|
311 |
+
Checkpoint the relevant state of a recipe.
|
312 |
+
|
313 |
+
This makes use of the `save_checkpoint` utility which is responsible for
|
314 |
+
writing the checkpoint dictionary to file. The contents of the dict are dictated
|
315 |
+
by whether training is complete or not.
|
316 |
+
|
317 |
+
If training is ongoing, optimizer state, seed and epochs_run are saved along with the
|
318 |
+
model weights.
|
319 |
+
"""
|
320 |
+
os.makedirs(self._output_dir, exist_ok=True)
|
321 |
+
output_loc = f"{self._output_dir}/model_{epoch}.ckpt"
|
322 |
+
ckpt_dict = {MODEL_KEY: self._model}
|
323 |
+
|
324 |
+
# if training is in-progress, checkpoint the optimizer state as well
|
325 |
+
if epoch + 1 < self.total_epochs:
|
326 |
+
ckpt_dict.update(
|
327 |
+
{
|
328 |
+
OPT_KEY: self._optimizer,
|
329 |
+
SEED_KEY: self.seed,
|
330 |
+
EPOCHS_KEY: self.epochs_run,
|
331 |
+
TOTAL_EPOCHS_KEY: self.total_epochs,
|
332 |
+
MAX_STEPS_KEY: self.max_steps_per_epoch,
|
333 |
+
}
|
334 |
+
)
|
335 |
+
utils.save_checkpoint(ckpt_dict, output_loc)
|
336 |
+
|
337 |
+
if self._is_rank_zero:
|
338 |
+
log.info(
|
339 |
+
f"Model checkpoint of size {os.path.getsize(output_loc) >> 20} MB saved to {output_loc}"
|
340 |
+
)
|
341 |
+
|
342 |
+
def _should_update_weights(self, curr_step: int) -> bool:
|
343 |
+
"""
|
344 |
+
Determines whether the weights should be updated on the current step or not.
|
345 |
+
True is returned either if we've accumulated gradients for enough steps or if this
|
346 |
+
is the last step in the epoch.
|
347 |
+
"""
|
348 |
+
should_update_weights = (
|
349 |
+
curr_step + 1
|
350 |
+
) % self._gradient_accumulation_steps == 0 or (
|
351 |
+
curr_step + 1
|
352 |
+
) == self._steps_per_epoch
|
353 |
+
return should_update_weights
|
354 |
+
|
355 |
+
def train(self) -> None:
|
356 |
+
"""
|
357 |
+
The core training loop. Supports training on subsets of the dataset using the
|
358 |
+
``max_steps_per_epoch``.
|
359 |
+
"""
|
360 |
+
_, rank = utils.get_world_size_and_rank()
|
361 |
+
|
362 |
+
# zero out the gradients before starting training
|
363 |
+
self._optimizer.zero_grad()
|
364 |
+
|
365 |
+
# self.epochs_run should be non-zero when we're resuming from a checkpoint
|
366 |
+
for curr_epoch in range(self.epochs_run, self.total_epochs):
|
367 |
+
|
368 |
+
# Update the sampler to ensure data is correctly shuffled across epochs
|
369 |
+
# in case shuffle is True
|
370 |
+
self._sampler.set_epoch(curr_epoch)
|
371 |
+
|
372 |
+
for idx, batch in enumerate(
|
373 |
+
pbar := tqdm(self._dataloader, disable=not (rank == 0))
|
374 |
+
):
|
375 |
+
if (
|
376 |
+
self.max_steps_per_epoch is not None
|
377 |
+
and (idx // self._gradient_accumulation_steps)
|
378 |
+
== self.max_steps_per_epoch
|
379 |
+
):
|
380 |
+
break
|
381 |
+
|
382 |
+
input_ids, labels = batch
|
383 |
+
input_ids = input_ids.to(self._device)
|
384 |
+
labels = labels.to(self._device)
|
385 |
+
|
386 |
+
with self._autocast:
|
387 |
+
logits = self._model(input_ids)
|
388 |
+
# Shift so that tokens < n predict n
|
389 |
+
logits = logits[..., :-1, :].contiguous()
|
390 |
+
labels = labels[..., 1:].contiguous()
|
391 |
+
logits = logits.transpose(1, 2)
|
392 |
+
# Compute loss
|
393 |
+
loss = self._loss_fn(logits, labels)
|
394 |
+
|
395 |
+
# Note: We're always logging the loss before normalizing it
|
396 |
+
# Check if this is the norm or not
|
397 |
+
pbar.set_description(f"{curr_epoch+1}|{idx+1}|Loss: {loss.item()}")
|
398 |
+
|
399 |
+
if self.total_training_steps % self._log_every_n_steps == 0:
|
400 |
+
self._metric_logger.log_dict(
|
401 |
+
{
|
402 |
+
"loss": loss.item(),
|
403 |
+
"lr": self._optimizer.param_groups[0]["lr"],
|
404 |
+
"gpu_resources": torch.cuda.memory_allocated(),
|
405 |
+
},
|
406 |
+
step=self.total_training_steps,
|
407 |
+
)
|
408 |
+
|
409 |
+
# Does loss normalization need to happen within autocast context?
|
410 |
+
loss = loss / self._gradient_accumulation_steps
|
411 |
+
self._grad_scaler.scale(loss).backward()
|
412 |
+
|
413 |
+
if self._should_update_weights(idx):
|
414 |
+
self._grad_scaler.step(self._optimizer)
|
415 |
+
self._grad_scaler.update()
|
416 |
+
self._optimizer.zero_grad(set_to_none=True)
|
417 |
+
|
418 |
+
# Update the number of steps when the weights are updated
|
419 |
+
self.total_training_steps += 1
|
420 |
+
|
421 |
+
self.epochs_run += 1
|
422 |
+
self.save_checkpoint(epoch=curr_epoch)
|
423 |
+
|
424 |
+
def cleanup(self) -> None:
|
425 |
+
self._metric_logger.close()
|
426 |
+
|
427 |
+
|
428 |
+
def recipe_main() -> None:
|
429 |
+
"""
|
430 |
+
Entry point for the recipe.
|
431 |
+
|
432 |
+
Configurable parameters are read in the following order:
|
433 |
+
- Parameters specified in ``FullFinetuneParams``
|
434 |
+
- Overwritten by Parameters specified in ``alpaca_llama2_full_finetune.yaml``
|
435 |
+
- Overwritten by arguments from the command-line using ``TuneArgumentParser``
|
436 |
+
"""
|
437 |
+
parser = utils.TuneArgumentParser(
|
438 |
+
description=FullFinetuneParams.__doc__,
|
439 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
440 |
+
)
|
441 |
+
args, _ = parser.parse_known_args()
|
442 |
+
args = vars(args)
|
443 |
+
recipe_params = FullFinetuneParams(**args)
|
444 |
+
|
445 |
+
# Env variables set by torch run; only need to initialize process group
|
446 |
+
# init_process_group(backend="nccl")
|
447 |
+
|
448 |
+
recipe = FullFinetuneRecipe(params=recipe_params)
|
449 |
+
recipe.setup(params=recipe_params)
|
450 |
+
recipe.train()
|
451 |
+
recipe.cleanup()
|
452 |
+
|
453 |
+
|
454 |
+
if __name__ == "__main__":
|
455 |
+
sys.exit(recipe_main())
|
colorful/adversarial_config.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Runs the full_finetune.py recipe
|
2 |
+
#
|
3 |
+
# To launch, run the following command from root:
|
4 |
+
# tune --nnodes 1 --nproc_per_node 1 --config alpaca_llama2_full_finetune --override model_checkpoint=<your_checkpoint_dir> ...
|
5 |
+
|
6 |
+
# Dataset and Dataloader
|
7 |
+
dataset: laurencer/yahma-alpaca-cleaned-adversarial
|
8 |
+
seed: 42
|
9 |
+
shuffle: True
|
10 |
+
|
11 |
+
# Checkpointing
|
12 |
+
# Removed for now given poor upload speeds for checkpoints
|
13 |
+
# hf_repo_id: laurencer/Llama7b-Alpaca-Tune-4epochs-WithColoring
|
14 |
+
checkpoint_every_n_steps: 500 # 6k steps per epoch
|
15 |
+
|
16 |
+
# Model Arguments
|
17 |
+
model_checkpoint: model/llama2_native.tune
|
18 |
+
tokenizer_checkpoint: model/tokenizer.model
|
19 |
+
|
20 |
+
color_layer_initialization: zeros
|
21 |
+
norm_before_color_layer: True
|
22 |
+
|
23 |
+
# Fine-tuning arguments
|
24 |
+
compile: False
|
25 |
+
batch_size: 8
|
26 |
+
lr: 2e-5
|
27 |
+
epochs: 4
|
28 |
+
optimizer: SGD
|
29 |
+
loss: CrossEntropyLoss
|
30 |
+
output_dir: output/alpaca-colorful-llama2-finetune
|
31 |
+
device: cuda
|
32 |
+
dtype: bf16
|
33 |
+
enable_fsdp: False
|
34 |
+
enable_activation_checkpointing: True
|
35 |
+
resume_from_checkpoint: False
|
36 |
+
|
37 |
+
# Logging arguments
|
38 |
+
metric_logger_type: wandb
|
39 |
+
project: torchtune
|
colorful/basic_config.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Runs the full_finetune.py recipe
|
2 |
+
#
|
3 |
+
# To launch, run the following command from root:
|
4 |
+
# tune --nnodes 1 --nproc_per_node 1 --config alpaca_llama2_full_finetune --override model_checkpoint=<your_checkpoint_dir> ...
|
5 |
+
|
6 |
+
# Dataset and Dataloader
|
7 |
+
dataset: yahma/alpaca-cleaned
|
8 |
+
seed: 42
|
9 |
+
shuffle: True
|
10 |
+
|
11 |
+
# Checkpointing
|
12 |
+
# Removed for now given poor upload speeds for checkpoints
|
13 |
+
# hf_repo_id: laurencer/Llama7b-Alpaca-Tune-4epochs-WithColoring
|
14 |
+
checkpoint_every_n_steps: 500 # 6k steps per epoch
|
15 |
+
|
16 |
+
# Model Arguments
|
17 |
+
model_checkpoint: model/llama2_native.tune
|
18 |
+
tokenizer_checkpoint: model/tokenizer.model
|
19 |
+
|
20 |
+
color_layer_initialization: zeros
|
21 |
+
norm_before_color_layer: True
|
22 |
+
|
23 |
+
# Fine-tuning arguments
|
24 |
+
compile: True
|
25 |
+
batch_size: 8
|
26 |
+
lr: 2e-5
|
27 |
+
epochs: 4
|
28 |
+
optimizer: SGD
|
29 |
+
loss: CrossEntropyLoss
|
30 |
+
output_dir: output/alpaca-colorful-llama2-finetune
|
31 |
+
device: cuda
|
32 |
+
dtype: bf16
|
33 |
+
enable_fsdp: False
|
34 |
+
enable_activation_checkpointing: True
|
35 |
+
resume_from_checkpoint: False
|
36 |
+
|
37 |
+
# Logging arguments
|
38 |
+
metric_logger_type: wandb
|
39 |
+
project: torchtune
|
colorful/custom_dataset.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from typing import List, Tuple
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from torch.nn.utils.rnn import pad_sequence
|
13 |
+
from torch.utils.data import Dataset
|
14 |
+
|
15 |
+
from datasets import load_dataset
|
16 |
+
|
17 |
+
# Not ideal to import this type here but it's needed for the transform function
|
18 |
+
from torchtune.modules import Tokenizer
|
19 |
+
|
20 |
+
|
21 |
+
CROSS_ENTROPY_IGNORE_IDX = -100
|
22 |
+
|
23 |
+
|
24 |
+
DEFAULT = 0
|
25 |
+
INSTRUCTION = 1
|
26 |
+
INPUT = 2
|
27 |
+
RESPONSE = 3
|
28 |
+
|
29 |
+
|
30 |
+
class ColoringAlpacaDataset(Dataset):
|
31 |
+
"""
|
32 |
+
See torchtune.datasets.alpaca.AlpacaDataset for the original implementation.
|
33 |
+
|
34 |
+
Constructor now takes in a dataset path directly.
|
35 |
+
|
36 |
+
This implementation returns 3 lists representing the tokens, labels, and token colors
|
37 |
+
(as opposed to just the tokens & labels from the original).
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
tokenizer: Tokenizer,
|
43 |
+
dataset_path: str = "yahma/alpaca-cleaned",
|
44 |
+
train_on_input: bool = True,
|
45 |
+
**kwargs
|
46 |
+
) -> None:
|
47 |
+
self._data = load_dataset(dataset_path, split="train")
|
48 |
+
self._tokenizer = tokenizer
|
49 |
+
self.train_on_input = train_on_input
|
50 |
+
self.num_colors = 4 # matches the above usage of DEFAULT, INSTRUCTION, INPUT, RESPONSE
|
51 |
+
|
52 |
+
def __len__(self):
|
53 |
+
return len(self._data)
|
54 |
+
|
55 |
+
def __getitem__(self, index: int) -> Tuple[List[int], List[int], List[int]]:
|
56 |
+
sample = self._data[index]
|
57 |
+
|
58 |
+
return self._transform(
|
59 |
+
instruction=sample["instruction"],
|
60 |
+
input=sample["input"],
|
61 |
+
output=sample["output"],
|
62 |
+
)
|
63 |
+
|
64 |
+
def _transform(
|
65 |
+
self, instruction: str, input: str, output: str
|
66 |
+
) -> Tuple[List[int], List[int], List[int]]:
|
67 |
+
"""
|
68 |
+
Split a sample on ``response`` tag to create input and labels.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
instruction (str): Instruction text.
|
72 |
+
input (str): Input text. Can be an empty string. Determines the prompt generation template
|
73 |
+
used.
|
74 |
+
output (str): Response text.
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
Tuple of encoded inputs, labels, token colors.
|
78 |
+
"""
|
79 |
+
prompt = self._generate_prompt(instruction, input)
|
80 |
+
|
81 |
+
# First handle the prompt
|
82 |
+
colors = []
|
83 |
+
tokenized = []
|
84 |
+
labels = []
|
85 |
+
is_first = True
|
86 |
+
for token_type, text in prompt:
|
87 |
+
tokenized_part = self._tokenizer.encode(
|
88 |
+
text=text, add_bos=is_first, add_eos=False
|
89 |
+
)
|
90 |
+
is_first = False
|
91 |
+
|
92 |
+
tokenized += tokenized_part
|
93 |
+
colors += [token_type] * len(tokenized_part)
|
94 |
+
if not self.train_on_input:
|
95 |
+
labels += [CROSS_ENTROPY_IGNORE_IDX] * len(tokenized_part)
|
96 |
+
else:
|
97 |
+
labels += tokenized_part
|
98 |
+
|
99 |
+
# Now add the response tokens
|
100 |
+
tokenized_part = self._tokenizer.encode(
|
101 |
+
text=output, add_bos=False, add_eos=True
|
102 |
+
)
|
103 |
+
tokenized += tokenized_part
|
104 |
+
colors += [RESPONSE] * len(tokenized_part)
|
105 |
+
labels += tokenized_part
|
106 |
+
|
107 |
+
assert len(tokenized) == len(labels)
|
108 |
+
assert len(tokenized) == len(colors)
|
109 |
+
|
110 |
+
return tokenized, labels, colors
|
111 |
+
|
112 |
+
def _generate_prompt(self, instruction: str, input: str) -> List[Tuple[(int, str)]]:
|
113 |
+
"""
|
114 |
+
Generate prompt from instruction and input.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
instruction (str): Instruction text.
|
118 |
+
input (str): Input text.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
List of (int, templated text)
|
122 |
+
"""
|
123 |
+
if input:
|
124 |
+
return [
|
125 |
+
(DEFAULT, (
|
126 |
+
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
127 |
+
"Write a response that appropriately completes the request.\n\n"
|
128 |
+
"### Instruction:\n"
|
129 |
+
)),
|
130 |
+
(INSTRUCTION, instruction),
|
131 |
+
(DEFAULT, "\n\n### Input:\n"),
|
132 |
+
(INPUT, input),
|
133 |
+
(DEFAULT, "\n\n### Response:\n"),
|
134 |
+
]
|
135 |
+
else:
|
136 |
+
return [
|
137 |
+
(DEFAULT, (
|
138 |
+
"Below is an instruction that describes a task. "
|
139 |
+
"Write a response that appropriately completes the request.\n\n"
|
140 |
+
"### Instruction:\n"
|
141 |
+
)),
|
142 |
+
(INSTRUCTION, instruction),
|
143 |
+
(DEFAULT, "\n\n### Response:\n"),
|
144 |
+
]
|
145 |
+
|
146 |
+
|
147 |
+
# TokenPair is a pair (tuple) of three lists: tokenized text inputs, labels, colors.
|
148 |
+
TokenPair = Tuple[List[int], List[int], List[int]]
|
149 |
+
|
150 |
+
|
151 |
+
def padded_collate(
|
152 |
+
batch: List[TokenPair],
|
153 |
+
padding_idx: int = 0,
|
154 |
+
ignore_idx: int = -100,
|
155 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
156 |
+
input_ids = pad_sequence(
|
157 |
+
[torch.tensor(x[0]) for x in batch],
|
158 |
+
batch_first=True,
|
159 |
+
padding_value=padding_idx,
|
160 |
+
)
|
161 |
+
labels = pad_sequence(
|
162 |
+
[torch.tensor(x[1]) for x in batch],
|
163 |
+
batch_first=True,
|
164 |
+
padding_value=ignore_idx,
|
165 |
+
)
|
166 |
+
colors = pad_sequence(
|
167 |
+
[torch.tensor(x[2]) for x in batch],
|
168 |
+
batch_first=True,
|
169 |
+
padding_value=padding_idx,
|
170 |
+
)
|
171 |
+
|
172 |
+
input_ids_seq_len = input_ids.shape[-1]
|
173 |
+
labels_seq_len = labels.shape[-1]
|
174 |
+
colors_seq_len = colors.shape[-1]
|
175 |
+
|
176 |
+
assert input_ids_seq_len == labels_seq_len
|
177 |
+
assert input_ids_seq_len == colors_seq_len
|
178 |
+
|
179 |
+
return input_ids, labels, colors
|
colorful/custom_model.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
import copy
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn, Tensor
|
7 |
+
|
8 |
+
from torchtune.modules import (
|
9 |
+
CausalSelfAttention,
|
10 |
+
FeedForward,
|
11 |
+
KVCache,
|
12 |
+
RMSNorm,
|
13 |
+
RotaryPositionalEmbeddings,
|
14 |
+
# TransformerDecoder, replaced with our custom implementation.
|
15 |
+
TransformerDecoderLayer,
|
16 |
+
)
|
17 |
+
|
18 |
+
from masked_apply import MaskedApply
|
19 |
+
|
20 |
+
|
21 |
+
def initialize_identity_linear(size):
|
22 |
+
layer = nn.Linear(size, size)
|
23 |
+
layer.weight.data.copy_(torch.eye(size))
|
24 |
+
layer.bias.data.copy_(torch.zeros(size))
|
25 |
+
return layer
|
26 |
+
|
27 |
+
def initialize_linear(size):
|
28 |
+
return nn.Linear(size, size)
|
29 |
+
|
30 |
+
def initialize_kaiming_uniform_linear(size):
|
31 |
+
layer = nn.Linear(size, size)
|
32 |
+
nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
|
33 |
+
layer.bias.data.copy_(torch.zeros(size))
|
34 |
+
return layer
|
35 |
+
|
36 |
+
def initialize_zeros_linear(size):
|
37 |
+
layer = nn.Linear(size, size)
|
38 |
+
layer.weight.data.copy_(torch.zeros(size))
|
39 |
+
layer.bias.data.copy_(torch.zeros(size))
|
40 |
+
return layer
|
41 |
+
|
42 |
+
INITIALIZATION_OPTIONS = {
|
43 |
+
"identity": initialize_identity_linear,
|
44 |
+
"default": initialize_linear,
|
45 |
+
"kaiming_uniform": initialize_kaiming_uniform_linear,
|
46 |
+
"zeros": initialize_zeros_linear,
|
47 |
+
}
|
48 |
+
|
49 |
+
def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
|
50 |
+
"""
|
51 |
+
Return a list of ``n`` identical layers.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
module (nn.Module): module to be cloned
|
55 |
+
n (int): number of clones
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
nn.ModuleList: list of ``n`` identical layers
|
59 |
+
"""
|
60 |
+
# FIXME: copy.deepcopy() is not defined on nn.module
|
61 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(n)])
|
62 |
+
|
63 |
+
|
64 |
+
class ColoringTransformerDecoder(nn.Module):
|
65 |
+
"""
|
66 |
+
See torchtune.models.llama2.TransformerDecoder for the original implementation.
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
tok_embeddings: nn.Embedding,
|
72 |
+
embedding_transform: nn.Module,
|
73 |
+
layer: TransformerDecoderLayer,
|
74 |
+
num_layers: int,
|
75 |
+
norm: nn.Module,
|
76 |
+
output: nn.Linear,
|
77 |
+
embedding_norm: nn.Module = None
|
78 |
+
) -> None:
|
79 |
+
super().__init__()
|
80 |
+
self.tok_embeddings = tok_embeddings
|
81 |
+
self.embedding_transform = embedding_transform
|
82 |
+
self.embedding_norm = embedding_norm
|
83 |
+
self.layers = _get_clones(layer, num_layers)
|
84 |
+
self.norm = norm
|
85 |
+
self.output = output
|
86 |
+
|
87 |
+
def forward(
|
88 |
+
self,
|
89 |
+
tokens: Tensor,
|
90 |
+
mask: Optional[Tensor] = None,
|
91 |
+
colors: Optional[Tensor] = None,
|
92 |
+
curr_pos: int = 0
|
93 |
+
) -> Tensor:
|
94 |
+
"""
|
95 |
+
Args:
|
96 |
+
tokens (Tensor): input tensor with shape [b x s]
|
97 |
+
mask (Optional[Tensor]): attention mask tensor, defaults to None.
|
98 |
+
curr_pos (int): current position in the seq, defaults to 0.
|
99 |
+
Only relevant when incrementally decoding.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
Tensor: output tensor with shape [b x s x v]
|
103 |
+
|
104 |
+
Notation used for tensor shapes:
|
105 |
+
- b: batch size
|
106 |
+
- s: sequence length
|
107 |
+
- v: vocab size
|
108 |
+
- d: embed dim
|
109 |
+
"""
|
110 |
+
# input tensor of shape [b, s]
|
111 |
+
bsz, seq_len = tokens.shape
|
112 |
+
|
113 |
+
# shape: [b, s, d]
|
114 |
+
h = self.tok_embeddings(tokens)
|
115 |
+
|
116 |
+
# Apply normalization before embedding transform to improve
|
117 |
+
# training stability.
|
118 |
+
ch = h
|
119 |
+
if self.embedding_norm is not None:
|
120 |
+
# TODO: norm does an in-place operation, so we need to clone the input
|
121 |
+
ch = self.embedding_norm(h.clone())
|
122 |
+
|
123 |
+
# Apply the embedding transform (e.g. color layer)
|
124 |
+
ch = self.embedding_transform(ch, colors)
|
125 |
+
|
126 |
+
# Add the output of the color transform to the embeddings
|
127 |
+
h = h + ch
|
128 |
+
|
129 |
+
# TODO: Fix the masking logic to not rely on checking kv_cache
|
130 |
+
if seq_len > 1 and self.layers[0].attn.kv_cache is not None:
|
131 |
+
mask = torch.full(
|
132 |
+
(1, 1, seq_len, seq_len), float("-inf"), device=tokens.device
|
133 |
+
)
|
134 |
+
mask = torch.triu(mask, diagonal=curr_pos + 1)
|
135 |
+
|
136 |
+
for layer in self.layers:
|
137 |
+
# shape: [b, s, d]
|
138 |
+
h = layer(h, mask, curr_pos)
|
139 |
+
|
140 |
+
# shape: [b, s, d]
|
141 |
+
h = self.norm(h)
|
142 |
+
|
143 |
+
# shape: [b, s, v]
|
144 |
+
output = self.output(h).float()
|
145 |
+
return output
|
146 |
+
|
147 |
+
|
148 |
+
def coloring_llama2_7b(color_layer_initialization: str, norm_before_color_layer: bool = False, max_batch_size: Optional[int] = None) -> ColoringTransformerDecoder:
|
149 |
+
"""Builder for creating a Llama2 model initialized w/ the default 7b parameter values.
|
150 |
+
From https://arxiv.org/abs/2307.09288, these default values are:
|
151 |
+
- vocab_size: 32,000
|
152 |
+
- embed_dim: 4,096
|
153 |
+
- num_layers: 32
|
154 |
+
- num_heads: 32
|
155 |
+
- num_kv_heads: 32
|
156 |
+
- max_seq_len: 4,096
|
157 |
+
- norm_eps: 1e-5
|
158 |
+
|
159 |
+
Args:
|
160 |
+
max_batch_size (Optional[int]): Maximum batch size to be passed to KVCache.
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
A ``TransformerDecoder`` instance of the Llama2 model.
|
164 |
+
"""
|
165 |
+
return coloring_llama2(
|
166 |
+
color_layer_initialization=color_layer_initialization,
|
167 |
+
vocab_size=32_000,
|
168 |
+
num_layers=32,
|
169 |
+
num_heads=32,
|
170 |
+
num_kv_heads=32,
|
171 |
+
embed_dim=4096,
|
172 |
+
max_seq_len=4096,
|
173 |
+
num_colors=4, # color for default, instruction, input, response
|
174 |
+
max_batch_size=max_batch_size,
|
175 |
+
attn_dropout=0.0,
|
176 |
+
norm_eps=1e-5,
|
177 |
+
norm_before_color_layer=norm_before_color_layer
|
178 |
+
)
|
179 |
+
|
180 |
+
def _scale_hidden_dim_for_mlp(dim: int, multiple_of: int = 256) -> int:
|
181 |
+
"""Scale hidden dimension for MLP to keep number of parameters and computation constant.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
dim (int): Input dimension.
|
185 |
+
multiple_of (int): Round scaled dimension to nearest multiple of `multiple_of` for clean computation.
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
Scaled hidden dimension.
|
189 |
+
"""
|
190 |
+
# Scale hidden dimension by (2/3)4d for SwiGLU to keep number of
|
191 |
+
# parameters and computation constant
|
192 |
+
hidden_dim = 4 * int(2 * dim / 3)
|
193 |
+
# Round hidden dimension to nearest multiple of `multiple_of`
|
194 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
195 |
+
return hidden_dim
|
196 |
+
|
197 |
+
|
198 |
+
def coloring_llama2(
|
199 |
+
color_layer_initialization: str,
|
200 |
+
vocab_size: int,
|
201 |
+
num_layers: int,
|
202 |
+
num_heads: int,
|
203 |
+
num_kv_heads: int,
|
204 |
+
embed_dim: int,
|
205 |
+
max_seq_len: int,
|
206 |
+
num_colors: int,
|
207 |
+
norm_before_color_layer: bool = False,
|
208 |
+
attn_dropout: float = 0.0,
|
209 |
+
max_batch_size: Optional[int] = None,
|
210 |
+
norm_eps: float = 1e-5,
|
211 |
+
):
|
212 |
+
if color_layer_initialization not in INITIALIZATION_OPTIONS:
|
213 |
+
raise ValueError(f"Invalid color_layer_initialization: {color_layer_initialization}. Expected one of {list(INITIALIZATION_OPTIONS.keys())}.")
|
214 |
+
color_layer_initializer = INITIALIZATION_OPTIONS[color_layer_initialization]
|
215 |
+
|
216 |
+
head_dim = embed_dim // num_heads
|
217 |
+
num_kv_heads = num_kv_heads if num_kv_heads else num_heads
|
218 |
+
kv_cache = (
|
219 |
+
KVCache(
|
220 |
+
max_batch_size=max_batch_size,
|
221 |
+
max_seq_len=max_seq_len,
|
222 |
+
n_kv_heads=num_heads,
|
223 |
+
head_dim=head_dim,
|
224 |
+
)
|
225 |
+
if max_batch_size is not None
|
226 |
+
else None
|
227 |
+
)
|
228 |
+
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)
|
229 |
+
self_attn = CausalSelfAttention(
|
230 |
+
embed_dim=embed_dim,
|
231 |
+
num_heads=num_heads,
|
232 |
+
num_kv_heads=num_kv_heads,
|
233 |
+
head_dim=head_dim,
|
234 |
+
q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
|
235 |
+
k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
|
236 |
+
v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
|
237 |
+
output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
|
238 |
+
pos_embeddings=rope,
|
239 |
+
kv_cache=kv_cache,
|
240 |
+
max_seq_len=max_seq_len,
|
241 |
+
attn_dropout=attn_dropout,
|
242 |
+
)
|
243 |
+
hidden_dim = _scale_hidden_dim_for_mlp(embed_dim)
|
244 |
+
mlp = FeedForward(dim=embed_dim, hidden_dim=hidden_dim, linear_class=nn.Linear)
|
245 |
+
layer = TransformerDecoderLayer(
|
246 |
+
attn=self_attn,
|
247 |
+
mlp=mlp,
|
248 |
+
sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
|
249 |
+
mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
|
250 |
+
)
|
251 |
+
tok_embeddings = nn.Embedding(vocab_size, embed_dim)
|
252 |
+
output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
|
253 |
+
embedding_transform = MaskedApply(
|
254 |
+
[color_layer_initializer(embed_dim) for _ in range(num_colors)],
|
255 |
+
strict=True
|
256 |
+
)
|
257 |
+
embedding_norm = RMSNorm(embed_dim, eps=norm_eps) if norm_before_color_layer else None
|
258 |
+
|
259 |
+
return ColoringTransformerDecoder(
|
260 |
+
tok_embeddings=tok_embeddings,
|
261 |
+
embedding_transform=embedding_transform,
|
262 |
+
embedding_norm=embedding_norm,
|
263 |
+
layer=layer,
|
264 |
+
num_layers=num_layers,
|
265 |
+
norm=RMSNorm(embed_dim, eps=norm_eps),
|
266 |
+
output=output_proj,
|
267 |
+
)
|
colorful/custom_params.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field, fields
|
2 |
+
from typing import List, Optional
|
3 |
+
|
4 |
+
from torchtune.datasets import ALL_DATASETS
|
5 |
+
from torchtune.models import ALL_MODELS, ALL_TOKENIZERS
|
6 |
+
from torchtune.utils.metric_logging import ALL_METRIC_LOGGERS
|
7 |
+
from torchtune.utils.precision import PRECISION_STR_TO_DTYPE
|
8 |
+
|
9 |
+
|
10 |
+
@dataclass
|
11 |
+
class ColoringFinetuneParams:
|
12 |
+
"""Arguments for the finetune_llm recipe.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
device (str): Device to use for training. Options are "cpu" and "cuda"
|
16 |
+
dtype (str): Data type to use for training.
|
17 |
+
seed (int): Random seed to use for training.
|
18 |
+
model (str): String specifying model architecture to fine-tune. See ``torchtune.models.get_model`` for options.
|
19 |
+
model_checkpoint (str): Local path to load model checkpoint from.
|
20 |
+
tokenizer (str): String specifying tokenizer to use. See ``torchtune.models.get_tokenizer`` for options.
|
21 |
+
tokenizer_checkpoint (str): Local path to load tokenizer checkpoint from.
|
22 |
+
dataset (str): String specifying dataset to use. See ``torchtune.datasets.get_dataset`` for options.
|
23 |
+
Currently, only predefined datasets in library are supported.
|
24 |
+
shuffle (bool): Whether to shuffle dataset.
|
25 |
+
batch_size (int): Batch size to use for training.
|
26 |
+
epochs (int): Number of epochs to train for.
|
27 |
+
optimizer (str): String specifying optimizer to use. See ``torchtune.optim.get_optimizer`` for options.
|
28 |
+
loss (str): String specifying loss function to use. See ``torchtune.losses.get_loss`` for options.
|
29 |
+
lr (float): Learning rate to use for optimizer.
|
30 |
+
activation_checkpointing (bool): Whether to use activation checkpointing.
|
31 |
+
output_dir (str): Local path to save checkpoints and logs to.
|
32 |
+
run_generation (int): Run eval on a prompt every ``run_generation`` steps. Set to 0 to disable.
|
33 |
+
max_steps_per_epoch (int): Maximum number of steps to take per epoch.
|
34 |
+
metric_logger_type (str): String specifying metric logger to use. See ``torchtune.utils.get_metric_logger``
|
35 |
+
for options.
|
36 |
+
project (str): Project name to use for logging. Used by ``WandBLogger``.
|
37 |
+
resume_from_previous_checkpoint (bool): Whether to resume fine-tuning from a previous checkpoint.
|
38 |
+
cpu_offload (bool): Whether to offload model to CPU.
|
39 |
+
|
40 |
+
Raises:
|
41 |
+
ValueError: If ``cpu_offload`` is ``True`` but ``device`` is not ``cuda`` and <= 1 GPUs.
|
42 |
+
"""
|
43 |
+
|
44 |
+
# Model
|
45 |
+
model_checkpoint: str = ""
|
46 |
+
|
47 |
+
color_layer_initialization: str = "default"
|
48 |
+
norm_before_color_layer: bool = False
|
49 |
+
|
50 |
+
# Tokenizer
|
51 |
+
tokenizer_checkpoint: str = ""
|
52 |
+
|
53 |
+
hf_repo_id: Optional[str] = None
|
54 |
+
checkpoint_every_n_steps: Optional[int] = None
|
55 |
+
|
56 |
+
# Dataset and Sampler
|
57 |
+
dataset: str = ""
|
58 |
+
train_on_input: bool = True
|
59 |
+
shuffle: bool = True
|
60 |
+
batch_size: int = 2
|
61 |
+
|
62 |
+
# Optimizer and Scheduler
|
63 |
+
optimizer: str = "SGD"
|
64 |
+
lr: float = 2e-5
|
65 |
+
loss: str = "CrossEntropyLoss"
|
66 |
+
gradient_accumulation_steps: int = 1
|
67 |
+
|
68 |
+
# Training
|
69 |
+
compile: bool = False
|
70 |
+
epochs: int = 3
|
71 |
+
max_steps_per_epoch: Optional[int] = None
|
72 |
+
resume_from_checkpoint: bool = False
|
73 |
+
run_generation: Optional[int] = None
|
74 |
+
|
75 |
+
# Distributed
|
76 |
+
cpu_offload: bool = False
|
77 |
+
enable_fsdp: bool = True
|
78 |
+
enable_activation_checkpointing: bool = True
|
79 |
+
|
80 |
+
# Environment
|
81 |
+
device: str = "cuda"
|
82 |
+
dtype: str = "fp16"
|
83 |
+
seed: Optional[int] = None
|
84 |
+
|
85 |
+
# Logging
|
86 |
+
output_dir: str = "/tmp/full_finetune_output"
|
87 |
+
metric_logger_type: str = "disk"
|
88 |
+
project: Optional[str] = None
|
89 |
+
log_every_n_steps: Optional[int] = None
|
90 |
+
|
91 |
+
def __post_init__(self):
|
92 |
+
for param in fields(self):
|
93 |
+
if getattr(self, param.name) == "":
|
94 |
+
raise TypeError(f"{param.name} needs to be specified")
|
95 |
+
|
96 |
+
if self.cpu_offload and self.device != "cuda":
|
97 |
+
raise ValueError(
|
98 |
+
"Cannot offload model to CPU if device is not cuda or <= 1 GPUs."
|
99 |
+
)
|
100 |
+
if self.enable_fsdp and self.device == "cpu":
|
101 |
+
raise ValueError("FSDP is not supported on CPU.")
|
102 |
+
|
103 |
+
if self.metric_logger_type not in ALL_METRIC_LOGGERS:
|
104 |
+
raise ValueError(
|
105 |
+
f"Metric logger not recognized. Expected one of {ALL_METRIC_LOGGERS}, received {self.metric_logger_type}."
|
106 |
+
)
|
107 |
+
if self.dtype not in PRECISION_STR_TO_DTYPE:
|
108 |
+
raise ValueError(
|
109 |
+
f"Dtype {self.dtype} must be one of {', '.join(PRECISION_STR_TO_DTYPE.keys())} for finetuning."
|
110 |
+
)
|
colorful/full_finetune.py
ADDED
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
|
11 |
+
from functools import partial
|
12 |
+
from typing import Any, Dict, Optional, Tuple
|
13 |
+
from warnings import warn
|
14 |
+
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from torch import nn
|
18 |
+
from torch.cuda.amp import GradScaler
|
19 |
+
from torch.distributed import init_process_group
|
20 |
+
from torch.optim import Optimizer
|
21 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
22 |
+
from torchtune.utils import get_device
|
23 |
+
|
24 |
+
from torchtune import models, modules, utils
|
25 |
+
from torchtune.utils.constants import (
|
26 |
+
EPOCHS_KEY,
|
27 |
+
MAX_STEPS_KEY,
|
28 |
+
MODEL_KEY,
|
29 |
+
OPT_KEY,
|
30 |
+
SEED_KEY,
|
31 |
+
TOTAL_EPOCHS_KEY,
|
32 |
+
)
|
33 |
+
|
34 |
+
from tqdm import tqdm
|
35 |
+
|
36 |
+
from recipes.interfaces import FTRecipeInterface
|
37 |
+
from recipes.params import FullFinetuneParams
|
38 |
+
|
39 |
+
from torchtune.models.llama2 import llama2_tokenizer
|
40 |
+
|
41 |
+
from huggingface_hub import HfApi
|
42 |
+
|
43 |
+
from custom_params import ColoringFinetuneParams
|
44 |
+
from custom_model import ColoringTransformerDecoder, coloring_llama2_7b
|
45 |
+
from custom_dataset import ColoringAlpacaDataset, padded_collate
|
46 |
+
|
47 |
+
log = utils.get_logger("DEBUG")
|
48 |
+
|
49 |
+
|
50 |
+
class ColoringFinetuneRecipe(FTRecipeInterface):
|
51 |
+
"""
|
52 |
+
Full finetuning recipe for dense transformer-based LLMs such as Llama2.
|
53 |
+
|
54 |
+
This recipe supports:
|
55 |
+
- FSDP and activation checkpointing. This is enabled by default but can be
|
56 |
+
configured using the ``enable_fsdp`` and ``enable_activation_checkpointing`` flags.
|
57 |
+
- Mixed precision training - fp32, fp16 and bf16 are supported.
|
58 |
+
- Checkpointing of model weights, optimizer state and the recipe state (epoch and seed).
|
59 |
+
- Resuming from checkpoints saved using the ``save_checkpoint`` functionality.
|
60 |
+
- Logging to terminal. WandB and TensorBoard are currently not supported.
|
61 |
+
|
62 |
+
Assumptions:
|
63 |
+
- Training is launched with the Tune CLI (recommended) which uses TorchRun under the
|
64 |
+
hood. Setting up the env variables is handled by TorchRun.
|
65 |
+
- Training happens on CUDA (CPU training is not supported)
|
66 |
+
- Checkpoints are ONLY saved at epoch boundaries. Mid-epoch checkpointing is NOT supported.
|
67 |
+
- Datasets are Map-style and data fits in memory (not streamed).
|
68 |
+
"""
|
69 |
+
|
70 |
+
_model: ColoringTransformerDecoder
|
71 |
+
|
72 |
+
def __init__(self, params: ColoringFinetuneParams) -> None:
|
73 |
+
self._params = params
|
74 |
+
|
75 |
+
self._device = utils.get_device(device=params.device)
|
76 |
+
self._dtype = utils.get_dtype(dtype=params.dtype)
|
77 |
+
|
78 |
+
self._hf_hub = HfApi()
|
79 |
+
self._hf_repo_id = params.hf_repo_id
|
80 |
+
|
81 |
+
if self._hf_repo_id is not None:
|
82 |
+
self._hf_hub.create_repo(
|
83 |
+
repo_id=self._hf_repo_id,
|
84 |
+
repo_type="model",
|
85 |
+
private=True,
|
86 |
+
exist_ok=True
|
87 |
+
)
|
88 |
+
|
89 |
+
# logging attributes
|
90 |
+
self._output_dir = params.output_dir
|
91 |
+
self._metric_logger = utils.get_metric_logger(
|
92 |
+
metric_logger_type=params.metric_logger_type,
|
93 |
+
project=params.project,
|
94 |
+
log_dir=params.output_dir,
|
95 |
+
)
|
96 |
+
self._log_every_n_steps = (
|
97 |
+
params.log_every_n_steps if params.log_every_n_steps else 1
|
98 |
+
)
|
99 |
+
|
100 |
+
self._checkpoint_every_n_steps = params.checkpoint_every_n_steps
|
101 |
+
|
102 |
+
# _is_rank_zero is used primarily for logging. In the future, the logger
|
103 |
+
# should directly take care of this
|
104 |
+
_, rank = utils.get_world_size_and_rank()
|
105 |
+
self._is_rank_zero = rank == 0
|
106 |
+
|
107 |
+
# Training params
|
108 |
+
self._compile = params.compile
|
109 |
+
self._resume_from_checkpoint = params.resume_from_checkpoint
|
110 |
+
self._enable_fsdp = params.enable_fsdp
|
111 |
+
self._gradient_accumulation_steps = params.gradient_accumulation_steps
|
112 |
+
|
113 |
+
# These are public properties which are updated by the checkpoint loader
|
114 |
+
# when ``resume_from_checkpoint`` is `True` or validated in tests
|
115 |
+
self.seed = utils.set_seed(seed=params.seed)
|
116 |
+
self.epochs_run = 0
|
117 |
+
self.total_epochs = params.epochs
|
118 |
+
self.max_steps_per_epoch = params.max_steps_per_epoch
|
119 |
+
self.total_training_steps = 0
|
120 |
+
|
121 |
+
def load_checkpoint(self, ckpt_path: str):
|
122 |
+
"""
|
123 |
+
Extract the checkpoint state from file and validate.
|
124 |
+
"""
|
125 |
+
ckpt_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
126 |
+
utils.validate_checkpoint(ckpt_dict, self._resume_from_checkpoint)
|
127 |
+
return ckpt_dict
|
128 |
+
|
129 |
+
def setup(self, params: FullFinetuneParams) -> None:
|
130 |
+
"""
|
131 |
+
Sets up the recipe state correctly. This includes setting recipe attributes based
|
132 |
+
on the ``resume_from_checkpoint`` flag.
|
133 |
+
"""
|
134 |
+
|
135 |
+
ckpt_dict = self.load_checkpoint(ckpt_path=params.model_checkpoint)
|
136 |
+
|
137 |
+
# If we're resuming from checkpoint, the recipe's state should be updated before
|
138 |
+
# initializing the training components. This ensures that the seed is correctly
|
139 |
+
# propagated to the relevant components
|
140 |
+
if self._resume_from_checkpoint:
|
141 |
+
self._update_recipe_state(ckpt_dict)
|
142 |
+
|
143 |
+
# ``_setup_model`` handles initialization and loading the state dict. This method
|
144 |
+
# should be called before ``_setup_optimizer`` since transforming the optimizer
|
145 |
+
# state dict requires the model
|
146 |
+
self._model = self._setup_model(
|
147 |
+
enable_fsdp=params.enable_fsdp,
|
148 |
+
enable_activation_checkpointing=params.enable_activation_checkpointing,
|
149 |
+
model_state_dict=ckpt_dict[MODEL_KEY],
|
150 |
+
)
|
151 |
+
|
152 |
+
self._tokenizer = self._setup_tokenizer(
|
153 |
+
tokenizer_checkpoint=params.tokenizer_checkpoint
|
154 |
+
)
|
155 |
+
|
156 |
+
# _setup_optimizer should take in ckpt_dict only if training is resumed from
|
157 |
+
# checkpoint. Transforming the opt state dict is handled by this method
|
158 |
+
self._optimizer = self._setup_optimizer(
|
159 |
+
optimizer=params.optimizer,
|
160 |
+
lr=params.lr,
|
161 |
+
opt_state_dict=ckpt_dict[OPT_KEY] if self._resume_from_checkpoint else None,
|
162 |
+
)
|
163 |
+
|
164 |
+
self._loss_fn = self._setup_loss(loss=params.loss)
|
165 |
+
|
166 |
+
# sampler and dataloader depend on the tokenizer and loss_fn and should be
|
167 |
+
# setup after both of these are initialized
|
168 |
+
self._sampler, self._dataloader = self._setup_data(
|
169 |
+
dataset=params.dataset,
|
170 |
+
train_on_input=params.train_on_input,
|
171 |
+
shuffle=params.shuffle,
|
172 |
+
batch_size=params.batch_size,
|
173 |
+
)
|
174 |
+
|
175 |
+
# training setup
|
176 |
+
self._autocast = utils.get_autocast(self._dtype, self._device)
|
177 |
+
self._grad_scaler = None
|
178 |
+
if self._dtype == torch.float16:
|
179 |
+
self._grad_scaler = utils.get_gradient_scaler(fsdp=params.enable_fsdp)
|
180 |
+
else:
|
181 |
+
self._grad_scaler = GradScaler(enabled=False)
|
182 |
+
|
183 |
+
# Finally update the recipe state which can only be correctly set after all of the
|
184 |
+
# other components have been initialized and updated.
|
185 |
+
#
|
186 |
+
# Number of training steps in each epoch depends on the number of batches produced
|
187 |
+
# by the dataloader, the max_steps_per_epoch param set by the user and the
|
188 |
+
# gradient_accumulation_steps param. This value is used for logging and tracking
|
189 |
+
# training state. The computation should happen after the dataloader has been setup
|
190 |
+
self._steps_per_epoch = (
|
191 |
+
len(self._dataloader) // self._gradient_accumulation_steps
|
192 |
+
)
|
193 |
+
if (
|
194 |
+
self.max_steps_per_epoch is not None
|
195 |
+
and self.max_steps_per_epoch < self._steps_per_epoch
|
196 |
+
):
|
197 |
+
self._steps_per_epoch = self.max_steps_per_epoch
|
198 |
+
self.total_training_steps = self.epochs_run * self._steps_per_epoch
|
199 |
+
|
200 |
+
def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
|
201 |
+
"""
|
202 |
+
Updates the recipe state from checkpoint.
|
203 |
+
"""
|
204 |
+
# If seed, total_epoch or max_steps_per_epoch don't match,
|
205 |
+
# warn the user and overwrite
|
206 |
+
if (
|
207 |
+
self.seed != ckpt_dict[SEED_KEY]
|
208 |
+
or self.total_epochs != ckpt_dict[TOTAL_EPOCHS_KEY]
|
209 |
+
or self.max_steps_per_epoch != ckpt_dict[MAX_STEPS_KEY]
|
210 |
+
):
|
211 |
+
warn(
|
212 |
+
message="""Configured value for seed, epochs or max_steps_per_epoch
|
213 |
+
does not match the value stored in checkpoint."""
|
214 |
+
)
|
215 |
+
self.seed = utils.set_seed(seed=ckpt_dict[SEED_KEY])
|
216 |
+
self.epochs_run = ckpt_dict[EPOCHS_KEY]
|
217 |
+
self.total_epochs = ckpt_dict[TOTAL_EPOCHS_KEY]
|
218 |
+
self.max_steps_per_epoch = ckpt_dict[MAX_STEPS_KEY]
|
219 |
+
|
220 |
+
def _setup_model(
|
221 |
+
self,
|
222 |
+
enable_fsdp: bool,
|
223 |
+
enable_activation_checkpointing: bool,
|
224 |
+
model_state_dict: Dict[str, Any],
|
225 |
+
) -> nn.Module:
|
226 |
+
"""
|
227 |
+
Set up the model including enabling FSDP and activation checkpointing. For this recipe,
|
228 |
+
``enable_fsdp`` should always be ``True``. This is currently a configurable flag for
|
229 |
+
running tests on CPUs.
|
230 |
+
"""
|
231 |
+
|
232 |
+
with get_device(self._device):
|
233 |
+
model = coloring_llama2_7b(
|
234 |
+
self._params.color_layer_initialization,
|
235 |
+
norm_before_color_layer=self._params.norm_before_color_layer
|
236 |
+
)
|
237 |
+
|
238 |
+
model = (
|
239 |
+
utils.wrap_fsdp(
|
240 |
+
model=model,
|
241 |
+
device=self._device,
|
242 |
+
dtype=self._dtype,
|
243 |
+
strategy="FULL_SHARD",
|
244 |
+
auto_wrap_policy={modules.TransformerDecoderLayer},
|
245 |
+
)
|
246 |
+
if enable_fsdp
|
247 |
+
else model
|
248 |
+
)
|
249 |
+
if enable_activation_checkpointing:
|
250 |
+
utils.set_activation_checkpointing(
|
251 |
+
model, auto_wrap_policy={modules.TransformerDecoderLayer}
|
252 |
+
)
|
253 |
+
|
254 |
+
model.load_state_dict(model_state_dict, strict=False)
|
255 |
+
|
256 |
+
if self._is_rank_zero:
|
257 |
+
log.info(
|
258 |
+
"Model is initialized. FSDP and Activation Checkpointing are enabled."
|
259 |
+
)
|
260 |
+
|
261 |
+
if self._compile:
|
262 |
+
log.info("Compiling model using torch.compile. The first batch may take a few minutes while compilation occurs.")
|
263 |
+
model = torch.compile(model)
|
264 |
+
else:
|
265 |
+
log.info("Skipping model compilation")
|
266 |
+
|
267 |
+
return model
|
268 |
+
|
269 |
+
def _setup_tokenizer(
|
270 |
+
self, tokenizer_checkpoint: str
|
271 |
+
) -> modules.Tokenizer:
|
272 |
+
"""
|
273 |
+
Unlike ```setup_model```, this takes in the checkpoint and loads the sentencepiece
|
274 |
+
tokenizer model. This is related to how the tokenizer is implemented and should
|
275 |
+
change in a future iteration.
|
276 |
+
"""
|
277 |
+
tokenizer = llama2_tokenizer(tokenizer_checkpoint)
|
278 |
+
|
279 |
+
if self._is_rank_zero:
|
280 |
+
log.info("Tokenizer is initialized from file.")
|
281 |
+
return tokenizer
|
282 |
+
|
283 |
+
def _setup_optimizer(
|
284 |
+
self, optimizer: str, lr: float, opt_state_dict: Optional[Dict[str, Any]] = None
|
285 |
+
) -> Optimizer:
|
286 |
+
"""
|
287 |
+
Set up the optimizer. This method also handles transforing the state dict
|
288 |
+
for FSDP.
|
289 |
+
"""
|
290 |
+
optimizer = modules.get_optimizer(optimizer, self._model, lr)
|
291 |
+
if opt_state_dict:
|
292 |
+
opt_state_dict = utils.transform_opt_state_dict(
|
293 |
+
opt_state_dict, self._model, optimizer
|
294 |
+
)
|
295 |
+
optimizer.load_state_dict(opt_state_dict)
|
296 |
+
|
297 |
+
if self._is_rank_zero:
|
298 |
+
log.info("Optimizer is initialized.")
|
299 |
+
return optimizer
|
300 |
+
|
301 |
+
def _setup_loss(self, loss: str) -> nn.Module:
|
302 |
+
loss_fn = modules.get_loss(loss)
|
303 |
+
|
304 |
+
if self._is_rank_zero:
|
305 |
+
log.info("Loss is initialized.")
|
306 |
+
|
307 |
+
return loss_fn
|
308 |
+
|
309 |
+
def _setup_data(
|
310 |
+
self, dataset: str, shuffle: bool, batch_size: int, train_on_input: bool
|
311 |
+
) -> Tuple[DistributedSampler, DataLoader]:
|
312 |
+
"""
|
313 |
+
All data related setup happens here. Currently this recipe only supports the
|
314 |
+
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
|
315 |
+
iterable datasets and streaming datasets are not supported.
|
316 |
+
"""
|
317 |
+
world_size, rank = utils.get_world_size_and_rank()
|
318 |
+
ds = ColoringAlpacaDataset(tokenizer=self._tokenizer, dataset=dataset, train_on_input=train_on_input)
|
319 |
+
|
320 |
+
sampler = DistributedSampler(
|
321 |
+
ds,
|
322 |
+
num_replicas=world_size,
|
323 |
+
rank=rank,
|
324 |
+
shuffle=shuffle,
|
325 |
+
seed=0,
|
326 |
+
)
|
327 |
+
|
328 |
+
dataloader = DataLoader(
|
329 |
+
dataset=ds,
|
330 |
+
batch_size=batch_size,
|
331 |
+
sampler=sampler,
|
332 |
+
collate_fn=partial(
|
333 |
+
padded_collate,
|
334 |
+
padding_idx=self._tokenizer.pad_id,
|
335 |
+
ignore_idx=self._loss_fn.ignore_index, # TODO support loss without ignore_index
|
336 |
+
),
|
337 |
+
)
|
338 |
+
|
339 |
+
if self._is_rank_zero:
|
340 |
+
log.info("Dataset and Sampler are initialized.")
|
341 |
+
|
342 |
+
return sampler, dataloader
|
343 |
+
|
344 |
+
def save_checkpoint(self, epoch: int) -> None:
|
345 |
+
"""
|
346 |
+
Checkpoint the relevant state of a recipe.
|
347 |
+
|
348 |
+
This makes use of the `save_checkpoint` utility which is responsible for
|
349 |
+
writing the checkpoint dictionary to file. The contents of the dict are dictated
|
350 |
+
by whether training is complete or not.
|
351 |
+
|
352 |
+
If training is ongoing, optimizer state, seed and epochs_run are saved along with the
|
353 |
+
model weights.
|
354 |
+
"""
|
355 |
+
os.makedirs(self._output_dir, exist_ok=True)
|
356 |
+
output_loc = f"{self._output_dir}/model_{epoch}.ckpt"
|
357 |
+
ckpt_dict = {MODEL_KEY: self._model}
|
358 |
+
|
359 |
+
# if training is in-progress, checkpoint the optimizer state as well
|
360 |
+
if epoch + 1 < self.total_epochs:
|
361 |
+
ckpt_dict.update(
|
362 |
+
{
|
363 |
+
OPT_KEY: self._optimizer,
|
364 |
+
SEED_KEY: self.seed,
|
365 |
+
EPOCHS_KEY: self.epochs_run,
|
366 |
+
TOTAL_EPOCHS_KEY: self.total_epochs,
|
367 |
+
MAX_STEPS_KEY: self.max_steps_per_epoch,
|
368 |
+
}
|
369 |
+
)
|
370 |
+
utils.save_checkpoint(ckpt_dict, output_loc)
|
371 |
+
|
372 |
+
if self._is_rank_zero:
|
373 |
+
log.info(
|
374 |
+
f"Model checkpoint of size {os.path.getsize(output_loc) >> 20} MB saved to {output_loc}"
|
375 |
+
)
|
376 |
+
|
377 |
+
if self._hf_repo_id is not None:
|
378 |
+
log.info(f"Uploading checkpoint to HuggingFace Hub: {self._hf_repo_id}")
|
379 |
+
self._hf_hub.upload_folder(
|
380 |
+
folder_path=self._output_dir,
|
381 |
+
repo_id=self._hf_repo_id,
|
382 |
+
repo_type="model",
|
383 |
+
run_as_future=True,
|
384 |
+
commit_message=f"Checkpoint for epoch {epoch} (step {self.total_training_steps})"
|
385 |
+
)
|
386 |
+
else:
|
387 |
+
log.info("Skipping uploading to HuggingFace Hub (no repo id specified)")
|
388 |
+
|
389 |
+
|
390 |
+
|
391 |
+
def _should_update_weights(self, curr_step: int) -> bool:
|
392 |
+
"""
|
393 |
+
Determines whether the weights should be updated on the current step or not.
|
394 |
+
True is returned either if we've accumulated gradients for enough steps or if this
|
395 |
+
is the last step in the epoch.
|
396 |
+
"""
|
397 |
+
should_update_weights = (
|
398 |
+
curr_step + 1
|
399 |
+
) % self._gradient_accumulation_steps == 0 or (
|
400 |
+
curr_step + 1
|
401 |
+
) == self._steps_per_epoch
|
402 |
+
return should_update_weights
|
403 |
+
|
404 |
+
def train(self) -> None:
|
405 |
+
"""
|
406 |
+
The core training loop. Supports training on subsets of the dataset using the
|
407 |
+
``max_steps_per_epoch``.
|
408 |
+
"""
|
409 |
+
_, rank = utils.get_world_size_and_rank()
|
410 |
+
|
411 |
+
# zero out the gradients before starting training
|
412 |
+
self._optimizer.zero_grad()
|
413 |
+
|
414 |
+
# self.epochs_run should be non-zero when we're resuming from a checkpoint
|
415 |
+
for curr_epoch in range(self.epochs_run, self.total_epochs):
|
416 |
+
|
417 |
+
# Update the sampler to ensure data is correctly shuffled across epochs
|
418 |
+
# in case shuffle is True
|
419 |
+
self._sampler.set_epoch(curr_epoch)
|
420 |
+
|
421 |
+
for idx, batch in enumerate(
|
422 |
+
pbar := tqdm(self._dataloader, disable=not (rank == 0))
|
423 |
+
):
|
424 |
+
if (
|
425 |
+
self.max_steps_per_epoch is not None
|
426 |
+
and (idx // self._gradient_accumulation_steps)
|
427 |
+
== self.max_steps_per_epoch
|
428 |
+
):
|
429 |
+
break
|
430 |
+
|
431 |
+
input_ids, labels, colors = batch
|
432 |
+
|
433 |
+
input_ids = input_ids.to(self._device)
|
434 |
+
labels = labels.to(self._device)
|
435 |
+
colors = colors.to(self._device)
|
436 |
+
|
437 |
+
with self._autocast:
|
438 |
+
logits = self._model(input_ids, colors=colors)
|
439 |
+
# Shift so that tokens < n predict n
|
440 |
+
logits = logits[..., :-1, :].contiguous()
|
441 |
+
labels = labels[..., 1:].contiguous()
|
442 |
+
logits = logits.transpose(1, 2)
|
443 |
+
# Compute loss
|
444 |
+
loss = self._loss_fn(logits, labels)
|
445 |
+
|
446 |
+
# Note: We're always logging the loss before normalizing it
|
447 |
+
# Check if this is the norm or not
|
448 |
+
pbar.set_description(f"{curr_epoch+1}|{idx+1}|Loss: {loss.item()}")
|
449 |
+
|
450 |
+
if self.total_training_steps % self._log_every_n_steps == 0:
|
451 |
+
self._metric_logger.log_dict(
|
452 |
+
{
|
453 |
+
"loss": loss.item(),
|
454 |
+
"lr": self._optimizer.param_groups[0]["lr"],
|
455 |
+
"gpu_resources": torch.cuda.memory_allocated(),
|
456 |
+
},
|
457 |
+
step=self.total_training_steps,
|
458 |
+
)
|
459 |
+
|
460 |
+
if self._checkpoint_every_n_steps is not None:
|
461 |
+
if self.total_training_steps > 0 and self.total_training_steps % self._checkpoint_every_n_steps == 0:
|
462 |
+
self.save_checkpoint(epoch=curr_epoch)
|
463 |
+
|
464 |
+
# Does loss normalization need to happen within autocast context?
|
465 |
+
loss = loss / self._gradient_accumulation_steps
|
466 |
+
self._grad_scaler.scale(loss).backward()
|
467 |
+
|
468 |
+
if self._should_update_weights(idx):
|
469 |
+
self._grad_scaler.step(self._optimizer)
|
470 |
+
self._grad_scaler.update()
|
471 |
+
self._optimizer.zero_grad(set_to_none=True)
|
472 |
+
|
473 |
+
# Update the number of steps when the weights are updated
|
474 |
+
self.total_training_steps += 1
|
475 |
+
|
476 |
+
self.epochs_run += 1
|
477 |
+
self.save_checkpoint(epoch=curr_epoch)
|
478 |
+
|
479 |
+
def cleanup(self) -> None:
|
480 |
+
self._metric_logger.close()
|
481 |
+
|
482 |
+
|
483 |
+
def recipe_main() -> None:
|
484 |
+
"""
|
485 |
+
Entry point for the recipe.
|
486 |
+
|
487 |
+
Configurable parameters are read in the following order:
|
488 |
+
- Parameters specified in ``ColoringFinetuneParams``
|
489 |
+
- Overwritten by Parameters specified in ``alpaca_llama2_full_finetune.yaml``
|
490 |
+
- Overwritten by arguments from the command-line using ``TuneArgumentParser``
|
491 |
+
"""
|
492 |
+
parser = utils.TuneArgumentParser(
|
493 |
+
description=ColoringFinetuneParams.__doc__,
|
494 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
495 |
+
)
|
496 |
+
args, _ = parser.parse_known_args()
|
497 |
+
args = vars(args)
|
498 |
+
recipe_params = ColoringFinetuneParams(**args)
|
499 |
+
|
500 |
+
# Env variables set by torch run; only need to initialize process group
|
501 |
+
# Disabled since this breaks for now on RunPod.
|
502 |
+
# init_process_group(backend="nccl")
|
503 |
+
|
504 |
+
recipe = ColoringFinetuneRecipe(params=recipe_params)
|
505 |
+
recipe.setup(params=recipe_params)
|
506 |
+
recipe.train()
|
507 |
+
recipe.cleanup()
|
508 |
+
|
509 |
+
|
510 |
+
if __name__ == "__main__":
|
511 |
+
sys.exit(recipe_main())
|
colorful/masked_apply.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class MaskedApply(nn.Module):
|
6 |
+
"""
|
7 |
+
Uses an index mask to select a sbuset of the input and apply a layer to it.
|
8 |
+
|
9 |
+
E.g. if mask is [[0, 1, 0]] layers[0] will be applied to the first and third element
|
10 |
+
and layers[1] will be applied to the second element.
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(self, layers, strict=False):
|
14 |
+
super(MaskedApply, self).__init__()
|
15 |
+
self.num_layers = len(layers)
|
16 |
+
self.layers = nn.ModuleList(layers)
|
17 |
+
self.strict = strict
|
18 |
+
|
19 |
+
# Create a CPU tensor to store the maximum value found.
|
20 |
+
# This will prevent the GPU being blocked while we check
|
21 |
+
# whether an index is > num_layers in strict mode.
|
22 |
+
self._maximum_found_cpu = torch.tensor([-1], device='cpu')
|
23 |
+
self._maximum_found = torch.tensor([-1])
|
24 |
+
if torch.cuda.is_available():
|
25 |
+
self._maximum_found_cpu = self._maximum_found_cpu.pin_memory()
|
26 |
+
|
27 |
+
def forward(self, x, mask):
|
28 |
+
# If in strict mode, check if we previously violated the maximum found.
|
29 |
+
if self._maximum_found_cpu >= self.num_layers:
|
30 |
+
raise ValueError(f'Unexpected index value found {self._maximum_found_cpu}. Should be less than {self.num_layers}')
|
31 |
+
|
32 |
+
# Ensure mask is a long tensor
|
33 |
+
mask = mask.long()
|
34 |
+
|
35 |
+
# Flatten x and mask for easier processing
|
36 |
+
batch_size, seq_length, embedding_size = x.shape
|
37 |
+
|
38 |
+
x_flat = x.view(-1, embedding_size)
|
39 |
+
mask_flat = mask.view(-1)
|
40 |
+
|
41 |
+
# Output placeholder
|
42 |
+
output_flat = torch.zeros_like(x_flat)
|
43 |
+
|
44 |
+
# Process each mask value
|
45 |
+
for i in range(self.num_layers):
|
46 |
+
# Find indices for current mask value
|
47 |
+
indices = torch.where(mask_flat == i)[0]
|
48 |
+
|
49 |
+
# Select relevant inputs for the current linear layer
|
50 |
+
selected_inputs = torch.index_select(x_flat, 0, indices)
|
51 |
+
|
52 |
+
# Apply linear layer
|
53 |
+
transformed = self.layers[i](selected_inputs)
|
54 |
+
|
55 |
+
# TODO: figure out why this is necessary.
|
56 |
+
transformed = transformed.to(x_flat.dtype)
|
57 |
+
|
58 |
+
# Place results back in the output tensor
|
59 |
+
output_flat.index_copy_(0, indices, transformed)
|
60 |
+
|
61 |
+
# Copy any out of range indices
|
62 |
+
if self.strict:
|
63 |
+
# This check is done asynchronously.
|
64 |
+
self._maximum_found = max(max(mask_flat), self._maximum_found)
|
65 |
+
self._maximum_found_cpu.copy_(self._maximum_found, non_blocking=True)
|
66 |
+
else:
|
67 |
+
indices = torch.where(mask_flat >= self.num_layers)[0]
|
68 |
+
selected_inputs = torch.index_select(x_flat, 0, indices)
|
69 |
+
output_flat.index_copy_(0, indices, selected_inputs)
|
70 |
+
|
71 |
+
# Reshape output to original dimensions
|
72 |
+
output = output_flat.view(batch_size, seq_length, embedding_size)
|
73 |
+
return output
|