laurencer committed
Commit c50cb21
1 Parent(s): 16c750e

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+wandb/run-20240211_141255-f3ffr2e5/run-f3ffr2e5.wandb filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,214 @@
+# Created by https://www.toptal.com/developers/gitignore/api/python,macos
+# Edit at https://www.toptal.com/developers/gitignore?templates=python,macos
+
+### TorchTune ###
+
+output/
+model/
+
+### macOS ###
+# General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+### macOS Patch ###
+# iCloud generated files
+*.icloud
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# LSP config files
+pyrightconfig.json
+
+# End of https://www.toptal.com/developers/gitignore/api/python,macos
README.md ADDED
@@ -0,0 +1,49 @@
+# torchtune research repo: token coloring (colorful llama)
+
+Playground to try out [token coloring](https://docs.google.com/document/d/1Win9vhddD-pu5P3SsG7E-dzN5oQl5DYWW1DhO7sBOgI/edit#heading=h.oqq00pt8expe) with TorchTune.
+
+The repo was generated using the alpha version of [torchtune](https://github.com/pytorch-labs/torchtune).
+
+Brief notes:
+
+- The starting recipe is based on the Alpaca Llama2 7B full finetune recipe (switched to bf16).
+- I copied a lot of functionality (like the actual model definition, dataset, etc.) directly from the torchtune repository since I needed to make changes.
+- I reduced the flexibility of the recipe in some ways (e.g. you cannot specify the model or tokenizer) and increased it in others (e.g. you can pass in a dataset path directly).
+- I added intermediate checkpointing (i.e. every `n` steps) with automatic upload of each checkpoint to HuggingFace Hub.
+- Assumes `output/` is used to store model outputs and `model/` is used to store the base model checkpoints.
+
+## Getting started
+
+The instructions below can be copy-pasted as-is onto a running instance. They assume that the `HF_TOKEN` environment variable is set with a valid token.
+
+```bash
+# for RunPod
+cd /workspace
+git clone git@github.com:pytorch-labs/torchtune.git
+cd torchtune
+pip install -e .
+
+cd /workspace
+git clone git@github.com:laurencer/torchtune-colorful-llama.git
+cd torchtune-colorful-llama
+
+# for wandb support
+pip install wandb
+```
+
+```bash
+mkdir -p model/
+tune download --repo-id meta-llama/Llama-2-7b --output-dir model/
+```
+
+```bash
+tune convert_checkpoint --checkpoint-path model/consolidated.00.pth --output-path model/llama2_native.tune
+```
+
+```bash
+mkdir -p output/
+# tune --nnodes 1 --nproc_per_node 1 ./full_finetune.py --config basic_config.yaml
+nohup tune --nnodes 1 --nproc_per_node 1 ./full_finetune.py --config basic_config.yaml > training_log_$(date "+%Y.%m.%d_%H.%M.%S").log 2>&1 &
+sleep 1
+tail -f training_log_*.log
+```
basic_config.yaml ADDED
@@ -0,0 +1,35 @@
+# Runs the full_finetune.py recipe
+#
+# To launch, run the following command from root:
+#    tune --nnodes 1 --nproc_per_node 1 --config alpaca_llama2_full_finetune --override model_checkpoint=<your_checkpoint_dir> ...
+
+# Dataset and Dataloader
+dataset: yahma/alpaca-cleaned
+seed: 42
+shuffle: True
+
+# Checkpointing
+# Removed for now given poor upload speeds for checkpoints
+# hf_repo_id: laurencer/Llama7b-Alpaca-Tune-4epochs-WithColoring
+checkpoint_every_n_steps: 5000  # 25k steps per epoch
+
+# Model Arguments
+model_checkpoint: model/llama2_native.tune
+tokenizer_checkpoint: model/tokenizer.model
+
+# Fine-tuning arguments
+batch_size: 2
+lr: 2e-5
+epochs: 4
+optimizer: SGD
+loss: CrossEntropyLoss
+output_dir: output/alpaca-llama2-finetune
+device: cuda
+dtype: fp16
+enable_fsdp: False
+enable_activation_checkpointing: True
+resume_from_checkpoint: False
+
+# Logging arguments
+metric_logger_type: wandb
+project: torchtune
custom_dataset.py ADDED
@@ -0,0 +1,179 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Tuple
+
+import torch
+
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset
+
+from datasets import load_dataset
+
+# Not ideal to import this type here but it's needed for the transform function
+from torchtune.modules import Tokenizer
+
+
+CROSS_ENTROPY_IGNORE_IDX = -100
+
+
+DEFAULT = 0
+INSTRUCTION = 1
+INPUT = 2
+RESPONSE = 3
+
+
+class ColoringAlpacaDataset(Dataset):
+    """
+    See torchtune.datasets.alpaca.AlpacaDataset for the original implementation.
+
+    Constructor now takes in a dataset path directly.
+
+    This implementation returns 3 lists representing the tokens, labels, and token colors
+    (as opposed to just the tokens & labels from the original).
+    """
+
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        dataset_path: str = "yahma/alpaca-cleaned",
+        train_on_input: bool = True,
+        **kwargs
+    ) -> None:
+        self._data = load_dataset(dataset_path, split="train")
+        self._tokenizer = tokenizer
+        self.train_on_input = train_on_input
+        self.num_colors = 4  # matches the above usage of DEFAULT, INSTRUCTION, INPUT, RESPONSE
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, index: int) -> Tuple[List[int], List[int], List[int]]:
+        sample = self._data[index]
+
+        return self._transform(
+            instruction=sample["instruction"],
+            input=sample["input"],
+            output=sample["output"],
+        )
+
+    def _transform(
+        self, instruction: str, input: str, output: str
+    ) -> Tuple[List[int], List[int], List[int]]:
+        """
+        Split a sample on ``response`` tag to create input and labels.
+
+        Args:
+            instruction (str): Instruction text.
+            input (str): Input text. Can be an empty string. Determines the prompt generation template
+                used.
+            output (str): Response text.
+
+        Returns:
+            Tuple of encoded inputs, labels, token colors.
+        """
+        prompt = self._generate_prompt(instruction, input)
+
+        # First handle the prompt
+        colors = []
+        tokenized = []
+        labels = []
+        is_first = True
+        for token_type, text in prompt:
+            tokenized_part = self._tokenizer.encode(
+                text=text, add_bos=is_first, add_eos=False
+            )
+            is_first = False
+
+            tokenized += tokenized_part
+            colors += [token_type] * len(tokenized_part)
+            if not self.train_on_input:
+                labels += [CROSS_ENTROPY_IGNORE_IDX] * len(tokenized_part)
+            else:
+                labels += tokenized_part
+
+        # Now add the response tokens
+        tokenized_part = self._tokenizer.encode(
+            text=output, add_bos=False, add_eos=True
+        )
+        tokenized += tokenized_part
+        colors += [RESPONSE] * len(tokenized_part)
+        labels += tokenized_part
+
+        assert len(tokenized) == len(labels)
+        assert len(tokenized) == len(colors)
+
+        return tokenized, labels, colors
+
+    def _generate_prompt(self, instruction: str, input: str) -> List[Tuple[int, str]]:
+        """
+        Generate prompt from instruction and input.
+
+        Args:
+            instruction (str): Instruction text.
+            input (str): Input text.
+
+        Returns:
+            List of (int, templated text)
+        """
+        if input:
+            return [
+                (DEFAULT, (
+                    "Below is an instruction that describes a task, paired with an input that provides further context. "
+                    "Write a response that appropriately completes the request.\n\n"
+                    "### Instruction:\n"
+                )),
+                (INSTRUCTION, instruction),
+                (DEFAULT, "\n\n### Input:\n"),
+                (INPUT, input),
+                (DEFAULT, "\n\n### Response:\n"),
+            ]
+        else:
+            return [
+                (DEFAULT, (
+                    "Below is an instruction that describes a task. "
+                    "Write a response that appropriately completes the request.\n\n"
+                    "### Instruction:\n"
+                )),
+                (INSTRUCTION, instruction),
+                (DEFAULT, "\n\n### Response:\n"),
+            ]
+
+
+# TokenPair is a tuple of three lists: tokenized text inputs, labels, and token colors.
+TokenPair = Tuple[List[int], List[int], List[int]]
+
+
+def padded_collate(
+    batch: List[TokenPair],
+    padding_idx: int = 0,
+    ignore_idx: int = -100,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    input_ids = pad_sequence(
+        [torch.tensor(x[0]) for x in batch],
+        batch_first=True,
+        padding_value=padding_idx,
+    )
+    labels = pad_sequence(
+        [torch.tensor(x[1]) for x in batch],
+        batch_first=True,
+        padding_value=ignore_idx,
+    )
+    colors = pad_sequence(
+        [torch.tensor(x[2]) for x in batch],
+        batch_first=True,
+        padding_value=ignore_idx,
+    )
+
+    input_ids_seq_len = input_ids.shape[-1]
+    labels_seq_len = labels.shape[-1]
+    colors_seq_len = colors.shape[-1]
+
+    assert input_ids_seq_len == labels_seq_len
+    assert input_ids_seq_len == colors_seq_len
+
+    return input_ids, labels, colors
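
To see what `padded_collate` produces, here is a minimal sketch using plain PyTorch; the toy token/label/color triples below are invented for illustration and are not from the Alpaca data.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two toy (tokens, labels, colors) samples of unequal length.
batch = [
    ([1, 2, 3], [1, 2, 3], [0, 1, 3]),
    ([4, 5], [4, 5], [0, 3]),
]

# The same calls padded_collate makes: pad tokens with 0, labels/colors with -100.
input_ids = pad_sequence([torch.tensor(x[0]) for x in batch], batch_first=True, padding_value=0)
labels = pad_sequence([torch.tensor(x[1]) for x in batch], batch_first=True, padding_value=-100)
colors = pad_sequence([torch.tensor(x[2]) for x in batch], batch_first=True, padding_value=-100)

print(input_ids.tolist())  # [[1, 2, 3], [4, 5, 0]]
print(labels.tolist())     # [[1, 2, 3], [4, 5, -100]]  (pad positions are ignored by the loss)
print(colors.tolist())     # [[0, 1, 3], [0, 3, -100]]
```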
custom_model.py ADDED
@@ -0,0 +1,211 @@
+from typing import Optional
+
+import torch
+
+from torch import nn, Tensor
+import copy
+
+from torchtune.modules import (
+    CausalSelfAttention,
+    FeedForward,
+    KVCache,
+    RMSNorm,
+    RotaryPositionalEmbeddings,
+    # TransformerDecoder, replaced with our custom implementation.
+    TransformerDecoderLayer,
+)
+
+from masked_apply import MaskedApply
+
+
+def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
+    """
+    Return a list of ``n`` identical layers.
+
+    Args:
+        module (nn.Module): module to be cloned
+        n (int): number of clones
+
+    Returns:
+        nn.ModuleList: list of ``n`` identical layers
+    """
+    # FIXME: copy.deepcopy() is not defined on nn.module
+    return nn.ModuleList([copy.deepcopy(module) for i in range(n)])
+
+
+class ColoringTransformerDecoder(nn.Module):
+    """
+    See torchtune.models.llama2.TransformerDecoder for the original implementation.
+    """
+
+    def __init__(
+        self,
+        tok_embeddings: nn.Embedding,
+        embedding_transform: nn.Module,
+        layer: TransformerDecoderLayer,
+        num_layers: int,
+        norm: nn.Module,
+        output: nn.Linear,
+    ) -> None:
+        super().__init__()
+        self.tok_embeddings = tok_embeddings
+        self.embedding_transform = embedding_transform
+        self.layers = _get_clones(layer, num_layers)
+        self.norm = norm
+        self.output = output
+
+    def forward(
+        self,
+        tokens: Tensor,
+        mask: Optional[Tensor] = None,
+        colors: Optional[Tensor] = None,
+        curr_pos: int = 0,
+    ) -> Tensor:
+        """
+        Args:
+            tokens (Tensor): input tensor with shape [b x s]
+            mask (Optional[Tensor]): attention mask tensor, defaults to None.
+            colors (Optional[Tensor]): per-token color ids with shape [b x s], used to
+                select the embedding transform applied to each token.
+            curr_pos (int): current position in the seq, defaults to 0.
+                Only relevant when incrementally decoding.
+
+        Returns:
+            Tensor: output tensor with shape [b x s x v]
+
+        Notation used for tensor shapes:
+            - b: batch size
+            - s: sequence length
+            - v: vocab size
+            - d: embed dim
+        """
+        # input tensor of shape [b, s]
+        bsz, seq_len = tokens.shape
+
+        # shape: [b, s, d]
+        h = self.tok_embeddings(tokens)
+
+        h = self.embedding_transform(h, colors)
+
+        # TODO: Fix the masking logic to not rely on checking kv_cache
+        if seq_len > 1 and self.layers[0].attn.kv_cache is not None:
+            mask = torch.full(
+                (1, 1, seq_len, seq_len), float("-inf"), device=tokens.device
+            )
+            mask = torch.triu(mask, diagonal=curr_pos + 1)
+
+        for layer in self.layers:
+            # shape: [b, s, d]
+            h = layer(h, mask, curr_pos)
+
+        # shape: [b, s, d]
+        h = self.norm(h)
+
+        # shape: [b, s, v]
+        output = self.output(h).float()
+        return output
+
+
+def colouring_llama2_7b(max_batch_size: Optional[int] = None) -> ColoringTransformerDecoder:
+    """Builder for creating a Llama2 model initialized w/ the default 7b parameter values.
+    From https://arxiv.org/abs/2307.09288, these default values are:
+    - vocab_size: 32,000
+    - embed_dim: 4,096
+    - num_layers: 32
+    - num_heads: 32
+    - num_kv_heads: 32
+    - max_seq_len: 4,096
+    - norm_eps: 1e-5
+
+    Args:
+        max_batch_size (Optional[int]): Maximum batch size to be passed to KVCache.
+
+    Returns:
+        A ``ColoringTransformerDecoder`` instance of the Llama2 model.
+    """
+    return colouring_llama2(
+        vocab_size=32_000,
+        num_layers=32,
+        num_heads=32,
+        num_kv_heads=32,
+        embed_dim=4096,
+        max_seq_len=4096,
+        num_colors=4,  # color for default, instruction, input, response
+        max_batch_size=max_batch_size,
+        attn_dropout=0.0,
+        norm_eps=1e-5,
+    )
+
+
+def _scale_hidden_dim_for_mlp(dim: int, multiple_of: int = 256) -> int:
+    """Scale hidden dimension for MLP to keep number of parameters and computation constant.
+
+    Args:
+        dim (int): Input dimension.
+        multiple_of (int): Round scaled dimension to nearest multiple of `multiple_of` for clean computation.
+
+    Returns:
+        Scaled hidden dimension.
+    """
+    # Scale hidden dimension by (2/3)4d for SwiGLU to keep number of
+    # parameters and computation constant
+    hidden_dim = 4 * int(2 * dim / 3)
+    # Round hidden dimension to nearest multiple of `multiple_of`
+    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+    return hidden_dim
+
+
+def colouring_llama2(
+    vocab_size: int,
+    num_layers: int,
+    num_heads: int,
+    num_kv_heads: int,
+    embed_dim: int,
+    max_seq_len: int,
+    num_colors: int,
+    attn_dropout: float = 0.0,
+    max_batch_size: Optional[int] = None,
+    norm_eps: float = 1e-5,
+):
+    head_dim = embed_dim // num_heads
+    num_kv_heads = num_kv_heads if num_kv_heads else num_heads
+    kv_cache = (
+        KVCache(
+            max_batch_size=max_batch_size,
+            max_seq_len=max_seq_len,
+            n_kv_heads=num_heads,
+            head_dim=head_dim,
+        )
+        if max_batch_size is not None
+        else None
+    )
+    rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)
+    self_attn = CausalSelfAttention(
+        embed_dim=embed_dim,
+        num_heads=num_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
+        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
+        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
+        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
+        pos_embeddings=rope,
+        kv_cache=kv_cache,
+        max_seq_len=max_seq_len,
+        attn_dropout=attn_dropout,
+    )
+    hidden_dim = _scale_hidden_dim_for_mlp(embed_dim)
+    mlp = FeedForward(dim=embed_dim, hidden_dim=hidden_dim, linear_class=nn.Linear)
+    layer = TransformerDecoderLayer(
+        attn=self_attn,
+        mlp=mlp,
+        sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
+        mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
+    )
+    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
+    output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
+    return ColoringTransformerDecoder(
+        tok_embeddings=tok_embeddings,
+        embedding_transform=MaskedApply([nn.Linear(embed_dim, embed_dim) for _ in range(num_colors)]),
+        layer=layer,
+        num_layers=num_layers,
+        norm=RMSNorm(embed_dim, eps=norm_eps),
+        output=output_proj,
+    )
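
As a quick sanity check on `_scale_hidden_dim_for_mlp`, the arithmetic for the 7B configuration (`embed_dim=4096`) works out as follows:

```python
dim = 4096
hidden_dim = 4 * int(2 * dim / 3)  # 4 * 2730 = 10920
multiple_of = 256
# Round up to the nearest multiple of 256.
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
print(hidden_dim)  # 11008 -- the FFN hidden size used by Llama2-7B
```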
custom_params.py ADDED
@@ -0,0 +1,106 @@
+from dataclasses import dataclass, field, fields
+from typing import List, Optional
+
+from torchtune.datasets import ALL_DATASETS
+from torchtune.models import ALL_MODELS, ALL_TOKENIZERS
+from torchtune.utils.metric_logging import ALL_METRIC_LOGGERS
+from torchtune.utils.precision import PRECISION_STR_TO_DTYPE
+
+
+@dataclass
+class ColoringFinetuneParams:
+    """Arguments for the finetune_llm recipe.
+
+    Args:
+        device (str): Device to use for training. Options are "cpu" and "cuda"
+        dtype (str): Data type to use for training.
+        seed (int): Random seed to use for training.
+        model (str): String specifying model architecture to fine-tune. See ``torchtune.models.get_model`` for options.
+        model_checkpoint (str): Local path to load model checkpoint from.
+        tokenizer (str): String specifying tokenizer to use. See ``torchtune.models.get_tokenizer`` for options.
+        tokenizer_checkpoint (str): Local path to load tokenizer checkpoint from.
+        dataset (str): String specifying dataset to use. See ``torchtune.datasets.get_dataset`` for options.
+            Currently, only predefined datasets in library are supported.
+        shuffle (bool): Whether to shuffle dataset.
+        batch_size (int): Batch size to use for training.
+        epochs (int): Number of epochs to train for.
+        optimizer (str): String specifying optimizer to use. See ``torchtune.optim.get_optimizer`` for options.
+        loss (str): String specifying loss function to use. See ``torchtune.losses.get_loss`` for options.
+        lr (float): Learning rate to use for optimizer.
+        activation_checkpointing (bool): Whether to use activation checkpointing.
+        output_dir (str): Local path to save checkpoints and logs to.
+        run_generation (int): Run eval on a prompt every ``run_generation`` steps. Set to 0 to disable.
+        max_steps_per_epoch (int): Maximum number of steps to take per epoch.
+        metric_logger_type (str): String specifying metric logger to use. See ``torchtune.utils.get_metric_logger``
+            for options.
+        project (str): Project name to use for logging. Used by ``WandBLogger``.
+        resume_from_previous_checkpoint (bool): Whether to resume fine-tuning from a previous checkpoint.
+        cpu_offload (bool): Whether to offload model to CPU.
+
+    Raises:
+        ValueError: If ``cpu_offload`` is ``True`` but ``device`` is not ``cuda`` and <= 1 GPUs.
+    """
+
+    # Model
+    model_checkpoint: str = ""
+
+    # Tokenizer
+    tokenizer_checkpoint: str = ""
+
+    hf_repo_id: Optional[str] = None
+    checkpoint_every_n_steps: Optional[int] = None
+
+    # Dataset and Sampler
+    dataset: str = ""
+    train_on_input: bool = True
+    shuffle: bool = True
+    batch_size: int = 2
+
+    # Optimizer and Scheduler
+    optimizer: str = "SGD"
+    lr: float = 2e-5
+    loss: str = "CrossEntropyLoss"
+    gradient_accumulation_steps: int = 1
+
+    # Training
+    epochs: int = 3
+    max_steps_per_epoch: Optional[int] = None
+    resume_from_checkpoint: bool = False
+    run_generation: Optional[int] = None
+
+    # Distributed
+    cpu_offload: bool = False
+    enable_fsdp: bool = True
+    enable_activation_checkpointing: bool = True
+
+    # Environment
+    device: str = "cuda"
+    dtype: str = "fp16"
+    seed: Optional[int] = None
+
+    # Logging
+    output_dir: str = "/tmp/full_finetune_output"
+    metric_logger_type: str = "disk"
+    project: Optional[str] = None
+    log_every_n_steps: Optional[int] = None
+
+    def __post_init__(self):
+        for param in fields(self):
+            if getattr(self, param.name) == "":
+                raise TypeError(f"{param.name} needs to be specified")
+
+        if self.cpu_offload and self.device != "cuda":
+            raise ValueError(
+                "Cannot offload model to CPU if device is not cuda or <= 1 GPUs."
+            )
+        if self.enable_fsdp and self.device == "cpu":
+            raise ValueError("FSDP is not supported on CPU.")
+
+        if self.metric_logger_type not in ALL_METRIC_LOGGERS:
+            raise ValueError(
+                f"Metric logger not recognized. Expected one of {ALL_METRIC_LOGGERS}, received {self.metric_logger_type}."
+            )
+        if self.dtype not in PRECISION_STR_TO_DTYPE:
+            raise ValueError(
+                f"Dtype {self.dtype} must be one of {', '.join(PRECISION_STR_TO_DTYPE.keys())} for finetuning."
+            )
full_finetune.py ADDED
@@ -0,0 +1,502 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+import sys
+
+from functools import partial
+from typing import Any, Dict, Optional, Tuple
+from warnings import warn
+
+import torch
+
+from torch import nn
+from torch.cuda.amp import GradScaler
+from torch.distributed import init_process_group
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune.utils import get_device
+
+from torchtune import models, modules, utils
+from torchtune.utils.constants import (
+    EPOCHS_KEY,
+    MAX_STEPS_KEY,
+    MODEL_KEY,
+    OPT_KEY,
+    SEED_KEY,
+    TOTAL_EPOCHS_KEY,
+)
+
+from tqdm import tqdm
+
+from recipes.interfaces import FTRecipeInterface
+from recipes.params import FullFinetuneParams
+
+from torchtune.models.llama2 import llama2_tokenizer
+
+from huggingface_hub import HfApi
+
+from custom_params import ColoringFinetuneParams
+from custom_model import ColoringTransformerDecoder, colouring_llama2_7b
+from custom_dataset import ColoringAlpacaDataset, padded_collate
+
+log = utils.get_logger("DEBUG")
+
+
+class ColoringFinetuneRecipe(FTRecipeInterface):
+    """
+    Full finetuning recipe for dense transformer-based LLMs such as Llama2.
+
+    This recipe supports:
+        - FSDP and activation checkpointing. This is enabled by default but can be
+          configured using the ``enable_fsdp`` and ``enable_activation_checkpointing`` flags.
+        - Mixed precision training - fp32, fp16 and bf16 are supported.
+        - Checkpointing of model weights, optimizer state and the recipe state (epoch and seed).
+        - Resuming from checkpoints saved using the ``save_checkpoint`` functionality.
+        - Logging via a configurable metric logger (e.g. disk or WandB).
+
+    Assumptions:
+        - Training is launched with the Tune CLI (recommended) which uses TorchRun under the
+          hood. Setting up the env variables is handled by TorchRun.
+        - Training happens on CUDA (CPU training is not supported)
+        - Checkpoints are saved at epoch boundaries and, optionally, every
+          ``checkpoint_every_n_steps`` steps within an epoch.
+        - Datasets are Map-style and data fits in memory (not streamed).
+    """
+
+    _model: ColoringTransformerDecoder
+
+    def __init__(self, params: ColoringFinetuneParams) -> None:
+        self._device = utils.get_device(device=params.device)
+        self._dtype = utils.get_dtype(dtype=params.dtype)
+
+        self._hf_hub = HfApi()
+        self._hf_repo_id = params.hf_repo_id
+
+        if self._hf_repo_id is not None:
+            self._hf_hub.create_repo(
+                repo_id=self._hf_repo_id,
+                repo_type="model",
+                private=True,
+                exist_ok=True,
+            )
+
+        # logging attributes
+        self._output_dir = params.output_dir
+        self._metric_logger = utils.get_metric_logger(
+            metric_logger_type=params.metric_logger_type,
+            project=params.project,
+            log_dir=params.output_dir,
+        )
+        self._log_every_n_steps = (
+            params.log_every_n_steps if params.log_every_n_steps else 1
+        )
+
+        self._checkpoint_every_n_steps = params.checkpoint_every_n_steps
+
+        # _is_rank_zero is used primarily for logging. In the future, the logger
+        # should directly take care of this
+        _, rank = utils.get_world_size_and_rank()
+        self._is_rank_zero = rank == 0
+
+        # Training params
+        self._resume_from_checkpoint = params.resume_from_checkpoint
+        self._enable_fsdp = params.enable_fsdp
+        self._gradient_accumulation_steps = params.gradient_accumulation_steps
+
+        # These are public properties which are updated by the checkpoint loader
+        # when ``resume_from_checkpoint`` is `True` or validated in tests
+        self.seed = utils.set_seed(seed=params.seed)
+        self.epochs_run = 0
+        self.total_epochs = params.epochs
+        self.max_steps_per_epoch = params.max_steps_per_epoch
+        self.total_training_steps = 0
+
+    def load_checkpoint(self, ckpt_path: str):
+        """
+        Extract the checkpoint state from file and validate.
+        """
+        ckpt_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+        utils.validate_checkpoint(ckpt_dict, self._resume_from_checkpoint)
+        return ckpt_dict
+
+    def setup(self, params: FullFinetuneParams) -> None:
+        """
+        Sets up the recipe state correctly. This includes setting recipe attributes based
+        on the ``resume_from_checkpoint`` flag.
+        """
+        ckpt_dict = self.load_checkpoint(ckpt_path=params.model_checkpoint)
+
+        # If we're resuming from checkpoint, the recipe's state should be updated before
+        # initializing the training components. This ensures that the seed is correctly
+        # propagated to the relevant components
+        if self._resume_from_checkpoint:
+            self._update_recipe_state(ckpt_dict)
+
+        # ``_setup_model`` handles initialization and loading the state dict. This method
+        # should be called before ``_setup_optimizer`` since transforming the optimizer
+        # state dict requires the model
+        self._model = self._setup_model(
+            enable_fsdp=params.enable_fsdp,
+            enable_activation_checkpointing=params.enable_activation_checkpointing,
+            model_state_dict=ckpt_dict[MODEL_KEY],
+        )
+
+        self._tokenizer = self._setup_tokenizer(
+            tokenizer_checkpoint=params.tokenizer_checkpoint
+        )
+
+        # _setup_optimizer should take in ckpt_dict only if training is resumed from
+        # checkpoint. Transforming the opt state dict is handled by this method
+        self._optimizer = self._setup_optimizer(
+            optimizer=params.optimizer,
+            lr=params.lr,
+            opt_state_dict=ckpt_dict[OPT_KEY] if self._resume_from_checkpoint else None,
+        )
+
+        self._loss_fn = self._setup_loss(loss=params.loss)
+
+        # sampler and dataloader depend on the tokenizer and loss_fn and should be
+        # setup after both of these are initialized
+        self._sampler, self._dataloader = self._setup_data(
+            dataset=params.dataset,
+            train_on_input=params.train_on_input,
+            shuffle=params.shuffle,
+            batch_size=params.batch_size,
+        )
+
+        # training setup
+        self._autocast = utils.get_autocast(self._dtype, self._device)
+        self._grad_scaler = None
+        if self._dtype == torch.float16:
+            self._grad_scaler = utils.get_gradient_scaler(fsdp=params.enable_fsdp)
+        else:
+            self._grad_scaler = GradScaler(enabled=False)
+
+        # Finally update the recipe state which can only be correctly set after all of the
+        # other components have been initialized and updated.
+        #
+        # Number of training steps in each epoch depends on the number of batches produced
+        # by the dataloader, the max_steps_per_epoch param set by the user and the
+        # gradient_accumulation_steps param. This value is used for logging and tracking
+        # training state. The computation should happen after the dataloader has been setup
+        self._steps_per_epoch = (
+            len(self._dataloader) // self._gradient_accumulation_steps
+        )
+        if (
+            self.max_steps_per_epoch is not None
+            and self.max_steps_per_epoch < self._steps_per_epoch
+        ):
+            self._steps_per_epoch = self.max_steps_per_epoch
+        self.total_training_steps = self.epochs_run * self._steps_per_epoch
+
+    def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
+        """
+        Updates the recipe state from checkpoint.
+        """
+        # If seed, total_epoch or max_steps_per_epoch don't match,
+        # warn the user and overwrite
+        if (
+            self.seed != ckpt_dict[SEED_KEY]
+            or self.total_epochs != ckpt_dict[TOTAL_EPOCHS_KEY]
+            or self.max_steps_per_epoch != ckpt_dict[MAX_STEPS_KEY]
+        ):
+            warn(
+                message="""Configured value for seed, epochs or max_steps_per_epoch
+                does not match the value stored in checkpoint."""
+            )
+        self.seed = utils.set_seed(seed=ckpt_dict[SEED_KEY])
+        self.epochs_run = ckpt_dict[EPOCHS_KEY]
+        self.total_epochs = ckpt_dict[TOTAL_EPOCHS_KEY]
+        self.max_steps_per_epoch = ckpt_dict[MAX_STEPS_KEY]
+
+    def _setup_model(
+        self,
+        enable_fsdp: bool,
+        enable_activation_checkpointing: bool,
+        model_state_dict: Dict[str, Any],
+    ) -> nn.Module:
+        """
+        Set up the model including enabling FSDP and activation checkpointing. For this recipe,
+        ``enable_fsdp`` should always be ``True``. This is currently a configurable flag for
+        running tests on CPUs.
+        """
+        with get_device(self._device):
+            model = colouring_llama2_7b()
+
+        model = (
+            utils.wrap_fsdp(
+                model=model,
+                device=self._device,
+                dtype=self._dtype,
+                strategy="FULL_SHARD",
+                auto_wrap_policy={modules.TransformerDecoderLayer},
+            )
+            if enable_fsdp
+            else model
+        )
+        if enable_activation_checkpointing:
+            utils.set_activation_checkpointing(
+                model, auto_wrap_policy={modules.TransformerDecoderLayer}
+            )
+
+        model.load_state_dict(model_state_dict, strict=False)
+
+        if self._is_rank_zero:
+            log.info(
+                "Model is initialized. FSDP and Activation Checkpointing are enabled."
+            )
+
+        log.info("Compiling model")
+        model = torch.compile(model)
+        return model
+
+    def _setup_tokenizer(
+        self, tokenizer_checkpoint: str
+    ) -> modules.Tokenizer:
+        """
+        Unlike ``_setup_model``, this takes in the checkpoint and loads the sentencepiece
+        tokenizer model. This is related to how the tokenizer is implemented and should
+        change in a future iteration.
+        """
+        tokenizer = llama2_tokenizer(tokenizer_checkpoint)
+
+        if self._is_rank_zero:
+            log.info("Tokenizer is initialized from file.")
+        return tokenizer
+
+    def _setup_optimizer(
+        self, optimizer: str, lr: float, opt_state_dict: Optional[Dict[str, Any]] = None
+    ) -> Optimizer:
+        """
+        Set up the optimizer. This method also handles transforming the state dict
+        for FSDP.
+        """
+        optimizer = modules.get_optimizer(optimizer, self._model, lr)
+        if opt_state_dict:
+            opt_state_dict = utils.transform_opt_state_dict(
+                opt_state_dict, self._model, optimizer
+            )
+            optimizer.load_state_dict(opt_state_dict)
+
+        if self._is_rank_zero:
+            log.info("Optimizer is initialized.")
+        return optimizer
+
+    def _setup_loss(self, loss: str) -> nn.Module:
+        loss_fn = modules.get_loss(loss)
+
+        if self._is_rank_zero:
+            log.info("Loss is initialized.")
+
+        return loss_fn
+
+    def _setup_data(
+        self, dataset: str, shuffle: bool, batch_size: int, train_on_input: bool
+    ) -> Tuple[DistributedSampler, DataLoader]:
+        """
+        All data related setup happens here. Currently this recipe only supports the
+        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
+        iterable datasets and streaming datasets are not supported.
+        """
+        world_size, rank = utils.get_world_size_and_rank()
+        ds = ColoringAlpacaDataset(
+            tokenizer=self._tokenizer,
+            dataset_path=dataset,
+            train_on_input=train_on_input,
+        )
+
+        sampler = DistributedSampler(
+            ds,
+            num_replicas=world_size,
+            rank=rank,
+            shuffle=shuffle,
+            seed=0,
+        )
+
+        dataloader = DataLoader(
+            dataset=ds,
+            batch_size=batch_size,
+            sampler=sampler,
+            collate_fn=partial(
+                padded_collate,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,  # TODO support loss without ignore_index
+            ),
+        )
+
+        if self._is_rank_zero:
+            log.info("Dataset and Sampler are initialized.")
+
+        return sampler, dataloader
+
+    def save_checkpoint(self, epoch: int) -> None:
+        """
+        Checkpoint the relevant state of a recipe.
+
+        This makes use of the `save_checkpoint` utility which is responsible for
+        writing the checkpoint dictionary to file. The contents of the dict are dictated
+        by whether training is complete or not.
+
+        If training is ongoing, optimizer state, seed and epochs_run are saved along with the
+        model weights.
+        """
+        os.makedirs(self._output_dir, exist_ok=True)
+        output_loc = f"{self._output_dir}/model_{epoch}.ckpt"
+        ckpt_dict = {MODEL_KEY: self._model}
+
+        # if training is in-progress, checkpoint the optimizer state as well
+        if epoch + 1 < self.total_epochs:
+            ckpt_dict.update(
+                {
+                    OPT_KEY: self._optimizer,
+                    SEED_KEY: self.seed,
+                    EPOCHS_KEY: self.epochs_run,
+                    TOTAL_EPOCHS_KEY: self.total_epochs,
+                    MAX_STEPS_KEY: self.max_steps_per_epoch,
+                }
+            )
+        utils.save_checkpoint(ckpt_dict, output_loc)
+
+        if self._is_rank_zero:
+            log.info(
+                f"Model checkpoint of size {os.path.getsize(output_loc) >> 20} MB saved to {output_loc}"
+            )
+
+        if self._hf_repo_id is not None:
+            log.info(f"Uploading checkpoint to HuggingFace Hub: {self._hf_repo_id}")
+            self._hf_hub.upload_folder(
+                folder_path=self._output_dir,
+                repo_id=self._hf_repo_id,
+                repo_type="model",
+                run_as_future=True,
+                commit_message=f"Checkpoint for epoch {epoch} (step {self.total_training_steps})",
+            )
+        else:
+            log.info("Skipping uploading to HuggingFace Hub (no repo id specified)")
+
+    def _should_update_weights(self, curr_step: int) -> bool:
+        """
+        Determines whether the weights should be updated on the current step or not.
+        True is returned either if we've accumulated gradients for enough steps or if this
+        is the last step in the epoch.
+        """
+        should_update_weights = (
+            curr_step + 1
+        ) % self._gradient_accumulation_steps == 0 or (
+            curr_step + 1
+        ) == self._steps_per_epoch
+        return should_update_weights
+
+    def train(self) -> None:
+        """
+        The core training loop. Supports training on subsets of the dataset using
+        ``max_steps_per_epoch``.
+        """
+        _, rank = utils.get_world_size_and_rank()
+
+        # zero out the gradients before starting training
+        self._optimizer.zero_grad()
+
+        # self.epochs_run should be non-zero when we're resuming from a checkpoint
+        for curr_epoch in range(self.epochs_run, self.total_epochs):
+            # Update the sampler to ensure data is correctly shuffled across epochs
+            # in case shuffle is True
+            self._sampler.set_epoch(curr_epoch)
+
+            for idx, batch in enumerate(
+                pbar := tqdm(self._dataloader, disable=not (rank == 0))
+            ):
+                if (
+                    self.max_steps_per_epoch is not None
+                    and (idx // self._gradient_accumulation_steps)
+                    == self.max_steps_per_epoch
+                ):
+                    break
+
+                input_ids, labels, colors = batch
+
+                input_ids = input_ids.to(self._device)
+                labels = labels.to(self._device)
+                colors = colors.to(self._device)
+
+                with self._autocast:
+                    logits = self._model(input_ids, colors=colors)
+                    # Shift so that tokens < n predict n
+                    logits = logits[..., :-1, :].contiguous()
+                    labels = labels[..., 1:].contiguous()
+                    logits = logits.transpose(1, 2)
+                    # Compute loss
+                    loss = self._loss_fn(logits, labels)
+
+                # Note: We're always logging the loss before normalizing it
+                # Check if this is the norm or not
+                pbar.set_description(f"{curr_epoch+1}|{idx+1}|Loss: {loss.item()}")
+
+                if self.total_training_steps % self._log_every_n_steps == 0:
+                    self._metric_logger.log_dict(
+                        {
+                            "loss": loss.item(),
+                            "lr": self._optimizer.param_groups[0]["lr"],
+                            "gpu_resources": torch.cuda.memory_allocated(),
+                        },
+                        step=self.total_training_steps,
+                    )
+
+                if self._checkpoint_every_n_steps is not None:
+                    if self.total_training_steps % self._checkpoint_every_n_steps == 0:
+                        self.save_checkpoint(epoch=curr_epoch)
+
+                # Does loss normalization need to happen within autocast context?
+                loss = loss / self._gradient_accumulation_steps
+                self._grad_scaler.scale(loss).backward()
+
+                if self._should_update_weights(idx):
+                    self._grad_scaler.step(self._optimizer)
+                    self._grad_scaler.update()
+                    self._optimizer.zero_grad(set_to_none=True)
+
+                    # Update the number of steps when the weights are updated
+                    self.total_training_steps += 1
+
+            self.epochs_run += 1
+            self.save_checkpoint(epoch=curr_epoch)
+
+    def cleanup(self) -> None:
+        self._metric_logger.close()
+
+
+def recipe_main() -> None:
+    """
+    Entry point for the recipe.
+
+    Configurable parameters are read in the following order:
+        - Parameters specified in ``ColoringFinetuneParams``
+        - Overwritten by Parameters specified in ``alpaca_llama2_full_finetune.yaml``
+        - Overwritten by arguments from the command-line using ``TuneArgumentParser``
+    """
+    parser = utils.TuneArgumentParser(
+        description=ColoringFinetuneParams.__doc__,
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+    )
+    args, _ = parser.parse_known_args()
+    args = vars(args)
+    recipe_params = ColoringFinetuneParams(**args)
+
+    # Env variables set by torch run; only need to initialize process group
+    # Disabled since this breaks for now on RunPod.
+    # init_process_group(backend="nccl")
+
+    recipe = ColoringFinetuneRecipe(params=recipe_params)
+    recipe.setup(params=recipe_params)
+    recipe.train()
+    recipe.cleanup()
+
+
+if __name__ == "__main__":
+    sys.exit(recipe_main())
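
To make the gradient accumulation boundary in `_should_update_weights` concrete, here is a small standalone sketch; the step counts are arbitrary and chosen only for illustration.

```python
gradient_accumulation_steps = 3
steps_per_epoch = 7  # in the recipe this is len(dataloader) // gradient_accumulation_steps

def should_update_weights(curr_step: int) -> bool:
    # Same condition as the recipe: enough accumulated micro-batches,
    # or the last step of the epoch.
    return (curr_step + 1) % gradient_accumulation_steps == 0 or (curr_step + 1) == steps_per_epoch

print([step for step in range(steps_per_epoch) if should_update_weights(step)])  # [2, 5, 6]
```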
masked_apply.py ADDED
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+
+
+class MaskedApply(nn.Module):
+    """
+    Uses an index mask to select a subset of the input and apply a layer to it.
+
+    E.g. if mask is [[0, 1, 0]] layers[0] will be applied to the first and third element
+    and layers[1] will be applied to the second element.
+    """
+
+    def __init__(self, layers):
+        super(MaskedApply, self).__init__()
+        self.num_layers = len(layers)
+        self.layers = nn.ModuleList(layers)
+
+    def forward(self, x, mask):
+        # Ensure mask is a long tensor
+        mask = mask.long()
+
+        # Flatten x and mask for easier processing
+        batch_size, seq_length, embedding_size = x.shape
+
+        x_flat = x.view(-1, embedding_size)
+        mask_flat = mask.view(-1)
+
+        # Output placeholder
+        output_flat = torch.zeros_like(x_flat)
+
+        # Process each mask value
+        for i in range(self.num_layers):
+            # Find indices for current mask value
+            indices = torch.where(mask_flat == i)[0]
+
+            # Select relevant inputs for the current linear layer
+            selected_inputs = torch.index_select(x_flat, 0, indices)
+
+            # Apply linear layer
+            transformed = self.layers[i](selected_inputs)
+
+            # TODO: figure out why this is necessary.
+            transformed = transformed.to(x_flat.dtype)
+
+            # Place results back in the output tensor
+            output_flat.index_copy_(0, indices, transformed)
+
+        # Reshape output to original dimensions
+        output = output_flat.view(batch_size, seq_length, embedding_size)
+        return output
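
A minimal usage sketch of `MaskedApply`, assuming the module above is importable as `masked_apply`; the shapes and mask values are illustrative only.

```python
import torch
import torch.nn as nn

from masked_apply import MaskedApply

torch.manual_seed(0)
layers = [nn.Linear(4, 4) for _ in range(2)]
masked = MaskedApply(layers)

x = torch.randn(1, 3, 4)          # [batch, seq, embed]
mask = torch.tensor([[0, 1, 0]])  # color id per token

out = masked(x, mask)
# Tokens 0 and 2 went through layers[0]; token 1 went through layers[1].
assert torch.allclose(out[0, 0], layers[0](x[0, 0]))
assert torch.allclose(out[0, 1], layers[1](x[0, 1]))
```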
wandb/run-20240211_140449-81tescpe/files/config.yaml ADDED
@@ -0,0 +1,33 @@
+wandb_version: 1
+
+log_dir:
+  desc: null
+  value: output/alpaca-llama2-finetune
+_wandb:
+  desc: null
+  value:
+    python_version: 3.10.12
+    cli_version: 0.16.3
+    framework: torch
+    is_jupyter_run: false
+    is_kaggle_kernel: false
+    start_time: 1707660289.367696
+    t:
+      1:
+      - 1
+      - 49
+      - 51
+      - 55
+      2:
+      - 1
+      - 49
+      - 51
+      - 55
+      3:
+      - 16
+      - 23
+      4: 3.10.12
+      5: 0.16.3
+      8:
+      - 5
+      13: linux-x86_64
wandb/run-20240211_140449-81tescpe/files/requirements.txt ADDED
@@ -0,0 +1,181 @@
+aiohttp==3.9.3
+aiosignal==1.3.1
+antlr4-python3-runtime==4.9.3
+anyio==4.2.0
+appdirs==1.4.4
+argon2-cffi-bindings==21.2.0
+argon2-cffi==23.1.0
+arrow==1.3.0
+asttokens==2.4.1
+async-lru==2.0.4
+async-timeout==4.0.3
+attrs==23.2.0
+babel==2.14.0
+beautifulsoup4==4.12.3
+bleach==6.1.0
+blinker==1.4
+certifi==2024.2.2
+cffi==1.16.0
+charset-normalizer==3.3.2
+click==8.1.7
+comm==0.2.1
+cryptography==3.4.8
+datasets==2.15.0
+dbus-python==1.2.18
+debugpy==1.8.0
+decorator==5.1.1
+defusedxml==0.7.1
+dill==0.3.7
+distro==1.7.0
+docker-pycreds==0.4.0
+entrypoints==0.4
+exceptiongroup==1.2.0
+executing==2.0.1
+fastjsonschema==2.19.1
+filelock==3.13.1
+fqdn==1.5.1
+frozenlist==1.4.1
+fsspec==2023.10.0
+gitdb==4.0.11
+gitpython==3.1.41
+h11==0.14.0
+httpcore==1.0.2
+httplib2==0.20.2
+httpx==0.26.0
+huggingface-hub==0.19.4
+idna==3.6
+importlib-metadata==4.6.4
+ipykernel==6.29.0
+ipython-genutils==0.2.0
+ipython==8.21.0
+ipywidgets==8.1.1
+isoduration==20.11.0
+jedi==0.19.1
+jeepney==0.7.1
+jinja2==3.1.3
+json5==0.9.14
+jsonpointer==2.4
+jsonschema-specifications==2023.12.1
+jsonschema==4.21.1
+jupyter-archive==3.4.0
+jupyter-client==7.4.9
+jupyter-contrib-core==0.4.2
+jupyter-contrib-nbextensions==0.7.0
+jupyter-core==5.7.1
+jupyter-events==0.9.0
+jupyter-highlight-selected-word==0.2.0
+jupyter-lsp==2.2.2
+jupyter-nbextensions-configurator==0.6.3
+jupyter-server-terminals==0.5.2
+jupyter-server==2.12.5
+jupyterlab-pygments==0.3.0
+jupyterlab-server==2.25.2
+jupyterlab-widgets==3.0.9
+jupyterlab==4.1.0
+keyring==23.5.0
+launchpadlib==1.10.16
+lazr.restfulclient==0.14.4
+lazr.uri==1.0.6
+lxml==5.1.0
+markupsafe==2.1.5
+matplotlib-inline==0.1.6
+mistune==3.0.2
+more-itertools==8.10.0
+mpmath==1.3.0
+multidict==6.0.5
+multiprocess==0.70.15
+nbclassic==1.0.0
+nbclient==0.9.0
+nbconvert==7.14.2
+nbformat==5.9.2
+nest-asyncio==1.6.0
+networkx==3.2.1
+notebook-shim==0.2.3
+notebook==6.5.5
+numpy==1.26.3
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==8.9.2.26
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.19.3
+nvidia-nvjitlink-cu12==12.3.101
+nvidia-nvtx-cu12==12.1.105
+oauthlib==3.2.0
+omegaconf==2.3.0
+overrides==7.7.0
+packaging==23.2
+pandas==2.2.0
+pandocfilters==1.5.1
+parso==0.8.3
+pexpect==4.9.0
+pillow==10.2.0
+pip==24.0
+platformdirs==4.2.0
+prometheus-client==0.19.0
+prompt-toolkit==3.0.43
+protobuf==4.25.2
+psutil==5.9.8
+ptyprocess==0.7.0
+pure-eval==0.2.2
+pyarrow-hotfix==0.6
+pyarrow==15.0.0
+pycparser==2.21
+pygments==2.17.2
+pygobject==3.42.1
+pyjwt==2.3.0
+pyparsing==2.4.7
+python-apt==2.4.0+ubuntu2
+python-dateutil==2.8.2
+python-json-logger==2.0.7
+pytz==2024.1
+pyyaml==6.0.1
+pyzmq==24.0.1
+referencing==0.33.0
+requests==2.31.0
+rfc3339-validator==0.1.4
+rfc3986-validator==0.1.1
+rpds-py==0.17.1
+secretstorage==3.3.1
+send2trash==1.8.2
+sentencepiece==0.1.99
+sentry-sdk==1.40.3
+setproctitle==1.3.3
+setuptools==69.0.3
+six==1.16.0
+smmap==5.0.1
+sniffio==1.3.0
+soupsieve==2.5
+stack-data==0.6.3
+sympy==1.12
+terminado==0.18.0
+tinycss2==1.2.1
+tomli==2.0.1
+torch==2.2.0
+torchaudio==2.2.0
+torchtune==0.0.1
+torchvision==0.17.0
+tornado==6.4
+tqdm==4.66.1
+traitlets==5.14.1
+triton==2.2.0
+types-python-dateutil==2.8.19.20240106
+typing-extensions==4.9.0
+tzdata==2023.4
+uri-template==1.3.0
+urllib3==2.2.0
+wadllib==1.3.6
+wandb==0.16.3
+wcwidth==0.2.13
+webcolors==1.13
+webencodings==0.5.1
+websocket-client==1.7.0
+wheel==0.42.0
+widgetsnbextension==4.0.9
+xxhash==3.4.1
+yarl==1.9.4
+zipp==1.0.0
wandb/run-20240211_140449-81tescpe/files/wandb-metadata.json ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ {
+ "os": "Linux-5.4.0-169-generic-x86_64-with-glibc2.35",
+ "python": "3.10.12",
+ "heartbeatAt": "2024-02-11T14:04:50.615271",
+ "startedAt": "2024-02-11T14:04:49.324806",
+ "docker": null,
+ "cuda": null,
+ "args": [
+ "--config",
+ "basic_config.yaml"
+ ],
+ "state": "running",
+ "program": "/workspace/torchtune-coloring/./full_finetune.py",
+ "codePathLocal": "full_finetune.py",
+ "codePath": "full_finetune.py",
+ "git": {
+ "remote": "git@github.com:laurencer/torchtune-colorful-llama.git",
+ "commit": "bce1cd9d7dc857040353558881688a78f4e8691b"
+ },
+ "email": null,
+ "root": "/workspace/torchtune-coloring",
+ "host": "513e57971672",
+ "username": "root",
+ "executable": "/usr/bin/python",
+ "cpu_count": 64,
+ "cpu_count_logical": 128,
+ "cpu_freq": {
+ "current": 1755.92525,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ "cpu_freq_per_core": [
+ {
+ "current": 1500.04,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.933,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.537,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.464,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.681,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.897,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.836,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.352,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2342.743,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2959.736,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1736.638,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2270.242,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.665,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.398,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.231,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.877,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.695,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.602,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.693,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.044,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1878.102,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2083.492,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2048.864,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2013.355,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2977.601,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 3724.526,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2979.262,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2979.431,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.174,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.507,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.878,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.719,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2979.341,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 3724.914,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2981.767,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2975.319,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1963.286,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1666.585,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2111.485,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2423.18,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.225,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.833,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.229,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.076,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1766.004,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1577.367,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1581.383,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1580.484,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.674,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.863,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.172,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.716,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.782,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.927,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.965,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.912,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.62,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.714,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.079,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.777,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.872,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.831,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.093,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.111,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.556,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.672,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.554,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.614,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1500.085,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.844,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.574,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.902,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2862.547,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 3409.479,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2926.343,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2321.842,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1495.763,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.724,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.288,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.339,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.647,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1496.864,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.035,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.227,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1867.397,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.385,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.957,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.714,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2979.319,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 3725.068,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2979.209,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2975.679,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.377,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.144,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.576,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.402,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2979.745,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 3725.174,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2980.206,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2978.411,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2072.139,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2094.813,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2050.315,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 3524.044,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.289,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1501.308,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.431,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.037,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.557,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.081,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.268,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.813,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.585,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.952,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1496.882,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.68,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.807,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.723,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.047,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.625,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1496.718,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.27,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.148,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.911,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.737,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.721,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1496.39,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1500.029,
+ "min": 1500.0,
+ "max": 2800.0
+ }
+ ],
+ "disk": {
+ "/": {
+ "total": 100.0,
+ "used": 13.073665618896484
+ }
+ },
+ "gpu": "NVIDIA A100 80GB PCIe",
+ "gpu_count": 1,
+ "gpu_devices": [
+ {
+ "name": "NVIDIA A100 80GB PCIe",
+ "memory_total": 85899345920
+ }
+ ],
+ "memory": {
+ "total": 1007.7841453552246
+ }
+ }
wandb/run-20240211_140449-81tescpe/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
+ {"loss": 6.642009258270264, "lr": 2e-05, "gpu_resources": 28185677312, "_timestamp": 1707660723.104045, "_runtime": 433.7363488674164, "_step": 649, "_wandb": {"runtime": 433}}
wandb/run-20240211_140449-81tescpe/run-81tescpe.wandb ADDED
Binary file (499 kB).
 
wandb/run-20240211_141255-f3ffr2e5/files/config.yaml ADDED
@@ -0,0 +1,33 @@
+ wandb_version: 1
+
+ log_dir:
+ desc: null
+ value: output/alpaca-llama2-finetune
+ _wandb:
+ desc: null
+ value:
+ python_version: 3.10.12
+ cli_version: 0.16.3
+ framework: torch
+ is_jupyter_run: false
+ is_kaggle_kernel: false
+ start_time: 1707660775.784475
+ t:
+ 1:
+ - 1
+ - 49
+ - 51
+ - 55
+ 2:
+ - 1
+ - 49
+ - 51
+ - 55
+ 3:
+ - 16
+ - 23
+ 4: 3.10.12
+ 5: 0.16.3
+ 8:
+ - 5
+ 13: linux-x86_64
wandb/run-20240211_141255-f3ffr2e5/files/requirements.txt ADDED
@@ -0,0 +1,181 @@
+ aiohttp==3.9.3
+ aiosignal==1.3.1
+ antlr4-python3-runtime==4.9.3
+ anyio==4.2.0
+ appdirs==1.4.4
+ argon2-cffi-bindings==21.2.0
+ argon2-cffi==23.1.0
+ arrow==1.3.0
+ asttokens==2.4.1
+ async-lru==2.0.4
+ async-timeout==4.0.3
+ attrs==23.2.0
+ babel==2.14.0
+ beautifulsoup4==4.12.3
+ bleach==6.1.0
+ blinker==1.4
+ certifi==2024.2.2
+ cffi==1.16.0
+ charset-normalizer==3.3.2
+ click==8.1.7
+ comm==0.2.1
+ cryptography==3.4.8
+ datasets==2.15.0
+ dbus-python==1.2.18
+ debugpy==1.8.0
+ decorator==5.1.1
+ defusedxml==0.7.1
+ dill==0.3.7
+ distro==1.7.0
+ docker-pycreds==0.4.0
+ entrypoints==0.4
+ exceptiongroup==1.2.0
+ executing==2.0.1
+ fastjsonschema==2.19.1
+ filelock==3.13.1
+ fqdn==1.5.1
+ frozenlist==1.4.1
+ fsspec==2023.10.0
+ gitdb==4.0.11
+ gitpython==3.1.41
+ h11==0.14.0
+ httpcore==1.0.2
+ httplib2==0.20.2
+ httpx==0.26.0
+ huggingface-hub==0.19.4
+ idna==3.6
+ importlib-metadata==4.6.4
+ ipykernel==6.29.0
+ ipython-genutils==0.2.0
+ ipython==8.21.0
+ ipywidgets==8.1.1
+ isoduration==20.11.0
+ jedi==0.19.1
+ jeepney==0.7.1
+ jinja2==3.1.3
+ json5==0.9.14
+ jsonpointer==2.4
+ jsonschema-specifications==2023.12.1
+ jsonschema==4.21.1
+ jupyter-archive==3.4.0
+ jupyter-client==7.4.9
+ jupyter-contrib-core==0.4.2
+ jupyter-contrib-nbextensions==0.7.0
+ jupyter-core==5.7.1
+ jupyter-events==0.9.0
+ jupyter-highlight-selected-word==0.2.0
+ jupyter-lsp==2.2.2
+ jupyter-nbextensions-configurator==0.6.3
+ jupyter-server-terminals==0.5.2
+ jupyter-server==2.12.5
+ jupyterlab-pygments==0.3.0
+ jupyterlab-server==2.25.2
+ jupyterlab-widgets==3.0.9
+ jupyterlab==4.1.0
+ keyring==23.5.0
+ launchpadlib==1.10.16
+ lazr.restfulclient==0.14.4
+ lazr.uri==1.0.6
+ lxml==5.1.0
+ markupsafe==2.1.5
+ matplotlib-inline==0.1.6
+ mistune==3.0.2
+ more-itertools==8.10.0
+ mpmath==1.3.0
+ multidict==6.0.5
+ multiprocess==0.70.15
+ nbclassic==1.0.0
+ nbclient==0.9.0
+ nbconvert==7.14.2
+ nbformat==5.9.2
+ nest-asyncio==1.6.0
+ networkx==3.2.1
+ notebook-shim==0.2.3
+ notebook==6.5.5
+ numpy==1.26.3
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==8.9.2.26
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.19.3
+ nvidia-nvjitlink-cu12==12.3.101
+ nvidia-nvtx-cu12==12.1.105
+ oauthlib==3.2.0
+ omegaconf==2.3.0
+ overrides==7.7.0
+ packaging==23.2
+ pandas==2.2.0
+ pandocfilters==1.5.1
+ parso==0.8.3
+ pexpect==4.9.0
+ pillow==10.2.0
+ pip==24.0
+ platformdirs==4.2.0
+ prometheus-client==0.19.0
+ prompt-toolkit==3.0.43
+ protobuf==4.25.2
+ psutil==5.9.8
+ ptyprocess==0.7.0
+ pure-eval==0.2.2
+ pyarrow-hotfix==0.6
+ pyarrow==15.0.0
+ pycparser==2.21
+ pygments==2.17.2
+ pygobject==3.42.1
+ pyjwt==2.3.0
+ pyparsing==2.4.7
+ python-apt==2.4.0+ubuntu2
+ python-dateutil==2.8.2
+ python-json-logger==2.0.7
+ pytz==2024.1
+ pyyaml==6.0.1
+ pyzmq==24.0.1
+ referencing==0.33.0
+ requests==2.31.0
+ rfc3339-validator==0.1.4
+ rfc3986-validator==0.1.1
+ rpds-py==0.17.1
+ secretstorage==3.3.1
+ send2trash==1.8.2
+ sentencepiece==0.1.99
+ sentry-sdk==1.40.3
+ setproctitle==1.3.3
+ setuptools==69.0.3
+ six==1.16.0
+ smmap==5.0.1
+ sniffio==1.3.0
+ soupsieve==2.5
+ stack-data==0.6.3
+ sympy==1.12
+ terminado==0.18.0
+ tinycss2==1.2.1
+ tomli==2.0.1
+ torch==2.2.0
+ torchaudio==2.2.0
+ torchtune==0.0.1
+ torchvision==0.17.0
+ tornado==6.4
+ tqdm==4.66.1
+ traitlets==5.14.1
+ triton==2.2.0
+ types-python-dateutil==2.8.19.20240106
+ typing-extensions==4.9.0
+ tzdata==2023.4
+ uri-template==1.3.0
+ urllib3==2.2.0
+ wadllib==1.3.6
+ wandb==0.16.3
+ wcwidth==0.2.13
+ webcolors==1.13
+ webencodings==0.5.1
+ websocket-client==1.7.0
+ wheel==0.42.0
+ widgetsnbextension==4.0.9
+ xxhash==3.4.1
+ yarl==1.9.4
+ zipp==1.0.0
wandb/run-20240211_141255-f3ffr2e5/files/wandb-metadata.json ADDED
@@ -0,0 +1,691 @@
+ {
+ "os": "Linux-5.4.0-169-generic-x86_64-with-glibc2.35",
+ "python": "3.10.12",
+ "heartbeatAt": "2024-02-11T14:12:57.431913",
+ "startedAt": "2024-02-11T14:12:55.736045",
+ "docker": null,
+ "cuda": null,
+ "args": [
+ "--config",
+ "basic_config.yaml"
+ ],
+ "state": "running",
+ "program": "/workspace/torchtune-coloring/./full_finetune.py",
+ "codePathLocal": "full_finetune.py",
+ "codePath": "full_finetune.py",
+ "git": {
+ "remote": "git@github.com:laurencer/torchtune-colorful-llama.git",
+ "commit": "bce1cd9d7dc857040353558881688a78f4e8691b"
+ },
+ "email": null,
+ "root": "/workspace/torchtune-coloring",
+ "host": "513e57971672",
+ "username": "root",
+ "executable": "/usr/bin/python",
+ "cpu_count": 64,
+ "cpu_count_logical": 128,
+ "cpu_freq": {
+ "current": 1584.06415625,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ "cpu_freq_per_core": [
+ {
+ "current": 1490.009,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.378,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.271,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.077,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1696.135,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1810.431,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1650.597,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1668.338,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.137,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.334,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.558,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1580.833,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.972,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.117,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1796.541,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.312,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2977.727,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2979.389,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2978.317,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 3695.755,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.377,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.215,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1494.35,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.254,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.02,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.747,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.785,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.015,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1526.811,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1566.368,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1701.151,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1507.923,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.608,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.845,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.249,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.128,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.649,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.034,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.386,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.641,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.814,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.542,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.895,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.555,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.328,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.571,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.412,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.382,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.44,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1495.766,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.108,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.73,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.463,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.523,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1500.13,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.545,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.452,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.325,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.653,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.635,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.506,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.004,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.265,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.955,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1331.298,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.548,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1384.617,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.803,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2003.768,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2386.047,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1670.529,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1680.364,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1496.711,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1496.734,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.113,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.733,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2185.862,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2139.21,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2640.006,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2195.686,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2979.829,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2979.073,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 2961.456,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 3723.45,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.311,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.576,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1493.545,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.524,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.523,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.226,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.089,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.806,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.455,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.626,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1500.045,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1496.146,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.683,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.746,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.509,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.5,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.181,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.949,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.742,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.275,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.657,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.18,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.544,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.82,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.69,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.346,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.574,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.708,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1495.929,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.447,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1496.645,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1495.605,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.426,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.76,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.735,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.099,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.845,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1499.781,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.862,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.535,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.513,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.411,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1497.487,
+ "min": 1500.0,
+ "max": 2800.0
+ },
+ {
+ "current": 1498.069,
+ "min": 1500.0,
+ "max": 2800.0
+ }
+ ],
+ "disk": {
+ "/": {
+ "total": 100.0,
+ "used": 13.073677062988281
+ }
+ },
+ "gpu": "NVIDIA A100 80GB PCIe",
+ "gpu_count": 1,
+ "gpu_devices": [
+ {
+ "name": "NVIDIA A100 80GB PCIe",
+ "memory_total": 85899345920
+ }
+ ],
+ "memory": {
+ "total": 1007.7841453552246
+ }
+ }
wandb/run-20240211_141255-f3ffr2e5/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
+ {"loss": 5.02125883102417, "lr": 2e-05, "gpu_resources": 41688505856, "_timestamp": 1707690323.6159635, "_runtime": 29547.831488370895, "_step": 72713, "_wandb": {"runtime": 29548}}
wandb/run-20240211_141255-f3ffr2e5/run-f3ffr2e5.wandb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f897ff6fe35e0befb48c9d12218b5443a432b2e464eca8e783351a8c84e92c8c
+ size 65078304