laurencer commited on
Commit
261dbc8
1 Parent(s): e143414
.gitignore ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by https://www.toptal.com/developers/gitignore/api/python,macos
2
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python,macos
3
+
4
+ ### TorchTune ###
5
+
6
+ output/
7
+ model/
8
+ wandb/
9
+
10
+ ### macOS ###
11
+ # General
12
+ .DS_Store
13
+ .AppleDouble
14
+ .LSOverride
15
+
16
+ # Icon must end with two \r
17
+ Icon
18
+
19
+
20
+ # Thumbnails
21
+ ._*
22
+
23
+ # Files that might appear in the root of a volume
24
+ .DocumentRevisions-V100
25
+ .fseventsd
26
+ .Spotlight-V100
27
+ .TemporaryItems
28
+ .Trashes
29
+ .VolumeIcon.icns
30
+ .com.apple.timemachine.donotpresent
31
+
32
+ # Directories potentially created on remote AFP share
33
+ .AppleDB
34
+ .AppleDesktop
35
+ Network Trash Folder
36
+ Temporary Items
37
+ .apdisk
38
+
39
+ ### macOS Patch ###
40
+ # iCloud generated files
41
+ *.icloud
42
+
43
+ ### Python ###
44
+ # Byte-compiled / optimized / DLL files
45
+ __pycache__/
46
+ *.py[cod]
47
+ *$py.class
48
+
49
+ # C extensions
50
+ *.so
51
+
52
+ # Distribution / packaging
53
+ .Python
54
+ build/
55
+ develop-eggs/
56
+ dist/
57
+ downloads/
58
+ eggs/
59
+ .eggs/
60
+ lib/
61
+ lib64/
62
+ parts/
63
+ sdist/
64
+ var/
65
+ wheels/
66
+ share/python-wheels/
67
+ *.egg-info/
68
+ .installed.cfg
69
+ *.egg
70
+ MANIFEST
71
+
72
+ # PyInstaller
73
+ # Usually these files are written by a python script from a template
74
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
75
+ *.manifest
76
+ *.spec
77
+
78
+ # Installer logs
79
+ pip-log.txt
80
+ pip-delete-this-directory.txt
81
+
82
+ # Unit test / coverage reports
83
+ htmlcov/
84
+ .tox/
85
+ .nox/
86
+ .coverage
87
+ .coverage.*
88
+ .cache
89
+ nosetests.xml
90
+ coverage.xml
91
+ *.cover
92
+ *.py,cover
93
+ .hypothesis/
94
+ .pytest_cache/
95
+ cover/
96
+
97
+ # Translations
98
+ *.mo
99
+ *.pot
100
+
101
+ # Django stuff:
102
+ *.log
103
+ local_settings.py
104
+ db.sqlite3
105
+ db.sqlite3-journal
106
+
107
+ # Flask stuff:
108
+ instance/
109
+ .webassets-cache
110
+
111
+ # Scrapy stuff:
112
+ .scrapy
113
+
114
+ # Sphinx documentation
115
+ docs/_build/
116
+
117
+ # PyBuilder
118
+ .pybuilder/
119
+ target/
120
+
121
+ # Jupyter Notebook
122
+ .ipynb_checkpoints
123
+
124
+ # IPython
125
+ profile_default/
126
+ ipython_config.py
127
+
128
+ # pyenv
129
+ # For a library or package, you might want to ignore these files since the code is
130
+ # intended to run in multiple environments; otherwise, check them in:
131
+ # .python-version
132
+
133
+ # pipenv
134
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
135
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
136
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
137
+ # install all needed dependencies.
138
+ #Pipfile.lock
139
+
140
+ # poetry
141
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
142
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
143
+ # commonly ignored for libraries.
144
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
145
+ #poetry.lock
146
+
147
+ # pdm
148
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
149
+ #pdm.lock
150
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
151
+ # in version control.
152
+ # https://pdm.fming.dev/#use-with-ide
153
+ .pdm.toml
154
+
155
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
156
+ __pypackages__/
157
+
158
+ # Celery stuff
159
+ celerybeat-schedule
160
+ celerybeat.pid
161
+
162
+ # SageMath parsed files
163
+ *.sage.py
164
+
165
+ # Environments
166
+ .env
167
+ .venv
168
+ env/
169
+ venv/
170
+ ENV/
171
+ env.bak/
172
+ venv.bak/
173
+
174
+ # Spyder project settings
175
+ .spyderproject
176
+ .spyproject
177
+
178
+ # Rope project settings
179
+ .ropeproject
180
+
181
+ # mkdocs documentation
182
+ /site
183
+
184
+ # mypy
185
+ .mypy_cache/
186
+ .dmypy.json
187
+ dmypy.json
188
+
189
+ # Pyre type checker
190
+ .pyre/
191
+
192
+ # pytype static type analyzer
193
+ .pytype/
194
+
195
+ # Cython debug symbols
196
+ cython_debug/
197
+
198
+ # PyCharm
199
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
200
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
201
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
202
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
203
+ #.idea/
204
+
205
+ ### Python Patch ###
206
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
207
+ poetry.toml
208
+
209
+ # ruff
210
+ .ruff_cache/
211
+
212
+ # LSP config files
213
+ pyrightconfig.json
214
+
215
+ # End of https://www.toptal.com/developers/gitignore/api/python,macos
README.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # torchtune research repo: token coloring (colorful llama)
2
+
3
+ Playground to try out [token coloring](https://docs.google.com/document/d/1Win9vhddD-pu5P3SsG7E-dzN5oQl5DYWW1DhO7sBOgI/edit#heading=h.oqq00pt8expe) with TorchTune.
4
+
5
+ The repo was generated using the alpha version of [torchtune](https://github.com/pytorch-labs/torchtune).
6
+
7
+ Brief notes:
8
+
9
+ - The starting recipe is based on the Alpaca Llama2 7B full finetune recipe (switched to bf16).
10
+ - I assume `output/` is used to store model outputs and `model/` is used to store the base model checkpoints.
11
+
12
+ For the `colorful` recipe:
13
+
14
+ - I copied a lot of functionality (like the actual model definition, dataset, etc) from torchtune repository directly since I needed to make changes.
15
+ - I reduced the flexiblity of the recipe (e.g. cannot specify the model or tokenizer) and increased it in other ways (e.g. can pass in a dataset path directly).
16
+ - I added intermediate checkpointing (i.e. every `n` steps) and automatically upload the checkpoint to HuggingFace Hub.
17
+
18
+ ## Getting started
19
+
20
+ The below instructions can be copy-pasted as is on to a running instance. They assume that the `HF_TOKEN` environment variable is set with a valid token.
21
+
22
+ ```bash
23
+ # for RunPod
24
+ cd /workspace
25
+ git clone git@github.com:pytorch-labs/torchtune.git
26
+ cd torchtune
27
+ pip install -e .
28
+
29
+ cd /workspace
30
+ git clone git@github.com:laurencer/torchtune-colorful-llama.git
31
+ cd torchtune-colorful-llama
32
+
33
+ # for wandb support
34
+ pip install wandb
35
+ ```
36
+
37
+ ```bash
38
+ mkdir -p model/
39
+ tune download --repo-id meta-llama/Llama-2-7b --output-dir model/
40
+ ```
41
+
42
+ ```bash
43
+ tune convert_checkpoint --checkpoint-path model/consolidated.00.pth --output-path model/llama2_native.tune
44
+ ```
45
+
46
+ ```bash
47
+ mkdir -p output/
48
+ # tune --nnodes 1 --nproc_per_node 1 ./colorful/full_finetune.py --config ./colorful/basic_config.yaml
49
+ nohup tune --nnodes 1 --nproc_per_node 1 ./colorful/full_finetune.py --config ./colorful/basic_config.yaml 2>&1 > training_log_$(date "+%Y.%m.%d_%H.%M.%S").log &
50
+ sleep 1
51
+ tail -f training_log_*.log
52
+ ```
53
+
54
+ ## Baselines
55
+
56
+ Two baseline configs are provided in the `baseline` directory.
57
+ We forked the original recipe to support customizing the location/path of the Alpaca dataset.
58
+
59
+ ```bash
60
+ # tune --nnodes 1 --nproc_per_node 1 ./baseline/full_finetune.py --config ./baseline/baseline_config.yaml
61
+ nohup tune --nnodes 1 --nproc_per_node 1 ./baseline/full_finetune.py --config ./baseline/baseline_config.yaml 2>&1 > training_log_$(date "+%Y.%m.%d_%H.%M.%S").log &
62
+ sleep 1
63
+ tail -f training_log_*.log
64
+ ```
65
+
66
+ The adversarial config uses a dataset that is equivalent to 4x the original alpaca cleaned dataset with extra examples that include prompt injection attempts. See [token coloring description](https://docs.google.com/document/d/1Win9vhddD-pu5P3SsG7E-dzN5oQl5DYWW1DhO7sBOgI/edit#heading=h.oqq00pt8expe) for more info.
67
+
68
+ ```bash
69
+ # tune --nnodes 1 --nproc_per_node 1 ./baseline/full_finetune.py --config ./baseline/adversarial_config.yaml
70
+ nohup tune --nnodes 1 --nproc_per_node 1 ./baseline/full_finetune.py --config ./baseline/adversarial_config.yaml 2>&1 > training_log_$(date "+%Y.%m.%d_%H.%M.%S").log &
71
+ sleep 1
72
+ tail -f training_log_*.log
73
+ ```
74
+
75
+ ## Colorful
76
+
77
+ The `colorful` directory implements the changes required to support token coloring. This includes a custom dataset implementation and training script.
baseline/adversarial_config.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Runs the full_finetune.py recipe
2
+ #
3
+ # To launch, run the following command from root:
4
+ # tune --nnodes 1 --nproc_per_node 1 --config alpaca_llama2_full_finetune --override model_checkpoint=<your_checkpoint_dir> ...
5
+
6
+ # Dataset and Dataloader
7
+ dataset: laurencer/yahma-alpaca-cleaned-adversarial
8
+ seed: 42
9
+ shuffle: True
10
+
11
+ # Model Arguments
12
+ model: llama2_7b
13
+ model_checkpoint: model/llama2_native.tune
14
+ tokenizer: llama2_tokenizer
15
+ tokenizer_checkpoint: model/tokenizer.model
16
+
17
+ # Fine-tuning arguments
18
+ batch_size: 8
19
+ lr: 2e-5
20
+ epochs: 1
21
+ optimizer: SGD
22
+ loss: CrossEntropyLoss
23
+ output_dir: output/alpaca-llama2-adversarial
24
+ device: cuda
25
+ dtype: bf16
26
+ enable_fsdp: False
27
+ enable_activation_checkpointing: True
28
+ resume_from_checkpoint: False
29
+
30
+ # Logging arguments
31
+ metric_logger_type: wandb
32
+ project: torchtune
baseline/baseline_config.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Runs the full_finetune.py recipe
2
+ #
3
+ # To launch, run the following command from root:
4
+ # tune --nnodes 1 --nproc_per_node 1 --config alpaca_llama2_full_finetune --override model_checkpoint=<your_checkpoint_dir> ...
5
+
6
+ # Dataset and Dataloader
7
+ dataset: yahma/alpaca-cleaned
8
+ seed: 42
9
+ shuffle: True
10
+
11
+ # Model Arguments
12
+ model: llama2_7b
13
+ model_checkpoint: model/llama2_native.tune
14
+ tokenizer: llama2_tokenizer
15
+ tokenizer_checkpoint: model/tokenizer.model
16
+
17
+ # Fine-tuning arguments
18
+ batch_size: 8
19
+ lr: 2e-5
20
+ epochs: 4
21
+ optimizer: SGD
22
+ loss: CrossEntropyLoss
23
+ output_dir: output/alpaca-llama2-baseline
24
+ device: cuda
25
+ dtype: bf16
26
+ enable_fsdp: False
27
+ enable_activation_checkpointing: True
28
+ resume_from_checkpoint: False
29
+
30
+ # Logging arguments
31
+ metric_logger_type: wandb
32
+ project: torchtune
baseline/custom_dataset.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import List, Tuple
3
+
4
+ from datasets import load_dataset
5
+ from torch.utils.data import Dataset
6
+
7
+ # Not ideal to import this type here but it's needed for the transform function
8
+ from torchtune.modules import Tokenizer
9
+
10
+
11
+ CROSS_ENTROPY_IGNORE_IDX = -100
12
+
13
+ _PROMPT_TEMPLATE = {
14
+ "prompt_input": (
15
+ "Below is an instruction that describes a task, paired with an input that provides further context. "
16
+ "Write a response that appropriately completes the request.\n\n"
17
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
18
+ ),
19
+ "prompt_no_input": (
20
+ "Below is an instruction that describes a task. "
21
+ "Write a response that appropriately completes the request.\n\n"
22
+ "### Instruction:\n{instruction}\n\n### Response:\n"
23
+ ),
24
+ }
25
+
26
+
27
+ class AlpacaDataset(Dataset):
28
+ """
29
+ See torchtune.datasets.AlpacaDataset for the original implementation.
30
+ This version supports custom dataset paths.
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ dataset_path: str,
36
+ tokenizer: Tokenizer,
37
+ train_on_input: bool = True,
38
+ **kwargs
39
+ ) -> None:
40
+ self._data = load_dataset(dataset_path, split="train")
41
+ self._tokenizer = tokenizer
42
+ self.train_on_input = train_on_input
43
+
44
+ def __len__(self):
45
+ return len(self._data)
46
+
47
+ def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
48
+ sample = self._data[index]
49
+
50
+ return self._transform(
51
+ instruction=sample["instruction"],
52
+ input=sample["input"],
53
+ output=sample["output"],
54
+ )
55
+
56
+ def _transform(
57
+ self, instruction: str, input: str, output: str
58
+ ) -> Tuple[List[int], List[int]]:
59
+ """
60
+ Split a sample on ``response`` tag to create input and labels.
61
+
62
+ Args:
63
+ instruction (str): Instruction text.
64
+ input (str): Input text. Can be an empty string. Determines the prompt generation template
65
+ used.
66
+ output (str): Response text.
67
+
68
+ Returns:
69
+ Tuple of encoded inputs and labels.
70
+ """
71
+ prompt = self._generate_prompt(instruction, input)
72
+ prompt_with_response = prompt + output
73
+
74
+ # add bos always; LlamaTokenizer sets this to True by default and neither
75
+ # alpaca-lora or the original authors change this
76
+ encoded_prompt = self._tokenizer.encode(
77
+ text=prompt, add_bos=True, add_eos=False
78
+ )
79
+ encoded_prompt_with_response = self._tokenizer.encode(
80
+ text=prompt_with_response, add_bos=True, add_eos=True
81
+ )
82
+ labels = encoded_prompt_with_response.copy()
83
+
84
+ if not self.train_on_input:
85
+ labels[: len(encoded_prompt)] = [CROSS_ENTROPY_IGNORE_IDX] * len(
86
+ encoded_prompt
87
+ )
88
+
89
+ assert len(encoded_prompt_with_response) == len(labels)
90
+
91
+ return encoded_prompt_with_response, labels
92
+
93
+ def _generate_prompt(self, instruction: str, input: str) -> str:
94
+ """
95
+ Generate prompt from instruction and input.
96
+
97
+ Args:
98
+ instruction (str): Instruction text.
99
+ input (str): Input text.
100
+
101
+ Returns:
102
+ Prompt text.
103
+ """
104
+ if input:
105
+ prompt = _PROMPT_TEMPLATE["prompt_input"].format(
106
+ instruction=instruction, input=input
107
+ )
108
+ else:
109
+ prompt = _PROMPT_TEMPLATE["prompt_no_input"].format(instruction=instruction)
110
+ return prompt
baseline/custom_params.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Customized to remove dataset validation.
2
+
3
+ from dataclasses import dataclass, field, fields
4
+ from typing import List, Optional
5
+
6
+ from torchtune.datasets import ALL_DATASETS
7
+ from torchtune.models import ALL_MODELS, ALL_TOKENIZERS
8
+ from torchtune.utils.metric_logging import ALL_METRIC_LOGGERS
9
+ from torchtune.utils.precision import PRECISION_STR_TO_DTYPE
10
+
11
+
12
+ @dataclass
13
+ class FullFinetuneParams:
14
+ """Arguments for the finetune_llm recipe.
15
+
16
+ Args:
17
+ device (str): Device to use for training. Options are "cpu" and "cuda"
18
+ dtype (str): Data type to use for training.
19
+ seed (int): Random seed to use for training.
20
+ model (str): String specifying model architecture to fine-tune. See ``torchtune.models.get_model`` for options.
21
+ model_checkpoint (str): Local path to load model checkpoint from.
22
+ tokenizer (str): String specifying tokenizer to use. See ``torchtune.models.get_tokenizer`` for options.
23
+ tokenizer_checkpoint (str): Local path to load tokenizer checkpoint from.
24
+ dataset (str): String specifying dataset to use. See ``torchtune.datasets.get_dataset`` for options.
25
+ Currently, only predefined datasets in library are supported.
26
+ shuffle (bool): Whether to shuffle dataset.
27
+ batch_size (int): Batch size to use for training.
28
+ epochs (int): Number of epochs to train for.
29
+ optimizer (str): String specifying optimizer to use. See ``torchtune.optim.get_optimizer`` for options.
30
+ loss (str): String specifying loss function to use. See ``torchtune.losses.get_loss`` for options.
31
+ lr (float): Learning rate to use for optimizer.
32
+ activation_checkpointing (bool): Whether to use activation checkpointing.
33
+ output_dir (str): Local path to save checkpoints and logs to.
34
+ run_generation (int): Run eval on a prompt every ``run_generation`` steps. Set to 0 to disable.
35
+ max_steps_per_epoch (int): Maximum number of steps to take per epoch.
36
+ metric_logger_type (str): String specifying metric logger to use. See ``torchtune.utils.get_metric_logger``
37
+ for options.
38
+ project (str): Project name to use for logging. Used by ``WandBLogger``.
39
+ resume_from_previous_checkpoint (bool): Whether to resume fine-tuning from a previous checkpoint.
40
+ cpu_offload (bool): Whether to offload model to CPU.
41
+
42
+ Raises:
43
+ ValueError: If ``cpu_offload`` is ``True`` but ``device`` is not ``cuda`` and <= 1 GPUs.
44
+ """
45
+
46
+ # Model
47
+ model: str = ""
48
+ model_checkpoint: str = ""
49
+
50
+ # Tokenizer
51
+ tokenizer: str = ""
52
+ tokenizer_checkpoint: str = ""
53
+
54
+ # Dataset and Sampler
55
+ dataset: str = ""
56
+ train_on_input: bool = True
57
+ shuffle: bool = True
58
+ batch_size: int = 2
59
+
60
+ # Optimizer and Scheduler
61
+ optimizer: str = "SGD"
62
+ lr: float = 2e-5
63
+ loss: str = "CrossEntropyLoss"
64
+ gradient_accumulation_steps: int = 1
65
+
66
+ # Training
67
+ epochs: int = 3
68
+ max_steps_per_epoch: Optional[int] = None
69
+ resume_from_checkpoint: bool = False
70
+ run_generation: Optional[int] = None
71
+
72
+ # Distributed
73
+ cpu_offload: bool = False
74
+ enable_fsdp: bool = True
75
+ enable_activation_checkpointing: bool = True
76
+
77
+ # Environment
78
+ device: str = "cuda"
79
+ dtype: str = "fp32"
80
+ seed: Optional[int] = None
81
+
82
+ # Logging
83
+ output_dir: str = "/tmp/full_finetune_output"
84
+ metric_logger_type: str = "disk"
85
+ project: Optional[str] = None
86
+ log_every_n_steps: Optional[int] = None
87
+
88
+ def __post_init__(self):
89
+ for param in fields(self):
90
+ if getattr(self, param.name) == "":
91
+ raise TypeError(f"{param.name} needs to be specified")
92
+
93
+ if self.cpu_offload and self.device != "cuda":
94
+ raise ValueError(
95
+ "Cannot offload model to CPU if device is not cuda or <= 1 GPUs."
96
+ )
97
+ if self.enable_fsdp and self.device == "cpu":
98
+ raise ValueError("FSDP is not supported on CPU.")
99
+ if self.model not in ALL_MODELS:
100
+ raise ValueError(
101
+ f"Model not recognized. Expected one of {ALL_MODELS}, received {self.model}."
102
+ )
103
+ if self.tokenizer not in ALL_TOKENIZERS:
104
+ raise ValueError(
105
+ f"Tokenizer not recognized. Expected one of {ALL_TOKENIZERS}, received {self.tokenizer}."
106
+ )
107
+ if self.metric_logger_type not in ALL_METRIC_LOGGERS:
108
+ raise ValueError(
109
+ f"Metric logger not recognized. Expected one of {ALL_METRIC_LOGGERS}, received {self.metric_logger_type}."
110
+ )
111
+ if self.dtype not in PRECISION_STR_TO_DTYPE:
112
+ raise ValueError(
113
+ f"Dtype {self.dtype} must be one of {', '.join(PRECISION_STR_TO_DTYPE.keys())} for finetuning."
114
+ )
baseline/full_finetune.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import os
9
+ import sys
10
+
11
+ from functools import partial
12
+ from typing import Any, Dict, Optional, Tuple
13
+ from warnings import warn
14
+
15
+ import torch
16
+
17
+ from torch import nn
18
+ from torch.cuda.amp import GradScaler
19
+ from torch.distributed import init_process_group
20
+ from torch.optim import Optimizer
21
+ from torch.utils.data import DataLoader, DistributedSampler
22
+
23
+ from torchtune import models, modules, utils
24
+ from torchtune.utils.constants import (
25
+ EPOCHS_KEY,
26
+ MAX_STEPS_KEY,
27
+ MODEL_KEY,
28
+ OPT_KEY,
29
+ SEED_KEY,
30
+ TOTAL_EPOCHS_KEY,
31
+ )
32
+
33
+ from tqdm import tqdm
34
+
35
+ from recipes.interfaces import FTRecipeInterface
36
+
37
+
38
+ from custom_params import FullFinetuneParams
39
+ from custom_dataset import AlpacaDataset
40
+
41
+ log = utils.get_logger("DEBUG")
42
+
43
+
44
+ class FullFinetuneRecipe(FTRecipeInterface):
45
+ """
46
+ Full finetuning recipe for dense transformer-based LLMs such as Llama2.
47
+
48
+ This recipe supports:
49
+ - FSDP and activation checkpointing. This is enabled by default but can be
50
+ configured using the ``enable_fsdp`` and ``enable_activation_checkpointing`` flags.
51
+ - Mixed precision training - fp32, fp16 and bf16 are supported.
52
+ - Checkpointing of model weights, optimizer state and the recipe state (epoch and seed).
53
+ - Resuming from checkpoints saved using the ``save_checkpoint`` functionality.
54
+ - Logging to terminal. WandB and TensorBoard are currently not supported.
55
+
56
+ Assumptions:
57
+ - Training is launched with the Tune CLI (recommended) which uses TorchRun under the
58
+ hood. Setting up the env variables is handled by TorchRun.
59
+ - Training happens on CUDA (CPU training is not supported)
60
+ - Checkpoints are ONLY saved at epoch boundaries. Mid-epoch checkpointing is NOT supported.
61
+ - Datasets are Map-style and data fits in memory (not streamed).
62
+ """
63
+
64
+ def __init__(self, params: FullFinetuneParams) -> None:
65
+
66
+ self._device = utils.get_device(device=params.device)
67
+ self._dtype = utils.get_dtype(dtype=params.dtype)
68
+
69
+ # logging attributes
70
+ self._output_dir = params.output_dir
71
+ self._metric_logger = utils.get_metric_logger(
72
+ metric_logger_type=params.metric_logger_type,
73
+ project=params.project,
74
+ log_dir=params.output_dir,
75
+ )
76
+ self._log_every_n_steps = (
77
+ params.log_every_n_steps if params.log_every_n_steps else 1
78
+ )
79
+
80
+ # _is_rank_zero is used primarily for logging. In the future, the logger
81
+ # should directly take care of this
82
+ _, rank = utils.get_world_size_and_rank()
83
+ self._is_rank_zero = rank == 0
84
+
85
+ # Training params
86
+ self._resume_from_checkpoint = params.resume_from_checkpoint
87
+ self._enable_fsdp = params.enable_fsdp
88
+ self._gradient_accumulation_steps = params.gradient_accumulation_steps
89
+
90
+ # These are public properties which are updated by the checkpoint loader
91
+ # when ``resume_from_checkpoint`` is `True` or validated in tests
92
+ self.seed = utils.set_seed(seed=params.seed)
93
+ self.epochs_run = 0
94
+ self.total_epochs = params.epochs
95
+ self.max_steps_per_epoch = params.max_steps_per_epoch
96
+ self.total_training_steps = 0
97
+
98
+ def load_checkpoint(self, ckpt_path: str):
99
+ """
100
+ Extract the checkpoint state from file and validate.
101
+ """
102
+ ckpt_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
103
+ utils.validate_checkpoint(ckpt_dict, self._resume_from_checkpoint)
104
+ return ckpt_dict
105
+
106
+ def setup(self, params: FullFinetuneParams) -> None:
107
+ """
108
+ Sets up the recipe state correctly. This includes setting recipe attributes based
109
+ on the ``resume_from_checkpoint`` flag.
110
+ """
111
+
112
+ ckpt_dict = self.load_checkpoint(ckpt_path=params.model_checkpoint)
113
+
114
+ # If we're resuming from checkpoint, the recipe's state should be updated before
115
+ # initializing the training components. This ensures that the seed is correctly
116
+ # propagated to the relevant components
117
+ if self._resume_from_checkpoint:
118
+ self._update_recipe_state(ckpt_dict)
119
+
120
+ # ``_setup_model`` handles initialization and loading the state dict. This method
121
+ # should be called before ``_setup_optimizer`` since transforming the optimizer
122
+ # state dict requires the model
123
+ self._model = self._setup_model(
124
+ model=params.model,
125
+ enable_fsdp=params.enable_fsdp,
126
+ enable_activation_checkpointing=params.enable_activation_checkpointing,
127
+ model_state_dict=ckpt_dict[MODEL_KEY],
128
+ )
129
+
130
+ self._tokenizer = self._setup_tokenizer(
131
+ tokenizer=params.tokenizer, tokenizer_checkpoint=params.tokenizer_checkpoint
132
+ )
133
+
134
+ # _setup_optimizer should take in ckpt_dict only if training is resumed from
135
+ # checkpoint. Transforming the opt state dict is handled by this method
136
+ self._optimizer = self._setup_optimizer(
137
+ optimizer=params.optimizer,
138
+ lr=params.lr,
139
+ opt_state_dict=ckpt_dict[OPT_KEY] if self._resume_from_checkpoint else None,
140
+ )
141
+
142
+ self._loss_fn = self._setup_loss(loss=params.loss)
143
+
144
+ # sampler and dataloader depend on the tokenizer and loss_fn and should be
145
+ # setup after both of these are initialized
146
+ self._sampler, self._dataloader = self._setup_data(
147
+ dataset=params.dataset,
148
+ train_on_input=params.train_on_input,
149
+ shuffle=params.shuffle,
150
+ batch_size=params.batch_size,
151
+ )
152
+
153
+ # training setup
154
+ self._autocast = utils.get_autocast(self._dtype, self._device)
155
+ self._grad_scaler = None
156
+ if self._dtype == torch.float16:
157
+ self._grad_scaler = utils.get_gradient_scaler(fsdp=params.enable_fsdp)
158
+ else:
159
+ self._grad_scaler = GradScaler(enabled=False)
160
+
161
+ # Finally update the recipe state which can only be correctly set after all of the
162
+ # other components have been initialized and updated.
163
+ #
164
+ # Number of training steps in each epoch depends on the number of batches produced
165
+ # by the dataloader, the max_steps_per_epoch param set by the user and the
166
+ # gradient_accumulation_steps param. This value is used for logging and tracking
167
+ # training state. The computation should happen after the dataloader has been setup
168
+ self._steps_per_epoch = (
169
+ len(self._dataloader) // self._gradient_accumulation_steps
170
+ )
171
+ if (
172
+ self.max_steps_per_epoch is not None
173
+ and self.max_steps_per_epoch < self._steps_per_epoch
174
+ ):
175
+ self._steps_per_epoch = self.max_steps_per_epoch
176
+ self.total_training_steps = self.epochs_run * self._steps_per_epoch
177
+
178
+ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
179
+ """
180
+ Updates the recipe state from checkpoint.
181
+ """
182
+ # If seed, total_epoch or max_steps_per_epoch don't match,
183
+ # warn the user and overwrite
184
+ if (
185
+ self.seed != ckpt_dict[SEED_KEY]
186
+ or self.total_epochs != ckpt_dict[TOTAL_EPOCHS_KEY]
187
+ or self.max_steps_per_epoch != ckpt_dict[MAX_STEPS_KEY]
188
+ ):
189
+ warn(
190
+ message="""Configured value for seed, epochs or max_steps_per_epoch
191
+ does not match the value stored in checkpoint."""
192
+ )
193
+ self.seed = utils.set_seed(seed=ckpt_dict[SEED_KEY])
194
+ self.epochs_run = ckpt_dict[EPOCHS_KEY]
195
+ self.total_epochs = ckpt_dict[TOTAL_EPOCHS_KEY]
196
+ self.max_steps_per_epoch = ckpt_dict[MAX_STEPS_KEY]
197
+
198
+ def _setup_model(
199
+ self,
200
+ model: str,
201
+ enable_fsdp: bool,
202
+ enable_activation_checkpointing: bool,
203
+ model_state_dict: Dict[str, Any],
204
+ ) -> nn.Module:
205
+ """
206
+ Set up the model including enabling FSDP and activation checkpointing. For this recipe,
207
+ ``enable_fsdp`` should always be ``True``. This is currently a configurable flag for
208
+ running tests on CPUs.
209
+ """
210
+ model = models.get_model(model, device=self._device)
211
+ model = (
212
+ utils.wrap_fsdp(
213
+ model=model,
214
+ device=self._device,
215
+ dtype=self._dtype,
216
+ strategy="FULL_SHARD",
217
+ auto_wrap_policy={modules.TransformerDecoderLayer},
218
+ )
219
+ if enable_fsdp
220
+ else model
221
+ )
222
+ if enable_activation_checkpointing:
223
+ utils.set_activation_checkpointing(
224
+ model, auto_wrap_policy={modules.TransformerDecoderLayer}
225
+ )
226
+
227
+ model.load_state_dict(model_state_dict)
228
+
229
+ if self._is_rank_zero:
230
+ log.info(
231
+ "Model is initialized. FSDP and Activation Checkpointing are enabled."
232
+ )
233
+ return model
234
+
235
+ def _setup_tokenizer(
236
+ self, tokenizer: str, tokenizer_checkpoint: str
237
+ ) -> modules.Tokenizer:
238
+ """
239
+ Unlike ```setup_model```, this takes in the checkpoint and loads the sentencepiece
240
+ tokenizer model. This is related to how the tokenizer is implemented and should
241
+ change in a future iteration.
242
+ """
243
+ tokenizer = models.get_tokenizer(tokenizer, path=tokenizer_checkpoint)
244
+
245
+ if self._is_rank_zero:
246
+ log.info("Tokenizer is initialized from file.")
247
+ return tokenizer
248
+
249
+ def _setup_optimizer(
250
+ self, optimizer: str, lr: float, opt_state_dict: Optional[Dict[str, Any]] = None
251
+ ) -> Optimizer:
252
+ """
253
+ Set up the optimizer. This method also handles transforing the state dict
254
+ for FSDP.
255
+ """
256
+ optimizer = modules.get_optimizer(optimizer, self._model, lr)
257
+ if opt_state_dict:
258
+ opt_state_dict = utils.transform_opt_state_dict(
259
+ opt_state_dict, self._model, optimizer
260
+ )
261
+ optimizer.load_state_dict(opt_state_dict)
262
+
263
+ if self._is_rank_zero:
264
+ log.info("Optimizer is initialized.")
265
+ return optimizer
266
+
267
+ def _setup_loss(self, loss: str) -> nn.Module:
268
+ loss_fn = modules.get_loss(loss)
269
+
270
+ if self._is_rank_zero:
271
+ log.info("Loss is initialized.")
272
+
273
+ return loss_fn
274
+
275
+ def _setup_data(
276
+ self, dataset: str, shuffle: bool, batch_size: int, train_on_input: bool
277
+ ) -> Tuple[DistributedSampler, DataLoader]:
278
+ """
279
+ All data related setup happens here. Currently this recipe only supports the
280
+ DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
281
+ iterable datasets and streaming datasets are not supported.
282
+ """
283
+ world_size, rank = utils.get_world_size_and_rank()
284
+ ds = AlpacaDataset(dataset, tokenizer=self._tokenizer, train_on_input=train_on_input)
285
+
286
+ sampler = DistributedSampler(
287
+ ds,
288
+ num_replicas=world_size,
289
+ rank=rank,
290
+ shuffle=shuffle,
291
+ seed=0,
292
+ )
293
+ dataloader = DataLoader(
294
+ dataset=ds,
295
+ batch_size=batch_size,
296
+ sampler=sampler,
297
+ collate_fn=partial(
298
+ utils.padded_collate,
299
+ padding_idx=self._tokenizer.pad_id,
300
+ ignore_idx=self._loss_fn.ignore_index, # TODO support loss without ignore_index
301
+ ),
302
+ )
303
+
304
+ if self._is_rank_zero:
305
+ log.info("Dataset and Sampler are initialized.")
306
+
307
+ return sampler, dataloader
308
+
309
+ def save_checkpoint(self, epoch: int) -> None:
310
+ """
311
+ Checkpoint the relevant state of a recipe.
312
+
313
+ This makes use of the `save_checkpoint` utility which is responsible for
314
+ writing the checkpoint dictionary to file. The contents of the dict are dictated
315
+ by whether training is complete or not.
316
+
317
+ If training is ongoing, optimizer state, seed and epochs_run are saved along with the
318
+ model weights.
319
+ """
320
+ os.makedirs(self._output_dir, exist_ok=True)
321
+ output_loc = f"{self._output_dir}/model_{epoch}.ckpt"
322
+ ckpt_dict = {MODEL_KEY: self._model}
323
+
324
+ # if training is in-progress, checkpoint the optimizer state as well
325
+ if epoch + 1 < self.total_epochs:
326
+ ckpt_dict.update(
327
+ {
328
+ OPT_KEY: self._optimizer,
329
+ SEED_KEY: self.seed,
330
+ EPOCHS_KEY: self.epochs_run,
331
+ TOTAL_EPOCHS_KEY: self.total_epochs,
332
+ MAX_STEPS_KEY: self.max_steps_per_epoch,
333
+ }
334
+ )
335
+ utils.save_checkpoint(ckpt_dict, output_loc)
336
+
337
+ if self._is_rank_zero:
338
+ log.info(
339
+ f"Model checkpoint of size {os.path.getsize(output_loc) >> 20} MB saved to {output_loc}"
340
+ )
341
+
342
+ def _should_update_weights(self, curr_step: int) -> bool:
343
+ """
344
+ Determines whether the weights should be updated on the current step or not.
345
+ True is returned either if we've accumulated gradients for enough steps or if this
346
+ is the last step in the epoch.
347
+ """
348
+ should_update_weights = (
349
+ curr_step + 1
350
+ ) % self._gradient_accumulation_steps == 0 or (
351
+ curr_step + 1
352
+ ) == self._steps_per_epoch
353
+ return should_update_weights
354
+
355
+ def train(self) -> None:
356
+ """
357
+ The core training loop. Supports training on subsets of the dataset using the
358
+ ``max_steps_per_epoch``.
359
+ """
360
+ _, rank = utils.get_world_size_and_rank()
361
+
362
+ # zero out the gradients before starting training
363
+ self._optimizer.zero_grad()
364
+
365
+ # self.epochs_run should be non-zero when we're resuming from a checkpoint
366
+ for curr_epoch in range(self.epochs_run, self.total_epochs):
367
+
368
+ # Update the sampler to ensure data is correctly shuffled across epochs
369
+ # in case shuffle is True
370
+ self._sampler.set_epoch(curr_epoch)
371
+
372
+ for idx, batch in enumerate(
373
+ pbar := tqdm(self._dataloader, disable=not (rank == 0))
374
+ ):
375
+ if (
376
+ self.max_steps_per_epoch is not None
377
+ and (idx // self._gradient_accumulation_steps)
378
+ == self.max_steps_per_epoch
379
+ ):
380
+ break
381
+
382
+ input_ids, labels = batch
383
+ input_ids = input_ids.to(self._device)
384
+ labels = labels.to(self._device)
385
+
386
+ with self._autocast:
387
+ logits = self._model(input_ids)
388
+ # Shift so that tokens < n predict n
389
+ logits = logits[..., :-1, :].contiguous()
390
+ labels = labels[..., 1:].contiguous()
391
+ logits = logits.transpose(1, 2)
392
+ # Compute loss
393
+ loss = self._loss_fn(logits, labels)
394
+
395
+ # Note: We're always logging the loss before normalizing it
396
+ # Check if this is the norm or not
397
+ pbar.set_description(f"{curr_epoch+1}|{idx+1}|Loss: {loss.item()}")
398
+
399
+ if self.total_training_steps % self._log_every_n_steps == 0:
400
+ self._metric_logger.log_dict(
401
+ {
402
+ "loss": loss.item(),
403
+ "lr": self._optimizer.param_groups[0]["lr"],
404
+ "gpu_resources": torch.cuda.memory_allocated(),
405
+ },
406
+ step=self.total_training_steps,
407
+ )
408
+
409
+ # Does loss normalization need to happen within autocast context?
410
+ loss = loss / self._gradient_accumulation_steps
411
+ self._grad_scaler.scale(loss).backward()
412
+
413
+ if self._should_update_weights(idx):
414
+ self._grad_scaler.step(self._optimizer)
415
+ self._grad_scaler.update()
416
+ self._optimizer.zero_grad(set_to_none=True)
417
+
418
+ # Update the number of steps when the weights are updated
419
+ self.total_training_steps += 1
420
+
421
+ self.epochs_run += 1
422
+ self.save_checkpoint(epoch=curr_epoch)
423
+
424
+ def cleanup(self) -> None:
425
+ self._metric_logger.close()
426
+
427
+
428
+ def recipe_main() -> None:
429
+ """
430
+ Entry point for the recipe.
431
+
432
+ Configurable parameters are read in the following order:
433
+ - Parameters specified in ``FullFinetuneParams``
434
+ - Overwritten by Parameters specified in ``alpaca_llama2_full_finetune.yaml``
435
+ - Overwritten by arguments from the command-line using ``TuneArgumentParser``
436
+ """
437
+ parser = utils.TuneArgumentParser(
438
+ description=FullFinetuneParams.__doc__,
439
+ formatter_class=argparse.RawDescriptionHelpFormatter,
440
+ )
441
+ args, _ = parser.parse_known_args()
442
+ args = vars(args)
443
+ recipe_params = FullFinetuneParams(**args)
444
+
445
+ # Env variables set by torch run; only need to initialize process group
446
+ # init_process_group(backend="nccl")
447
+
448
+ recipe = FullFinetuneRecipe(params=recipe_params)
449
+ recipe.setup(params=recipe_params)
450
+ recipe.train()
451
+ recipe.cleanup()
452
+
453
+
454
+ if __name__ == "__main__":
455
+ sys.exit(recipe_main())
colorful/adversarial_config.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Runs the full_finetune.py recipe
2
+ #
3
+ # To launch, run the following command from root:
4
+ # tune --nnodes 1 --nproc_per_node 1 --config alpaca_llama2_full_finetune --override model_checkpoint=<your_checkpoint_dir> ...
5
+
6
+ # Dataset and Dataloader
7
+ dataset: laurencer/yahma-alpaca-cleaned-adversarial
8
+ seed: 42
9
+ shuffle: True
10
+
11
+ # Checkpointing
12
+ # Removed for now given poor upload speeds for checkpoints
13
+ # hf_repo_id: laurencer/Llama7b-Alpaca-Tune-4epochs-WithColoring
14
+ checkpoint_every_n_steps: 500 # 6k steps per epoch
15
+
16
+ # Model Arguments
17
+ model_checkpoint: model/llama2_native.tune
18
+ tokenizer_checkpoint: model/tokenizer.model
19
+
20
+ color_layer_initialization: zeros
21
+ norm_before_color_layer: True
22
+
23
+ # Fine-tuning arguments
24
+ compile: False
25
+ batch_size: 8
26
+ lr: 2e-5
27
+ epochs: 4
28
+ optimizer: SGD
29
+ loss: CrossEntropyLoss
30
+ output_dir: output/alpaca-colorful-llama2-finetune
31
+ device: cuda
32
+ dtype: bf16
33
+ enable_fsdp: False
34
+ enable_activation_checkpointing: True
35
+ resume_from_checkpoint: False
36
+
37
+ # Logging arguments
38
+ metric_logger_type: wandb
39
+ project: torchtune
colorful/basic_config.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Runs the full_finetune.py recipe
2
+ #
3
+ # To launch, run the following command from root:
4
+ # tune --nnodes 1 --nproc_per_node 1 --config alpaca_llama2_full_finetune --override model_checkpoint=<your_checkpoint_dir> ...
5
+
6
+ # Dataset and Dataloader
7
+ dataset: yahma/alpaca-cleaned
8
+ seed: 42
9
+ shuffle: True
10
+
11
+ # Checkpointing
12
+ # Removed for now given poor upload speeds for checkpoints
13
+ # hf_repo_id: laurencer/Llama7b-Alpaca-Tune-4epochs-WithColoring
14
+ checkpoint_every_n_steps: 500 # 6k steps per epoch
15
+
16
+ # Model Arguments
17
+ model_checkpoint: model/llama2_native.tune
18
+ tokenizer_checkpoint: model/tokenizer.model
19
+
20
+ color_layer_initialization: zeros
21
+ norm_before_color_layer: True
22
+
23
+ # Fine-tuning arguments
24
+ compile: True
25
+ batch_size: 8
26
+ lr: 2e-5
27
+ epochs: 4
28
+ optimizer: SGD
29
+ loss: CrossEntropyLoss
30
+ output_dir: output/alpaca-colorful-llama2-finetune
31
+ device: cuda
32
+ dtype: bf16
33
+ enable_fsdp: False
34
+ enable_activation_checkpointing: True
35
+ resume_from_checkpoint: False
36
+
37
+ # Logging arguments
38
+ metric_logger_type: wandb
39
+ project: torchtune
colorful/custom_dataset.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Tuple
8
+
9
+ import torch
10
+
11
+ import torch.nn.functional as F
12
+ from torch.nn.utils.rnn import pad_sequence
13
+ from torch.utils.data import Dataset
14
+
15
+ from datasets import load_dataset
16
+
17
+ # Not ideal to import this type here but it's needed for the transform function
18
+ from torchtune.modules import Tokenizer
19
+
20
+
21
+ CROSS_ENTROPY_IGNORE_IDX = -100
22
+
23
+
24
+ DEFAULT = 0
25
+ INSTRUCTION = 1
26
+ INPUT = 2
27
+ RESPONSE = 3
28
+
29
+
30
+ class ColoringAlpacaDataset(Dataset):
31
+ """
32
+ See torchtune.datasets.alpaca.AlpacaDataset for the original implementation.
33
+
34
+ Constructor now takes in a dataset path directly.
35
+
36
+ This implementation returns 3 lists representing the tokens, labels, and token colors
37
+ (as opposed to just the tokens & labels from the original).
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ tokenizer: Tokenizer,
43
+ dataset_path: str = "yahma/alpaca-cleaned",
44
+ train_on_input: bool = True,
45
+ **kwargs
46
+ ) -> None:
47
+ self._data = load_dataset(dataset_path, split="train")
48
+ self._tokenizer = tokenizer
49
+ self.train_on_input = train_on_input
50
+ self.num_colors = 4 # matches the above usage of DEFAULT, INSTRUCTION, INPUT, RESPONSE
51
+
52
+ def __len__(self):
53
+ return len(self._data)
54
+
55
+ def __getitem__(self, index: int) -> Tuple[List[int], List[int], List[int]]:
56
+ sample = self._data[index]
57
+
58
+ return self._transform(
59
+ instruction=sample["instruction"],
60
+ input=sample["input"],
61
+ output=sample["output"],
62
+ )
63
+
64
+ def _transform(
65
+ self, instruction: str, input: str, output: str
66
+ ) -> Tuple[List[int], List[int], List[int]]:
67
+ """
68
+ Split a sample on ``response`` tag to create input and labels.
69
+
70
+ Args:
71
+ instruction (str): Instruction text.
72
+ input (str): Input text. Can be an empty string. Determines the prompt generation template
73
+ used.
74
+ output (str): Response text.
75
+
76
+ Returns:
77
+ Tuple of encoded inputs, labels, token colors.
78
+ """
79
+ prompt = self._generate_prompt(instruction, input)
80
+
81
+ # First handle the prompt
82
+ colors = []
83
+ tokenized = []
84
+ labels = []
85
+ is_first = True
86
+ for token_type, text in prompt:
87
+ tokenized_part = self._tokenizer.encode(
88
+ text=text, add_bos=is_first, add_eos=False
89
+ )
90
+ is_first = False
91
+
92
+ tokenized += tokenized_part
93
+ colors += [token_type] * len(tokenized_part)
94
+ if not self.train_on_input:
95
+ labels += [CROSS_ENTROPY_IGNORE_IDX] * len(tokenized_part)
96
+ else:
97
+ labels += tokenized_part
98
+
99
+ # Now add the response tokens
100
+ tokenized_part = self._tokenizer.encode(
101
+ text=output, add_bos=False, add_eos=True
102
+ )
103
+ tokenized += tokenized_part
104
+ colors += [RESPONSE] * len(tokenized_part)
105
+ labels += tokenized_part
106
+
107
+ assert len(tokenized) == len(labels)
108
+ assert len(tokenized) == len(colors)
109
+
110
+ return tokenized, labels, colors
111
+
112
+ def _generate_prompt(self, instruction: str, input: str) -> List[Tuple[(int, str)]]:
113
+ """
114
+ Generate prompt from instruction and input.
115
+
116
+ Args:
117
+ instruction (str): Instruction text.
118
+ input (str): Input text.
119
+
120
+ Returns:
121
+ List of (int, templated text)
122
+ """
123
+ if input:
124
+ return [
125
+ (DEFAULT, (
126
+ "Below is an instruction that describes a task, paired with an input that provides further context. "
127
+ "Write a response that appropriately completes the request.\n\n"
128
+ "### Instruction:\n"
129
+ )),
130
+ (INSTRUCTION, instruction),
131
+ (DEFAULT, "\n\n### Input:\n"),
132
+ (INPUT, input),
133
+ (DEFAULT, "\n\n### Response:\n"),
134
+ ]
135
+ else:
136
+ return [
137
+ (DEFAULT, (
138
+ "Below is an instruction that describes a task. "
139
+ "Write a response that appropriately completes the request.\n\n"
140
+ "### Instruction:\n"
141
+ )),
142
+ (INSTRUCTION, instruction),
143
+ (DEFAULT, "\n\n### Response:\n"),
144
+ ]
145
+
146
+
147
+ # TokenPair is a pair (tuple) of three lists: tokenized text inputs, labels, colors.
148
+ TokenPair = Tuple[List[int], List[int], List[int]]
149
+
150
+
151
+ def padded_collate(
152
+ batch: List[TokenPair],
153
+ padding_idx: int = 0,
154
+ ignore_idx: int = -100,
155
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
156
+ input_ids = pad_sequence(
157
+ [torch.tensor(x[0]) for x in batch],
158
+ batch_first=True,
159
+ padding_value=padding_idx,
160
+ )
161
+ labels = pad_sequence(
162
+ [torch.tensor(x[1]) for x in batch],
163
+ batch_first=True,
164
+ padding_value=ignore_idx,
165
+ )
166
+ colors = pad_sequence(
167
+ [torch.tensor(x[2]) for x in batch],
168
+ batch_first=True,
169
+ padding_value=padding_idx,
170
+ )
171
+
172
+ input_ids_seq_len = input_ids.shape[-1]
173
+ labels_seq_len = labels.shape[-1]
174
+ colors_seq_len = colors.shape[-1]
175
+
176
+ assert input_ids_seq_len == labels_seq_len
177
+ assert input_ids_seq_len == colors_seq_len
178
+
179
+ return input_ids, labels, colors
colorful/custom_model.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import copy
3
+ import math
4
+
5
+ import torch
6
+ from torch import nn, Tensor
7
+
8
+ from torchtune.modules import (
9
+ CausalSelfAttention,
10
+ FeedForward,
11
+ KVCache,
12
+ RMSNorm,
13
+ RotaryPositionalEmbeddings,
14
+ # TransformerDecoder, replaced with our custom implementation.
15
+ TransformerDecoderLayer,
16
+ )
17
+
18
+ from masked_apply import MaskedApply
19
+
20
+
21
+ def initialize_identity_linear(size):
22
+ layer = nn.Linear(size, size)
23
+ layer.weight.data.copy_(torch.eye(size))
24
+ layer.bias.data.copy_(torch.zeros(size))
25
+ return layer
26
+
27
+ def initialize_linear(size):
28
+ return nn.Linear(size, size)
29
+
30
+ def initialize_kaiming_uniform_linear(size):
31
+ layer = nn.Linear(size, size)
32
+ nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
33
+ layer.bias.data.copy_(torch.zeros(size))
34
+ return layer
35
+
36
+ def initialize_zeros_linear(size):
37
+ layer = nn.Linear(size, size)
38
+ layer.weight.data.copy_(torch.zeros(size))
39
+ layer.bias.data.copy_(torch.zeros(size))
40
+ return layer
41
+
42
+ INITIALIZATION_OPTIONS = {
43
+ "identity": initialize_identity_linear,
44
+ "default": initialize_linear,
45
+ "kaiming_uniform": initialize_kaiming_uniform_linear,
46
+ "zeros": initialize_zeros_linear,
47
+ }
48
+
49
+ def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
50
+ """
51
+ Return a list of ``n`` identical layers.
52
+
53
+ Args:
54
+ module (nn.Module): module to be cloned
55
+ n (int): number of clones
56
+
57
+ Returns:
58
+ nn.ModuleList: list of ``n`` identical layers
59
+ """
60
+ # FIXME: copy.deepcopy() is not defined on nn.module
61
+ return nn.ModuleList([copy.deepcopy(module) for i in range(n)])
62
+
63
+
64
+ class ColoringTransformerDecoder(nn.Module):
65
+ """
66
+ See torchtune.models.llama2.TransformerDecoder for the original implementation.
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ tok_embeddings: nn.Embedding,
72
+ embedding_transform: nn.Module,
73
+ layer: TransformerDecoderLayer,
74
+ num_layers: int,
75
+ norm: nn.Module,
76
+ output: nn.Linear,
77
+ embedding_norm: nn.Module = None
78
+ ) -> None:
79
+ super().__init__()
80
+ self.tok_embeddings = tok_embeddings
81
+ self.embedding_transform = embedding_transform
82
+ self.embedding_norm = embedding_norm
83
+ self.layers = _get_clones(layer, num_layers)
84
+ self.norm = norm
85
+ self.output = output
86
+
87
+ def forward(
88
+ self,
89
+ tokens: Tensor,
90
+ mask: Optional[Tensor] = None,
91
+ colors: Optional[Tensor] = None,
92
+ curr_pos: int = 0
93
+ ) -> Tensor:
94
+ """
95
+ Args:
96
+ tokens (Tensor): input tensor with shape [b x s]
97
+ mask (Optional[Tensor]): attention mask tensor, defaults to None.
98
+ curr_pos (int): current position in the seq, defaults to 0.
99
+ Only relevant when incrementally decoding.
100
+
101
+ Returns:
102
+ Tensor: output tensor with shape [b x s x v]
103
+
104
+ Notation used for tensor shapes:
105
+ - b: batch size
106
+ - s: sequence length
107
+ - v: vocab size
108
+ - d: embed dim
109
+ """
110
+ # input tensor of shape [b, s]
111
+ bsz, seq_len = tokens.shape
112
+
113
+ # shape: [b, s, d]
114
+ h = self.tok_embeddings(tokens)
115
+
116
+ # Apply normalization before embedding transform to improve
117
+ # training stability.
118
+ ch = h
119
+ if self.embedding_norm is not None:
120
+ # TODO: norm does an in-place operation, so we need to clone the input
121
+ ch = self.embedding_norm(h.clone())
122
+
123
+ # Apply the embedding transform (e.g. color layer)
124
+ ch = self.embedding_transform(ch, colors)
125
+
126
+ # Add the output of the color transform to the embeddings
127
+ h = h + ch
128
+
129
+ # TODO: Fix the masking logic to not rely on checking kv_cache
130
+ if seq_len > 1 and self.layers[0].attn.kv_cache is not None:
131
+ mask = torch.full(
132
+ (1, 1, seq_len, seq_len), float("-inf"), device=tokens.device
133
+ )
134
+ mask = torch.triu(mask, diagonal=curr_pos + 1)
135
+
136
+ for layer in self.layers:
137
+ # shape: [b, s, d]
138
+ h = layer(h, mask, curr_pos)
139
+
140
+ # shape: [b, s, d]
141
+ h = self.norm(h)
142
+
143
+ # shape: [b, s, v]
144
+ output = self.output(h).float()
145
+ return output
146
+
147
+
148
+ def coloring_llama2_7b(color_layer_initialization: str, norm_before_color_layer: bool = False, max_batch_size: Optional[int] = None) -> ColoringTransformerDecoder:
149
+ """Builder for creating a Llama2 model initialized w/ the default 7b parameter values.
150
+ From https://arxiv.org/abs/2307.09288, these default values are:
151
+ - vocab_size: 32,000
152
+ - embed_dim: 4,096
153
+ - num_layers: 32
154
+ - num_heads: 32
155
+ - num_kv_heads: 32
156
+ - max_seq_len: 4,096
157
+ - norm_eps: 1e-5
158
+
159
+ Args:
160
+ max_batch_size (Optional[int]): Maximum batch size to be passed to KVCache.
161
+
162
+ Returns:
163
+ A ``TransformerDecoder`` instance of the Llama2 model.
164
+ """
165
+ return coloring_llama2(
166
+ color_layer_initialization=color_layer_initialization,
167
+ vocab_size=32_000,
168
+ num_layers=32,
169
+ num_heads=32,
170
+ num_kv_heads=32,
171
+ embed_dim=4096,
172
+ max_seq_len=4096,
173
+ num_colors=4, # color for default, instruction, input, response
174
+ max_batch_size=max_batch_size,
175
+ attn_dropout=0.0,
176
+ norm_eps=1e-5,
177
+ norm_before_color_layer=norm_before_color_layer
178
+ )
179
+
180
+ def _scale_hidden_dim_for_mlp(dim: int, multiple_of: int = 256) -> int:
181
+ """Scale hidden dimension for MLP to keep number of parameters and computation constant.
182
+
183
+ Args:
184
+ dim (int): Input dimension.
185
+ multiple_of (int): Round scaled dimension to nearest multiple of `multiple_of` for clean computation.
186
+
187
+ Returns:
188
+ Scaled hidden dimension.
189
+ """
190
+ # Scale hidden dimension by (2/3)4d for SwiGLU to keep number of
191
+ # parameters and computation constant
192
+ hidden_dim = 4 * int(2 * dim / 3)
193
+ # Round hidden dimension to nearest multiple of `multiple_of`
194
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
195
+ return hidden_dim
196
+
197
+
198
+ def coloring_llama2(
199
+ color_layer_initialization: str,
200
+ vocab_size: int,
201
+ num_layers: int,
202
+ num_heads: int,
203
+ num_kv_heads: int,
204
+ embed_dim: int,
205
+ max_seq_len: int,
206
+ num_colors: int,
207
+ norm_before_color_layer: bool = False,
208
+ attn_dropout: float = 0.0,
209
+ max_batch_size: Optional[int] = None,
210
+ norm_eps: float = 1e-5,
211
+ ):
212
+ if color_layer_initialization not in INITIALIZATION_OPTIONS:
213
+ raise ValueError(f"Invalid color_layer_initialization: {color_layer_initialization}. Expected one of {list(INITIALIZATION_OPTIONS.keys())}.")
214
+ color_layer_initializer = INITIALIZATION_OPTIONS[color_layer_initialization]
215
+
216
+ head_dim = embed_dim // num_heads
217
+ num_kv_heads = num_kv_heads if num_kv_heads else num_heads
218
+ kv_cache = (
219
+ KVCache(
220
+ max_batch_size=max_batch_size,
221
+ max_seq_len=max_seq_len,
222
+ n_kv_heads=num_heads,
223
+ head_dim=head_dim,
224
+ )
225
+ if max_batch_size is not None
226
+ else None
227
+ )
228
+ rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)
229
+ self_attn = CausalSelfAttention(
230
+ embed_dim=embed_dim,
231
+ num_heads=num_heads,
232
+ num_kv_heads=num_kv_heads,
233
+ head_dim=head_dim,
234
+ q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
235
+ k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
236
+ v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
237
+ output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
238
+ pos_embeddings=rope,
239
+ kv_cache=kv_cache,
240
+ max_seq_len=max_seq_len,
241
+ attn_dropout=attn_dropout,
242
+ )
243
+ hidden_dim = _scale_hidden_dim_for_mlp(embed_dim)
244
+ mlp = FeedForward(dim=embed_dim, hidden_dim=hidden_dim, linear_class=nn.Linear)
245
+ layer = TransformerDecoderLayer(
246
+ attn=self_attn,
247
+ mlp=mlp,
248
+ sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
249
+ mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
250
+ )
251
+ tok_embeddings = nn.Embedding(vocab_size, embed_dim)
252
+ output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
253
+ embedding_transform = MaskedApply(
254
+ [color_layer_initializer(embed_dim) for _ in range(num_colors)],
255
+ strict=True
256
+ )
257
+ embedding_norm = RMSNorm(embed_dim, eps=norm_eps) if norm_before_color_layer else None
258
+
259
+ return ColoringTransformerDecoder(
260
+ tok_embeddings=tok_embeddings,
261
+ embedding_transform=embedding_transform,
262
+ embedding_norm=embedding_norm,
263
+ layer=layer,
264
+ num_layers=num_layers,
265
+ norm=RMSNorm(embed_dim, eps=norm_eps),
266
+ output=output_proj,
267
+ )
colorful/custom_params.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field, fields
2
+ from typing import List, Optional
3
+
4
+ from torchtune.datasets import ALL_DATASETS
5
+ from torchtune.models import ALL_MODELS, ALL_TOKENIZERS
6
+ from torchtune.utils.metric_logging import ALL_METRIC_LOGGERS
7
+ from torchtune.utils.precision import PRECISION_STR_TO_DTYPE
8
+
9
+
10
+ @dataclass
11
+ class ColoringFinetuneParams:
12
+ """Arguments for the finetune_llm recipe.
13
+
14
+ Args:
15
+ device (str): Device to use for training. Options are "cpu" and "cuda"
16
+ dtype (str): Data type to use for training.
17
+ seed (int): Random seed to use for training.
18
+ model (str): String specifying model architecture to fine-tune. See ``torchtune.models.get_model`` for options.
19
+ model_checkpoint (str): Local path to load model checkpoint from.
20
+ tokenizer (str): String specifying tokenizer to use. See ``torchtune.models.get_tokenizer`` for options.
21
+ tokenizer_checkpoint (str): Local path to load tokenizer checkpoint from.
22
+ dataset (str): String specifying dataset to use. See ``torchtune.datasets.get_dataset`` for options.
23
+ Currently, only predefined datasets in library are supported.
24
+ shuffle (bool): Whether to shuffle dataset.
25
+ batch_size (int): Batch size to use for training.
26
+ epochs (int): Number of epochs to train for.
27
+ optimizer (str): String specifying optimizer to use. See ``torchtune.optim.get_optimizer`` for options.
28
+ loss (str): String specifying loss function to use. See ``torchtune.losses.get_loss`` for options.
29
+ lr (float): Learning rate to use for optimizer.
30
+ activation_checkpointing (bool): Whether to use activation checkpointing.
31
+ output_dir (str): Local path to save checkpoints and logs to.
32
+ run_generation (int): Run eval on a prompt every ``run_generation`` steps. Set to 0 to disable.
33
+ max_steps_per_epoch (int): Maximum number of steps to take per epoch.
34
+ metric_logger_type (str): String specifying metric logger to use. See ``torchtune.utils.get_metric_logger``
35
+ for options.
36
+ project (str): Project name to use for logging. Used by ``WandBLogger``.
37
+ resume_from_previous_checkpoint (bool): Whether to resume fine-tuning from a previous checkpoint.
38
+ cpu_offload (bool): Whether to offload model to CPU.
39
+
40
+ Raises:
41
+ ValueError: If ``cpu_offload`` is ``True`` but ``device`` is not ``cuda`` and <= 1 GPUs.
42
+ """
43
+
44
+ # Model
45
+ model_checkpoint: str = ""
46
+
47
+ color_layer_initialization: str = "default"
48
+ norm_before_color_layer: bool = False
49
+
50
+ # Tokenizer
51
+ tokenizer_checkpoint: str = ""
52
+
53
+ hf_repo_id: Optional[str] = None
54
+ checkpoint_every_n_steps: Optional[int] = None
55
+
56
+ # Dataset and Sampler
57
+ dataset: str = ""
58
+ train_on_input: bool = True
59
+ shuffle: bool = True
60
+ batch_size: int = 2
61
+
62
+ # Optimizer and Scheduler
63
+ optimizer: str = "SGD"
64
+ lr: float = 2e-5
65
+ loss: str = "CrossEntropyLoss"
66
+ gradient_accumulation_steps: int = 1
67
+
68
+ # Training
69
+ compile: bool = False
70
+ epochs: int = 3
71
+ max_steps_per_epoch: Optional[int] = None
72
+ resume_from_checkpoint: bool = False
73
+ run_generation: Optional[int] = None
74
+
75
+ # Distributed
76
+ cpu_offload: bool = False
77
+ enable_fsdp: bool = True
78
+ enable_activation_checkpointing: bool = True
79
+
80
+ # Environment
81
+ device: str = "cuda"
82
+ dtype: str = "fp16"
83
+ seed: Optional[int] = None
84
+
85
+ # Logging
86
+ output_dir: str = "/tmp/full_finetune_output"
87
+ metric_logger_type: str = "disk"
88
+ project: Optional[str] = None
89
+ log_every_n_steps: Optional[int] = None
90
+
91
+ def __post_init__(self):
92
+ for param in fields(self):
93
+ if getattr(self, param.name) == "":
94
+ raise TypeError(f"{param.name} needs to be specified")
95
+
96
+ if self.cpu_offload and self.device != "cuda":
97
+ raise ValueError(
98
+ "Cannot offload model to CPU if device is not cuda or <= 1 GPUs."
99
+ )
100
+ if self.enable_fsdp and self.device == "cpu":
101
+ raise ValueError("FSDP is not supported on CPU.")
102
+
103
+ if self.metric_logger_type not in ALL_METRIC_LOGGERS:
104
+ raise ValueError(
105
+ f"Metric logger not recognized. Expected one of {ALL_METRIC_LOGGERS}, received {self.metric_logger_type}."
106
+ )
107
+ if self.dtype not in PRECISION_STR_TO_DTYPE:
108
+ raise ValueError(
109
+ f"Dtype {self.dtype} must be one of {', '.join(PRECISION_STR_TO_DTYPE.keys())} for finetuning."
110
+ )
colorful/full_finetune.py ADDED
@@ -0,0 +1,511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import os
9
+ import sys
10
+
11
+ from functools import partial
12
+ from typing import Any, Dict, Optional, Tuple
13
+ from warnings import warn
14
+
15
+ import torch
16
+
17
+ from torch import nn
18
+ from torch.cuda.amp import GradScaler
19
+ from torch.distributed import init_process_group
20
+ from torch.optim import Optimizer
21
+ from torch.utils.data import DataLoader, DistributedSampler
22
+ from torchtune.utils import get_device
23
+
24
+ from torchtune import models, modules, utils
25
+ from torchtune.utils.constants import (
26
+ EPOCHS_KEY,
27
+ MAX_STEPS_KEY,
28
+ MODEL_KEY,
29
+ OPT_KEY,
30
+ SEED_KEY,
31
+ TOTAL_EPOCHS_KEY,
32
+ )
33
+
34
+ from tqdm import tqdm
35
+
36
+ from recipes.interfaces import FTRecipeInterface
37
+ from recipes.params import FullFinetuneParams
38
+
39
+ from torchtune.models.llama2 import llama2_tokenizer
40
+
41
+ from huggingface_hub import HfApi
42
+
43
+ from custom_params import ColoringFinetuneParams
44
+ from custom_model import ColoringTransformerDecoder, coloring_llama2_7b
45
+ from custom_dataset import ColoringAlpacaDataset, padded_collate
46
+
47
+ log = utils.get_logger("DEBUG")
48
+
49
+
50
+ class ColoringFinetuneRecipe(FTRecipeInterface):
51
+ """
52
+ Full finetuning recipe for dense transformer-based LLMs such as Llama2.
53
+
54
+ This recipe supports:
55
+ - FSDP and activation checkpointing. This is enabled by default but can be
56
+ configured using the ``enable_fsdp`` and ``enable_activation_checkpointing`` flags.
57
+ - Mixed precision training - fp32, fp16 and bf16 are supported.
58
+ - Checkpointing of model weights, optimizer state and the recipe state (epoch and seed).
59
+ - Resuming from checkpoints saved using the ``save_checkpoint`` functionality.
60
+ - Logging to terminal. WandB and TensorBoard are currently not supported.
61
+
62
+ Assumptions:
63
+ - Training is launched with the Tune CLI (recommended) which uses TorchRun under the
64
+ hood. Setting up the env variables is handled by TorchRun.
65
+ - Training happens on CUDA (CPU training is not supported)
66
+ - Checkpoints are ONLY saved at epoch boundaries. Mid-epoch checkpointing is NOT supported.
67
+ - Datasets are Map-style and data fits in memory (not streamed).
68
+ """
69
+
70
+ _model: ColoringTransformerDecoder
71
+
72
+ def __init__(self, params: ColoringFinetuneParams) -> None:
73
+ self._params = params
74
+
75
+ self._device = utils.get_device(device=params.device)
76
+ self._dtype = utils.get_dtype(dtype=params.dtype)
77
+
78
+ self._hf_hub = HfApi()
79
+ self._hf_repo_id = params.hf_repo_id
80
+
81
+ if self._hf_repo_id is not None:
82
+ self._hf_hub.create_repo(
83
+ repo_id=self._hf_repo_id,
84
+ repo_type="model",
85
+ private=True,
86
+ exist_ok=True
87
+ )
88
+
89
+ # logging attributes
90
+ self._output_dir = params.output_dir
91
+ self._metric_logger = utils.get_metric_logger(
92
+ metric_logger_type=params.metric_logger_type,
93
+ project=params.project,
94
+ log_dir=params.output_dir,
95
+ )
96
+ self._log_every_n_steps = (
97
+ params.log_every_n_steps if params.log_every_n_steps else 1
98
+ )
99
+
100
+ self._checkpoint_every_n_steps = params.checkpoint_every_n_steps
101
+
102
+ # _is_rank_zero is used primarily for logging. In the future, the logger
103
+ # should directly take care of this
104
+ _, rank = utils.get_world_size_and_rank()
105
+ self._is_rank_zero = rank == 0
106
+
107
+ # Training params
108
+ self._compile = params.compile
109
+ self._resume_from_checkpoint = params.resume_from_checkpoint
110
+ self._enable_fsdp = params.enable_fsdp
111
+ self._gradient_accumulation_steps = params.gradient_accumulation_steps
112
+
113
+ # These are public properties which are updated by the checkpoint loader
114
+ # when ``resume_from_checkpoint`` is `True` or validated in tests
115
+ self.seed = utils.set_seed(seed=params.seed)
116
+ self.epochs_run = 0
117
+ self.total_epochs = params.epochs
118
+ self.max_steps_per_epoch = params.max_steps_per_epoch
119
+ self.total_training_steps = 0
120
+
121
+ def load_checkpoint(self, ckpt_path: str):
122
+ """
123
+ Extract the checkpoint state from file and validate.
124
+ """
125
+ ckpt_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
126
+ utils.validate_checkpoint(ckpt_dict, self._resume_from_checkpoint)
127
+ return ckpt_dict
128
+
129
+ def setup(self, params: FullFinetuneParams) -> None:
130
+ """
131
+ Sets up the recipe state correctly. This includes setting recipe attributes based
132
+ on the ``resume_from_checkpoint`` flag.
133
+ """
134
+
135
+ ckpt_dict = self.load_checkpoint(ckpt_path=params.model_checkpoint)
136
+
137
+ # If we're resuming from checkpoint, the recipe's state should be updated before
138
+ # initializing the training components. This ensures that the seed is correctly
139
+ # propagated to the relevant components
140
+ if self._resume_from_checkpoint:
141
+ self._update_recipe_state(ckpt_dict)
142
+
143
+ # ``_setup_model`` handles initialization and loading the state dict. This method
144
+ # should be called before ``_setup_optimizer`` since transforming the optimizer
145
+ # state dict requires the model
146
+ self._model = self._setup_model(
147
+ enable_fsdp=params.enable_fsdp,
148
+ enable_activation_checkpointing=params.enable_activation_checkpointing,
149
+ model_state_dict=ckpt_dict[MODEL_KEY],
150
+ )
151
+
152
+ self._tokenizer = self._setup_tokenizer(
153
+ tokenizer_checkpoint=params.tokenizer_checkpoint
154
+ )
155
+
156
+ # _setup_optimizer should take in ckpt_dict only if training is resumed from
157
+ # checkpoint. Transforming the opt state dict is handled by this method
158
+ self._optimizer = self._setup_optimizer(
159
+ optimizer=params.optimizer,
160
+ lr=params.lr,
161
+ opt_state_dict=ckpt_dict[OPT_KEY] if self._resume_from_checkpoint else None,
162
+ )
163
+
164
+ self._loss_fn = self._setup_loss(loss=params.loss)
165
+
166
+ # sampler and dataloader depend on the tokenizer and loss_fn and should be
167
+ # setup after both of these are initialized
168
+ self._sampler, self._dataloader = self._setup_data(
169
+ dataset=params.dataset,
170
+ train_on_input=params.train_on_input,
171
+ shuffle=params.shuffle,
172
+ batch_size=params.batch_size,
173
+ )
174
+
175
+ # training setup
176
+ self._autocast = utils.get_autocast(self._dtype, self._device)
177
+ self._grad_scaler = None
178
+ if self._dtype == torch.float16:
179
+ self._grad_scaler = utils.get_gradient_scaler(fsdp=params.enable_fsdp)
180
+ else:
181
+ self._grad_scaler = GradScaler(enabled=False)
182
+
183
+ # Finally update the recipe state which can only be correctly set after all of the
184
+ # other components have been initialized and updated.
185
+ #
186
+ # Number of training steps in each epoch depends on the number of batches produced
187
+ # by the dataloader, the max_steps_per_epoch param set by the user and the
188
+ # gradient_accumulation_steps param. This value is used for logging and tracking
189
+ # training state. The computation should happen after the dataloader has been setup
190
+ self._steps_per_epoch = (
191
+ len(self._dataloader) // self._gradient_accumulation_steps
192
+ )
193
+ if (
194
+ self.max_steps_per_epoch is not None
195
+ and self.max_steps_per_epoch < self._steps_per_epoch
196
+ ):
197
+ self._steps_per_epoch = self.max_steps_per_epoch
198
+ self.total_training_steps = self.epochs_run * self._steps_per_epoch
199
+
200
+ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
201
+ """
202
+ Updates the recipe state from checkpoint.
203
+ """
204
+ # If seed, total_epoch or max_steps_per_epoch don't match,
205
+ # warn the user and overwrite
206
+ if (
207
+ self.seed != ckpt_dict[SEED_KEY]
208
+ or self.total_epochs != ckpt_dict[TOTAL_EPOCHS_KEY]
209
+ or self.max_steps_per_epoch != ckpt_dict[MAX_STEPS_KEY]
210
+ ):
211
+ warn(
212
+ message="""Configured value for seed, epochs or max_steps_per_epoch
213
+ does not match the value stored in checkpoint."""
214
+ )
215
+ self.seed = utils.set_seed(seed=ckpt_dict[SEED_KEY])
216
+ self.epochs_run = ckpt_dict[EPOCHS_KEY]
217
+ self.total_epochs = ckpt_dict[TOTAL_EPOCHS_KEY]
218
+ self.max_steps_per_epoch = ckpt_dict[MAX_STEPS_KEY]
219
+
220
+ def _setup_model(
221
+ self,
222
+ enable_fsdp: bool,
223
+ enable_activation_checkpointing: bool,
224
+ model_state_dict: Dict[str, Any],
225
+ ) -> nn.Module:
226
+ """
227
+ Set up the model including enabling FSDP and activation checkpointing. For this recipe,
228
+ ``enable_fsdp`` should always be ``True``. This is currently a configurable flag for
229
+ running tests on CPUs.
230
+ """
231
+
232
+ with get_device(self._device):
233
+ model = coloring_llama2_7b(
234
+ self._params.color_layer_initialization,
235
+ norm_before_color_layer=self._params.norm_before_color_layer
236
+ )
237
+
238
+ model = (
239
+ utils.wrap_fsdp(
240
+ model=model,
241
+ device=self._device,
242
+ dtype=self._dtype,
243
+ strategy="FULL_SHARD",
244
+ auto_wrap_policy={modules.TransformerDecoderLayer},
245
+ )
246
+ if enable_fsdp
247
+ else model
248
+ )
249
+ if enable_activation_checkpointing:
250
+ utils.set_activation_checkpointing(
251
+ model, auto_wrap_policy={modules.TransformerDecoderLayer}
252
+ )
253
+
254
+ model.load_state_dict(model_state_dict, strict=False)
255
+
256
+ if self._is_rank_zero:
257
+ log.info(
258
+ "Model is initialized. FSDP and Activation Checkpointing are enabled."
259
+ )
260
+
261
+ if self._compile:
262
+ log.info("Compiling model using torch.compile. The first batch may take a few minutes while compilation occurs.")
263
+ model = torch.compile(model)
264
+ else:
265
+ log.info("Skipping model compilation")
266
+
267
+ return model
268
+
269
+ def _setup_tokenizer(
270
+ self, tokenizer_checkpoint: str
271
+ ) -> modules.Tokenizer:
272
+ """
273
+ Unlike ```setup_model```, this takes in the checkpoint and loads the sentencepiece
274
+ tokenizer model. This is related to how the tokenizer is implemented and should
275
+ change in a future iteration.
276
+ """
277
+ tokenizer = llama2_tokenizer(tokenizer_checkpoint)
278
+
279
+ if self._is_rank_zero:
280
+ log.info("Tokenizer is initialized from file.")
281
+ return tokenizer
282
+
283
+ def _setup_optimizer(
284
+ self, optimizer: str, lr: float, opt_state_dict: Optional[Dict[str, Any]] = None
285
+ ) -> Optimizer:
286
+ """
287
+ Set up the optimizer. This method also handles transforing the state dict
288
+ for FSDP.
289
+ """
290
+ optimizer = modules.get_optimizer(optimizer, self._model, lr)
291
+ if opt_state_dict:
292
+ opt_state_dict = utils.transform_opt_state_dict(
293
+ opt_state_dict, self._model, optimizer
294
+ )
295
+ optimizer.load_state_dict(opt_state_dict)
296
+
297
+ if self._is_rank_zero:
298
+ log.info("Optimizer is initialized.")
299
+ return optimizer
300
+
301
+ def _setup_loss(self, loss: str) -> nn.Module:
302
+ loss_fn = modules.get_loss(loss)
303
+
304
+ if self._is_rank_zero:
305
+ log.info("Loss is initialized.")
306
+
307
+ return loss_fn
308
+
309
+ def _setup_data(
310
+ self, dataset: str, shuffle: bool, batch_size: int, train_on_input: bool
311
+ ) -> Tuple[DistributedSampler, DataLoader]:
312
+ """
313
+ All data related setup happens here. Currently this recipe only supports the
314
+ DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
315
+ iterable datasets and streaming datasets are not supported.
316
+ """
317
+ world_size, rank = utils.get_world_size_and_rank()
318
+ ds = ColoringAlpacaDataset(tokenizer=self._tokenizer, dataset=dataset, train_on_input=train_on_input)
319
+
320
+ sampler = DistributedSampler(
321
+ ds,
322
+ num_replicas=world_size,
323
+ rank=rank,
324
+ shuffle=shuffle,
325
+ seed=0,
326
+ )
327
+
328
+ dataloader = DataLoader(
329
+ dataset=ds,
330
+ batch_size=batch_size,
331
+ sampler=sampler,
332
+ collate_fn=partial(
333
+ padded_collate,
334
+ padding_idx=self._tokenizer.pad_id,
335
+ ignore_idx=self._loss_fn.ignore_index, # TODO support loss without ignore_index
336
+ ),
337
+ )
338
+
339
+ if self._is_rank_zero:
340
+ log.info("Dataset and Sampler are initialized.")
341
+
342
+ return sampler, dataloader
343
+
344
+ def save_checkpoint(self, epoch: int) -> None:
345
+ """
346
+ Checkpoint the relevant state of a recipe.
347
+
348
+ This makes use of the `save_checkpoint` utility which is responsible for
349
+ writing the checkpoint dictionary to file. The contents of the dict are dictated
350
+ by whether training is complete or not.
351
+
352
+ If training is ongoing, optimizer state, seed and epochs_run are saved along with the
353
+ model weights.
354
+ """
355
+ os.makedirs(self._output_dir, exist_ok=True)
356
+ output_loc = f"{self._output_dir}/model_{epoch}.ckpt"
357
+ ckpt_dict = {MODEL_KEY: self._model}
358
+
359
+ # if training is in-progress, checkpoint the optimizer state as well
360
+ if epoch + 1 < self.total_epochs:
361
+ ckpt_dict.update(
362
+ {
363
+ OPT_KEY: self._optimizer,
364
+ SEED_KEY: self.seed,
365
+ EPOCHS_KEY: self.epochs_run,
366
+ TOTAL_EPOCHS_KEY: self.total_epochs,
367
+ MAX_STEPS_KEY: self.max_steps_per_epoch,
368
+ }
369
+ )
370
+ utils.save_checkpoint(ckpt_dict, output_loc)
371
+
372
+ if self._is_rank_zero:
373
+ log.info(
374
+ f"Model checkpoint of size {os.path.getsize(output_loc) >> 20} MB saved to {output_loc}"
375
+ )
376
+
377
+ if self._hf_repo_id is not None:
378
+ log.info(f"Uploading checkpoint to HuggingFace Hub: {self._hf_repo_id}")
379
+ self._hf_hub.upload_folder(
380
+ folder_path=self._output_dir,
381
+ repo_id=self._hf_repo_id,
382
+ repo_type="model",
383
+ run_as_future=True,
384
+ commit_message=f"Checkpoint for epoch {epoch} (step {self.total_training_steps})"
385
+ )
386
+ else:
387
+ log.info("Skipping uploading to HuggingFace Hub (no repo id specified)")
388
+
389
+
390
+
391
+ def _should_update_weights(self, curr_step: int) -> bool:
392
+ """
393
+ Determines whether the weights should be updated on the current step or not.
394
+ True is returned either if we've accumulated gradients for enough steps or if this
395
+ is the last step in the epoch.
396
+ """
397
+ should_update_weights = (
398
+ curr_step + 1
399
+ ) % self._gradient_accumulation_steps == 0 or (
400
+ curr_step + 1
401
+ ) == self._steps_per_epoch
402
+ return should_update_weights
403
+
404
+ def train(self) -> None:
405
+ """
406
+ The core training loop. Supports training on subsets of the dataset using the
407
+ ``max_steps_per_epoch``.
408
+ """
409
+ _, rank = utils.get_world_size_and_rank()
410
+
411
+ # zero out the gradients before starting training
412
+ self._optimizer.zero_grad()
413
+
414
+ # self.epochs_run should be non-zero when we're resuming from a checkpoint
415
+ for curr_epoch in range(self.epochs_run, self.total_epochs):
416
+
417
+ # Update the sampler to ensure data is correctly shuffled across epochs
418
+ # in case shuffle is True
419
+ self._sampler.set_epoch(curr_epoch)
420
+
421
+ for idx, batch in enumerate(
422
+ pbar := tqdm(self._dataloader, disable=not (rank == 0))
423
+ ):
424
+ if (
425
+ self.max_steps_per_epoch is not None
426
+ and (idx // self._gradient_accumulation_steps)
427
+ == self.max_steps_per_epoch
428
+ ):
429
+ break
430
+
431
+ input_ids, labels, colors = batch
432
+
433
+ input_ids = input_ids.to(self._device)
434
+ labels = labels.to(self._device)
435
+ colors = colors.to(self._device)
436
+
437
+ with self._autocast:
438
+ logits = self._model(input_ids, colors=colors)
439
+ # Shift so that tokens < n predict n
440
+ logits = logits[..., :-1, :].contiguous()
441
+ labels = labels[..., 1:].contiguous()
442
+ logits = logits.transpose(1, 2)
443
+ # Compute loss
444
+ loss = self._loss_fn(logits, labels)
445
+
446
+ # Note: We're always logging the loss before normalizing it
447
+ # Check if this is the norm or not
448
+ pbar.set_description(f"{curr_epoch+1}|{idx+1}|Loss: {loss.item()}")
449
+
450
+ if self.total_training_steps % self._log_every_n_steps == 0:
451
+ self._metric_logger.log_dict(
452
+ {
453
+ "loss": loss.item(),
454
+ "lr": self._optimizer.param_groups[0]["lr"],
455
+ "gpu_resources": torch.cuda.memory_allocated(),
456
+ },
457
+ step=self.total_training_steps,
458
+ )
459
+
460
+ if self._checkpoint_every_n_steps is not None:
461
+ if self.total_training_steps > 0 and self.total_training_steps % self._checkpoint_every_n_steps == 0:
462
+ self.save_checkpoint(epoch=curr_epoch)
463
+
464
+ # Does loss normalization need to happen within autocast context?
465
+ loss = loss / self._gradient_accumulation_steps
466
+ self._grad_scaler.scale(loss).backward()
467
+
468
+ if self._should_update_weights(idx):
469
+ self._grad_scaler.step(self._optimizer)
470
+ self._grad_scaler.update()
471
+ self._optimizer.zero_grad(set_to_none=True)
472
+
473
+ # Update the number of steps when the weights are updated
474
+ self.total_training_steps += 1
475
+
476
+ self.epochs_run += 1
477
+ self.save_checkpoint(epoch=curr_epoch)
478
+
479
+ def cleanup(self) -> None:
480
+ self._metric_logger.close()
481
+
482
+
483
+ def recipe_main() -> None:
484
+ """
485
+ Entry point for the recipe.
486
+
487
+ Configurable parameters are read in the following order:
488
+ - Parameters specified in ``ColoringFinetuneParams``
489
+ - Overwritten by Parameters specified in ``alpaca_llama2_full_finetune.yaml``
490
+ - Overwritten by arguments from the command-line using ``TuneArgumentParser``
491
+ """
492
+ parser = utils.TuneArgumentParser(
493
+ description=ColoringFinetuneParams.__doc__,
494
+ formatter_class=argparse.RawDescriptionHelpFormatter,
495
+ )
496
+ args, _ = parser.parse_known_args()
497
+ args = vars(args)
498
+ recipe_params = ColoringFinetuneParams(**args)
499
+
500
+ # Env variables set by torch run; only need to initialize process group
501
+ # Disabled since this breaks for now on RunPod.
502
+ # init_process_group(backend="nccl")
503
+
504
+ recipe = ColoringFinetuneRecipe(params=recipe_params)
505
+ recipe.setup(params=recipe_params)
506
+ recipe.train()
507
+ recipe.cleanup()
508
+
509
+
510
+ if __name__ == "__main__":
511
+ sys.exit(recipe_main())
colorful/masked_apply.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class MaskedApply(nn.Module):
6
+ """
7
+ Uses an index mask to select a sbuset of the input and apply a layer to it.
8
+
9
+ E.g. if mask is [[0, 1, 0]] layers[0] will be applied to the first and third element
10
+ and layers[1] will be applied to the second element.
11
+ """
12
+
13
+ def __init__(self, layers, strict=False):
14
+ super(MaskedApply, self).__init__()
15
+ self.num_layers = len(layers)
16
+ self.layers = nn.ModuleList(layers)
17
+ self.strict = strict
18
+
19
+ # Create a CPU tensor to store the maximum value found.
20
+ # This will prevent the GPU being blocked while we check
21
+ # whether an index is > num_layers in strict mode.
22
+ self._maximum_found_cpu = torch.tensor([-1], device='cpu')
23
+ self._maximum_found = torch.tensor([-1])
24
+ if torch.cuda.is_available():
25
+ self._maximum_found_cpu = self._maximum_found_cpu.pin_memory()
26
+
27
+ def forward(self, x, mask):
28
+ # If in strict mode, check if we previously violated the maximum found.
29
+ if self._maximum_found_cpu >= self.num_layers:
30
+ raise ValueError(f'Unexpected index value found {self._maximum_found_cpu}. Should be less than {self.num_layers}')
31
+
32
+ # Ensure mask is a long tensor
33
+ mask = mask.long()
34
+
35
+ # Flatten x and mask for easier processing
36
+ batch_size, seq_length, embedding_size = x.shape
37
+
38
+ x_flat = x.view(-1, embedding_size)
39
+ mask_flat = mask.view(-1)
40
+
41
+ # Output placeholder
42
+ output_flat = torch.zeros_like(x_flat)
43
+
44
+ # Process each mask value
45
+ for i in range(self.num_layers):
46
+ # Find indices for current mask value
47
+ indices = torch.where(mask_flat == i)[0]
48
+
49
+ # Select relevant inputs for the current linear layer
50
+ selected_inputs = torch.index_select(x_flat, 0, indices)
51
+
52
+ # Apply linear layer
53
+ transformed = self.layers[i](selected_inputs)
54
+
55
+ # TODO: figure out why this is necessary.
56
+ transformed = transformed.to(x_flat.dtype)
57
+
58
+ # Place results back in the output tensor
59
+ output_flat.index_copy_(0, indices, transformed)
60
+
61
+ # Copy any out of range indices
62
+ if self.strict:
63
+ # This check is done asynchronously.
64
+ self._maximum_found = max(max(mask_flat), self._maximum_found)
65
+ self._maximum_found_cpu.copy_(self._maximum_found, non_blocking=True)
66
+ else:
67
+ indices = torch.where(mask_flat >= self.num_layers)[0]
68
+ selected_inputs = torch.index_select(x_flat, 0, indices)
69
+ output_flat.index_copy_(0, indices, selected_inputs)
70
+
71
+ # Reshape output to original dimensions
72
+ output = output_flat.view(batch_size, seq_length, embedding_size)
73
+ return output