Andrew DalPino committed on commit e431f0f (parent: fc4824e)

Blanket optimizations

Files changed:
- .gitignore +1 -0
- README.md +26 -23
- beam_search.py +3 -1
- export_model.ipynb +39 -95
- generate.py +7 -1
- instruction-tune.py +11 -2
- model.py +33 -29
- model_sizing.ipynb +72 -43
- pre-train.py → pretrain.py +25 -16
- requirements.txt +3 -1
- runs/.gitignore +2 -0
.gitignore CHANGED
@@ -10,6 +10,7 @@ wheels/
 *.egg-info/
 .installed.cfg
 *.egg
+*.sarif
 .venv
 venv/
 ENV/
README.md CHANGED
@@ -13,7 +13,7 @@ tags:
 ---
 # LightGPT
 
-LightGPT is a lightweight generative pre-trained Transformer (GPT) model for the …
+LightGPT is a lightweight generative pretrained Transformer (GPT) model for the people! Built using PyTorch and trained on the Fineweb and Alpaca datasets, LightGPT can answer questions, follow instructions, summarize documents, chat, and more. Best of all, the model weights *and* code are fully open-source for you to customize, improve upon, and share with the world.
 
 ## Features
 
@@ -23,18 +23,18 @@ LightGPT is a lightweight generative pre-trained Transformer (GPT) model for the
 
 - **Fully Open-source**: Unlike closed-source LLMs, LightGPT provides both the model weights *and* the source code to train, fine-tune, export, and generate text from the model using your own hardware. With the help of the open-source software community, we aim to democratize access to AI and continually improve the models.
 
-## Suggested …
+## Suggested Pretraining Configurations
 
-Below is a table of some suggested …
+Below is a table of some suggested pretraining configurations, but feel free to experiment with settings on your own. See the `model_sizing.ipynb` notebook to estimate the memory and compute requirements for your model configuration.
 
-| Name | Vocab. Size | …
+| Name | Vocab. Size | Embedding Dim. | Attn. Heads | Layers | Parameters | Training Tokens |
 |---|---|---|---|---|---|---|
-| Small | 50,257 | 1024 | …
-| Medium | 50,257 | …
-| Large | 100,275 | …
-| X-large | 100,275 | …
-| XX-large | 200,017 | …
-| XXX-large | 200,017 | …
+| Small | 50,257 | 1024 | 16 | 24 | 353M | 7B |
+| Medium | 50,257 | 2048 | 32 | 32 | 1.7B | 34B |
+| Large | 100,275 | 4096 | 64 | 32 | 6.8B | 132B |
+| X-large | 100,275 | 4096 | 64 | 64 | 13B | 262B |
+| XX-large | 200,017 | 8192 | 128 | 64 | 53B | 1T |
+| XXX-large | 200,017 | 8192 | 128 | 128 | 105B | 2T |
 
 ## Install Project Dependencies
 
@@ -48,37 +48,37 @@ source ./.venv/bin/activate
 pip install -r requirements.txt
 ```
 
-## …
+## Pretraining
 
-For the …
+For the pretraining corpus we use the Fineweb dataset, which consists of about 15T high-quality tokens gathered from the worldwide web. The dataset has been split into 3 subsets (10BT, 100BT, and 350BT versions) for training smaller models. If you'd like to start training right away, the default settings should work on most single-GPU systems with 12G of VRAM or more.
 
 ```
-python pre-train.py
+python pretrain.py
 ```
 
 **Note** that it will take a while to download and pre-process the dataset the first time that the training script is run.
 
-To customize the default "Small" architecture you can adjust the `block_size`, `embedding_dimensions`, `num_hidden_layers`, and `num_attention_heads` arguments of the …
+To customize the default "Small" architecture you can adjust the `block_size`, `embedding_dimensions`, `num_hidden_layers`, and `num_attention_heads` arguments of the pretraining script.
 
 ```
-python pre-train.py …
+python pretrain.py --block_size=2048 --embedding_dimensions=4096 --num_hidden_layers=64 --num_attention_heads=64
 ```
 
 You can also adjust the `batch_size`, `learning_rate`, and `gradient_accumulation_steps` to suit your training setup.
 
 ```
-python pre-train.py …
+python pretrain.py --batch_size=32 --learning_rate=0.01 --gradient_accumulation_steps=128
 ```
 
 For distributed training, use PyTorch's [torchrun](https://pytorch.org/docs/stable/elastic/run.html) extension to launch a distributed data parallel (DDP) session. The example below is for executing the training script on a single node with 8 individual GPUs.
 
 ```
-torchrun --standalone --nnodes=1 --nproc-per-node=8 pre-train.py --batch_size=16 …
+torchrun --standalone --nnodes=1 --nproc-per-node=8 pretrain.py --batch_size=16 --gradient_accumulation_steps=128
 ```
 
 **Note** that when training in data-parallel mode it's important that the world size divides evenly into `gradient_accumulation_steps` for maximum performance. For example, if we have an 8 GPU cluster, we could perform 32 gradient accumulation steps in exactly 4 passes over the network.
 
-### …
+### Pretraining Arguments
 
 | Argument | Default | Type | Description |
 |---|---|---|---|
@@ -88,6 +88,7 @@ torchrun --standalone --nnodes=1 --nproc-per-node=8 pre-train.py --batch_size=16
 | --num_dataset_processes | 8 | int | The number of processes (CPUs) to use to process the dataset. |
 | --batch_size | 1 | int | The number of samples to pass through the network at a time. |
 | --gradient_accumulation_steps | 128 | int | The number of batches to pass through the network before updating the weights. |
+| --tokens_per_sample | 1024 | int | The number of tokens to pack into a single training sequence. This is sometimes called the context length or block size. |
 | --samples_per_epoch | 4096 | int | The number of training samples to pass through the network every epoch. |
 | --num_epochs | 1686 | int | The number of epochs to train for. |
 | --learning_rate | 1e-2 | float | The learning rate of the Adafactor optimizer. |
@@ -95,17 +96,17 @@ torchrun --standalone --nnodes=1 --nproc-per-node=8 pre-train.py --batch_size=16
 | --low_memory_optimizer | False | bool | Should the optimizer reduce its memory consumption in exchange for a slightly slower runtime? |
 | --max_gradient_norm | 1.0 | float | Clip gradients above this threshold before stepping. |
 | --eval_interval | 10 | int | Evaluate the model after this many epochs on the testing set. |
-| --block_size | 1024 | int | The number of tokens within the context window for every sample. |
 | --embedding_dimensions | 1024 | int | The dimensionality of the token embeddings. |
 | --num_attention_heads | 16 | int | The number of attention heads within every block. |
 | --num_hidden_layers | 24 | int | The number of attention/MLP blocks within the hidden layer of the network. |
 | --feed_forward_ratio | 4 | (1, 2, 4) | The ratio of hidden neurons to embedding dimensions in the MLP layers of the network. |
 | --dropout | 0.1 | float | The proportion of signals to send to zero during training as regularization. |
-| --activation_checkpointing | False | bool | Should we use activation checkpointing? This will reduce …
+| --activation_checkpointing | False | bool | Should we use activation checkpointing? This will drastically reduce memory utilization during training at the cost of recomputing the forward pass. |
 | --ddp_sharding_level | 2 | int | The level of sharding to use for DDP training. Options are 2 or 3 for partial and full sharding respectively, or 0 for no sharding. |
-| --checkpoint_interval | 20 | int | Save the model …
+| --checkpoint_interval | 20 | int | Save the model checkpoint to disk every this many epochs. |
 | --checkpoint_path | "./checkpoints/checkpoint.pt" | str | The path to the base checkpoint file on disk. |
 | --resume | False | bool | Should we resume training from the last checkpoint? |
+| --run_dir_path | "./runs/pretrain" | str | The path to the TensorBoard run directory for this training session. |
 | --device | "cuda" | str | The device to run the computation on. |
 | --seed | None | int | The seed for the random number generator. |
 
@@ -116,12 +117,13 @@ torchrun --standalone --nnodes=1 --nproc-per-node=8 pre-train.py --batch_size=16
 | Argument | Default | Type | Description |
 |---|---|---|---|
 | --base_model_path | "./checkpoints/checkpoint.pt" | string | The path to the base checkpoint on disk. |
+| --max_tokens_per_sample | 4096 | int | The maximum number of tokens to pack into a single training sequence. |
+| --mask_input | False | bool | Should we mask the input part of the training sequences i.e. only train on the supervised output? |
 | --batch_size | 1 | int | The number of samples to pass through the network at a time. |
 | --gradient_accumulation_steps | 64 | int | The number of batches to pass through the network before updating the weights. |
 | --learning_rate | 5e-4 | float | The learning rate of the Adafactor optimizer. |
 | --rms_decay | -0.8 | float | The decay rate of the RMS coefficient of the Adafactor optimizer. |
-| --optimizer_low_memory | …
+| --optimizer_low_memory | False | bool | Should the optimizer reduce its memory consumption in exchange for a slightly slower runtime? |
-| --mask_input | False | bool | Should we mask the input part of the sample i.e. only train on the output? |
 | --rank | 8 | int | The rank of the LoRA decomposition matrices. |
 | --alpha | 1.0 | float | The strength of the LoRA signal. |
 | --dropout | 0.05 | float | The proportion of signals to send to zero during training as regularization. |
@@ -131,6 +133,7 @@ torchrun --standalone --nnodes=1 --nproc-per-node=8 pre-train.py --batch_size=16
 | --checkpoint_interval | 1 | int | Save the model parameters to disk every this many epochs. |
 | --checkpoint_path | "./checkpoints/lora_instruction.pt" | string | The path to the LoRA checkpoint. |
 | --resume | False | bool | Should we resume training from the last checkpoint? |
+| --run_dir_path | "./runs/instruction-tune" | str | The path to the TensorBoard run directory for this training session. |
 | --device | "cuda" | string | The device to run the computation on. |
 | --seed | None | int | The seed for the random number generator. |
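The parameter counts in the configurations table follow from the architecture settings alone. Below is a minimal sketch of the accounting, using the same formulas as `model_sizing.ipynb` (the output layer is weight-tied to the token embeddings, so it adds nothing); the variable names mirror the notebook and the exact total is approximate.

```
# Rough parameter estimate for the "Small" configuration.
vocabulary_size = 50257
embedding_dimensions = 1024
num_hidden_layers = 24
feed_forward_ratio = 4

token_embeddings = vocabulary_size * embedding_dimensions
attention = 4 * embedding_dimensions**2 * num_hidden_layers
mlp = 2 * feed_forward_ratio * embedding_dimensions**2 * num_hidden_layers
rms_norm = embedding_dimensions * (2 * num_hidden_layers + 1)

total = token_embeddings + attention + mlp + rms_norm

print(f"{total:,}")  # ~353.5M, matching the "Small" row of the table
```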
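To make the data-parallel note concrete: with 32 gradient accumulation steps on an 8-GPU node, each weight update takes exactly 32 / 8 = 4 forward/backward passes per GPU. A quick sketch of the divisibility check (illustrative names, not taken from the training script):

```
world_size = 8  # number of GPUs in the DDP session
gradient_accumulation_steps = 32

# The accumulation steps should split evenly across the workers.
assert gradient_accumulation_steps % world_size == 0

passes_per_update = gradient_accumulation_steps // world_size

print(f"{passes_per_update} passes over the network per weight update")  # 4
```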
beam_search.py CHANGED
@@ -22,7 +22,8 @@ def main():
         "--checkpoint_path", default="./checkpoints/checkpoint.pt", type=str
     )
     parser.add_argument("--lora_path", default=None, type=str)
-    parser.add_argument("--max_tokens", default=…
+    parser.add_argument("--max_tokens", default=100, type=int)
+    parser.add_argument("--context_length", default=1024, type=int)
     parser.add_argument("--num_candidates", default=3, type=int)
     parser.add_argument("--beam_width", default=16, type=int)
     parser.add_argument("--device", default="cuda", type=str)
@@ -92,6 +93,7 @@ def main():
     candidates = model.beam_search(
         prompt,
         args.max_tokens,
+        args.context_length,
         args.num_candidates,
         args.beam_width,
     )
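The new `--context_length` argument bounds how much of the growing candidate sequence is fed back into the model. A minimal sketch of the sliding-window slice the diff introduces (the tensor below is a stand-in for real token ids):

```
import torch

context_length = 1024

tokens = torch.arange(1500)  # stand-in for a growing sequence of token ids

context_window = tokens[-context_length:]  # keep only the most recent tokens

print(context_window.shape)  # torch.Size([1024])
```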
export_model.ipynb CHANGED
@@ -9,7 +9,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 18,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -28,25 +28,21 @@
 },
 {
  "cell_type": "code",
- "execution_count": 3,
+ "execution_count": 19,
  "metadata": {},
  "outputs": [
   {
-   "ename": "TypeError",
-   "evalue": "GPT.__init__() missing 1 required positional argument: 'feed_forward_ratio'",
-   "output_type": "error",
-   "traceback": ["…", "TypeError: GPT.__init__() missing 1 required positional argument: 'feed_forward_ratio'"]
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "Base checkpoint loaded successfully\n"
+   ]
   }
  ],
 "source": [
  "import torch\n",
  "\n",
- "from model import GPT, GPTWithLoRA\n",
+ "from model import GPT\n",
  "\n",
  "checkpoint = torch.load(checkpoint_path, map_location=\"cpu\", weights_only=True)\n",
  "\n",
@@ -68,10 +64,12 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 20,
  "metadata": {},
  "outputs": [],
  "source": [
+  "from model import GPTWithLoRA\n",
+  "\n",
   "if lora_path != None:\n",
   "    checkpoint = torch.load(lora_path, map_location=\"cpu\", weights_only=True)\n",
   "\n",
@@ -95,14 +93,14 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 21,
  "metadata": {},
  "outputs": [
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "Model saved to ./exports/lightgpt-small…
+    "Model saved to ./exports/lightgpt-small.safetensors\n"
    ]
   }
  ],
@@ -127,66 +125,46 @@
 },
 {
  "cell_type": "code",
- "execution_count": …,
+ "execution_count": 22,
  "metadata": {},
  "outputs": [
-  {
-   "name": "stdout",
-   "output_type": "stream",
-   "text": [
-    "[torch.onnx] Obtain model graph for `OptimizedModule([...]` with `torch.export.export`...\n"
-   ]
-  },
-  {
-   "name": "stderr",
-   "output_type": "stream",
-   "text": [
-    "W0108 18:27:01.430000 5473 torch/onnx/_internal/exporter/_registration.py:73] torchvision is not installed. Skipping torchvision::nms\n"
-   ]
-  },
-  {
-   "name": "stdout",
-   "output_type": "stream",
-   "text": [
-    "[torch.onnx] Obtain model graph for `OptimizedModule([...]` with `torch.export.export`... ✅\n",
-    "[torch.onnx] Translate the graph into ONNX...\n"
-   ]
-  },
   {
    "name": "stderr",
    "output_type": "stream",
    "text": [
-    "…
+    "/home/andrew/Workspace/LightGPT/.venv/lib/python3.12/site-packages/torch/onnx/_internal/_exporter_legacy.py:116: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.\n",
+    "  warnings.warn(\n"
    ]
   },
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "…
-    "Model saved to ./exports/lightgpt-small…
+    "Applied 72 of general pattern rewrite rules.\n",
+    "Model saved to ./exports/lightgpt-small.onnx\n"
    ]
   }
  ],
 "source": [
-  "from …
+  "from model import ONNXModel\n",
+  "\n",
+  "from torch.onnx import dynamo_export, ExportOptions\n",
   "\n",
   "example_input = torch.randint(0, model.vocabulary_size - 1, (1, model.block_size))\n",
   "\n",
+  "model = ONNXModel(model) # Nicer inferencing API\n",
+  "\n",
   "model.eval() # Turn off dropout and other train-time operations\n",
   "\n",
-  "…
+  "export_options = ExportOptions(\n",
+  "    dynamic_shapes=True\n",
+  ") # Necessary for variable batch and sequence lengths\n",
+  "\n",
+  "onnx_model = dynamo_export(model, example_input, export_options=export_options)\n",
   "\n",
   "onnx_path = path.join(exports_path, f\"{model_name}.onnx\")\n",
   "\n",
-  "torch.onnx.export(\n",
-  "    model,\n",
-  "    example_input,\n",
-  "    onnx_path,\n",
-  "    input_names=[\"input_tokens\", \"labels\"],\n",
-  "    output_names=[\"logits\"],\n",
-  "    dynamo=True,\n",
-  ")\n",
+  "onnx_model.save(onnx_path)\n",
   "\n",
   "print(f\"Model saved to {onnx_path}\")"
  ]
@@ -195,75 +173,41 @@
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
-  "…
+  "Lastly, let's compare the output of PyTorch with the ONNX runtime to see if they are the same."
  ]
 },
 {
  "cell_type": "code",
- "execution_count": …,
- "metadata": {},
- "outputs": [
-  {
-   "name": "stdout",
-   "output_type": "stream",
-   "text": [
-    "…
-   ]
-  }
- ],
- "source": [
-  "import onnx\n",
-  "\n",
-  "onnx_model = onnx.load(onnx_path)\n",
-  "\n",
-  "onnx.checker.check_model(onnx_model)\n",
-  "\n",
-  "print(\"Looks OK\")"
- ]
-},
-{
- "cell_type": "markdown",
- "metadata": {},
- "source": [
-  "Lastly, let's compare the output of PyTorch with the ONNX runtime to see if they are the same."
- ]
-},
-{
- "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [
   {
-   "ename": "NameError",
-   "evalue": "name 'onnx_path' is not defined",
-   "output_type": "error",
-   "traceback": ["…", "NameError: name 'onnx_path' is not defined"]
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "Looking good!\n"
+   ]
   }
  ],
 "source": [
  "import onnxruntime\n",
  "\n",
- "import numpy as np\n",
- "\n",
  "from numpy.testing import assert_allclose\n",
  "\n",
+ "pytorch_logits = model(example_input)\n",
+ "\n",
  "session = onnxruntime.InferenceSession(onnx_path, providers=[\"CPUExecutionProvider\"])\n",
  "\n",
- "onnx_input = {\"…
+ "onnx_input = {\"l_x_\": example_input.numpy()}\n",
  "\n",
- "…
+ "onnx_logits = session.run(None, onnx_input)\n",
  "\n",
- "…
- "…
+ "onnx_logits = onnx_logits[0]\n",
+ "pytorch_logits = pytorch_logits.detach().numpy()\n",
  "\n",
- "assert_allclose(…
+ "assert_allclose(pytorch_logits, onnx_logits, rtol=1e-2, atol=1e-03)\n",
  "\n",
- "print(\"…
+ "print(\"Looks good!\")"
 ]
}
],
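For reference, a minimal standalone inference sketch against the exported graph. The input name "l_x_" and the export path are taken from the notebook cells above; the greedy argmax step is illustrative and not part of the notebook.

```
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession(
    "./exports/lightgpt-small.onnx", providers=["CPUExecutionProvider"]
)

# A batch of random token ids standing in for a real tokenized prompt.
input_tokens = np.random.randint(0, 50257, size=(1, 1024), dtype=np.int64)

logits = session.run(None, {"l_x_": input_tokens})[0]

next_token = int(logits[0, -1].argmax())  # greedy next-token prediction

print(next_token)
```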
generate.py CHANGED
@@ -23,6 +23,7 @@ def main():
     )
     parser.add_argument("--lora_path", default=None, type=str)
     parser.add_argument("--max_tokens", default=1000, type=int)
+    parser.add_argument("--context_length", default=1024, type=int)
     parser.add_argument("--temperature", default=1.0, type=float)
     parser.add_argument("--top_k", default=500, type=int)
     parser.add_argument("--top_p", default=0.9, type=float)
@@ -91,7 +92,12 @@ def main():
     prompt = torch.tensor(prompt, dtype=torch.int64, device=args.device)
 
     for token in model.generate(
-        prompt, …
+        prompt,
+        args.max_tokens,
+        args.context_length,
+        args.temperature,
+        args.top_k,
+        args.top_p,
     ):
         out = tokenizer.decode_single_token_bytes(token).decode(
             "utf-8", errors="replace"
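The `--temperature`, `--top_k`, and `--top_p` arguments shape the next-token distribution. Below is a sketch of the standard filtering technique these parameters usually drive; it mirrors the common approach, not necessarily the exact body of `model.generate`.

```
import torch

def sample_next_token(logits, temperature=1.0, top_k=500, top_p=0.9):
    logits = logits / temperature

    # Keep only the top_k highest-scoring candidates (sorted descending).
    values, indices = torch.topk(logits, k=min(top_k, logits.size(-1)))

    probabilities = torch.softmax(values, dim=-1)

    # Nucleus (top_p) filtering: keep the smallest prefix of candidates
    # whose cumulative probability covers top_p.
    cumulative = torch.cumsum(probabilities, dim=-1)
    mask = cumulative - probabilities <= top_p

    probabilities = probabilities * mask
    probabilities /= probabilities.sum()

    choice = torch.multinomial(probabilities, num_samples=1)

    return indices[choice].item()

next_token = sample_next_token(torch.randn(50257))
```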
instruction-tune.py CHANGED
@@ -9,6 +9,7 @@ from torch.optim import Adafactor
 from torch.amp import autocast
 from torch.cuda import is_available as cuda_is_available, is_bf16_supported
 from torch.utils.data import random_split
+from torch.utils.tensorboard import SummaryWriter
 
 from torchmetrics.text import Perplexity
 
@@ -26,12 +27,13 @@ def main():
     parser.add_argument(
         "--base_model_path", default="./checkpoints/checkpoint.pt", type=str
     )
+    parser.add_argument("--max_tokens_per_sample", default=4096, type=int)
+    parser.add_argument("--mask_input", action="store_true")
     parser.add_argument("--batch_size", default=1, type=int)
     parser.add_argument("--gradient_accumulation_steps", default=64, type=int)
     parser.add_argument("--learning_rate", default=5e-4, type=float)
     parser.add_argument("--rms_decay", default=-0.8, type=float)
     parser.add_argument("--optimizer_low_memory", default=True, type=bool)
-    parser.add_argument("--mask_input", default=False, type=bool)
     parser.add_argument("--num_epochs", default=4, type=int)
     parser.add_argument("--rank", default=8, type=int)
     parser.add_argument("--alpha", default=1.0, type=float)
@@ -43,6 +45,7 @@ def main():
         "--checkpoint_path", default="./checkpoints/lora_instruction.pt", type=str
     )
     parser.add_argument("--resume", action="store_true")
+    parser.add_argument("--run_dir_path", default="./runs/instruction-tune", type=str)
     parser.add_argument("--device", default="cuda", type=str)
     parser.add_argument("--seed", default=None, type=int)
 
@@ -65,6 +68,8 @@ def main():
         torch.manual_seed(args.seed)
         random.seed(args.seed)
 
+    logger = SummaryWriter(args.run_dir_path)
+
     checkpoint = torch.load(
         args.base_model_path, map_location=args.device, weights_only=True
     )
@@ -75,7 +80,7 @@ def main():
 
     dataset = Alpaca(
         tokenizer,
-        max_tokens_per_sample=…
+        max_tokens_per_sample=args.max_tokens_per_sample,
         mask_input=args.mask_input,
     )
 
@@ -173,6 +178,8 @@ def main():
 
         average_cross_entropy = total_cross_entropy / total_batches
 
+        logger.add_scalar("cross entropy", average_cross_entropy, epoch)
+
        print(
            f"Epoch {epoch}: Cross Entropy: {average_cross_entropy:.5f}",
        )
@@ -191,6 +198,8 @@ def main():
 
            perplexity = perplexity_metric.compute()
 
+            logger.add_scalar("perplexity", perplexity, epoch)
+
            print(f"Perplexity: {perplexity:.3f}")
 
            perplexity_metric.reset()
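The commit wires a TensorBoard `SummaryWriter` into the training loop, logging one scalar per epoch. A minimal sketch of the same pattern follows (the loop and metric values are placeholders); the resulting run directory is viewed with `tensorboard --logdir ./runs`.

```
from torch.utils.tensorboard import SummaryWriter

logger = SummaryWriter("./runs/instruction-tune")

for epoch in range(1, 5):
    average_cross_entropy = 1.0 / epoch  # placeholder metric value

    logger.add_scalar("cross entropy", average_cross_entropy, epoch)

logger.close()
```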
model.py CHANGED
@@ -1,4 +1,4 @@
-from math import sqrt …
+from math import sqrt
 from dataclasses import dataclass
 from functools import partial, cached_property
 from typing import Iterator, Self
@@ -13,8 +13,8 @@ from torch.nn import (
     Embedding,
     MultiheadAttention,
     Linear,
+    SiLU,
     RMSNorm,
-    GELU,
     Dropout1d,
     CrossEntropyLoss,
     Parameter,
@@ -27,25 +27,21 @@ from torch.utils.checkpoint import checkpoint as torch_checkpoint
 
 
 class GPT(Module):
-    """A generative …"""
+    """A generative pretrained transformer."""
 
     def __init__(
         self,
-        block_size: int,
+        vocabulary_size: int,
         embedding_dimensions: int,
         num_heads: int,
         num_layers: int,
         feed_forward_ratio: int,
         dropout: float,
-        vocabulary_size: int,
         padding_index: int,
         eos_index: int,
     ):
         super().__init__()
 
-        if block_size < 1:
-            raise ValueError(f"Block size must be greater than 0, {block_size} given.")
-
         if num_layers <= 0:
             raise ValueError(f"Num layers must be greater than 0, {num_layers} given.")
 
@@ -67,16 +63,10 @@
 
         self.token_embeddings = token_embeddings
 
-        causal_mask = torch.full((block_size, block_size), float("-inf"))
-        causal_mask = torch.triu(causal_mask, diagonal=1)
-
-        self.causal_mask = Buffer(causal_mask, persistent=False)
-
         self.body = ModuleList(
             [
                 CausalSelfAttentionBlock(
                     embedding_dimensions,
-                    block_size,
                     num_heads,
                     feed_forward_ratio,
                     dropout,
@@ -93,7 +83,6 @@
         self.loss_function = CrossEntropyLoss(ignore_index=padding_index)
 
         self.vocabulary_size = vocabulary_size
-        self.block_size = block_size
         self.eos_index = eos_index
 
     @cached_property
@@ -108,9 +97,10 @@
     ) -> tuple[Tensor, Tensor | None]:
         z = self.token_embeddings(x)
 
-        b, t = …
+        b, t, d = z.size()
 
-        causal_mask = …
+        causal_mask = torch.full((t, t), float("-inf"), dtype=z.dtype, device=z.device)
+        causal_mask = torch.triu(causal_mask, diagonal=1)
 
         for layer in self.body:
             z = self.checkpoint(layer, z, causal_mask)
@@ -132,14 +122,15 @@
     def generate(
         self,
         prompt: Tensor,
-        max_tokens: int = …
+        max_tokens: int = 1000,
+        context_length: int = 1024,
         temperature: float = 1.0,
         top_k: int = 500,
         top_p: float = 0.9,
     ) -> Iterator:
         """
         Given a prompt, sample the next {max_tokens} tokens from the model weighted
-        by their predicted probabilities.
+        by their predicted probabilities and filtered by the {top_k} and {top_p}.
         """
 
         if max_tokens <= 0:
@@ -161,7 +152,7 @@
         context_window = prompt
 
         for _ in range(max_tokens):
-            context_window = context_window[-…
+            context_window = context_window[-context_length:]
 
             y_pred, _ = self.forward(context_window.unsqueeze(0))
 
@@ -201,12 +192,15 @@
     def beam_search(
         self,
         prompt: Tensor,
-        max_tokens: int = …
+        max_tokens: int = 100,
+        context_length: int = 1024,
         num_candidates: int = 3,
         beam_width: int = 16,
     ) -> list:
         """
-        Given a prompt, return the {num_candidates} highest probability sequences.
+        Given a prompt, return the {num_candidates} highest probability sequences. Note that
+        this method is often best for generating shorter sequences and is typically less
+        natural sounding than sequences that are more random in nature.
         """
 
         if max_tokens <= 0:
@@ -267,7 +261,7 @@
 
             context_window = torch.cat((prompt, candidate.tokens))
 
-            context_window = context_window[-…
+            context_window = context_window[-context_length:]
 
             y_pred, _ = self.forward(context_window.unsqueeze(0))
 
@@ -293,7 +287,7 @@
 
 class GPTWithLoRA(Module):
     """
-    A wrapper for …
+    A wrapper for pretrained GPT models that applies a LoRA reparameterization
     to the intermediate layers of the network.
     """
 
@@ -382,13 +376,26 @@
         return self.model.beam_search(prompt, max_tokens, num_candidates, beam_width)
 
 
+class ONNXModel(Module):
+    """This wrapper provides a cleaner inferencing API for production models."""
+
+    def __init__(self, model: GPT | GPTWithLoRA):
+        super().__init__()
+
+        self.model = model
+
+    def forward(self, x: Tensor) -> Tensor:
+        logits, _ = self.model.forward(x, None)
+
+        return logits
+
+
 class CausalSelfAttentionBlock(Module):
     """Causal self-attention block with residual connections."""
 
     def __init__(
         self,
         embedding_dimensions: int,
-        block_size: int,
         num_heads: int,
         feed_forward_ratio: int,
         dropout: float,
@@ -400,9 +407,6 @@
             f"Embedding dimensions must be greater than 0, {embedding_dimensions} given."
         )
 
-        if block_size <= 0:
-            raise ValueError(f"Block size must be greater than 0, {block_size} given.")
-
         if num_heads <= 0:
             raise ValueError(f"Num heads must be greater than 0, {num_heads} given.")
 
@@ -459,7 +463,7 @@ class MLP(Module):
 
         self.layers = Sequential(
             Linear(embedding_dimensions, hidden_dimensions, bias=False),
-            GELU(),
+            SiLU(),
             Linear(hidden_dimensions, embedding_dimensions, bias=False),
         )
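The commit replaces the fixed `block_size` mask buffer with a mask built on every forward pass, sized to the actual sequence length. A small demo of the same two-line construction:

```
import torch

t = 4  # sequence length of the current batch

causal_mask = torch.full((t, t), float("-inf"))
causal_mask = torch.triu(causal_mask, diagonal=1)  # zero below the diagonal

print(causal_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```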
model_sizing.ipynb
CHANGED
@@ -9,15 +9,19 @@
|
|
9 |
},
|
10 |
{
|
11 |
"cell_type": "code",
|
12 |
-
"execution_count":
|
13 |
"metadata": {},
|
14 |
"outputs": [],
|
15 |
"source": [
|
16 |
-
"
|
17 |
"vocabulary_size = 50257\n",
|
18 |
"embedding_dimensions = 1024\n",
|
19 |
"num_attention_heads = 16\n",
|
20 |
"num_hidden_layers = 24\n",
|
|
|
|
|
|
|
|
|
21 |
"samples_per_epoch = 4096"
|
22 |
]
|
23 |
},
|
@@ -30,7 +34,7 @@
|
|
30 |
},
|
31 |
{
|
32 |
"cell_type": "code",
|
33 |
-
"execution_count":
|
34 |
"metadata": {},
|
35 |
"outputs": [
|
36 |
{
|
@@ -49,12 +53,12 @@
|
|
49 |
"text": [
|
50 |
"Token Embeddings 51,463,168 14.56%\n",
|
51 |
"Attention 100,663,296 28.48%\n",
|
52 |
-
"MLP 201,326,
|
53 |
"RMS Norm 50,176 0.01%\n",
|
54 |
"Output Layer 0 0.00%\n",
|
55 |
"\n",
|
56 |
"\n",
|
57 |
-
"Total parameters: 353,503,
|
58 |
]
|
59 |
}
|
60 |
],
|
@@ -67,7 +71,11 @@
|
|
67 |
" embedding_dimensions**2 + embedding_dimensions * 3 * embedding_dimensions\n",
|
68 |
" )\n",
|
69 |
" * num_hidden_layers,\n",
|
70 |
-
" \"MLP\": embedding_dimensions
|
|
|
|
|
|
|
|
|
71 |
" \"RMS Norm\": embedding_dimensions * num_hidden_layers * 2 + embedding_dimensions,\n",
|
72 |
" \"Output Layer\": 0, # Tied to token embeddings\n",
|
73 |
"}\n",
|
@@ -99,7 +107,7 @@
|
|
99 |
},
|
100 |
{
|
101 |
"cell_type": "code",
|
102 |
-
"execution_count":
|
103 |
"metadata": {},
|
104 |
"outputs": [
|
105 |
{
|
@@ -125,7 +133,7 @@
|
|
125 |
},
|
126 |
{
|
127 |
"cell_type": "code",
|
128 |
-
"execution_count":
|
129 |
"metadata": {},
|
130 |
"outputs": [
|
131 |
{
|
@@ -151,7 +159,7 @@
|
|
151 |
},
|
152 |
{
|
153 |
"cell_type": "code",
|
154 |
-
"execution_count":
|
155 |
"metadata": {},
|
156 |
"outputs": [
|
157 |
{
|
@@ -181,14 +189,14 @@
|
|
181 |
},
|
182 |
{
|
183 |
"cell_type": "code",
|
184 |
-
"execution_count":
|
185 |
"metadata": {},
|
186 |
"outputs": [
|
187 |
{
|
188 |
"name": "stdout",
|
189 |
"output_type": "stream",
|
190 |
"text": [
|
191 |
-
"Optimal training tokens: 7,070,
|
192 |
"Epochs required: 1,686\n",
|
193 |
"\n"
|
194 |
]
|
@@ -197,7 +205,9 @@
|
|
197 |
"source": [
|
198 |
"num_training_tokens = 20 * total_parameter_count\n",
|
199 |
"\n",
|
200 |
-
"num_epochs_required = round(
|
|
|
|
|
201 |
"\n",
|
202 |
"print(f\"Optimal training tokens: {num_training_tokens:,}\")\n",
|
203 |
"\n",
|
@@ -213,7 +223,7 @@
|
|
213 |
},
|
214 |
{
|
215 |
"cell_type": "code",
|
216 |
-
"execution_count":
|
217 |
"metadata": {},
|
218 |
"outputs": [
|
219 |
{
|
@@ -231,43 +241,57 @@
|
|
231 |
"output_type": "stream",
|
232 |
"text": [
|
233 |
"Attention 309,237,645,312 37.39%\n",
|
234 |
-
"MLP 412,317,
|
235 |
"RMS Norm 179,200 0.00%\n",
|
236 |
"Output Layer 105,396,568,064 12.75%\n",
|
237 |
"\n",
|
238 |
"\n",
|
239 |
-
"Total forward FLOPs: 826,
|
240 |
]
|
241 |
}
|
242 |
],
|
243 |
"source": [
|
244 |
"ops_per_matmul = 2 # Multiply + accumulate (MAC)\n",
|
245 |
-
"ops_per_activation =
|
246 |
"ops_per_rms_norm = 7 # y = (x / sqrt(rms[x] + epsilon)) * gamma\n",
|
247 |
"\n",
|
248 |
"head_dimensions = embedding_dimensions // num_attention_heads\n",
|
249 |
"\n",
|
250 |
"# K, Q, V projections\n",
|
251 |
"attention = (\n",
|
252 |
-
" ops_per_matmul
|
|
|
|
|
253 |
")\n",
|
254 |
"\n",
|
255 |
"# Attention logits\n",
|
256 |
-
"attention +=
|
|
|
|
|
257 |
"\n",
|
258 |
"# Reductions\n",
|
259 |
"attention += (\n",
|
260 |
-
" ops_per_matmul
|
|
|
|
|
261 |
")\n",
|
262 |
"\n",
|
263 |
"# Output projection\n",
|
264 |
-
"attention += ops_per_matmul *
|
265 |
"\n",
|
266 |
"attention *= num_hidden_layers\n",
|
267 |
"\n",
|
268 |
"# Linear transformations\n",
|
269 |
-
"mlp =
|
270 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
271 |
"\n",
|
272 |
"# Non-linear activations\n",
|
273 |
"mlp += ops_per_activation * (4 * embedding_dimensions)\n",
|
@@ -276,7 +300,9 @@
|
|
276 |
"\n",
|
277 |
"rms_norm = ops_per_rms_norm * embedding_dimensions * (num_hidden_layers + 1)\n",
|
278 |
"\n",
|
279 |
-
"output_layer =
|
|
|
|
|
280 |
"\n",
|
281 |
"flops = {\n",
|
282 |
" \"Attention\": attention,\n",
|
@@ -312,14 +338,14 @@
|
|
312 |
},
|
313 |
{
|
314 |
"cell_type": "code",
|
315 |
-
"execution_count":
|
316 |
"metadata": {},
|
317 |
"outputs": [
|
318 |
{
|
319 |
"name": "stdout",
|
320 |
"output_type": "stream",
|
321 |
"text": [
|
322 |
-
"Total backward FLOPs: 1,653,
|
323 |
]
|
324 |
}
|
325 |
],
|
@@ -338,14 +364,14 @@
|
|
338 |
},
|
339 |
{
|
340 |
"cell_type": "code",
|
341 |
-
"execution_count":
|
342 |
"metadata": {},
|
343 |
"outputs": [
|
344 |
{
|
345 |
"name": "stdout",
|
346 |
"output_type": "stream",
|
347 |
"text": [
|
348 |
-
"Total roundtrip FLOPs: 2,480,
|
349 |
]
|
350 |
}
|
351 |
],
|
@@ -364,24 +390,24 @@
|
|
364 |
},
|
365 |
{
|
366 |
"cell_type": "code",
|
367 |
-
"execution_count":
|
368 |
"metadata": {},
|
369 |
"outputs": [
|
370 |
{
|
371 |
"name": "stdout",
|
372 |
"output_type": "stream",
|
373 |
"text": [
|
374 |
-
"Total PaLM FLOPs: 2,481,161,
|
375 |
]
|
376 |
}
|
377 |
],
|
378 |
"source": [
|
379 |
"palm_flops_per_token = (\n",
|
380 |
" 6 * total_parameter_count\n",
|
381 |
-
" + 12 * num_hidden_layers * num_attention_heads * head_dimensions *
|
382 |
")\n",
|
383 |
"\n",
|
384 |
-
"total_palm_flops = palm_flops_per_token *
|
385 |
"\n",
|
386 |
"print(f\"Total PaLM FLOPs: {total_palm_flops:,}\")"
|
387 |
]
|
@@ -390,23 +416,30 @@
|
|
390 |
"cell_type": "markdown",
|
391 |
"metadata": {},
|
392 |
"source": [
|
393 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
394 |
]
|
395 |
},
|
396 |
{
|
397 |
"cell_type": "code",
|
398 |
-
"execution_count":
|
399 |
"metadata": {},
|
400 |
"outputs": [
|
401 |
{
|
402 |
"name": "stdout",
|
403 |
"output_type": "stream",
|
404 |
"text": [
|
405 |
-
"RTX A2000: 935.43 seconds/epoch, 18.25 days required
|
406 |
-
"RTX A4000: 348.64 seconds/epoch, 6.80 days required
|
407 |
-
"RTX 3090: 154.75 seconds/epoch, 3.02 days required
|
408 |
-
"A100 SXM: 44.01 seconds/epoch, 0.86 days required
|
409 |
-
"HGX A100: 6.79 seconds/epoch, 0.13 days required
|
410 |
]
|
411 |
}
|
412 |
],
|
@@ -421,10 +454,6 @@
|
|
421 |
" mfu: float\n",
|
422 |
"\n",
|
423 |
" @property\n",
|
424 |
-
" def percentage_utilization(self) -> float:\n",
|
425 |
-
" return self.mfu * 100\n",
|
426 |
-
"\n",
|
427 |
-
" @property\n",
|
428 |
" def actual_flops(self) -> float:\n",
|
429 |
" return self.mfu * self.advertised_flops\n",
|
430 |
"\n",
|
@@ -443,7 +472,7 @@
|
|
443 |
" days_required = num_epochs_required * seconds_per_epoch / 60 / 60 / 24\n",
|
444 |
"\n",
|
445 |
" print(\n",
|
446 |
-
" f\"{device.name}: {seconds_per_epoch:.2f} seconds/epoch, {days_required:,.2f} days required
|
447 |
" )"
|
448 |
]
|
449 |
}
|
|
|
9 |
},
|
10 |
{
|
11 |
"cell_type": "code",
|
12 |
+
"execution_count": 188,
|
13 |
"metadata": {},
|
14 |
"outputs": [],
|
15 |
"source": [
|
16 |
+
"# Model\n",
|
17 |
"vocabulary_size = 50257\n",
|
18 |
"embedding_dimensions = 1024\n",
|
19 |
"num_attention_heads = 16\n",
|
20 |
"num_hidden_layers = 24\n",
|
21 |
+
"feed_forward_ratio = 4\n",
|
22 |
+
"\n",
|
23 |
+
"# Training set\n",
|
24 |
+
"tokens_per_sample = 1024\n",
|
25 |
"samples_per_epoch = 4096"
|
26 |
]
|
27 |
},
|
|
|
34 |
},
|
35 |
{
|
36 |
"cell_type": "code",
|
37 |
+
"execution_count": 189,
|
38 |
"metadata": {},
|
39 |
"outputs": [
|
40 |
{
|
|
|
53 |
"text": [
|
54 |
"Token Embeddings 51,463,168 14.56%\n",
|
55 |
"Attention 100,663,296 28.48%\n",
|
56 |
+
"MLP 201,326,616 56.95%\n",
|
57 |
"RMS Norm 50,176 0.01%\n",
|
58 |
"Output Layer 0 0.00%\n",
|
59 |
"\n",
|
60 |
"\n",
|
61 |
+
"Total parameters: 353,503,256\n"
|
62 |
]
|
63 |
}
|
64 |
],
|
|
|
71 |
" embedding_dimensions**2 + embedding_dimensions * 3 * embedding_dimensions\n",
|
72 |
" )\n",
|
73 |
" * num_hidden_layers,\n",
|
74 |
+
" \"MLP\": embedding_dimensions\n",
|
75 |
+
" * feed_forward_ratio\n",
|
76 |
+
" * embedding_dimensions\n",
|
77 |
+
" * 2\n",
|
78 |
+
" * num_hidden_layers,\n",
|
79 |
" \"RMS Norm\": embedding_dimensions * num_hidden_layers * 2 + embedding_dimensions,\n",
|
80 |
" \"Output Layer\": 0, # Tied to token embeddings\n",
|
81 |
"}\n",
|
|
|
107 |
},
|
108 |
{
|
109 |
"cell_type": "code",
|
110 |
+
"execution_count": 190,
|
111 |
"metadata": {},
|
112 |
"outputs": [
|
113 |
{
|
|
|
133 |
},
|
134 |
{
|
135 |
"cell_type": "code",
|
136 |
+
"execution_count": 191,
|
137 |
"metadata": {},
|
138 |
"outputs": [
|
139 |
{
|
|
|
159 |
},
|
160     {
161       "cell_type": "code",
162 +     "execution_count": 192,
163       "metadata": {},
164       "outputs": [
165        {

189     },
190     {
191       "cell_type": "code",
192 +     "execution_count": 193,
193       "metadata": {},
194       "outputs": [
195        {
196         "name": "stdout",
197         "output_type": "stream",
198         "text": [
199 +       "Optimal training tokens: 7,070,065,120\n",
200         "Epochs required: 1,686\n",
201         "\n"
202        ]

205       "source": [
206        "num_training_tokens = 20 * total_parameter_count\n",
207        "\n",
208 +      "num_epochs_required = round(\n",
209 +      "    num_training_tokens / (samples_per_epoch * tokens_per_sample)\n",
210 +      ")\n",
211        "\n",
212        "print(f\"Optimal training tokens: {num_training_tokens:,}\")\n",
213        "\n",

223     },
224     {
225       "cell_type": "code",
226 +     "execution_count": 194,
227       "metadata": {},
228       "outputs": [
229        {

241         "output_type": "stream",
242         "text": [
243         "Attention     309,237,645,312    37.39%\n",
244 +       "MLP           412,317,450,240    49.86%\n",
245         "RMS Norm              179,200     0.00%\n",
246         "Output Layer  105,396,568,064    12.75%\n",
247         "\n",
248         "\n",
249 +       "Total forward FLOPs: 826,951,842,816\n"
250        ]
251        }
252       ],
253       "source": [
254        "ops_per_matmul = 2  # Multiply + accumulate (MAC)\n",
255 +      "ops_per_activation = 5  # Assuming SiLU\n",
256        "ops_per_rms_norm = 7  # y = (x / sqrt(rms[x] + epsilon)) * gamma\n",
257        "\n",
258        "head_dimensions = embedding_dimensions // num_attention_heads\n",
259        "\n",
260        "# K, Q, V projections\n",
261        "attention = (\n",
262 +      "    ops_per_matmul\n",
263 +      "    * tokens_per_sample\n",
264 +      "    * (embedding_dimensions * 3 * embedding_dimensions)\n",
265        ")\n",
266        "\n",
267        "# Attention logits\n",
268 +      "attention += (\n",
269 +      "    ops_per_matmul * tokens_per_sample * tokens_per_sample * embedding_dimensions\n",
270 +      ")\n",
271        "\n",
272        "# Reductions\n",
273        "attention += (\n",
274 +      "    ops_per_matmul\n",
275 +      "    * num_attention_heads\n",
276 +      "    * (tokens_per_sample * tokens_per_sample * head_dimensions)\n",
277        ")\n",
278        "\n",
279        "# Output projection\n",
280 +      "attention += ops_per_matmul * tokens_per_sample * embedding_dimensions**2\n",
281        "\n",
282        "attention *= num_hidden_layers\n",
283        "\n",
284        "# Linear transformations\n",
285 +      "mlp = (\n",
286 +      "    ops_per_matmul\n",
287 +      "    * tokens_per_sample\n",
288 +      "    * (embedding_dimensions * (4 * embedding_dimensions))\n",
289 +      ")\n",
290 +      "mlp += (\n",
291 +      "    ops_per_matmul\n",
292 +      "    * tokens_per_sample\n",
293 +      "    * ((4 * embedding_dimensions) * embedding_dimensions)\n",
294 +      ")\n",
295        "\n",
296        "# Non-linear activations\n",
297        "mlp += ops_per_activation * (4 * embedding_dimensions)\n",

300        "\n",
301        "rms_norm = ops_per_rms_norm * embedding_dimensions * (num_hidden_layers + 1)\n",
302        "\n",
303 +      "output_layer = (\n",
304 +      "    ops_per_matmul * tokens_per_sample * embedding_dimensions * vocabulary_size\n",
305 +      ")\n",
306        "\n",
307        "flops = {\n",
308        "    \"Attention\": attention,\n",

338     },
339     {
340       "cell_type": "code",
341 +     "execution_count": 195,
342       "metadata": {},
343       "outputs": [
344        {
345         "name": "stdout",
346         "output_type": "stream",
347         "text": [
348 +       "Total backward FLOPs: 1,653,903,685,632\n"
349        ]
350        }
351       ],

364     },
365     {
366       "cell_type": "code",
367 +     "execution_count": 196,
368       "metadata": {},
369       "outputs": [
370        {
371         "name": "stdout",
372         "output_type": "stream",
373         "text": [
374 +       "Total roundtrip FLOPs: 2,480,855,528,448\n"
375        ]
376        }
377       ],

390     },
391     {
392       "cell_type": "code",
393 +     "execution_count": 197,
394       "metadata": {},
395       "outputs": [
396        {
397         "name": "stdout",
398         "output_type": "stream",
399         "text": [
400 +       "Total PaLM FLOPs: 2,481,161,650,176\n"
401        ]
402        }
403       ],
404       "source": [
405        "palm_flops_per_token = (\n",
406        "    6 * total_parameter_count\n",
407 +      "    + 12 * num_hidden_layers * num_attention_heads * head_dimensions * tokens_per_sample\n",
408        ")\n",
409        "\n",
410 +      "total_palm_flops = palm_flops_per_token * tokens_per_sample\n",
411        "\n",
412        "print(f\"Total PaLM FLOPs: {total_palm_flops:,}\")"
413       ]

416       "cell_type": "markdown",
417       "metadata": {},
418       "source": [
419 +      "The two estimates are pretty close, so let's proceed."
420 +     ]
421 +   },
422 +   {
423 +     "cell_type": "markdown",
424 +     "metadata": {},
425 +     "source": [
426 +      "Finally, let's estimate how long it would take to train over the optimal number of tokens on some common Nvidia Ampere-generation GPU configurations. Note that these figures are theoretical and do not factor in additional overhead such as activation checkpointing or network latency."
427       ]
428     },
429     {
430       "cell_type": "code",
431 +     "execution_count": 198,
432       "metadata": {},
433       "outputs": [
434        {
435         "name": "stdout",
436         "output_type": "stream",
437         "text": [
438 +       "RTX A2000: 935.43 seconds/epoch, 18.25 days required\n",
439 +       "RTX A4000: 348.64 seconds/epoch, 6.80 days required\n",
440 +       "RTX 3090: 154.75 seconds/epoch, 3.02 days required\n",
441 +       "A100 SXM: 44.01 seconds/epoch, 0.86 days required\n",
442 +       "HGX A100: 6.79 seconds/epoch, 0.13 days required\n"
443        ]
444        }
445       ],

454        "    mfu: float\n",
455        "\n",
456        "    @property\n",
457        "    def actual_flops(self) -> float:\n",
458        "        return self.mfu * self.advertised_flops\n",
459        "\n",

472        "    days_required = num_epochs_required * seconds_per_epoch / 60 / 60 / 24\n",
473        "\n",
474        "    print(\n",
475 +      "        f\"{device.name}: {seconds_per_epoch:.2f} seconds/epoch, {days_required:,.2f} days required\"\n",
476        "    )"
477       ]
478     }
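As a sanity check on the numbers above: the backward pass is estimated at twice the forward cost, so the roundtrip is 3 × 826,951,842,816 = 2,480,855,528,448 FLOPs per sample, which agrees with the PaLM-style estimate of 2,481,161,650,176 to within about 0.01%. The estimation loop in the final cell is only partially visible in this diff, so here is a minimal self-contained sketch of the same calculation. The `Device` dataclass and the loop structure follow the visible fragments; the advertised-throughput and MFU figures below are illustrative assumptions, not the notebook's actual values.

```
from dataclasses import dataclass

ROUNDTRIP_FLOPS = 2_480_855_528_448  # Per 1,024-token sample, from the cells above
SAMPLES_PER_EPOCH = 4096
NUM_EPOCHS_REQUIRED = 1686


@dataclass
class Device:
    name: str
    advertised_flops: float  # Peak throughput in FLOP/s (assumed figure)
    mfu: float  # Assumed model FLOPs utilization, in [0, 1]

    @property
    def actual_flops(self) -> float:
        # Sustained throughput is the advertised peak scaled by MFU.
        return self.mfu * self.advertised_flops


# Hypothetical hardware figures, for illustration only.
devices = [
    Device("RTX 3090", advertised_flops=71e12, mfu=0.9),
    Device("A100 SXM", advertised_flops=312e12, mfu=0.75),
]

for device in devices:
    seconds_per_epoch = SAMPLES_PER_EPOCH * ROUNDTRIP_FLOPS / device.actual_flops
    days_required = NUM_EPOCHS_REQUIRED * seconds_per_epoch / 60 / 60 / 24

    print(
        f"{device.name}: {seconds_per_epoch:.2f} seconds/epoch,"
        f" {days_required:,.2f} days required"
    )
```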
pre-train.py → pretrain.py
RENAMED
@@ -16,6 +16,7 @@ from torch.cuda import set_device, is_available as cuda_is_available, is_bf16_su
 from torch.nn.utils import clip_grad_norm_
 from torch.distributed import init_process_group, destroy_process_group
 from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy
+from torch.utils.tensorboard import SummaryWriter
 
 from torchmetrics.text import Perplexity
 
@@ -38,7 +39,7 @@ DDP_BACKEND = "nccl"
 
 
 def main():
-    parser = ArgumentParser(description="
+    parser = ArgumentParser(description="Pretrain the GPT.")
 
     parser.add_argument(
        "--dataset_subset",
@@ -54,6 +55,7 @@ def main():
     parser.add_argument("--num_dataset_processes", default=8, type=int)
     parser.add_argument("--batch_size", default=1, type=int)
     parser.add_argument("--gradient_accumulation_steps", default=128, type=int)
+    parser.add_argument("--tokens_per_sample", default=1024, type=int)
     parser.add_argument("--samples_per_epoch", default=4096, type=int)
     parser.add_argument("--num_epochs", default=1686, type=int)
     parser.add_argument("--learning_rate", default=1e-2, type=float)
@@ -61,7 +63,6 @@ def main():
     parser.add_argument("--low_memory_optimizer", action="store_true")
     parser.add_argument("--max_gradient_norm", default=1.0, type=float)
     parser.add_argument("--dropout", default=0.1, type=float)
-    parser.add_argument("--block_size", default=1024, type=int)
     parser.add_argument("--embedding_dimensions", default=1024, type=int)
     parser.add_argument("--num_attention_heads", default=16, type=int)
     parser.add_argument("--num_hidden_layers", default=24, type=int)
@@ -74,6 +75,7 @@ def main():
        "--checkpoint_path", default="./checkpoints/checkpoint.pt", type=str
     )
     parser.add_argument("--resume", action="store_true")
+    parser.add_argument("--run_dir_path", default="./runs/pretrain", type=str)
     parser.add_argument("--device", default="cuda", type=str)
     parser.add_argument("--seed", default=None, type=int)
 
@@ -156,23 +158,25 @@ def main():
     torch.manual_seed(args.seed)
     random.seed(args.seed)
 
+    logger = SummaryWriter(args.run_dir_path)
+
     tokenizer = tiktoken.get_encoding(args.token_encoding)
 
     training = Fineweb(
-        tokenizer,
+        tokenizer=tokenizer,
         root_path=args.dataset_path,
         subset=args.dataset_subset,
         split="train",
-        tokens_per_sample=args.
+        tokens_per_sample=args.tokens_per_sample,
         samples_per_epoch=args.samples_per_epoch,
         num_processes=args.num_dataset_processes,
     )
     testing = Fineweb(
-        tokenizer,
+        tokenizer=tokenizer,
         root_path=args.dataset_path,
         subset=args.dataset_subset,
         split="test",
-        tokens_per_sample=args.
+        tokens_per_sample=args.tokens_per_sample,
         samples_per_epoch=args.samples_per_epoch,
         num_processes=args.num_dataset_processes,
     )
@@ -185,13 +189,12 @@ def main():
     )
 
     model_args = {
-        "
+        "vocabulary_size": tokenizer.n_vocab,
         "embedding_dimensions": args.embedding_dimensions,
         "num_heads": args.num_attention_heads,
         "num_layers": args.num_hidden_layers,
         "feed_forward_ratio": args.feed_forward_ratio,
         "dropout": args.dropout,
-        "vocabulary_size": tokenizer.n_vocab,
         "padding_index": training.PADDING_INDEX,
         "eos_index": tokenizer.eot_token,
     }
@@ -252,7 +255,7 @@ def main():
 
     register_signal_handlers()
 
-    print("
+    print("Pretraining ...")
 
     for epoch in range(starting_epoch, args.num_epochs + 1):
         total_cross_entropy, total_gradient_norm = 0.0, 0.0
@@ -292,14 +295,18 @@ def main():
 
             total_batches += 1
 
-
-
 
-
-
-
-
-
+        if IS_MASTER:
+            average_cross_entropy = total_cross_entropy / total_batches
+            average_gradient_norm = total_gradient_norm / total_steps
+
+            logger.add_scalar("cross entropy", average_cross_entropy, epoch)
+            logger.add_scalar("gradient norm", average_gradient_norm, epoch)
+
+            print(
+                f"Epoch {epoch}:",
+                f"Cross Entropy: {average_cross_entropy:.5f},",
+                f"Gradient Norm: {average_gradient_norm:.4f}",
+            )
 
         if epoch % args.eval_interval == 0 and IS_MASTER:
             model.eval()
@@ -315,6 +322,8 @@ def main():
 
         perplexity = perplexity_metric.compute()
 
+        logger.add_scalar("perplexity", perplexity, epoch)
+
        print(f"Perplexity: {perplexity:.3f}")
 
        perplexity_metric.reset()
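With the script renamed to pretrain.py and TensorBoard logging wired in, a typical run might look like the following. The flag values are examples only, and the `tensorboard` command assumes the newly pinned tensorboard package is installed.

```
python pretrain.py --batch_size=16 --gradient_accumulation_steps=8 --tokens_per_sample=1024 --run_dir_path=./runs/pretrain

tensorboard --logdir ./runs
```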
requirements.txt
CHANGED
@@ -6,5 +6,7 @@ tiktoken==0.8.0
 tqdm==4.66.6
 matplotlib==3.9.2
 safetensors==0.5.2
-
+onnx==1.17.0
+onnxscript==0.1.0.dev20250108
 onnxruntime==1.20.1
+tensorboard==2.18.0
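The new onnx and onnxscript pins support PyTorch's dynamo-based ONNX exporter, which the reworked export_model.ipynb relies on. A rough, hypothetical sketch of that flow is below; the toy model stands in for the real LightGPT checkpoint.

```
import torch
from torch import nn


class TinyLM(nn.Module):
    """A stand-in model; the real export target is a trained LightGPT."""

    def __init__(self, vocabulary_size: int = 50257, embedding_dimensions: int = 64):
        super().__init__()

        self.embedding = nn.Embedding(vocabulary_size, embedding_dimensions)
        self.output = nn.Linear(embedding_dimensions, vocabulary_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.output(self.embedding(tokens))


model = TinyLM().eval()

example_tokens = torch.randint(0, 50257, (1, 16))

# dynamo=True routes the export through torch.export and onnxscript,
# which is what the new onnx/onnxscript requirements provide.
onnx_program = torch.onnx.export(model, (example_tokens,), dynamo=True)

onnx_program.save("tiny_lm.onnx")
```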
runs/.gitignore
ADDED
@@ -0,0 +1,2 @@
+*
+!.gitignore