Andrew DalPino committed on
Commit
e431f0f
·
1 Parent(s): fc4824e

Blanket optimizations

.gitignore CHANGED
@@ -10,6 +10,7 @@ wheels/
10
  *.egg-info/
11
  .installed.cfg
12
  *.egg
 
13
  .venv
14
  venv/
15
  ENV/
 
10
  *.egg-info/
11
  .installed.cfg
12
  *.egg
13
+ *.sarif
14
  .venv
15
  venv/
16
  ENV/
README.md CHANGED
@@ -13,7 +13,7 @@ tags:
13
  ---
14
  # LightGPT
15
 
16
- LightGPT is a lightweight generative pre-trained Transformer (GPT) model for the people! Built using pure PyTorch, LightGPT can answer questions, follow instructions, summarize documents, chat, and much more. Best of all, the model weights *and* code are fully open-source for you to customize, improve upon, and share with the world.
17
 
18
  ## Features
19
 
@@ -23,18 +23,18 @@ LightGPT is a lightweight generative pre-trained Transformer (GPT) model for the
23
 
24
  - **Fully Open-source**: Unlike closed-source LLMs, LightGPT provides both the model weights *and* the source code to train, fine-tune, export, and generate text from the model using your own hardware. With the help of the open-source software community, we aim to democratize access to AI and continually improve the models.
25
 
26
- ## Suggested Pre-training Configurations
27
 
28
- Below is a table of some suggested pre-training configurations but feel free to experiment with settings on your own. See the `model_sizing.ipynb` notebook to estimate the memory and compute requirements for your model configuration.
29
 
30
- | Name | Vocab. Size | Block Size | Embedding Dim. | Attn. Heads | Layers | Parameters | Training Tokens |
31
  |---|---|---|---|---|---|---|---|
32
- | Small | 50,257 | 1024 | 1024 | 16 | 24 | 353M | 7B |
33
- | Medium | 50,257 | 1024 | 2048 | 32 | 32 | 1.7B | 34B |
34
- | Large | 100,275 | 2048 | 4096 | 64 | 32 | 6.8B | 132B |
35
- | X-large | 100,275 | 2048 | 4096 | 64 | 64 | 13B | 262B |
36
- | XX-large | 200,017 | 4096 | 8192 | 128 | 64 | 53B | 1T |
37
- | XXX-large | 200,017 | 4096 | 8192 | 128 | 128 | 105B | 2T |
38
 
39
  ## Install Project Dependencies
40
 
@@ -48,37 +48,37 @@ source ./.venv/bin/activate
48
  pip install -r requirements.txt
49
  ```
50
 
51
- ## Pre-training
52
 
53
- For the pre-training corpus we use the Fineweb dataset which consists of about 15T high-quality tokens gathered from the worldwide web. The dataset has been split into 3 subsets (10BT, 100BT, and 350BT versions) for training smaller models. If you'd like to start training right away, the default settings should work on most single-GPU systems with 12G of VRAM or more.
54
 
55
  ```
56
- python pre-train.py
57
  ```
58
 
59
  **Note** that it will take a while to download and pre-process the dataset the first time that the training script is run.
60
 
61
- To customize the default "Small" architecture you can adjust the `block_size`, `embedding_dimensions`, `num_hidden_layers`, and `num_attention_heads` arguments of the pre-training script.
62
 
63
  ```
64
- python pre-train.py --block_size=2048 --embedding_dimensions=4096 --num_hidden_layers=64 --num_attention_heads=64
65
  ```
66
 
67
  You can also adjust the `batch_size`, `learning_rate`, and `gradient_accumulation_steps` to suit your training setup.
68
 
69
  ```
70
- python pre-train.py --batch_size=32 --learning_rate=0.01 --gradient_accumulation_steps=128
71
  ```
72
 
73
  For distributed training, use PyTorch's [torchrun](https://pytorch.org/docs/stable/elastic/run.html) extension to launch a distributed data parallel (DDP) session. The example below is for executing the training script on a single node with 8 individual GPUs.
74
 
75
  ```
76
- torchrun --standalone --nnodes=1 --nproc-per-node=8 pre-train.py --batch_size=16 --gradient_accumulation_steps=128
77
  ```
78
 
79
  **Note** that when training in data-parallel mode it's important that `gradient_accumulation_steps` is evenly divisible by the world size for maximum performance. For example, on an 8-GPU cluster we could perform 32 gradient accumulation steps in exactly 4 passes over the network.
80
 
81
- ### Pre-training Arguments
82
 
83
  | Argument | Default | Type | Description |
84
  |---|---|---|---|
@@ -88,6 +88,7 @@ torchrun --standalone --nnodes=1 --nproc-per-node=8 pre-train.py --batch_size=16
88
  | --num_dataset_processes | 8 | int | The number of processes (CPUs) to use to process the dataset. |
89
  | --batch_size | 1 | int | The number of samples to pass through the network at a time. |
90
  | --gradient_accumulation_steps | 128 | int | The number of batches to pass through the network before updating the weights. |
 
91
  | --samples_per_epoch | 4096 | int | The number of training samples to pass through the network every epoch. |
92
  | --num_epochs | 1686 | int | The number of epochs to train for. |
93
  | --learning_rate | 1e-2 | float | The learning rate of the Adafactor optimizer. |
@@ -95,17 +96,17 @@ torchrun --standalone --nnodes=1 --nproc-per-node=8 pre-train.py --batch_size=16
95
  | --low_memory_optimizer | False | bool | Should the optimizer reduce its memory consumption in exchange for a slightly slower runtime? |
96
  | --max_gradient_norm | 1.0 | float | Clip gradients above this threshold before stepping. |
97
  | --eval_interval | 10 | int | Evaluate the model after this many epochs on the testing set. |
98
- | --block_size | 1024 | int | The number of tokens within the context window for every sample. |
99
  | --embedding_dimensions | 1024 | int | The dimensionality of the token embeddings. |
100
  | --num_attention_heads | 16 | int | The number of attention heads within every block. |
101
  | --num_hidden_layers | 24 | int | The number of attention/MLP blocks within the hidden layer of the network. |
102
  | --feed_forward_ratio | 4 | (1, 2, 4) | The ratio of hidden neurons to embedding dimensions in the MLP layers of the network. |
103
  | --dropout | 0.1 | float | The proportion of signals to send to zero during training as regularization. |
104
- | --activation_checkpointing | False | bool | Should we use activation checkpointing? This will reduce drastically memory utilization during training at the cost of needing to recompute the forward pass. |
105
  | --ddp_sharding_level | 2 | int | The level of sharding to use for DDP training. Options are 2 or 3 for partial and full sharding respectively, or 0 for no sharding. |
106
- | --checkpoint_interval | 20 | int | Save the model parameters to disk every this many epochs. |
107
  | --checkpoint_path | "./checkpoints/checkpoint.pt" | str | The path to the base checkpoint file on disk. |
108
  | --resume | False | bool | Should we resume training from the last checkpoint? |
 
109
  | --device | "cuda" | str | The device to run the computation on. |
110
  | --seed | None | int | The seed for the random number generator. |
111
 
@@ -116,12 +117,13 @@ torchrun --standalone --nnodes=1 --nproc-per-node=8 pre-train.py --batch_size=16
116
  | Argument | Default | Type | Description |
117
  |---|---|---|---|
118
  | --base_model_path | "./checkpoints/checkpoint.pt" | string | The path to the base checkpoint on disk. |
 
 
119
  | --batch_size | 1 | int | The number of samples to pass through the network at a time. |
120
  | --gradient_accumulation_steps | 64 | int | The number of batches to pass through the network before updating the weights. |
121
  | --learning_rate | 5e-4 | float | The learning rate of the Adafactor optimizer. |
122
  | --rms_decay | -0.8 | float | The decay rate of the RMS coefficient of the Adafactor optimizer. |
123
- | --optimizer_low_memory | True | bool | Should the optimizer reduce its memory consumption in exchange for a slightly slower runtime? |
124
- | --mask_input | False | bool | Should we mask the input part of the sample i.e. only train on the output? |
125
  | --rank | 8 | int | The rank of the LoRA decomposition matrices. |
126
  | --alpha | 1.0 | float | The strength of the LoRA signal. |
127
  | --dropout | 0.05 | float | The proportion of signals to send to zero during training as regularization. |
@@ -131,6 +133,7 @@ torchrun --standalone --nnodes=1 --nproc-per-node=8 pre-train.py --batch_size=16
131
  | --checkpoint_interval | 1 | int | Save the model parameters to disk every this many epochs. |
132
  | --checkpoint_path | "./checkpoints/lora_instruction.pt" | string | The path to the LoRA checkpoint. |
133
  | --resume | False | bool | Should we resume training from the last checkpoint? |
 
134
  | --device | "cuda" | string | The device to run the computation on. |
135
  | --seed | None | int | The seed for the random number generator. |
136
 
 
13
  ---
14
  # LightGPT
15
 
16
+ LightGPT is a lightweight generative pretrained Transformer (GPT) model for the people! Built using PyTorch and trained on the Fineweb and Alpaca datasets, LightGPT can answer questions, follow instructions, summarize documents, chat, and more. Best of all, the model weights *and* code are fully open-source for you to customize, improve upon, and share with the world.
17
 
18
  ## Features
19
 
 
23
 
24
  - **Fully Open-source**: Unlike closed-source LLMs, LightGPT provides both the model weights *and* the source code to train, fine-tune, export, and generate text from the model using your own hardware. With the help of the open-source software community, we aim to democratize access to AI and continually improve the models.
25
 
26
+ ## Suggested Pretraining Configurations
27
 
28
+ Below is a table of some suggested pretraining configurations, but feel free to experiment with the settings on your own. See the `model_sizing.ipynb` notebook to estimate the memory and compute requirements for your model configuration.
29
 
30
+ | Name | Vocab. Size | Embedding Dim. | Attn. Heads | Layers | Parameters | Training Tokens |
31
  |---|---|---|---|---|---|---|---|
32
+ | Small | 50,257 | 1024 | 16 | 24 | 353M | 7B |
33
+ | Medium | 50,257 | 2048 | 32 | 32 | 1.7B | 34B |
34
+ | Large | 100,275 | 4096 | 64 | 32 | 6.8B | 132B |
35
+ | X-large | 100,275 | 4096 | 64 | 64 | 13B | 262B |
36
+ | XX-large | 200,017 | 8192 | 128 | 64 | 53B | 1T |
37
+ | XXX-large | 200,017 | 8192 | 128 | 128 | 105B | 2T |
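
For a rough sense of where these parameter counts come from, here is a back-of-the-envelope sketch that mirrors the arithmetic in `model_sizing.ipynb` (biases, norms, and other small terms ignored; the "Small" values below are only an illustration):

```
vocabulary_size = 50257
embedding_dimensions = 1024
num_hidden_layers = 24
feed_forward_ratio = 4

token_embeddings = vocabulary_size * embedding_dimensions           # tied with the output layer
attention = 4 * embedding_dimensions**2 * num_hidden_layers         # Q, K, V, and output projections
mlp = 2 * feed_forward_ratio * embedding_dimensions**2 * num_hidden_layers

total = token_embeddings + attention + mlp  # ~353M for the "Small" configuration
print(f"{total:,}")
```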
38
 
39
  ## Install Project Dependencies
40
 
 
48
  pip install -r requirements.txt
49
  ```
50
 
51
+ ## Pretraining
52
 
53
+ For the pretraining corpus, we use the Fineweb dataset, which consists of about 15T high-quality tokens gathered from the web. The dataset has been split into 3 subsets (10BT, 100BT, and 350BT versions) for training smaller models. If you'd like to start training right away, the default settings should work on most single-GPU systems with 12GB of VRAM or more.
54
 
55
  ```
56
+ python pretrain.py
57
  ```
58
 
59
  **Note** that it will take a while to download and pre-process the dataset the first time that the training script is run.
60
 
61
+ To customize the default "Small" architecture, you can adjust the `tokens_per_sample`, `embedding_dimensions`, `num_hidden_layers`, and `num_attention_heads` arguments of the pretraining script.
62
 
63
  ```
64
+ python pretrain.py --tokens_per_sample=2048 --embedding_dimensions=4096 --num_hidden_layers=64 --num_attention_heads=64
65
  ```
66
 
67
  You can also adjust the `batch_size`, `learning_rate`, and `gradient_accumulation_steps` to suit your training setup.
68
 
69
  ```
70
+ python pretrain.py --batch_size=32 --learning_rate=0.01 --gradient_accumulation_steps=128
71
  ```
72
 
73
  For distributed training, use PyTorch's [torchrun](https://pytorch.org/docs/stable/elastic/run.html) extension to launch a distributed data parallel (DDP) session. The example below is for executing the training script on a single node with 8 individual GPUs.
74
 
75
  ```
76
+ torchrun --standalone --nnodes=1 --nproc-per-node=8 pretrain.py --batch_size=16 --gradient_accumulation_steps=128
77
  ```
78
 
79
  **Note** that when training in data-parallel mode it's important that `gradient_accumulation_steps` is evenly divisible by the world size for maximum performance. For example, on an 8-GPU cluster we could perform 32 gradient accumulation steps in exactly 4 passes over the network.
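
A minimal sketch of that divisibility rule, assuming an 8-GPU node and 32 accumulation steps (illustrative values only):

```
world_size = 8                    # number of GPUs participating in DDP
gradient_accumulation_steps = 32  # batches accumulated before each optimizer step

# Accumulation steps should be evenly divisible by the world size so that
# every rank performs the same number of forward/backward passes per update.
assert gradient_accumulation_steps % world_size == 0

passes_per_update = gradient_accumulation_steps // world_size  # 4 passes per GPU
print(passes_per_update)
```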
80
 
81
+ ### Pretraining Arguments
82
 
83
  | Argument | Default | Type | Description |
84
  |---|---|---|---|
 
88
  | --num_dataset_processes | 8 | int | The number of processes (CPUs) to use to process the dataset. |
89
  | --batch_size | 1 | int | The number of samples to pass through the network at a time. |
90
  | --gradient_accumulation_steps | 128 | int | The number of batches to pass through the network before updating the weights. |
91
+ | --tokens_per_sample | 1024 | int | The number of tokens to pack into a single training sequence. This is sometimes called the context length or block size. |
92
  | --samples_per_epoch | 4096 | int | The number of training samples to pass through the network every epoch. |
93
  | --num_epochs | 1686 | int | The number of epochs to train for. |
94
  | --learning_rate | 1e-2 | float | The learning rate of the Adafactor optimizer. |
 
96
  | --low_memory_optimizer | False | bool | Should the optimizer reduce its memory consumption in exchange for a slightly slower runtime? |
97
  | --max_gradient_norm | 1.0 | float | Clip gradients above this threshold before stepping. |
98
  | --eval_interval | 10 | int | Evaluate the model after this many epochs on the testing set. |
 
99
  | --embedding_dimensions | 1024 | int | The dimensionality of the token embeddings. |
100
  | --num_attention_heads | 16 | int | The number of attention heads within every block. |
101
  | --num_hidden_layers | 24 | int | The number of attention/MLP blocks within the hidden layer of the network. |
102
  | --feed_forward_ratio | 4 | (1, 2, 4) | The ratio of hidden neurons to embedding dimensions in the MLP layers of the network. |
103
  | --dropout | 0.1 | float | The proportion of signals to send to zero during training as regularization. |
104
+ | --activation_checkpointing | False | bool | Should we use activation checkpointing? This will drastically reduce memory utilization during training at the cost of recomputing the forward pass. |
105
  | --ddp_sharding_level | 2 | int | The level of sharding to use for DDP training. Options are 2 or 3 for partial and full sharding respectively, or 0 for no sharding. |
106
+ | --checkpoint_interval | 20 | int | Save the model checkpoint to disk every this many epochs. |
107
  | --checkpoint_path | "./checkpoints/checkpoint.pt" | str | The path to the base checkpoint file on disk. |
108
  | --resume | False | bool | Should we resume training from the last checkpoint? |
109
+ | --run_dir_path | "./runs/pretrain" | str | The path to the TensorBoard run directory for this training session. |
110
  | --device | "cuda" | str | The device to run the computation on. |
111
  | --seed | None | int | The seed for the random number generator. |
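
As a hypothetical example that combines several of the arguments above (values chosen purely for illustration):

```
python pretrain.py --tokens_per_sample=1024 --batch_size=2 --gradient_accumulation_steps=64 --checkpoint_interval=10 --run_dir_path="./runs/pretrain" --seed=42
```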
112
 
 
117
  | Argument | Default | Type | Description |
118
  |---|---|---|---|
119
  | --base_model_path | "./checkpoints/checkpoint.pt" | string | The path to the base checkpoint on disk. |
120
+ | --max_tokens_per_sample | 4096 | int | The maximum number of tokens to pack into a single training sequence. |
121
+ | --mask_input | False | bool | Should we mask the input part of the training sequences, i.e. only train on the supervised output? |
122
  | --batch_size | 1 | int | The number of samples to pass through the network at a time. |
123
  | --gradient_accumulation_steps | 64 | int | The number of batches to pass through the network before updating the weights. |
124
  | --learning_rate | 5e-4 | float | The learning rate of the Adafactor optimizer. |
125
  | --rms_decay | -0.8 | float | The decay rate of the RMS coefficient of the Adafactor optimizer. |
126
+ | --optimizer_low_memory | False | bool | Should the optimizer reduce its memory consumption in exchange for a slightly slower runtime? |
 
127
  | --rank | 8 | int | The rank of the LoRA decomposition matrices. |
128
  | --alpha | 1.0 | float | The strength of the LoRA signal. |
129
  | --dropout | 0.05 | float | The proportion of signals to send to zero during training as regularization. |
 
133
  | --checkpoint_interval | 1 | int | Save the model parameters to disk every this many epochs. |
134
  | --checkpoint_path | "./checkpoints/lora_instruction.pt" | string | The path to the LoRA checkpoint. |
135
  | --resume | False | bool | Should we resume training from the last checkpoint? |
136
+ | --run_dir_path | "./runs/instruction-tune" | str | The path to the TensorBoard run directory for this training session. |
137
  | --device | "cuda" | string | The device to run the computation on. |
138
  | --seed | None | int | The seed for the random number generator. |
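
Similarly, a hypothetical instruction-tuning invocation using the new `--max_tokens_per_sample` and `--mask_input` arguments (values chosen purely for illustration):

```
python instruction-tune.py --max_tokens_per_sample=1024 --mask_input --rank=8 --alpha=1.0 --run_dir_path="./runs/instruction-tune"
```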
139
 
beam_search.py CHANGED
@@ -22,7 +22,8 @@ def main():
22
  "--checkpoint_path", default="./checkpoints/checkpoint.pt", type=str
23
  )
24
  parser.add_argument("--lora_path", default=None, type=str)
25
- parser.add_argument("--max_tokens", default=500, type=int)
 
26
  parser.add_argument("--num_candidates", default=3, type=int)
27
  parser.add_argument("--beam_width", default=16, type=int)
28
  parser.add_argument("--device", default="cuda", type=str)
@@ -92,6 +93,7 @@ def main():
92
  candidates = model.beam_search(
93
  prompt,
94
  args.max_tokens,
 
95
  args.num_candidates,
96
  args.beam_width,
97
  )
 
22
  "--checkpoint_path", default="./checkpoints/checkpoint.pt", type=str
23
  )
24
  parser.add_argument("--lora_path", default=None, type=str)
25
+ parser.add_argument("--max_tokens", default=100, type=int)
26
+ parser.add_argument("--context_length", default=1024, type=int)
27
  parser.add_argument("--num_candidates", default=3, type=int)
28
  parser.add_argument("--beam_width", default=16, type=int)
29
  parser.add_argument("--device", default="cuda", type=str)
 
93
  candidates = model.beam_search(
94
  prompt,
95
  args.max_tokens,
96
+ args.context_length,
97
  args.num_candidates,
98
  args.beam_width,
99
  )
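
With the new `--context_length` argument wired into `beam_search`, a sketch of a command-line invocation using the defaults shown above (values are illustrative):

```
python beam_search.py --max_tokens=100 --context_length=1024 --num_candidates=3 --beam_width=16
```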
export_model.ipynb CHANGED
@@ -9,7 +9,7 @@
9
  },
10
  {
11
  "cell_type": "code",
12
- "execution_count": 2,
13
  "metadata": {},
14
  "outputs": [],
15
  "source": [
@@ -28,25 +28,21 @@
28
  },
29
  {
30
  "cell_type": "code",
31
- "execution_count": 3,
32
  "metadata": {},
33
  "outputs": [
34
  {
35
- "ename": "TypeError",
36
- "evalue": "GPT.__init__() missing 1 required positional argument: 'feed_forward_ratio'",
37
- "output_type": "error",
38
- "traceback": [
39
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
40
- "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
41
- "Cell \u001b[0;32mIn[3], line 7\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmodel\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m GPT, GPTWithLoRA\n\u001b[1;32m 5\u001b[0m checkpoint \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mload(checkpoint_path, map_location\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, weights_only\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 7\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mGPT\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mcheckpoint\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodel_args\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m model \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcompile(model)\n\u001b[1;32m 11\u001b[0m model\u001b[38;5;241m.\u001b[39mload_state_dict(checkpoint[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n",
42
- "\u001b[0;31mTypeError\u001b[0m: GPT.__init__() missing 1 required positional argument: 'feed_forward_ratio'"
43
  ]
44
  }
45
  ],
46
  "source": [
47
  "import torch\n",
48
  "\n",
49
- "from model import GPT, GPTWithLoRA\n",
50
  "\n",
51
  "checkpoint = torch.load(checkpoint_path, map_location=\"cpu\", weights_only=True)\n",
52
  "\n",
@@ -68,10 +64,12 @@
68
  },
69
  {
70
  "cell_type": "code",
71
- "execution_count": 58,
72
  "metadata": {},
73
  "outputs": [],
74
  "source": [
 
 
75
  "if lora_path != None:\n",
76
  " checkpoint = torch.load(lora_path, map_location=\"cpu\", weights_only=True)\n",
77
  "\n",
@@ -95,14 +93,14 @@
95
  },
96
  {
97
  "cell_type": "code",
98
- "execution_count": 59,
99
  "metadata": {},
100
  "outputs": [
101
  {
102
  "name": "stdout",
103
  "output_type": "stream",
104
  "text": [
105
- "Model saved to ./exports/lightgpt-small-turbo.safetensors\n"
106
  ]
107
  }
108
  ],
@@ -127,66 +125,46 @@
127
  },
128
  {
129
  "cell_type": "code",
130
- "execution_count": 86,
131
  "metadata": {},
132
  "outputs": [
133
- {
134
- "name": "stdout",
135
- "output_type": "stream",
136
- "text": [
137
- "[torch.onnx] Obtain model graph for `OptimizedModule([...]` with `torch.export.export`...\n"
138
- ]
139
- },
140
- {
141
- "name": "stderr",
142
- "output_type": "stream",
143
- "text": [
144
- "W0108 18:27:01.430000 5473 torch/onnx/_internal/exporter/_registration.py:73] torchvision is not installed. Skipping torchvision::nms\n"
145
- ]
146
- },
147
- {
148
- "name": "stdout",
149
- "output_type": "stream",
150
- "text": [
151
- "[torch.onnx] Obtain model graph for `OptimizedModule([...]` with `torch.export.export`... ✅\n",
152
- "[torch.onnx] Translate the graph into ONNX...\n"
153
- ]
154
- },
155
  {
156
  "name": "stderr",
157
  "output_type": "stream",
158
  "text": [
159
- "W0108 18:27:04.197000 5473 torch/onnx/_internal/exporter/_core.py:848] Skipping constant argument ConstantArgument(name='', value=None)\n"
 
160
  ]
161
  },
162
  {
163
  "name": "stdout",
164
  "output_type": "stream",
165
  "text": [
166
- "[torch.onnx] Translate the graph into ONNX... ✅\n",
167
- "Model saved to ./exports/lightgpt-small-turbo.onnx\n"
168
  ]
169
  }
170
  ],
171
  "source": [
172
- "from torch.onnx import export\n",
 
 
173
  "\n",
174
  "example_input = torch.randint(0, model.vocabulary_size - 1, (1, model.block_size))\n",
175
  "\n",
 
 
176
  "model.eval() # Turn off dropout and other train-time operations\n",
177
  "\n",
178
- "example_output, _ = model(example_input)\n",
 
 
 
 
179
  "\n",
180
  "onnx_path = path.join(exports_path, f\"{model_name}.onnx\")\n",
181
  "\n",
182
- "export(\n",
183
- " model,\n",
184
- " example_input,\n",
185
- " onnx_path,\n",
186
- " input_names=[\"input_tokens\", \"labels\"],\n",
187
- " output_names=[\"logits\"],\n",
188
- " dynamo=True,\n",
189
- ")\n",
190
  "\n",
191
  "print(f\"Model saved to {onnx_path}\")"
192
  ]
@@ -195,75 +173,41 @@
195
  "cell_type": "markdown",
196
  "metadata": {},
197
  "source": [
198
- "We can verify the ONNX model with the ONNX API."
199
  ]
200
  },
201
  {
202
  "cell_type": "code",
203
- "execution_count": 87,
204
  "metadata": {},
205
  "outputs": [
206
  {
207
  "name": "stdout",
208
  "output_type": "stream",
209
  "text": [
210
- "Looks OK\n"
211
- ]
212
- }
213
- ],
214
- "source": [
215
- "import onnx\n",
216
- "\n",
217
- "onnx_model = onnx.load(onnx_path)\n",
218
- "\n",
219
- "onnx.checker.check_model(onnx_model)\n",
220
- "\n",
221
- "print(\"Looks OK\")"
222
- ]
223
- },
224
- {
225
- "cell_type": "markdown",
226
- "metadata": {},
227
- "source": [
228
- "Lastly, let's compare the output of PyTorch with the ONNX runtime to see if they are the same."
229
- ]
230
- },
231
- {
232
- "cell_type": "code",
233
- "execution_count": null,
234
- "metadata": {},
235
- "outputs": [
236
- {
237
- "ename": "NameError",
238
- "evalue": "name 'onnx_path' is not defined",
239
- "output_type": "error",
240
- "traceback": [
241
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
242
- "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
243
- "Cell \u001b[0;32mIn[1], line 7\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtesting\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m assert_allclose\n\u001b[0;32m----> 7\u001b[0m session \u001b[38;5;241m=\u001b[39m onnxruntime\u001b[38;5;241m.\u001b[39mInferenceSession(\u001b[43monnx_path\u001b[49m, providers\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCPUExecutionProvider\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 9\u001b[0m onnx_input \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_tokens\u001b[39m\u001b[38;5;124m\"\u001b[39m: example_input\u001b[38;5;241m.\u001b[39mnumpy()}\n\u001b[1;32m 11\u001b[0m output \u001b[38;5;241m=\u001b[39m session\u001b[38;5;241m.\u001b[39mrun(\u001b[38;5;28;01mNone\u001b[39;00m, onnx_input)\n",
244
- "\u001b[0;31mNameError\u001b[0m: name 'onnx_path' is not defined"
245
  ]
246
  }
247
  ],
248
  "source": [
249
  "import onnxruntime\n",
250
  "\n",
251
- "import numpy as np\n",
252
- "\n",
253
  "from numpy.testing import assert_allclose\n",
254
  "\n",
 
 
255
  "session = onnxruntime.InferenceSession(onnx_path, providers=[\"CPUExecutionProvider\"])\n",
256
  "\n",
257
- "onnx_input = {\"input_tokens\": example_input.numpy()}\n",
258
  "\n",
259
- "output = session.run(None, onnx_input)\n",
260
  "\n",
261
- "onnx_output = output[0]\n",
262
- "pytorch_output = np.array(example_output.detach())\n",
263
  "\n",
264
- "assert_allclose(pytorch_output, onnx_output, rtol=1e-2, atol=1e-03)\n",
265
  "\n",
266
- "print(\"Looking good\")"
267
  ]
268
  }
269
  ],
 
9
  },
10
  {
11
  "cell_type": "code",
12
+ "execution_count": 18,
13
  "metadata": {},
14
  "outputs": [],
15
  "source": [
 
28
  },
29
  {
30
  "cell_type": "code",
31
+ "execution_count": 19,
32
  "metadata": {},
33
  "outputs": [
34
  {
35
+ "name": "stdout",
36
+ "output_type": "stream",
37
+ "text": [
38
+ "Base checkpoint loaded successfully\n"
 
 
 
 
39
  ]
40
  }
41
  ],
42
  "source": [
43
  "import torch\n",
44
  "\n",
45
+ "from model import GPT\n",
46
  "\n",
47
  "checkpoint = torch.load(checkpoint_path, map_location=\"cpu\", weights_only=True)\n",
48
  "\n",
 
64
  },
65
  {
66
  "cell_type": "code",
67
+ "execution_count": 20,
68
  "metadata": {},
69
  "outputs": [],
70
  "source": [
71
+ "from model import GPTWithLoRA\n",
72
+ "\n",
73
  "if lora_path != None:\n",
74
  " checkpoint = torch.load(lora_path, map_location=\"cpu\", weights_only=True)\n",
75
  "\n",
 
93
  },
94
  {
95
  "cell_type": "code",
96
+ "execution_count": 21,
97
  "metadata": {},
98
  "outputs": [
99
  {
100
  "name": "stdout",
101
  "output_type": "stream",
102
  "text": [
103
+ "Model saved to ./exports/lightgpt-small.safetensors\n"
104
  ]
105
  }
106
  ],
 
125
  },
126
  {
127
  "cell_type": "code",
128
+ "execution_count": 22,
129
  "metadata": {},
130
  "outputs": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  {
132
  "name": "stderr",
133
  "output_type": "stream",
134
  "text": [
135
+ "/home/andrew/Workspace/LightGPT/.venv/lib/python3.12/site-packages/torch/onnx/_internal/_exporter_legacy.py:116: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.\n",
136
+ " warnings.warn(\n"
137
  ]
138
  },
139
  {
140
  "name": "stdout",
141
  "output_type": "stream",
142
  "text": [
143
+ "Applied 72 of general pattern rewrite rules.\n",
144
+ "Model saved to ./exports/lightgpt-small.onnx\n"
145
  ]
146
  }
147
  ],
148
  "source": [
149
+ "from model import ONNXModel\n",
150
+ "\n",
151
+ "from torch.onnx import dynamo_export, ExportOptions\n",
152
  "\n",
153
  "example_input = torch.randint(0, model.vocabulary_size - 1, (1, model.block_size))\n",
154
  "\n",
155
+ "model = ONNXModel(model) # Nicer inferencing API\n",
156
+ "\n",
157
  "model.eval() # Turn off dropout and other train-time operations\n",
158
  "\n",
159
+ "export_options = ExportOptions(\n",
160
+ " dynamic_shapes=True\n",
161
+ ") # Necessary for variable batch and sequence lengths\n",
162
+ "\n",
163
+ "onnx_model = dynamo_export(model, example_input, export_options=export_options)\n",
164
  "\n",
165
  "onnx_path = path.join(exports_path, f\"{model_name}.onnx\")\n",
166
  "\n",
167
+ "onnx_model.save(onnx_path)\n",
 
 
 
 
 
 
 
168
  "\n",
169
  "print(f\"Model saved to {onnx_path}\")"
170
  ]
 
173
  "cell_type": "markdown",
174
  "metadata": {},
175
  "source": [
176
+ "Lastly, let's compare the output of PyTorch with the ONNX runtime to see if they are the same."
177
  ]
178
  },
179
  {
180
  "cell_type": "code",
181
+ "execution_count": null,
182
  "metadata": {},
183
  "outputs": [
184
  {
185
  "name": "stdout",
186
  "output_type": "stream",
187
  "text": [
188
+ "Looking good!\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  ]
190
  }
191
  ],
192
  "source": [
193
  "import onnxruntime\n",
194
  "\n",
 
 
195
  "from numpy.testing import assert_allclose\n",
196
  "\n",
197
+ "pytorch_logits = model(example_input)\n",
198
+ "\n",
199
  "session = onnxruntime.InferenceSession(onnx_path, providers=[\"CPUExecutionProvider\"])\n",
200
  "\n",
201
+ "onnx_input = {\"l_x_\": example_input.numpy()}\n",
202
  "\n",
203
+ "onnx_logits = session.run(None, onnx_input)\n",
204
  "\n",
205
+ "onnx_logits = onnx_logits[0]\n",
206
+ "pytorch_logits = pytorch_logits.detach().numpy()\n",
207
  "\n",
208
+ "assert_allclose(pytorch_logits, onnx_logits, rtol=1e-2, atol=1e-03)\n",
209
  "\n",
210
+ "print(\"Looks good!\")"
211
  ]
212
  }
213
  ],
generate.py CHANGED
@@ -23,6 +23,7 @@ def main():
23
  )
24
  parser.add_argument("--lora_path", default=None, type=str)
25
  parser.add_argument("--max_tokens", default=1000, type=int)
 
26
  parser.add_argument("--temperature", default=1.0, type=float)
27
  parser.add_argument("--top_k", default=500, type=int)
28
  parser.add_argument("--top_p", default=0.9, type=float)
@@ -91,7 +92,12 @@ def main():
91
  prompt = torch.tensor(prompt, dtype=torch.int64, device=args.device)
92
 
93
  for token in model.generate(
94
- prompt, args.max_tokens, args.temperature, args.top_k, args.top_p
 
 
 
 
 
95
  ):
96
  out = tokenizer.decode_single_token_bytes(token).decode(
97
  "utf-8", errors="replace"
 
23
  )
24
  parser.add_argument("--lora_path", default=None, type=str)
25
  parser.add_argument("--max_tokens", default=1000, type=int)
26
+ parser.add_argument("--context_length", default=1024, type=int)
27
  parser.add_argument("--temperature", default=1.0, type=float)
28
  parser.add_argument("--top_k", default=500, type=int)
29
  parser.add_argument("--top_p", default=0.9, type=float)
 
92
  prompt = torch.tensor(prompt, dtype=torch.int64, device=args.device)
93
 
94
  for token in model.generate(
95
+ prompt,
96
+ args.max_tokens,
97
+ args.context_length,
98
+ args.temperature,
99
+ args.top_k,
100
+ args.top_p,
101
  ):
102
  out = tokenizer.decode_single_token_bytes(token).decode(
103
  "utf-8", errors="replace"
instruction-tune.py CHANGED
@@ -9,6 +9,7 @@ from torch.optim import Adafactor
9
  from torch.amp import autocast
10
  from torch.cuda import is_available as cuda_is_available, is_bf16_supported
11
  from torch.utils.data import random_split
 
12
 
13
  from torchmetrics.text import Perplexity
14
 
@@ -26,12 +27,13 @@ def main():
26
  parser.add_argument(
27
  "--base_model_path", default="./checkpoints/checkpoint.pt", type=str
28
  )
 
 
29
  parser.add_argument("--batch_size", default=1, type=int)
30
  parser.add_argument("--gradient_accumulation_steps", default=64, type=int)
31
  parser.add_argument("--learning_rate", default=5e-4, type=float)
32
  parser.add_argument("--rms_decay", default=-0.8, type=float)
33
  parser.add_argument("--optimizer_low_memory", default=True, type=bool)
34
- parser.add_argument("--mask_input", default=False, type=bool)
35
  parser.add_argument("--num_epochs", default=4, type=int)
36
  parser.add_argument("--rank", default=8, type=int)
37
  parser.add_argument("--alpha", default=1.0, type=float)
@@ -43,6 +45,7 @@ def main():
43
  "--checkpoint_path", default="./checkpoints/lora_instruction.pt", type=str
44
  )
45
  parser.add_argument("--resume", action="store_true")
 
46
  parser.add_argument("--device", default="cuda", type=str)
47
  parser.add_argument("--seed", default=None, type=int)
48
 
@@ -65,6 +68,8 @@ def main():
65
  torch.manual_seed(args.seed)
66
  random.seed(args.seed)
67
 
 
 
68
  checkpoint = torch.load(
69
  args.base_model_path, map_location=args.device, weights_only=True
70
  )
@@ -75,7 +80,7 @@ def main():
75
 
76
  dataset = Alpaca(
77
  tokenizer,
78
- max_tokens_per_sample=model_args["block_size"],
79
  mask_input=args.mask_input,
80
  )
81
 
@@ -173,6 +178,8 @@ def main():
173
 
174
  average_cross_entropy = total_cross_entropy / total_batches
175
 
 
 
176
  print(
177
  f"Epoch {epoch}: Cross Entropy: {average_cross_entropy:.5f}",
178
  )
@@ -191,6 +198,8 @@ def main():
191
 
192
  perplexity = perplexity_metric.compute()
193
 
 
 
194
  print(f"Perplexity: {perplexity:.3f}")
195
 
196
  perplexity_metric.reset()
 
9
  from torch.amp import autocast
10
  from torch.cuda import is_available as cuda_is_available, is_bf16_supported
11
  from torch.utils.data import random_split
12
+ from torch.utils.tensorboard import SummaryWriter
13
 
14
  from torchmetrics.text import Perplexity
15
 
 
27
  parser.add_argument(
28
  "--base_model_path", default="./checkpoints/checkpoint.pt", type=str
29
  )
30
+ parser.add_argument("--max_tokens_per_sample", default=4096, type=int)
31
+ parser.add_argument("--mask_input", action="store_true")
32
  parser.add_argument("--batch_size", default=1, type=int)
33
  parser.add_argument("--gradient_accumulation_steps", default=64, type=int)
34
  parser.add_argument("--learning_rate", default=5e-4, type=float)
35
  parser.add_argument("--rms_decay", default=-0.8, type=float)
36
  parser.add_argument("--optimizer_low_memory", default=True, type=bool)
 
37
  parser.add_argument("--num_epochs", default=4, type=int)
38
  parser.add_argument("--rank", default=8, type=int)
39
  parser.add_argument("--alpha", default=1.0, type=float)
 
45
  "--checkpoint_path", default="./checkpoints/lora_instruction.pt", type=str
46
  )
47
  parser.add_argument("--resume", action="store_true")
48
+ parser.add_argument("--run_dir_path", default="./runs/instruction-tune", type=str)
49
  parser.add_argument("--device", default="cuda", type=str)
50
  parser.add_argument("--seed", default=None, type=int)
51
 
 
68
  torch.manual_seed(args.seed)
69
  random.seed(args.seed)
70
 
71
+ logger = SummaryWriter(args.run_dir_path)
72
+
73
  checkpoint = torch.load(
74
  args.base_model_path, map_location=args.device, weights_only=True
75
  )
 
80
 
81
  dataset = Alpaca(
82
  tokenizer,
83
+ max_tokens_per_sample=args.max_tokens_per_sample,
84
  mask_input=args.mask_input,
85
  )
86
 
 
178
 
179
  average_cross_entropy = total_cross_entropy / total_batches
180
 
181
+ logger.add_scalar("cross entropy", average_cross_entropy, epoch)
182
+
183
  print(
184
  f"Epoch {epoch}: Cross Entropy: {average_cross_entropy:.5f}",
185
  )
 
198
 
199
  perplexity = perplexity_metric.compute()
200
 
201
+ logger.add_scalar("perplexity", perplexity, epoch)
202
+
203
  print(f"Perplexity: {perplexity:.3f}")
204
 
205
  perplexity_metric.reset()
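
Because this commit adds `SummaryWriter` logging, the recorded cross entropy and perplexity scalars can be viewed with TensorBoard, assuming it is installed in the environment:

```
tensorboard --logdir ./runs
```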
model.py CHANGED
@@ -1,4 +1,4 @@
1
- from math import sqrt, exp
2
  from dataclasses import dataclass
3
  from functools import partial, cached_property
4
  from typing import Iterator, Self
@@ -13,8 +13,8 @@ from torch.nn import (
13
  Embedding,
14
  MultiheadAttention,
15
  Linear,
 
16
  RMSNorm,
17
- GELU,
18
  Dropout1d,
19
  CrossEntropyLoss,
20
  Parameter,
@@ -27,25 +27,21 @@ from torch.utils.checkpoint import checkpoint as torch_checkpoint
27
 
28
 
29
  class GPT(Module):
30
- """A generative pre-trained transformer."""
31
 
32
  def __init__(
33
  self,
34
- block_size: int,
35
  embedding_dimensions: int,
36
  num_heads: int,
37
  num_layers: int,
38
  feed_forward_ratio: int,
39
  dropout: float,
40
- vocabulary_size: int,
41
  padding_index: int,
42
  eos_index: int,
43
  ):
44
  super().__init__()
45
 
46
- if block_size < 1:
47
- raise ValueError(f"Block size must be greater than 0, {block_size} given.")
48
-
49
  if num_layers <= 0:
50
  raise ValueError(f"Num layers must be greater than 0, {num_layers} given.")
51
 
@@ -67,16 +63,10 @@ class GPT(Module):
67
 
68
  self.token_embeddings = token_embeddings
69
 
70
- causal_mask = torch.full((block_size, block_size), float("-inf"))
71
- causal_mask = torch.triu(causal_mask, diagonal=1)
72
-
73
- self.causal_mask = Buffer(causal_mask, persistent=False)
74
-
75
  self.body = ModuleList(
76
  [
77
  CausalSelfAttentionBlock(
78
  embedding_dimensions,
79
- block_size,
80
  num_heads,
81
  feed_forward_ratio,
82
  dropout,
@@ -93,7 +83,6 @@ class GPT(Module):
93
  self.loss_function = CrossEntropyLoss(ignore_index=padding_index)
94
 
95
  self.vocabulary_size = vocabulary_size
96
- self.block_size = block_size
97
  self.eos_index = eos_index
98
 
99
  @cached_property
@@ -108,9 +97,10 @@ class GPT(Module):
108
  ) -> tuple[Tensor, Tensor | None]:
109
  z = self.token_embeddings(x)
110
 
111
- b, t = x.size()
112
 
113
- causal_mask = self.causal_mask[:t, :t]
 
114
 
115
  for layer in self.body:
116
  z = self.checkpoint(layer, z, causal_mask)
@@ -132,14 +122,15 @@ class GPT(Module):
132
  def generate(
133
  self,
134
  prompt: Tensor,
135
- max_tokens: int = 500,
 
136
  temperature: float = 1.0,
137
  top_k: int = 500,
138
  top_p: float = 0.9,
139
  ) -> Iterator:
140
  """
141
  Given a prompt, sample the next {max_tokens} tokens from the model weighted
142
- by their predicted probabilities.
143
  """
144
 
145
  if max_tokens <= 0:
@@ -161,7 +152,7 @@ class GPT(Module):
161
  context_window = prompt
162
 
163
  for _ in range(max_tokens):
164
- context_window = context_window[-self.block_size :]
165
 
166
  y_pred, _ = self.forward(context_window.unsqueeze(0))
167
 
@@ -201,12 +192,15 @@ class GPT(Module):
201
  def beam_search(
202
  self,
203
  prompt: Tensor,
204
- max_tokens: int = 200,
 
205
  num_candidates: int = 3,
206
  beam_width: int = 16,
207
  ) -> list:
208
  """
209
- Given a prompt, return the {num_candidates} highest probability sequences.
 
 
210
  """
211
 
212
  if max_tokens <= 0:
@@ -267,7 +261,7 @@ class GPT(Module):
267
 
268
  context_window = torch.cat((prompt, candidate.tokens))
269
 
270
- context_window = context_window[-self.block_size :]
271
 
272
  y_pred, _ = self.forward(context_window.unsqueeze(0))
273
 
@@ -293,7 +287,7 @@ class GPT(Module):
293
 
294
  class GPTWithLoRA(Module):
295
  """
296
- A wrapper for pre-trained GPT models that applies a LoRA reparameterization
297
  to the intermediate layers of the network.
298
  """
299
 
@@ -382,13 +376,26 @@ class GPTWithLoRA(Module):
382
  return self.model.beam_search(prompt, max_tokens, num_candidates, beam_width)
383
 
384
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  class CausalSelfAttentionBlock(Module):
386
  """Causal self-attention block with residual connections."""
387
 
388
  def __init__(
389
  self,
390
  embedding_dimensions: int,
391
- block_size: int,
392
  num_heads: int,
393
  feed_forward_ratio: int,
394
  dropout: float,
@@ -400,9 +407,6 @@ class CausalSelfAttentionBlock(Module):
400
  f"Embedding dimensions must be greater than 0, {embedding_dimensions} given."
401
  )
402
 
403
- if block_size <= 0:
404
- raise ValueError(f"Block size must be greater than 0, {block_size} given.")
405
-
406
  if num_heads <= 0:
407
  raise ValueError(f"Num heads must be greater than 0, {num_heads} given.")
408
 
@@ -459,7 +463,7 @@ class MLP(Module):
459
 
460
  self.layers = Sequential(
461
  Linear(embedding_dimensions, hidden_dimensions, bias=False),
462
- GELU(),
463
  Linear(hidden_dimensions, embedding_dimensions, bias=False),
464
  )
465
 
 
1
+ from math import sqrt
2
  from dataclasses import dataclass
3
  from functools import partial, cached_property
4
  from typing import Iterator, Self
 
13
  Embedding,
14
  MultiheadAttention,
15
  Linear,
16
+ SiLU,
17
  RMSNorm,
 
18
  Dropout1d,
19
  CrossEntropyLoss,
20
  Parameter,
 
27
 
28
 
29
  class GPT(Module):
30
+ """A generative pretrained transformer."""
31
 
32
  def __init__(
33
  self,
34
+ vocabulary_size: int,
35
  embedding_dimensions: int,
36
  num_heads: int,
37
  num_layers: int,
38
  feed_forward_ratio: int,
39
  dropout: float,
 
40
  padding_index: int,
41
  eos_index: int,
42
  ):
43
  super().__init__()
44
 
 
 
 
45
  if num_layers <= 0:
46
  raise ValueError(f"Num layers must be greater than 0, {num_layers} given.")
47
 
 
63
 
64
  self.token_embeddings = token_embeddings
65
 
 
 
 
 
 
66
  self.body = ModuleList(
67
  [
68
  CausalSelfAttentionBlock(
69
  embedding_dimensions,
 
70
  num_heads,
71
  feed_forward_ratio,
72
  dropout,
 
83
  self.loss_function = CrossEntropyLoss(ignore_index=padding_index)
84
 
85
  self.vocabulary_size = vocabulary_size
 
86
  self.eos_index = eos_index
87
 
88
  @cached_property
 
97
  ) -> tuple[Tensor, Tensor | None]:
98
  z = self.token_embeddings(x)
99
 
100
+ b, t, d = z.size()
101
 
102
+ causal_mask = torch.full((t, t), float("-inf"), dtype=z.dtype, device=z.device)
103
+ causal_mask = torch.triu(causal_mask, diagonal=1)
104
 
105
  for layer in self.body:
106
  z = self.checkpoint(layer, z, causal_mask)
 
122
  def generate(
123
  self,
124
  prompt: Tensor,
125
+ max_tokens: int = 1000,
126
+ context_length: int = 1024,
127
  temperature: float = 1.0,
128
  top_k: int = 500,
129
  top_p: float = 0.9,
130
  ) -> Iterator:
131
  """
132
  Given a prompt, sample the next {max_tokens} tokens from the model weighted
133
+ by their predicted probabilities and filtered by the {top_k} and {top_p} parameters.
134
  """
135
 
136
  if max_tokens <= 0:
 
152
  context_window = prompt
153
 
154
  for _ in range(max_tokens):
155
+ context_window = context_window[-context_length:]
156
 
157
  y_pred, _ = self.forward(context_window.unsqueeze(0))
158
 
 
192
  def beam_search(
193
  self,
194
  prompt: Tensor,
195
+ max_tokens: int = 100,
196
+ context_length: int = 1024,
197
  num_candidates: int = 3,
198
  beam_width: int = 16,
199
  ) -> list:
200
  """
201
+ Given a prompt, return the {num_candidates} highest-probability sequences. Note that
202
+ this method works best for generating shorter sequences, and its output typically sounds
203
+ less natural than sequences produced by stochastic sampling.
204
  """
205
 
206
  if max_tokens <= 0:
 
261
 
262
  context_window = torch.cat((prompt, candidate.tokens))
263
 
264
+ context_window = context_window[-context_length:]
265
 
266
  y_pred, _ = self.forward(context_window.unsqueeze(0))
267
 
 
287
 
288
  class GPTWithLoRA(Module):
289
  """
290
+ A wrapper for pretrained GPT models that applies a LoRA reparameterization
291
  to the intermediate layers of the network.
292
  """
293
 
 
376
  return self.model.beam_search(prompt, max_tokens, num_candidates, beam_width)
377
 
378
 
379
+ class ONNXModel(Module):
380
+ """This wrapper provides a cleaner inferencing API for production models."""
381
+
382
+ def __init__(self, model: GPT | GPTWithLoRA):
383
+ super().__init__()
384
+
385
+ self.model = model
386
+
387
+ def forward(self, x: Tensor) -> Tensor:
388
+ logits, _ = self.model.forward(x, None)
389
+
390
+ return logits
391
+
392
+
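
A minimal usage sketch for the `ONNXModel` wrapper defined above (the sequence length of 1024 is only an assumption for illustration):

```
wrapped = ONNXModel(model)  # wrap a trained GPT so forward() returns only the logits
wrapped.eval()              # disable dropout and other train-time behavior

example_input = torch.randint(0, model.vocabulary_size - 1, (1, 1024))

logits = wrapped(example_input)  # shape (1, 1024, vocabulary_size)
```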
393
  class CausalSelfAttentionBlock(Module):
394
  """Causal self-attention block with residual connections."""
395
 
396
  def __init__(
397
  self,
398
  embedding_dimensions: int,
 
399
  num_heads: int,
400
  feed_forward_ratio: int,
401
  dropout: float,
 
407
  f"Embedding dimensions must be greater than 0, {embedding_dimensions} given."
408
  )
409
 
 
 
 
410
  if num_heads <= 0:
411
  raise ValueError(f"Num heads must be greater than 0, {num_heads} given.")
412
 
 
463
 
464
  self.layers = Sequential(
465
  Linear(embedding_dimensions, hidden_dimensions, bias=False),
466
+ SiLU(),
467
  Linear(hidden_dimensions, embedding_dimensions, bias=False),
468
  )
469
 
model_sizing.ipynb CHANGED
@@ -9,15 +9,19 @@
9
  },
10
  {
11
  "cell_type": "code",
12
- "execution_count": 148,
13
  "metadata": {},
14
  "outputs": [],
15
  "source": [
16
- "block_size = 1024\n",
17
  "vocabulary_size = 50257\n",
18
  "embedding_dimensions = 1024\n",
19
  "num_attention_heads = 16\n",
20
  "num_hidden_layers = 24\n",
 
 
 
 
21
  "samples_per_epoch = 4096"
22
  ]
23
  },
@@ -30,7 +34,7 @@
30
  },
31
  {
32
  "cell_type": "code",
33
- "execution_count": 149,
34
  "metadata": {},
35
  "outputs": [
36
  {
@@ -49,12 +53,12 @@
49
  "text": [
50
  "Token Embeddings 51,463,168 14.56%\n",
51
  "Attention 100,663,296 28.48%\n",
52
- "MLP 201,326,592 56.95%\n",
53
  "RMS Norm 50,176 0.01%\n",
54
  "Output Layer 0 0.00%\n",
55
  "\n",
56
  "\n",
57
- "Total parameters: 353,503,232\n"
58
  ]
59
  }
60
  ],
@@ -67,7 +71,11 @@
67
  " embedding_dimensions**2 + embedding_dimensions * 3 * embedding_dimensions\n",
68
  " )\n",
69
  " * num_hidden_layers,\n",
70
- " \"MLP\": embedding_dimensions * 4 * embedding_dimensions * 2 * num_hidden_layers,\n",
 
 
 
 
71
  " \"RMS Norm\": embedding_dimensions * num_hidden_layers * 2 + embedding_dimensions,\n",
72
  " \"Output Layer\": 0, # Tied to token embeddings\n",
73
  "}\n",
@@ -99,7 +107,7 @@
99
  },
100
  {
101
  "cell_type": "code",
102
- "execution_count": 150,
103
  "metadata": {},
104
  "outputs": [
105
  {
@@ -125,7 +133,7 @@
125
  },
126
  {
127
  "cell_type": "code",
128
- "execution_count": 151,
129
  "metadata": {},
130
  "outputs": [
131
  {
@@ -151,7 +159,7 @@
151
  },
152
  {
153
  "cell_type": "code",
154
- "execution_count": 152,
155
  "metadata": {},
156
  "outputs": [
157
  {
@@ -181,14 +189,14 @@
181
  },
182
  {
183
  "cell_type": "code",
184
- "execution_count": 153,
185
  "metadata": {},
186
  "outputs": [
187
  {
188
  "name": "stdout",
189
  "output_type": "stream",
190
  "text": [
191
- "Optimal training tokens: 7,070,064,640\n",
192
  "Epochs required: 1,686\n",
193
  "\n"
194
  ]
@@ -197,7 +205,9 @@
197
  "source": [
198
  "num_training_tokens = 20 * total_parameter_count\n",
199
  "\n",
200
- "num_epochs_required = round(num_training_tokens / (samples_per_epoch * block_size))\n",
 
 
201
  "\n",
202
  "print(f\"Optimal training tokens: {num_training_tokens:,}\")\n",
203
  "\n",
@@ -213,7 +223,7 @@
213
  },
214
  {
215
  "cell_type": "code",
216
- "execution_count": 154,
217
  "metadata": {},
218
  "outputs": [
219
  {
@@ -231,43 +241,57 @@
231
  "output_type": "stream",
232
  "text": [
233
  "Attention 309,237,645,312 37.39%\n",
234
- "MLP 412,317,745,152 49.86%\n",
235
  "RMS Norm 179,200 0.00%\n",
236
  "Output Layer 105,396,568,064 12.75%\n",
237
  "\n",
238
  "\n",
239
- "Total forward FLOPs: 826,952,137,728\n"
240
  ]
241
  }
242
  ],
243
  "source": [
244
  "ops_per_matmul = 2 # Multiply + accumulate (MAC)\n",
245
- "ops_per_activation = 9 # Assuming GELU\n",
246
  "ops_per_rms_norm = 7 # y = (x / sqrt(rms[x] + epsilon)) * gamma\n",
247
  "\n",
248
  "head_dimensions = embedding_dimensions // num_attention_heads\n",
249
  "\n",
250
  "# K, Q, V projections\n",
251
  "attention = (\n",
252
- " ops_per_matmul * block_size * (embedding_dimensions * 3 * embedding_dimensions)\n",
 
 
253
  ")\n",
254
  "\n",
255
  "# Attention logits\n",
256
- "attention += ops_per_matmul * block_size * block_size * embedding_dimensions\n",
 
 
257
  "\n",
258
  "# Reductions\n",
259
  "attention += (\n",
260
- " ops_per_matmul * num_attention_heads * (block_size * block_size * head_dimensions)\n",
 
 
261
  ")\n",
262
  "\n",
263
  "# Output projection\n",
264
- "attention += ops_per_matmul * block_size * embedding_dimensions**2\n",
265
  "\n",
266
  "attention *= num_hidden_layers\n",
267
  "\n",
268
  "# Linear transformations\n",
269
- "mlp = ops_per_matmul * block_size * (embedding_dimensions * (4 * embedding_dimensions))\n",
270
- "mlp += ops_per_matmul * block_size * ((4 * embedding_dimensions) * embedding_dimensions)\n",
 
 
 
 
 
 
 
 
271
  "\n",
272
  "# Non-linear activations\n",
273
  "mlp += ops_per_activation * (4 * embedding_dimensions)\n",
@@ -276,7 +300,9 @@
276
  "\n",
277
  "rms_norm = ops_per_rms_norm * embedding_dimensions * (num_hidden_layers + 1)\n",
278
  "\n",
279
- "output_layer = ops_per_matmul * block_size * embedding_dimensions * vocabulary_size\n",
 
 
280
  "\n",
281
  "flops = {\n",
282
  " \"Attention\": attention,\n",
@@ -312,14 +338,14 @@
312
  },
313
  {
314
  "cell_type": "code",
315
- "execution_count": 155,
316
  "metadata": {},
317
  "outputs": [
318
  {
319
  "name": "stdout",
320
  "output_type": "stream",
321
  "text": [
322
- "Total backward FLOPs: 1,653,904,275,456\n"
323
  ]
324
  }
325
  ],
@@ -338,14 +364,14 @@
338
  },
339
  {
340
  "cell_type": "code",
341
- "execution_count": 156,
342
  "metadata": {},
343
  "outputs": [
344
  {
345
  "name": "stdout",
346
  "output_type": "stream",
347
  "text": [
348
- "Total roundtrip FLOPs: 2,480,856,413,184\n"
349
  ]
350
  }
351
  ],
@@ -364,24 +390,24 @@
364
  },
365
  {
366
  "cell_type": "code",
367
- "execution_count": 157,
368
  "metadata": {},
369
  "outputs": [
370
  {
371
  "name": "stdout",
372
  "output_type": "stream",
373
  "text": [
374
- "Total PaLM FLOPs: 2,481,161,502,720\n"
375
  ]
376
  }
377
  ],
378
  "source": [
379
  "palm_flops_per_token = (\n",
380
  " 6 * total_parameter_count\n",
381
- " + 12 * num_hidden_layers * num_attention_heads * head_dimensions * block_size\n",
382
  ")\n",
383
  "\n",
384
- "total_palm_flops = palm_flops_per_token * block_size\n",
385
  "\n",
386
  "print(f\"Total PaLM FLOPs: {total_palm_flops:,}\")"
387
  ]
@@ -390,23 +416,30 @@
390
  "cell_type": "markdown",
391
  "metadata": {},
392
  "source": [
393
- "Finally, let's estimate how long it would take to train over the estimated optimal number of tokens given the hardware conifgurations we defined above. Note that these results shown here are a theoretical scenario and do not factor in additional overhead such as activation checkpointing or network latency."
 
 
 
 
 
 
 
394
  ]
395
  },
396
  {
397
  "cell_type": "code",
398
- "execution_count": 158,
399
  "metadata": {},
400
  "outputs": [
401
  {
402
  "name": "stdout",
403
  "output_type": "stream",
404
  "text": [
405
- "RTX A2000: 935.43 seconds/epoch, 18.25 days required, MFU: 17.0%\n",
406
- "RTX A4000: 348.64 seconds/epoch, 6.80 days required, MFU: 19.0%\n",
407
- "RTX 3090: 154.75 seconds/epoch, 3.02 days required, MFU: 23.0%\n",
408
- "A100 SXM: 44.01 seconds/epoch, 0.86 days required, MFU: 37.0%\n",
409
- "HGX A100: 6.79 seconds/epoch, 0.13 days required, MFU: 30.0%\n"
410
  ]
411
  }
412
  ],
@@ -421,10 +454,6 @@
421
  " mfu: float\n",
422
  "\n",
423
  " @property\n",
424
- " def percentage_utilization(self) -> float:\n",
425
- " return self.mfu * 100\n",
426
- "\n",
427
- " @property\n",
428
  " def actual_flops(self) -> float:\n",
429
  " return self.mfu * self.advertised_flops\n",
430
  "\n",
@@ -443,7 +472,7 @@
443
  " days_required = num_epochs_required * seconds_per_epoch / 60 / 60 / 24\n",
444
  "\n",
445
  " print(\n",
446
- " f\"{device.name}: {seconds_per_epoch:.2f} seconds/epoch, {days_required:,.2f} days required, MFU: {device.percentage_utilization}%\"\n",
447
  " )"
448
  ]
449
  }
 
9
  },
10
  {
11
  "cell_type": "code",
12
+ "execution_count": 188,
13
  "metadata": {},
14
  "outputs": [],
15
  "source": [
16
+ "# Model\n",
17
  "vocabulary_size = 50257\n",
18
  "embedding_dimensions = 1024\n",
19
  "num_attention_heads = 16\n",
20
  "num_hidden_layers = 24\n",
21
+ "feed_forward_ratio = 4\n",
22
+ "\n",
23
+ "# Training set\n",
24
+ "tokens_per_sample = 1024\n",
25
  "samples_per_epoch = 4096"
26
  ]
27
  },
 
34
  },
35
  {
36
  "cell_type": "code",
37
+ "execution_count": 189,
38
  "metadata": {},
39
  "outputs": [
40
  {
 
53
  "text": [
54
  "Token Embeddings 51,463,168 14.56%\n",
55
  "Attention 100,663,296 28.48%\n",
56
+ "MLP 201,326,616 56.95%\n",
57
  "RMS Norm 50,176 0.01%\n",
58
  "Output Layer 0 0.00%\n",
59
  "\n",
60
  "\n",
61
+ "Total parameters: 353,503,256\n"
62
  ]
63
  }
64
  ],
 
71
  " embedding_dimensions**2 + embedding_dimensions * 3 * embedding_dimensions\n",
72
  " )\n",
73
  " * num_hidden_layers,\n",
74
+ " \"MLP\": embedding_dimensions\n",
75
+ " * feed_forward_ratio\n",
76
+ " * embedding_dimensions\n",
77
+ " * 2\n",
78
+ " * num_hidden_layers,\n",
79
  " \"RMS Norm\": embedding_dimensions * num_hidden_layers * 2 + embedding_dimensions,\n",
80
  " \"Output Layer\": 0, # Tied to token embeddings\n",
81
  "}\n",
 
107
  },
108
  {
109
  "cell_type": "code",
110
+ "execution_count": 190,
111
  "metadata": {},
112
  "outputs": [
113
  {
 
133
  },
134
  {
135
  "cell_type": "code",
136
+ "execution_count": 191,
137
  "metadata": {},
138
  "outputs": [
139
  {
 
159
  },
160
  {
161
  "cell_type": "code",
162
+ "execution_count": 192,
163
  "metadata": {},
164
  "outputs": [
165
  {
 
189
  },
190
  {
191
  "cell_type": "code",
192
+ "execution_count": 193,
193
  "metadata": {},
194
  "outputs": [
195
  {
196
  "name": "stdout",
197
  "output_type": "stream",
198
  "text": [
199
+ "Optimal training tokens: 7,070,065,120\n",
200
  "Epochs required: 1,686\n",
201
  "\n"
202
  ]
 
205
  "source": [
206
  "num_training_tokens = 20 * total_parameter_count\n",
207
  "\n",
208
+ "num_epochs_required = round(\n",
209
+ " num_training_tokens / (samples_per_epoch * tokens_per_sample)\n",
210
+ ")\n",
211
  "\n",
212
  "print(f\"Optimal training tokens: {num_training_tokens:,}\")\n",
213
  "\n",
 
223
  },
224
  {
225
  "cell_type": "code",
226
+ "execution_count": 194,
227
  "metadata": {},
228
  "outputs": [
229
  {
 
241
  "output_type": "stream",
242
  "text": [
243
  "Attention 309,237,645,312 37.39%\n",
244
+ "MLP 412,317,450,240 49.86%\n",
245
  "RMS Norm 179,200 0.00%\n",
246
  "Output Layer 105,396,568,064 12.75%\n",
247
  "\n",
248
  "\n",
249
+ "Total forward FLOPs: 826,951,842,816\n"
250
  ]
251
  }
252
  ],
253
  "source": [
254
  "ops_per_matmul = 2 # Multiply + accumulate (MAC)\n",
255
+ "ops_per_activation = 5 # Assuming SiLU\n",
256
  "ops_per_rms_norm = 7 # y = (x / sqrt(rms[x] + epsilon)) * gamma\n",
257
  "\n",
258
  "head_dimensions = embedding_dimensions // num_attention_heads\n",
259
  "\n",
260
  "# K, Q, V projections\n",
261
  "attention = (\n",
262
+ " ops_per_matmul\n",
263
+ " * tokens_per_sample\n",
264
+ " * (embedding_dimensions * 3 * embedding_dimensions)\n",
265
  ")\n",
266
  "\n",
267
  "# Attention logits\n",
268
+ "attention += (\n",
269
+ " ops_per_matmul * tokens_per_sample * tokens_per_sample * embedding_dimensions\n",
270
+ ")\n",
271
  "\n",
272
  "# Reductions\n",
273
  "attention += (\n",
274
+ " ops_per_matmul\n",
275
+ " * num_attention_heads\n",
276
+ " * (tokens_per_sample * tokens_per_sample * head_dimensions)\n",
277
  ")\n",
278
  "\n",
279
  "# Output projection\n",
280
+ "attention += ops_per_matmul * tokens_per_sample * embedding_dimensions**2\n",
281
  "\n",
282
  "attention *= num_hidden_layers\n",
283
  "\n",
284
  "# Linear transformations\n",
285
+ "mlp = (\n",
286
+ " ops_per_matmul\n",
287
+ " * tokens_per_sample\n",
288
+ " * (embedding_dimensions * (4 * embedding_dimensions))\n",
289
+ ")\n",
290
+ "mlp += (\n",
291
+ " ops_per_matmul\n",
292
+ " * tokens_per_sample\n",
293
+ " * ((4 * embedding_dimensions) * embedding_dimensions)\n",
294
+ ")\n",
295
  "\n",
296
  "# Non-linear activations\n",
297
  "mlp += ops_per_activation * (4 * embedding_dimensions)\n",
 
300
  "\n",
301
  "rms_norm = ops_per_rms_norm * embedding_dimensions * (num_hidden_layers + 1)\n",
302
  "\n",
303
+ "output_layer = (\n",
304
+ " ops_per_matmul * tokens_per_sample * embedding_dimensions * vocabulary_size\n",
305
+ ")\n",
306
  "\n",
307
  "flops = {\n",
308
  " \"Attention\": attention,\n",
 
338
  },
339
  {
340
  "cell_type": "code",
341
+ "execution_count": 195,
342
  "metadata": {},
343
  "outputs": [
344
  {
345
  "name": "stdout",
346
  "output_type": "stream",
347
  "text": [
348
+ "Total backward FLOPs: 1,653,903,685,632\n"
349
  ]
350
  }
351
  ],
 
364
  },
365
  {
366
  "cell_type": "code",
367
+ "execution_count": 196,
368
  "metadata": {},
369
  "outputs": [
370
  {
371
  "name": "stdout",
372
  "output_type": "stream",
373
  "text": [
374
+ "Total roundtrip FLOPs: 2,480,855,528,448\n"
375
  ]
376
  }
377
  ],
 
390
  },
391
  {
392
  "cell_type": "code",
393
+ "execution_count": 197,
394
  "metadata": {},
395
  "outputs": [
396
  {
397
  "name": "stdout",
398
  "output_type": "stream",
399
  "text": [
400
+ "Total PaLM FLOPs: 2,481,161,650,176\n"
401
  ]
402
  }
403
  ],
404
  "source": [
405
  "palm_flops_per_token = (\n",
406
  " 6 * total_parameter_count\n",
407
+ " + 12 * num_hidden_layers * num_attention_heads * head_dimensions * tokens_per_sample\n",
408
  ")\n",
409
  "\n",
410
+ "total_palm_flops = palm_flops_per_token * tokens_per_sample\n",
411
  "\n",
412
  "print(f\"Total PaLM FLOPs: {total_palm_flops:,}\")"
413
  ]
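As a cross-check of the PaLM-style total above, plugging in the "Small" defaults (embedding_dimensions=1024, num_attention_heads=16, num_hidden_layers=24, tokens_per_sample=1024) together with the parameter count implied earlier reproduces the printed number; the values below are taken or derived from the outputs above, not measured independently:

```
# Worked check of the PaLM approximation with concrete "Small" values.
total_parameter_count = 353_503_256            # implied by 7,070,065,120 / 20
num_hidden_layers = 24
num_attention_heads = 16
head_dimensions = 1024 // num_attention_heads  # 64
tokens_per_sample = 1024

palm_flops_per_token = (
    6 * total_parameter_count
    + 12 * num_hidden_layers * num_attention_heads * head_dimensions * tokens_per_sample
)

print(f"{palm_flops_per_token * tokens_per_sample:,}")  # 2,481,161,650,176
```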
 
416
  "cell_type": "markdown",
417
  "metadata": {},
418
  "source": [
419
+ "The two estimates are pretty close so let's proceed."
420
+ ]
421
+ },
422
+ {
423
+ "cell_type": "markdown",
424
+ "metadata": {},
425
+ "source": [
426
+ "Finally, let's estimate how long it would take to train over the optimal number of tokens given some common Nvidia Ampere generation GPU hardware configurations. Note that these results shown here are a theoretical scenario and do not factor in additional overhead such as activation checkpointing or network latency."
427
  ]
428
  },
429
  {
430
  "cell_type": "code",
431
+ "execution_count": 198,
432
  "metadata": {},
433
  "outputs": [
434
  {
435
  "name": "stdout",
436
  "output_type": "stream",
437
  "text": [
438
+ "RTX A2000: 935.43 seconds/epoch, 18.25 days required\n",
439
+ "RTX A4000: 348.64 seconds/epoch, 6.80 days required\n",
440
+ "RTX 3090: 154.75 seconds/epoch, 3.02 days required\n",
441
+ "A100 SXM: 44.01 seconds/epoch, 0.86 days required\n",
442
+ "HGX A100: 6.79 seconds/epoch, 0.13 days required\n"
443
  ]
444
  }
445
  ],
 
454
  " mfu: float\n",
455
  "\n",
456
  " @property\n",
 
 
 
 
457
  " def actual_flops(self) -> float:\n",
458
  " return self.mfu * self.advertised_flops\n",
459
  "\n",
 
472
  " days_required = num_epochs_required * seconds_per_epoch / 60 / 60 / 24\n",
473
  "\n",
474
  " print(\n",
475
+ " f\"{device.name}: {seconds_per_epoch:.2f} seconds/epoch, {days_required:,.2f} days required\"\n",
476
  " )"
477
  ]
478
  }
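For readers skimming the diff, here is a condensed, standalone sketch of the training-time estimate the notebook performs above. The roundtrip FLOPs, samples per epoch, and epoch count come from the notebook outputs; the advertised throughput and MFU below are illustrative assumptions only and are not the per-device values used to produce the table above:

```
# Standalone sketch of the training-time estimate. The device spec and MFU are
# illustrative assumptions; only the first three constants come from the notebook.
roundtrip_flops_per_sample = 2_480_855_528_448
samples_per_epoch = 4096
num_epochs_required = 1686

advertised_flops = 312e12  # hypothetical accelerator: 312 TFLOPS (BF16)
mfu = 0.4                  # assumed model FLOPs utilization

actual_flops = mfu * advertised_flops
seconds_per_epoch = roundtrip_flops_per_sample * samples_per_epoch / actual_flops
days_required = num_epochs_required * seconds_per_epoch / 60 / 60 / 24

print(f"{seconds_per_epoch:.2f} seconds/epoch, {days_required:.2f} days required")
```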
pre-train.py → pretrain.py RENAMED
@@ -16,6 +16,7 @@ from torch.cuda import set_device, is_available as cuda_is_available, is_bf16_su
16
  from torch.nn.utils import clip_grad_norm_
17
  from torch.distributed import init_process_group, destroy_process_group
18
  from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy
 
19
 
20
  from torchmetrics.text import Perplexity
21
 
@@ -38,7 +39,7 @@ DDP_BACKEND = "nccl"
38
 
39
 
40
  def main():
41
- parser = ArgumentParser(description="Pre-train the GPT.")
42
 
43
  parser.add_argument(
44
  "--dataset_subset",
@@ -54,6 +55,7 @@ def main():
54
  parser.add_argument("--num_dataset_processes", default=8, type=int)
55
  parser.add_argument("--batch_size", default=1, type=int)
56
  parser.add_argument("--gradient_accumulation_steps", default=128, type=int)
 
57
  parser.add_argument("--samples_per_epoch", default=4096, type=int)
58
  parser.add_argument("--num_epochs", default=1686, type=int)
59
  parser.add_argument("--learning_rate", default=1e-2, type=float)
@@ -61,7 +63,6 @@ def main():
61
  parser.add_argument("--low_memory_optimizer", action="store_true")
62
  parser.add_argument("--max_gradient_norm", default=1.0, type=float)
63
  parser.add_argument("--dropout", default=0.1, type=float)
64
- parser.add_argument("--block_size", default=1024, type=int)
65
  parser.add_argument("--embedding_dimensions", default=1024, type=int)
66
  parser.add_argument("--num_attention_heads", default=16, type=int)
67
  parser.add_argument("--num_hidden_layers", default=24, type=int)
@@ -74,6 +75,7 @@ def main():
74
  "--checkpoint_path", default="./checkpoints/checkpoint.pt", type=str
75
  )
76
  parser.add_argument("--resume", action="store_true")
 
77
  parser.add_argument("--device", default="cuda", type=str)
78
  parser.add_argument("--seed", default=None, type=int)
79
 
@@ -156,23 +158,25 @@ def main():
156
  torch.manual_seed(args.seed)
157
  random.seed(args.seed)
158
 
 
 
159
  tokenizer = tiktoken.get_encoding(args.token_encoding)
160
 
161
  training = Fineweb(
162
- tokenizer,
163
  root_path=args.dataset_path,
164
  subset=args.dataset_subset,
165
  split="train",
166
- tokens_per_sample=args.block_size,
167
  samples_per_epoch=args.samples_per_epoch,
168
  num_processes=args.num_dataset_processes,
169
  )
170
  testing = Fineweb(
171
- tokenizer,
172
  root_path=args.dataset_path,
173
  subset=args.dataset_subset,
174
  split="test",
175
- tokens_per_sample=args.block_size,
176
  samples_per_epoch=args.samples_per_epoch,
177
  num_processes=args.num_dataset_processes,
178
  )
@@ -185,13 +189,12 @@ def main():
185
  )
186
 
187
  model_args = {
188
- "block_size": args.block_size,
189
  "embedding_dimensions": args.embedding_dimensions,
190
  "num_heads": args.num_attention_heads,
191
  "num_layers": args.num_hidden_layers,
192
  "feed_forward_ratio": args.feed_forward_ratio,
193
  "dropout": args.dropout,
194
- "vocabulary_size": tokenizer.n_vocab,
195
  "padding_index": training.PADDING_INDEX,
196
  "eos_index": tokenizer.eot_token,
197
  }
@@ -252,7 +255,7 @@ def main():
252
 
253
  register_signal_handlers()
254
 
255
- print("Pre-training ...")
256
 
257
  for epoch in range(starting_epoch, args.num_epochs + 1):
258
  total_cross_entropy, total_gradient_norm = 0.0, 0.0
@@ -292,14 +295,18 @@ def main():
292
 
293
  total_batches += 1
294
 
295
- average_cross_entropy = total_cross_entropy / total_batches
296
- average_gradient_norm = total_gradient_norm / total_steps
 
297
 
298
- print(
299
- f"Epoch {epoch}:",
300
- f"Cross Entropy: {average_cross_entropy:.5f},",
301
- f"Gradient Norm: {average_gradient_norm:.4f}",
302
- )
 
 
 
303
 
304
  if epoch % args.eval_interval == 0 and IS_MASTER:
305
  model.eval()
@@ -315,6 +322,8 @@ def main():
315
 
316
  perplexity = perplexity_metric.compute()
317
 
 
 
318
  print(f"Perplexity: {perplexity:.3f}")
319
 
320
  perplexity_metric.reset()
 
16
  from torch.nn.utils import clip_grad_norm_
17
  from torch.distributed import init_process_group, destroy_process_group
18
  from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy
19
+ from torch.utils.tensorboard import SummaryWriter
20
 
21
  from torchmetrics.text import Perplexity
22
 
 
39
 
40
 
41
  def main():
42
+ parser = ArgumentParser(description="Pretrain the GPT.")
43
 
44
  parser.add_argument(
45
  "--dataset_subset",
 
55
  parser.add_argument("--num_dataset_processes", default=8, type=int)
56
  parser.add_argument("--batch_size", default=1, type=int)
57
  parser.add_argument("--gradient_accumulation_steps", default=128, type=int)
58
+ parser.add_argument("--tokens_per_sample", default=1024, type=int)
59
  parser.add_argument("--samples_per_epoch", default=4096, type=int)
60
  parser.add_argument("--num_epochs", default=1686, type=int)
61
  parser.add_argument("--learning_rate", default=1e-2, type=float)
 
63
  parser.add_argument("--low_memory_optimizer", action="store_true")
64
  parser.add_argument("--max_gradient_norm", default=1.0, type=float)
65
  parser.add_argument("--dropout", default=0.1, type=float)
 
66
  parser.add_argument("--embedding_dimensions", default=1024, type=int)
67
  parser.add_argument("--num_attention_heads", default=16, type=int)
68
  parser.add_argument("--num_hidden_layers", default=24, type=int)
 
75
  "--checkpoint_path", default="./checkpoints/checkpoint.pt", type=str
76
  )
77
  parser.add_argument("--resume", action="store_true")
78
+ parser.add_argument("--run_dir_path", default="./runs/pretrain", type=str)
79
  parser.add_argument("--device", default="cuda", type=str)
80
  parser.add_argument("--seed", default=None, type=int)
81
 
 
158
  torch.manual_seed(args.seed)
159
  random.seed(args.seed)
160
 
161
+ logger = SummaryWriter(args.run_dir_path)
162
+
163
  tokenizer = tiktoken.get_encoding(args.token_encoding)
164
 
165
  training = Fineweb(
166
+ tokenizer=tokenizer,
167
  root_path=args.dataset_path,
168
  subset=args.dataset_subset,
169
  split="train",
170
+ tokens_per_sample=args.tokens_per_sample,
171
  samples_per_epoch=args.samples_per_epoch,
172
  num_processes=args.num_dataset_processes,
173
  )
174
  testing = Fineweb(
175
+ tokenizer=tokenizer,
176
  root_path=args.dataset_path,
177
  subset=args.dataset_subset,
178
  split="test",
179
+ tokens_per_sample=args.tokens_per_sample,
180
  samples_per_epoch=args.samples_per_epoch,
181
  num_processes=args.num_dataset_processes,
182
  )
 
189
  )
190
 
191
  model_args = {
192
+ "vocabulary_size": tokenizer.n_vocab,
193
  "embedding_dimensions": args.embedding_dimensions,
194
  "num_heads": args.num_attention_heads,
195
  "num_layers": args.num_hidden_layers,
196
  "feed_forward_ratio": args.feed_forward_ratio,
197
  "dropout": args.dropout,
 
198
  "padding_index": training.PADDING_INDEX,
199
  "eos_index": tokenizer.eot_token,
200
  }
 
255
 
256
  register_signal_handlers()
257
 
258
+ print("Pretraining ...")
259
 
260
  for epoch in range(starting_epoch, args.num_epochs + 1):
261
  total_cross_entropy, total_gradient_norm = 0.0, 0.0
 
295
 
296
  total_batches += 1
297
 
298
+ if IS_MASTER:
299
+ average_cross_entropy = total_cross_entropy / total_batches
300
+ average_gradient_norm = total_gradient_norm / total_steps
301
 
302
+ logger.add_scalar("cross entropy", average_cross_entropy, epoch)
303
+ logger.add_scalar("gradient norm", average_gradient_norm, epoch)
304
+
305
+ print(
306
+ f"Epoch {epoch}:",
307
+ f"Cross Entropy: {average_cross_entropy:.5f},",
308
+ f"Gradient Norm: {average_gradient_norm:.4f}",
309
+ )
310
 
311
  if epoch % args.eval_interval == 0 and IS_MASTER:
312
  model.eval()
 
322
 
323
  perplexity = perplexity_metric.compute()
324
 
325
+ logger.add_scalar("perplexity", perplexity, epoch)
326
+
327
  print(f"Perplexity: {perplexity:.3f}")
328
 
329
  perplexity_metric.reset()
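Since this commit also introduces TensorBoard logging, the new scalars can be viewed interactively with `tensorboard --logdir ./runs/pretrain` (the default `--run_dir_path`). Below is a minimal sketch of reading the logged values back programmatically; it assumes the default run directory and the "cross entropy" tag added above:

```
# Minimal sketch: read back the scalars logged by pretrain.py each epoch.
# Assumes the default run directory and the "cross entropy" tag used above.
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

accumulator = EventAccumulator("./runs/pretrain")
accumulator.Reload()  # parse the event files on disk

for event in accumulator.Scalars("cross entropy"):
    print(f"epoch {event.step}: {event.value:.5f}")
```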
requirements.txt CHANGED
@@ -6,5 +6,7 @@ tiktoken==0.8.0
6
  tqdm==4.66.6
7
  matplotlib==3.9.2
8
  safetensors==0.5.2
9
- onnxscript==0.1.0
 
10
  onnxruntime==1.20.1
 
 
6
  tqdm==4.66.6
7
  matplotlib==3.9.2
8
  safetensors==0.5.2
9
+ onnx==1.17.0
10
+ onnxscript==0.1.0.dev20250108
11
  onnxruntime==1.20.1
12
+ tensorboard==2.18.0
runs/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore