ThomasTheMaker committed on
Commit feba2ad · verified · 1 Parent(s): 6557434

Upload folder using huggingface_hub

Files changed (47)
  1. .env.example +2 -0
  2. .gitignore +174 -0
  3. .pre-commit-config.yaml +10 -0
  4. LICENSE +201 -0
  5. README.md +170 -0
  6. configs/examples/demo.yaml +48 -0
  7. configs/examples/pico-decoder-large.yaml +35 -0
  8. configs/examples/pico-decoder-medium.yaml +35 -0
  9. configs/examples/pico-decoder-small.yaml +35 -0
  10. configs/examples/pico-decoder-tiny.yaml +35 -0
  11. configs/pico-decoder-tiny-dolma10M-v1.yaml +78 -0
  12. configs/pico-decoder-tiny-dolma20M-v1.yaml +78 -0
  13. configs/pico-decoder-tiny-dolma5M-v1.yaml +78 -0
  14. plots/.gitignore +74 -0
  15. plots/404.html +33 -0
  16. plots/README.md +90 -0
  17. plots/code.js +550 -0
  18. plots/data.json +0 -0
  19. plots/index.html +72 -0
  20. plots/style.css +258 -0
  21. pyproject.toml +33 -0
  22. scripts/README.md +109 -0
  23. scripts/generate_data.py +198 -0
  24. scripts/train.py +30 -0
  25. setup.sh +200 -0
  26. src/checkpointing/__init__.py +23 -0
  27. src/checkpointing/evaluation.py +68 -0
  28. src/checkpointing/learning_dynamics.py +424 -0
  29. src/checkpointing/training.py +287 -0
  30. src/config/__init__.py +31 -0
  31. src/config/_constants.py +18 -0
  32. src/config/checkpointing_config.py +97 -0
  33. src/config/data_config.py +36 -0
  34. src/config/evaluation_config.py +28 -0
  35. src/config/model_config.py +33 -0
  36. src/config/monitoring_config.py +29 -0
  37. src/config/training_config.py +40 -0
  38. src/evaluation/__init__.py +103 -0
  39. src/evaluation/tasks/paloma.py +52 -0
  40. src/model/__init__.py +12 -0
  41. src/model/pico_decoder.py +911 -0
  42. src/training/trainer.py +753 -0
  43. src/training/utils/__init__.py +34 -0
  44. src/training/utils/data.py +35 -0
  45. src/training/utils/initialization.py +702 -0
  46. src/training/utils/io.py +52 -0
  47. src/training/utils/logging.py +48 -0
.env.example ADDED
@@ -0,0 +1,2 @@
+ WANDB_API_KEY=your_wandb_key
+ HF_TOKEN=your_huggingface_token
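These two variables are what the training stack reads for Hugging Face uploads and Weights & Biases logging. A minimal sketch of how a script might consume them, assuming `python-dotenv` is installed (the repo itself may load them differently, e.g. via `setup.sh`):

```python
import os

from dotenv import load_dotenv  # assumption: python-dotenv is available

load_dotenv()  # reads key=value pairs from .env into the process environment
hf_token = os.environ["HF_TOKEN"]        # used for Hugging Face checkpoint uploads
wandb_key = os.environ["WANDB_API_KEY"]  # used for Weights & Biases logging
```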
.gitignore ADDED
@@ -0,0 +1,174 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ poetry.lock
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Data
+ data/
+
+ # Checkpoint and Logging Directories
+ runs/
+ wandb/
+ # configs/
+
+ .vscode/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,10 @@
+ repos:
+   - repo: https://github.com/astral-sh/ruff-pre-commit
+     # Ruff version.
+     rev: v0.7.1
+     hooks:
+       # Run the linter.
+       - id: ruff
+         args: [ --fix, --extend-select, I ]
+       # Run the formatter.
+       - id: ruff-format
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,170 @@
+ # 🚀 **Pico Train**
+
+ Pico Train is a lightweight framework for training language models—from tiny-scale (~1M parameters) to mid-scale (~1B parameters)—with built-in rich checkpointing that captures activations, gradients, and model states, enabling detailed learning dynamics research.
+
+ Our **suite of pre-trained models** is already publicly available on our [Hugging Face organization](https://huggingface.co/pico-lm), and a dedicated companion library for advanced analysis—[**pico-analyze**](https://github.com/pico-lm/pico-analyze)—is fully released for deeper checkpoint studies.
+
+ > For a **detailed run-through**, check out the **full tutorial** on our website at [picolm.io](https://picolm.io).
+
+ ---
+
+ ## **Key Features**
+
+ 1. **Pico Decoder: LLAMA-style Transformer Architecture**
+    - RMSNorm, RoPE, multi-head self-attention with KV-cache, and SwiGLU activations
+    - Currently supports the **pico-decoder** model, with future expansions planned (pico-diffusion, pico-statespace, etc.)
+
+ 2. **Comprehensive Checkpoints**
+    - Saves model states, optimizer states, and training metadata
+    - Enriched with **activation and gradient** snapshots for interpretability
+
+ 3. **Focused Scale Range**
+    - Optimized to train models from **1M to 1B parameters**, where learning dynamics research is most viable
+
+ 4. **Clean, Pre-tokenized Data**
+    - Uses a pre-tokenized, pre-shuffled version of [Dolma](https://allenai.org/dolma) that we make available on [Hugging Face](https://huggingface.co/datasets/pico-lm/pretokenized-dolma)
+    - Facilitates training models using identical data for **consistency** and **comparability**
+
+ 5. **Research Ready**
+    - Minimal, well-documented code suitable for **forking and tailoring**
+    - Logs essential metrics (e.g. perplexity) throughout training
+    - Works seamlessly with [pico-analyze](https://github.com/pico-lm/pico-analyze) for advanced post-training interpretation
+
+ ---
+
+ ## **Training Philosophy**
+
+ All models in the Pico suite (both pre-trained and user-trained):
+
+ - Employ **identical architectures** and **optimizer settings**
+ - **Share** the same data order and tokens
+ - Automatically log **rich checkpoint data** (including activations, gradients)
+ - Facilitate **direct cross-scale comparisons**
+
+ This uniformity means you can isolate model size as the primary variable, giving you clearer insights into **how model capacity affects learning**.
+
+ ---
+
+ ## **Resources**
+
+ - **Pre-trained Models** (1M–1B parameters), publicly hosted on [Hugging Face](https://huggingface.co/pico-lm)
+ - **Pre-tokenized Datasets** for straightforward streaming-based training
+ - **Extensive Checkpoints** logging activation and gradient snapshots
+ - **Evaluation Metrics** (perplexity and more) tracked at each checkpoint
+
+ ---
+
+ ## **Core Components**
+
+ - **Pico-Decoder Model** (see the sketch at the end of this section)
+   - LLAMA-style auto-regressive transformer
+   - RMSNorm
+   - RoPE (Rotary Positional Embeddings)
+   - Multi-head attention with KV-cache
+   - SwiGLU activation
+
+   *Future plans include additional architectures like pico-diffusion and pico-statespace.*
+
+ - **Training & Checkpointing**
+   - Automatic storage of model and optimizer states
+   - Periodic hooks for saving **learning dynamics** (activations, gradients)
+   - Optional logging to Weights & Biases
+
+ - **Config-Driven Setup**
+   - Specify architecture, optimizer, dataset, and logging settings in YAML
+   - Straightforward to extend or modify
+
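To make the architecture bullets concrete, here is a minimal PyTorch sketch of two of the named building blocks, RMSNorm and SwiGLU. It is illustrative only; the actual implementations live in `src/model/pico_decoder.py`:

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square normalization: rescale by 1/RMS, no mean subtraction."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * x * rms


class SwiGLU(nn.Module):
    """SiLU-gated feed-forward: silu(W_gate x) * (W_up x), projected back down."""

    def __init__(self, d_model: int, hidden_dim: int):
        super().__init__()
        self.w_gate = nn.Linear(d_model, hidden_dim, bias=False)
        self.w_up = nn.Linear(d_model, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(nn.functional.silu(self.w_gate(x)) * self.w_up(x))


# Dimensions match the tiny/demo configs (d_model: 96, activation_hidden_dim: 384).
x = torch.randn(2, 10, 96)                    # (batch, seq, d_model)
print(SwiGLU(96, 384)(RMSNorm(96)(x)).shape)  # torch.Size([2, 10, 96])
```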
+ ---
+
+ ## **Quick Start**
+
+ 1. **Clone the Repository**
+
+    ```bash
+    git clone https://github.com/pico-lm/pico-train
+    cd pico-train
+    ```
+
+ 2. **Configure Environment**
+
+    Create a `.env` file at the root with your Hugging Face and Weights & Biases tokens:
+    ```bash
+    export HF_TOKEN=your_huggingface_token
+    export WANDB_API_KEY=your_wandb_key
+    ```
+
+ 3. **Install Dependencies**
+
+    ```bash
+    source setup.sh
+    ```
+    This script checks your environment, installs necessary tools, and sets up a Poetry virtual environment.
+
+ 4. **Train Your Model Suite**
+
+    - Edit (or create) a config file (e.g., `configs/examples/demo.yaml`) to specify your architecture and training preferences.
+    - Then run:
+    ```bash
+    poetry run train --config_path configs/examples/demo.yaml
+    ```
+    - This launches training, automatically checkpointing states and saving learning dynamics data.
+
+ 5. **Explore Checkpoints**
+    - By default, checkpoints are stored under `runs/YOUR_RUN_NAME/checkpoints/`.
+    - Each checkpoint contains:
+      - **Model state** (PyTorch + Hugging Face formats)
+      - **Optimizer state**
+      - **Gradients and activations** for interpretability
+      - **Evaluation logs** (e.g. perplexity) and metrics
+
+ ---
+
+ ## **Repository Structure**
+
+ - **`src/model/pico_decoder.py`**
+   - Core LLAMA-style decoder implementation (attention, RMSNorm, RoPE, etc.)
+
+ - **`src/training/trainer.py`**
+   - Main training loop
+   - Manages distributed and multi-node settings
+   - Collects/logs metrics
+   - Orchestrates checkpoint saving
+
+ - **`src/checkpointing`**
+   - Logic for saving model states, gradients, activations
+   - Tools for uploading checkpoints to Hugging Face
+
+ - **`src/config`**
+   - Flexible dataclass-based config system (model and training hyperparameters, checkpointing, logging); see the sketch below
+
+ - **`configs/examples/demo.yaml`**
+   - Example config with default values for quick experimentation
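To illustrate the YAML-over-dataclass flow, here is a minimal sketch of how a config file can override dataclass defaults. Class and field names here are illustrative, not the exact definitions in `src/config`, and it assumes PyYAML:

```python
from dataclasses import dataclass

import yaml  # assumption: PyYAML is installed


@dataclass
class ModelConfig:
    d_model: int = 96
    activation_hidden_dim: int = 384


@dataclass
class TrainingConfig:
    max_steps: int = 100
    lr: float = 0.001


def load_configs(path: str) -> tuple[ModelConfig, TrainingConfig]:
    """Start from dataclass defaults, then apply any overrides found in the YAML."""
    with open(path) as f:
        overrides = yaml.safe_load(f) or {}
    model = ModelConfig(**overrides.get("model", {}))
    training = TrainingConfig(**overrides.get("training", {}))
    return model, training
```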
+
+ ---
+
+ ## **Advanced Analysis with Pico Analyze**
+
+ For deeper checkpoint analysis—comparing gradients, tracking representation shifts, measuring sparsity—use our companion repository [**pico-analyze**](https://github.com/pico-lm/pico-analyze). It automatically processes **pico-train** checkpoints and applies advanced metrics like **CKA**, **PWCCA**, **Gini**, **Hoyer**, and more to reveal **how** your models learn over time.
+
+ ---
+
+ ## **License**
+
+ Pico is open-source under the [Apache License 2.0](LICENSE).
+
+ ---
+
+ ## **Citation**
+
+ If you use **Pico** in your research, please cite:
+
+ ```bibtex
+ @software{pico2025,
+   author = {Diehl Martinez, Richard},
+   title = {Pico: A Lightweight Framework for Studying Language Model Learning Dynamics},
+   year = {2025},
+   url = {https://github.com/pico-lm}
+ }
+ ```
+
+ **Happy Training!** For more information and tutorials, visit our website at [picolm.io](https://picolm.io).
configs/examples/demo.yaml ADDED
@@ -0,0 +1,48 @@
+ # Demo config file
+ # You can follow this template to create your own config file
+ # Refer to the config files in the configs/ directory to see all the available options
+
+ data:
+   dataloader:
+     batch_size: 32
+
+ checkpointing:
+   run_name: "pico-decoder-demo-1"
+   save_every_n_steps: 50
+
+   save_to_hf: true
+   hf_checkpoint:
+     repo_id: "pico-lm/demo"
+
+   learning_dynamics:
+     batch_size: 16
+
+ model:
+   d_model: 96
+   activation_hidden_dim: 384
+
+ evaluation:
+   paloma:
+     batch_size: 32
+
+ monitoring:
+
+   save_to_wandb: true
+   wandb:
+     project: "pico-demo"
+     entity: "pico-lm"
+
+   logging:
+     log_every_n_steps: 10
+
+ training:
+   max_steps: 100
+
+   optimization:
+     lr: 0.001
+     lr_warmup_steps: 30
+
+     gradient_accumulation_steps: 2
+
+   fabric:
+     num_devices: 1
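As a quick sanity check on these values: the effective batch size per optimizer step is the product of the dataloader batch size, the gradient accumulation steps, and the device count. A back-of-the-envelope in Python, mirroring the demo config above:

```python
# Values mirror configs/examples/demo.yaml
batch_size = 32    # data.dataloader.batch_size
grad_accum = 2     # training.optimization.gradient_accumulation_steps
num_devices = 1    # training.fabric.num_devices

effective_batch = batch_size * grad_accum * num_devices
print(f"sequences per optimizer step: {effective_batch}")  # -> 64
```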
configs/examples/pico-decoder-large.yaml ADDED
@@ -0,0 +1,35 @@
+ # Demo config file
+ # You can follow this template to create your own config file
+ # Refer to the config files in the configs/ directory to see all the available options
+
+ checkpointing:
+   run_name: "pico-decoder-large-1"
+   save_to_hf: true
+   hf_checkpoint:
+     repo_id: "pico-lm/pico-decoder-large"
+
+   learning_dynamics:
+     batch_size: 128
+
+ model:
+   d_model: 1536
+   activation_hidden_dim: 6144
+
+ monitoring:
+   save_to_wandb: true
+   wandb:
+     project: "pico-decoder"
+     entity: "pico-lm"
+
+ training:
+   optimization:
+     gradient_accumulation_steps: 8
+
+   fabric:
+     num_nodes: 4
+     num_devices: 4
+
+ evaluation:
+   paloma:
+     batch_size: 16
+
configs/examples/pico-decoder-medium.yaml ADDED
@@ -0,0 +1,35 @@
+ # Demo config file
+ # You can follow this template to create your own config file
+ # Refer to the config files in the configs/ directory to see all the available options
+
+ checkpointing:
+   run_name: "pico-decoder-medium-1"
+   save_to_hf: true
+   hf_checkpoint:
+     repo_id: "pico-lm/pico-decoder-medium"
+
+   learning_dynamics:
+     batch_size: 128
+
+ model:
+   d_model: 768
+   activation_hidden_dim: 3072
+
+ monitoring:
+   save_to_wandb: true
+   wandb:
+     project: "pico-decoder"
+     entity: "pico-lm"
+
+ training:
+   optimization:
+     gradient_accumulation_steps: 8
+
+   fabric:
+     num_nodes: 4
+     num_devices: 4
+
+ evaluation:
+   paloma:
+     batch_size: 16
+
configs/examples/pico-decoder-small.yaml ADDED
@@ -0,0 +1,35 @@
+ # Demo config file
+ # You can follow this template to create your own config file
+ # Refer to the config files in the configs/ directory to see all the available options
+
+ checkpointing:
+   run_name: "pico-decoder-small-1"
+   save_to_hf: true
+   hf_checkpoint:
+     repo_id: "pico-lm/pico-decoder-small"
+
+   learning_dynamics:
+     batch_size: 128
+
+ model:
+   d_model: 384
+   activation_hidden_dim: 1536
+
+ monitoring:
+   save_to_wandb: true
+   wandb:
+     project: "pico-decoder"
+     entity: "pico-lm"
+
+ training:
+   optimization:
+     gradient_accumulation_steps: 8
+
+   fabric:
+     num_nodes: 4
+     num_devices: 4
+
+ evaluation:
+   paloma:
+     batch_size: 16
+
configs/examples/pico-decoder-tiny.yaml ADDED
@@ -0,0 +1,35 @@
+ # Demo config file
+ # You can follow this template to create your own config file
+ # Refer to the config files in the configs/ directory to see all the available options
+
+ checkpointing:
+   run_name: "pico-decoder-tiny-1"
+   save_to_hf: true
+   hf_checkpoint:
+     repo_id: "pico-lm/pico-decoder-tiny"
+
+   learning_dynamics:
+     batch_size: 256
+
+ model:
+   d_model: 96
+   activation_hidden_dim: 384
+
+ monitoring:
+   save_to_wandb: true
+   wandb:
+     project: "pico-decoder"
+     entity: "pico-lm"
+
+ training:
+   optimization:
+     gradient_accumulation_steps: 4
+
+   fabric:
+     num_nodes: 4
+     num_devices: 4
+
+ evaluation:
+   paloma:
+     batch_size: 32
+
configs/pico-decoder-tiny-dolma10M-v1.yaml ADDED
@@ -0,0 +1,78 @@
+ # High Quality Training Config - Optimized for H100 80GB Performance
+ # Fast training configuration maintaining identical model quality
+ # Optimized for H100 80GB with maximum throughput while preserving stability
+ # Updated for efficient training on Dolma 10M tokens with H100-optimized hyperparameters
+
+ checkpointing:
+   run_name: "pico-decoder-tiny-dolma10M-v1"
+   save_to_hf: true
+   hf_checkpoint:
+     repo_id: "ThomasTheMaker/pico-decoder-tiny"
+   save_every_n_steps: 2000  # Reduced checkpoint frequency for faster training
+
+   learning_dynamics:
+     batch_size: 1  # Minimal batch size for learning dynamics
+     eval_data: null  # Disable learning dynamics to save memory
+
+ model:
+   d_model: 96
+   activation_hidden_dim: 384
+   dropout: 0.15  # Increased dropout for stronger regularization
+   attention_dropout: 0.15  # Increased attention dropout
+   layer_norm_eps: 1e-5  # Tighter normalization for stability
+   weight_init_type: "truncated_normal"  # Truncated normal for stability
+   layer_norm_type: "rms_norm"  # RMSNorm for better stability
+   use_qk_norm: true  # Query-Key normalization for attention stability
+
+ monitoring:
+   save_to_wandb: false
+   wandb:
+     project: "pico-decoder-tiny"
+     entity: "boymyc"
+   logging:
+     log_every_n_steps: 100  # Reduced logging frequency for faster training
+
+ training:
+   max_steps: 100000  # Longer training for better convergence
+   optimization:
+     lr: 0.0002  # Scaled learning rate for larger batch size (4x increase)
+     lr_warmup_steps: 2000  # Reduced warmup for faster convergence
+     lr_scheduler: "cosine"  # Cosine decay over full dataset for sustained learning
+     weight_decay: 0.02  # Increased weight decay for stronger regularization
+     max_grad_norm: 0.5  # Tighter gradient clipping for stability
+     gradient_accumulation_steps: 1  # Reduced for faster training with larger batches
+     optimizer: "adamw"
+     adam_beta1: 0.9  # Standard AdamW beta1
+     adam_beta2: 0.999  # Standard AdamW beta2
+     adam_epsilon: 1e-8  # Tighter epsilon for numerical stability and convergence
+
+   fabric:
+     num_nodes: 1
+     num_devices: 1
+     precision: "bf16-mixed"  # BF16 for Tensor Core optimization
+
+ evaluation:
+   paloma:
+     batch_size: 1  # Minimal evaluation batch size
+     eval_every_n_steps: 1000  # Reduced evaluation frequency for faster training
+
+ data:
+   dataset:
+     name: "ThomasTheMaker/pretokenized-dolma-10M"  # Updated to the 10M token dataset
+   dataloader:
+     batch_size: 16  # Conservative H100 optimization - 4x larger for stable fast training
+   tokenizer:
+     name: "allenai/OLMo-7B-0724-hf"
+     vocab_size: 50304
+
+ # H100-optimized training strategy for fast, memory-safe training:
+ # 1. Conservative batch size (16) with scaled learning rate (0.0002) for stable H100 utilization
+ # 2. Reduced gradient accumulation (1 step) for faster optimization cycles
+ # 3. Shorter warmup (2000 steps) for quicker convergence with larger batches
+ # 4. Reduced evaluation frequency (1000 steps) to minimize training interruptions
+ # 5. Reduced checkpoint/logging frequency to minimize I/O overhead
+ # 6. Same model architecture and regularization for identical final performance
+ # 7. Expected 4-6x training speedup while maintaining model quality and memory safety
+ # 8. Memory usage: ~15-25GB of 80GB H100 VRAM (safe utilization avoiding OOM)
+ # 9. Maintains all stability features: RMSNorm, QK-Norm, dropout, weight decay
+ # 10. Same convergence quality with significant speedup and no memory issues
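As a reference for the `lr_scheduler: "cosine"` and warmup settings above, here is a minimal sketch of linear warmup followed by cosine decay. The trainer's actual scheduler may differ in details (e.g. a non-zero floor):

```python
import math

# Values mirror this config: lr, lr_warmup_steps, max_steps
LR, WARMUP, MAX_STEPS = 2e-4, 2000, 100_000


def lr_at(step: int) -> float:
    if step < WARMUP:
        return LR * step / WARMUP                       # linear warmup to peak LR
    progress = (step - WARMUP) / (MAX_STEPS - WARMUP)   # 0 -> 1 after warmup
    return 0.5 * LR * (1.0 + math.cos(math.pi * progress))  # cosine decay to 0


print(lr_at(1_000), lr_at(2_000), lr_at(100_000))  # 1e-4, 2e-4, ~0.0
```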
configs/pico-decoder-tiny-dolma20M-v1.yaml ADDED
@@ -0,0 +1,78 @@
+ # High Quality Training Config - Optimized for H100 80GB Performance
+ # Fast training configuration maintaining identical model quality
+ # Optimized for H100 80GB with maximum throughput while preserving stability
+ # Updated for efficient training on Dolma 20M tokens with H100-optimized hyperparameters
+
+ checkpointing:
+   run_name: "pico-decoder-tiny-dolma20M-v1"
+   save_to_hf: false
+   hf_checkpoint:
+     repo_id: "ThomasTheMaker/pico-decoder-tiny"
+   save_every_n_steps: 1000  # Reduced checkpoint frequency for faster training
+
+   learning_dynamics:
+     batch_size: 1  # Minimal batch size for learning dynamics
+     eval_data: null  # Disable learning dynamics to save memory
+
+ model:
+   d_model: 96
+   activation_hidden_dim: 384
+   dropout: 0.15  # Increased dropout for stronger regularization
+   attention_dropout: 0.15  # Increased attention dropout
+   layer_norm_eps: 1e-5  # Tighter normalization for stability
+   weight_init_type: "truncated_normal"  # Truncated normal for stability
+   layer_norm_type: "rms_norm"  # RMSNorm for better stability
+   use_qk_norm: true  # Query-Key normalization for attention stability
+
+ monitoring:
+   save_to_wandb: false
+   wandb:
+     project: "pico-decoder-tiny"
+     entity: "boymyc"
+   logging:
+     log_every_n_steps: 100  # Reduced logging frequency for faster training
+
+ training:
+   max_steps: 100000  # Longer training for better convergence
+   optimization:
+     lr: 0.0002  # Scaled learning rate for larger batch size (4x increase)
+     lr_warmup_steps: 2000  # Reduced warmup for faster convergence
+     lr_scheduler: "cosine"  # Cosine decay over full dataset for sustained learning
+     weight_decay: 0.02  # Increased weight decay for stronger regularization
+     max_grad_norm: 0.5  # Tighter gradient clipping for stability
+     gradient_accumulation_steps: 1  # Reduced for faster training with larger batches
+     optimizer: "adamw"
+     adam_beta1: 0.9  # Standard AdamW beta1
+     adam_beta2: 0.999  # Standard AdamW beta2
+     adam_epsilon: 1e-8  # Tighter epsilon for numerical stability and convergence
+
+   fabric:
+     num_nodes: 1
+     num_devices: 1
+     precision: "bf16-mixed"  # BF16 for Tensor Core optimization
+
+ evaluation:
+   paloma:
+     batch_size: 1  # Minimal evaluation batch size
+     eval_every_n_steps: 1000  # Reduced evaluation frequency for faster training
+
+ data:
+   dataset:
+     name: "ThomasTheMaker/pretokenized-dolma-20M"  # Updated to the 20M token dataset
+   dataloader:
+     batch_size: 16  # Conservative H100 optimization - 4x larger for stable fast training
+   tokenizer:
+     name: "allenai/OLMo-7B-0724-hf"
+     vocab_size: 50304
+
+ # H100-optimized training strategy for fast, memory-safe training:
+ # 1. Conservative batch size (16) with scaled learning rate (0.0002) for stable H100 utilization
+ # 2. Reduced gradient accumulation (1 step) for faster optimization cycles
+ # 3. Shorter warmup (2000 steps) for quicker convergence with larger batches
+ # 4. Reduced evaluation frequency (1000 steps) to minimize training interruptions
+ # 5. Reduced checkpoint/logging frequency to minimize I/O overhead
+ # 6. Same model architecture and regularization for identical final performance
+ # 7. Expected 4-6x training speedup while maintaining model quality and memory safety
+ # 8. Memory usage: ~15-25GB of 80GB H100 VRAM (safe utilization avoiding OOM)
+ # 9. Maintains all stability features: RMSNorm, QK-Norm, dropout, weight decay
+ # 10. Same convergence quality with significant speedup and no memory issues
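The `use_qk_norm: true` flag above refers to normalizing queries and keys before the attention dot product so logits stay bounded. A minimal self-contained sketch of the idea; one common variant applies RMSNorm with learned scales, and this L2-normalized version is illustrative only, not the exact pico-decoder code:

```python
import torch
import torch.nn.functional as F


def qk_norm_attention(q, k, v, scale: float = 1.0):
    """Attention with unit-norm queries/keys: logits are bounded in [-scale, scale]."""
    q = F.normalize(q, dim=-1)                   # unit-norm queries
    k = F.normalize(k, dim=-1)                   # unit-norm keys
    logits = (q @ k.transpose(-2, -1)) * scale   # cannot overflow, even in bf16
    return torch.softmax(logits, dim=-1) @ v


q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq, head_dim)
print(qk_norm_attention(q, k, v).shape)  # torch.Size([1, 4, 8, 16])
```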
configs/pico-decoder-tiny-dolma5M-v1.yaml ADDED
@@ -0,0 +1,78 @@
+ # High Quality Training Config - Optimized for superior model performance
+ # This configuration prioritizes model quality over training speed
+ # Designed for RTX 5090 with focus on preventing overfitting and maximizing generalization
+ # Updated for scaling training on Dolma 5M tokens with stability-focused hyperparameters
+
+ checkpointing:
+   run_name: "pico-decoder-tiny-dolma5M-v1"
+   save_to_hf: true
+   hf_checkpoint:
+     repo_id: "ThomasTheMaker/pico-decoder-tiny"
+   save_every_n_steps: 500  # Frequent checkpoints for quality monitoring
+
+   learning_dynamics:
+     batch_size: 1  # Minimal batch size for learning dynamics
+     eval_data: null  # Disable learning dynamics to save memory
+
+ model:
+   d_model: 96
+   activation_hidden_dim: 384
+   dropout: 0.15  # Increased dropout for stronger regularization
+   attention_dropout: 0.15  # Increased attention dropout
+   layer_norm_eps: 1e-5  # Tighter normalization for stability
+   weight_init_type: "truncated_normal"  # Truncated normal for stability
+   layer_norm_type: "rms_norm"  # RMSNorm for better stability
+   use_qk_norm: true  # Query-Key normalization for attention stability
+
+ monitoring:
+   save_to_wandb: false
+   wandb:
+     project: "pico-decoder-tiny"
+     entity: "boymyc"
+   logging:
+     log_every_n_steps: 25  # Very frequent logging for quality monitoring
+
+ training:
+   max_steps: 100000  # Longer training for better convergence
+   optimization:
+     lr: 0.00005  # Even lower learning rate for precision training
+     lr_warmup_steps: 8000  # Extended warmup for stability
+     lr_scheduler: "cosine"  # Cosine decay over full dataset for sustained learning
+     weight_decay: 0.02  # Increased weight decay for stronger regularization
+     max_grad_norm: 0.5  # Tighter gradient clipping for stability
+     gradient_accumulation_steps: 4  # Increased for better gradient estimates
+     optimizer: "adamw"
+     adam_beta1: 0.9  # Standard AdamW beta1
+     adam_beta2: 0.999  # Standard AdamW beta2
+     adam_epsilon: 1e-8  # Tighter epsilon for numerical stability and convergence
+
+   fabric:
+     num_nodes: 1
+     num_devices: 1
+     precision: "bf16-mixed"  # BF16 for Tensor Core optimization
+
+ evaluation:
+   paloma:
+     batch_size: 1  # Minimal evaluation batch size
+     eval_every_n_steps: 250  # Very frequent evaluation for quality monitoring
+
+ data:
+   dataset:
+     name: "ThomasTheMaker/pretokenized-dolma-5M"  # Updated to the 5M token dataset
+   dataloader:
+     batch_size: 4  # Reduced for more stable training
+   tokenizer:
+     name: "allenai/OLMo-7B-0724-hf"
+     vocab_size: 50304
+
+ # Stability-focused training strategy for large-scale Dolma training:
+ # 1. Cosine learning rate schedule for sustained learning over full dataset
+ # 2. Truncated normal weight initialization to prevent extreme outliers
+ # 3. RMSNorm for better gradient stability during long training runs
+ # 4. Query-Key normalization (QK-Norm) to prevent attention logit overflow
+ # 5. AdamW epsilon 1e-8 for improved training stability and convergence
+ # 6. Extended warmup (8000 steps) for stable foundation
+ # 7. Stronger regularization (dropout 0.15, weight decay 0.02)
+ # 8. Tighter gradient clipping (0.5) for stability
+ # 9. More frequent evaluation (every 250 steps) for quality monitoring
+ # 10. Longer training (100000 steps) for full convergence on 5M tokens
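For reference, `gradient_accumulation_steps: 4` means gradients from four micro-batches are summed before a single optimizer update. A minimal runnable sketch of the pattern using toy stand-ins (the real loop lives in `src/training/trainer.py` and trains the pico-decoder model on a streaming dataloader):

```python
import torch
from torch import nn

# Toy stand-ins so the pattern runs end to end.
model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.02)
dataloader = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(8)]

accum_steps = 4  # training.optimization.gradient_accumulation_steps

for step, (x, y) in enumerate(dataloader):
    loss = nn.functional.mse_loss(model(x), y) / accum_steps  # scale so the sum averages
    loss.backward()                                           # grads accumulate in .grad
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # max_grad_norm: 0.5
        optimizer.step()        # one parameter update per 4 micro-batches
        optimizer.zero_grad()   # reset for the next accumulation window
```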
plots/.gitignore ADDED
@@ -0,0 +1,74 @@
+ # Logs
+ logs
+ *.log
+ npm-debug.log*
+ yarn-debug.log*
+ yarn-error.log*
+ firebase-debug.log*
+ firebase-debug.*.log*
+
+ # Firebase cache
+ .firebase/
+
+ # Firebase config
+
+ # Uncomment this if you'd like others to create their own Firebase project.
+ # For a team working on the same Firebase project(s), it is recommended to leave
+ # it commented so all members can deploy to the same project(s) in .firebaserc.
+ # .firebaserc
+
+ # Runtime data
+ pids
+ *.pid
+ *.seed
+ *.pid.lock
+
+ # Directory for instrumented libs generated by jscoverage/JSCover
+ lib-cov
+
+ # Coverage directory used by tools like istanbul
+ coverage
+
+ # nyc test coverage
+ .nyc_output
+
+ # Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files)
+ .grunt
+
+ # Bower dependency directory (https://bower.io/)
+ bower_components
+
+ # node-waf configuration
+ .lock-wscript
+
+ # Compiled binary addons (http://nodejs.org/api/addons.html)
+ build/Release
+
+ # Dependency directories
+ node_modules/
+
+ # Optional npm cache directory
+ .npm
+
+ # Optional eslint cache
+ .eslintcache
+
+ # Optional REPL history
+ .node_repl_history
+
+ # Output of 'npm pack'
+ *.tgz
+
+ # Yarn Integrity file
+ .yarn-integrity
+
+ # dotenv environment variables file
+ .env
+
+ # dataconnect generated files
+ .dataconnect
+
+ # firebase files
+
+ .firebaserc
+ firebase.json
plots/404.html ADDED
@@ -0,0 +1,33 @@
+ <!DOCTYPE html>
+ <html>
+ <head>
+   <meta charset="utf-8">
+   <meta name="viewport" content="width=device-width, initial-scale=1">
+   <title>Page Not Found</title>
+
+   <style media="screen">
+     body { background: #ECEFF1; color: rgba(0,0,0,0.87); font-family: Roboto, Helvetica, Arial, sans-serif; margin: 0; padding: 0; }
+     #message { background: white; max-width: 360px; margin: 100px auto 16px; padding: 32px 24px 16px; border-radius: 3px; }
+     #message h3 { color: #888; font-weight: normal; font-size: 16px; margin: 16px 0 12px; }
+     #message h2 { color: #ffa100; font-weight: bold; font-size: 16px; margin: 0 0 8px; }
+     #message h1 { font-size: 22px; font-weight: 300; color: rgba(0,0,0,0.6); margin: 0 0 16px;}
+     #message p { line-height: 140%; margin: 16px 0 24px; font-size: 14px; }
+     #message a { display: block; text-align: center; background: #039be5; text-transform: uppercase; text-decoration: none; color: white; padding: 16px; border-radius: 4px; }
+     #message, #message a { box-shadow: 0 1px 3px rgba(0,0,0,0.12), 0 1px 2px rgba(0,0,0,0.24); }
+     #load { color: rgba(0,0,0,0.4); text-align: center; font-size: 13px; }
+     @media (max-width: 600px) {
+       body, #message { margin-top: 0; background: white; box-shadow: none; }
+       body { border-top: 16px solid #ffa100; }
+     }
+   </style>
+ </head>
+ <body>
+   <div id="message">
+     <h2>404</h2>
+     <h1>Page Not Found</h1>
+     <p>The specified file was not found on this website. Please check the URL for mistakes and try again.</p>
+     <h3>Why am I seeing this?</h3>
+     <p>This page was generated by the Firebase Command-Line Interface. To modify it, edit the <code>404.html</code> file in your project's configured <code>public</code> directory.</p>
+   </div>
+ </body>
+ </html>
plots/README.md ADDED
@@ -0,0 +1,90 @@
+ # 🚀 Pico Training Metrics Dashboard
+
+ A beautiful, interactive web dashboard for visualizing training progress across all your Pico model runs.
+
+ ## ✨ Features
+
+ - **📈 Training Loss Visualization**: Track loss curves over time for all runs
+ - **🎯 Learning Rate Schedules**: Monitor LR progression and warmup patterns
+ - **📊 Paloma Evaluation**: View perplexity metrics during training
+ - **🔄 Combined View**: See all metrics together for easy comparison
+ - **🎨 Interactive Charts**: Built with Chart.js for smooth interactions
+ - **📱 Responsive Design**: Works on desktop and mobile devices
+ - **⚙️ Run Comparison**: Compare different model configurations side-by-side
+
+ ## 🚀 Quick Start
+
+ 1. **Generate Data**: First, run the data generation script to parse your training logs:
+    ```bash
+    python scripts/generate_data.py
+    ```
+
+ 2. **View the Dashboard**: Open `index.html` in your web browser
+ 3. **Select Runs**: Use the dropdown to view specific runs or all runs together
+ 4. **Toggle Metrics**: Check/uncheck boxes to show/hide different metric types
+ 5. **Explore Charts**: Hover over data points for detailed information
+
+ ## 📁 Files
+
+ - `index.html` - Main dashboard interface
+ - `style.css` - Modern, responsive styling
+ - `code.js` - Interactive chart functionality
+ - `data.json` - Training metrics data (auto-generated from logs)
+
+ ## 🔧 Data Source
+
+ The dashboard automatically extracts training metrics from:
+ - Training loss at each step
+ - Learning rate progression
+ - Paloma evaluation results
+ - Model configuration parameters
+
+ ## 🔄 Updating Data
+
+ To refresh the dashboard with new training data:
+ 1. **Run new training sessions** - logs will be saved to `runs/*/logs/`
+ 2. **Generate updated data.json**:
+    ```bash
+    python scripts/generate_data.py
+    ```
+ 3. **Refresh the dashboard** - new runs will appear automatically
+
+ ## 🎨 Chart Types
+
+ 1. **Training Loss**: Line charts showing loss reduction over time
+ 2. **Learning Rate**: Logarithmic scale for LR schedule visualization
+ 3. **Evaluation**: Paloma perplexity metrics during training
+ 4. **Combined**: All metrics on one chart for easy comparison
+
+ ## 💡 Usage Tips
+
+ - **Compare Runs**: Select "All Runs" to see how different configurations perform
+ - **Zoom In**: Use the chart zoom features to focus on specific training phases
+ - **Export**: Right-click charts to save as images
+ - **Mobile**: Dashboard is fully responsive for mobile devices
+
+ ## 🎯 Key Metrics Tracked
+
+ - **Training Loss**: Primary performance indicator
+ - **Learning Rate**: Schedule adherence and warmup progress
+ - **Paloma Perplexity**: Model evaluation quality
+ - **Inf/NaN Counts**: Training stability monitoring
+ - **Model Config**: Architecture and hyperparameter details
+
+ ## 🌟 Design Features
+
+ - **Modern UI**: Clean, professional interface
+ - **Color Coding**: Distinct colors for each model run
+ - **Responsive Layout**: Adapts to different screen sizes
+ - **Interactive Elements**: Hover effects and smooth animations
+ - **Professional Typography**: Easy-to-read fonts and spacing
+
+ ## 📚 Documentation
+
+ For more details on generating the data.json file, see:
+ - `scripts/README.md` - Complete script documentation
+ - `scripts/generate_data.py` - The data generation script
+
+ ---
+
+ Built with ❤️ for the Pico Language Model training community
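For a quick sanity check of the generated `data.json` before opening the dashboard, a short Python snippet (the schema is inferred from how `plots/code.js` reads it: `runs[].run_name`, `runs[].training_metrics[].step/.loss`, and `runs[].evaluation_results[].step/.paloma`):

```python
import json

# Load the dashboard data and summarize what each run contains.
with open("plots/data.json") as f:
    data = json.load(f)

for run in data["runs"]:
    n_train = len(run.get("training_metrics", []))
    n_eval = len(run.get("evaluation_results", []))
    print(f"{run['run_name']}: {n_train} training points, {n_eval} eval points")
```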
plots/code.js ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Global variables
2
+ let trainingData = null;
3
+ let charts = {};
4
+
5
+ // Color palette for different runs
6
+ const colors = [
7
+ '#667eea', '#764ba2', '#f093fb', '#f5576c', '#4facfe', '#00f2fe',
8
+ '#43e97b', '#38f9d7', '#fa7093', '#fee140', '#a8edea', '#fed6e3'
9
+ ];
10
+
11
+ // Initialize the dashboard
12
+ document.addEventListener('DOMContentLoaded', function() {
13
+ loadData();
14
+ setupEventListeners();
15
+ });
16
+
17
+ // Load training data from JSON file
18
+ async function loadData() {
19
+ try {
20
+ const response = await fetch('data.json');
21
+ trainingData = await response.json();
22
+
23
+ // Merge continuation logs from the same model run
24
+ mergeContinuationLogs();
25
+
26
+ populateRunSelector();
27
+ createCharts();
28
+ updateRunSummary();
29
+ updateConfigDetails();
30
+
31
+ console.log('Data loaded and merged successfully:', trainingData);
32
+ } catch (error) {
33
+ console.error('Error loading data:', error);
34
+ document.body.innerHTML = '<div class="loading">Error loading training data. Please check the console for details.</div>';
35
+ }
36
+ }
37
+
38
+ // Merge continuation logs from the same model run
39
+ function mergeContinuationLogs() {
40
+ const runGroups = {};
41
+
42
+ // Group runs by base model name
43
+ trainingData.runs.forEach(run => {
44
+ const baseName = run.run_name;
45
+ if (!runGroups[baseName]) {
46
+ runGroups[baseName] = [];
47
+ }
48
+ runGroups[baseName].push(run);
49
+ });
50
+
51
+ // Merge runs with the same base name
52
+ const mergedRuns = [];
53
+
54
+ Object.entries(runGroups).forEach(([baseName, runs]) => {
55
+ if (runs.length === 1) {
56
+ // Single run, no merging needed
57
+ mergedRuns.push(runs[0]);
58
+ } else {
59
+ // Multiple runs to merge
60
+ console.log(`Merging ${runs.length} continuation logs for ${baseName}`);
61
+
62
+ const mergedRun = {
63
+ run_name: baseName,
64
+ log_files: runs.map(r => r.log_file),
65
+ training_metrics: [],
66
+ evaluation_results: [],
67
+ config: runs[0].config || {}
68
+ };
69
+
70
+ // Merge training metrics (they should be continuous)
71
+ runs.forEach(run => {
72
+ if (run.training_metrics) {
73
+ mergedRun.training_metrics.push(...run.training_metrics);
74
+ }
75
+ });
76
+
77
+ // Merge evaluation results (they should be continuous)
78
+ runs.forEach(run => {
79
+ if (run.evaluation_results) {
80
+ mergedRun.evaluation_results.push(...run.evaluation_results);
81
+ }
82
+ });
83
+
84
+ // Sort by step number to ensure proper ordering
85
+ mergedRun.training_metrics.sort((a, b) => a.step - b.step);
86
+ mergedRun.evaluation_results.sort((a, b) => a.step - b.step);
87
+
88
+ // Remove duplicates based on step number
89
+ mergedRun.training_metrics = mergedRun.training_metrics.filter((metric, index, self) =>
90
+ index === 0 || metric.step !== self[index - 1].step
91
+ );
92
+ mergedRun.evaluation_results = mergedRun.evaluation_results.filter((result, index, self) =>
93
+ index === 0 || result.step !== self[index - 1].step
94
+ );
95
+
96
+ console.log(`Merged ${baseName}: ${mergedRun.training_metrics.length} training points, ${mergedRun.evaluation_results.length} eval points`);
97
+ mergedRuns.push(mergedRun);
98
+ }
99
+ });
100
+
101
+ trainingData.runs = mergedRuns;
102
+ }
103
+
104
+ // Setup event listeners for controls
105
+ function setupEventListeners() {
106
+ document.getElementById('runSelect').addEventListener('change', function() {
107
+ updateCharts();
108
+ updateRunSummary();
109
+ updateConfigDetails();
110
+ });
111
+     document.getElementById('showTraining').addEventListener('change', updateCharts);
+     document.getElementById('showLearningRate').addEventListener('change', updateCharts);
+     document.getElementById('showEvaluation').addEventListener('change', updateCharts);
+ }
+
+ // Populate run selector dropdown
+ function populateRunSelector() {
+     const select = document.getElementById('runSelect');
+     const runs = trainingData.runs;
+
+     // Clear existing options
+     select.innerHTML = '<option value="all">All Runs</option>';
+
+     runs.forEach((run, index) => {
+         const option = document.createElement('option');
+         option.value = index;
+         option.textContent = run.run_name;
+         select.appendChild(option);
+     });
+ }
+
+ // Create all charts
+ function createCharts() {
+     createLossChart();
+     createLRChart();
+     createEvalChart();
+     createCombinedChart();
+ }
+
+ // Create training loss chart
+ function createLossChart() {
+     const ctx = document.getElementById('lossChart').getContext('2d');
+
+     charts.loss = new Chart(ctx, {
+         type: 'line',
+         data: getChartData('loss'),
+         options: {
+             responsive: true,
+             maintainAspectRatio: false,
+             plugins: {
+                 title: {
+                     display: true,
+                     text: 'Training Loss Over Time'
+                 },
+                 legend: {
+                     position: 'top'
+                 }
+             },
+             scales: {
+                 x: {
+                     type: 'linear',
+                     title: {
+                         display: true,
+                         text: 'Training Step'
+                     }
+                 },
+                 y: {
+                     title: {
+                         display: true,
+                         text: 'Loss'
+                     },
+                     beginAtZero: false
+                 }
+             },
+             interaction: {
+                 intersect: false,
+                 mode: 'index'
+             }
+         }
+     });
+ }
+
+ // Create learning rate chart
+ function createLRChart() {
+     const ctx = document.getElementById('lrChart').getContext('2d');
+
+     charts.lr = new Chart(ctx, {
+         type: 'line',
+         data: getChartData('lr'),
+         options: {
+             responsive: true,
+             maintainAspectRatio: false,
+             plugins: {
+                 title: {
+                     display: true,
+                     text: 'Learning Rate Schedule'
+                 },
+                 legend: {
+                     position: 'top'
+                 }
+             },
+             scales: {
+                 x: {
+                     type: 'linear',
+                     title: {
+                         display: true,
+                         text: 'Training Step'
+                     }
+                 },
+                 y: {
+                     title: {
+                         display: true,
+                         text: 'Learning Rate'
+                     },
+                     type: 'logarithmic'
+                 }
+             },
+             interaction: {
+                 intersect: false,
+                 mode: 'index'
+             }
+         }
+     });
+ }
+
+ // Create evaluation chart
+ function createEvalChart() {
+     const ctx = document.getElementById('evalChart').getContext('2d');
+
+     charts.eval = new Chart(ctx, {
+         type: 'line',
+         data: getChartData('eval'),
+         options: {
+             responsive: true,
+             maintainAspectRatio: false,
+             plugins: {
+                 title: {
+                     display: true,
+                     text: 'Paloma Evaluation Metrics'
+                 },
+                 legend: {
+                     position: 'top'
+                 }
+             },
+             scales: {
+                 x: {
+                     type: 'linear',
+                     title: {
+                         display: true,
+                         text: 'Training Step'
+                     }
+                 },
+                 y: {
+                     title: {
+                         display: true,
+                         text: 'Perplexity'
+                     },
+                     type: 'logarithmic'
+                 }
+             },
+             interaction: {
+                 intersect: false,
+                 mode: 'index'
+             }
+         }
+     });
+ }
+
+ // Create combined chart
+ function createCombinedChart() {
+     const ctx = document.getElementById('combinedChart').getContext('2d');
+
+     charts.combined = new Chart(ctx, {
+         type: 'line',
+         data: getCombinedChartData(),
+         options: {
+             responsive: true,
+             maintainAspectRatio: false,
+             plugins: {
+                 title: {
+                     display: true,
+                     text: 'Combined Training Metrics'
+                 },
+                 legend: {
+                     position: 'top'
+                 }
+             },
+             scales: {
+                 x: {
+                     type: 'linear',
+                     title: {
+                         display: true,
+                         text: 'Training Step'
+                     }
+                 },
+                 y: {
+                     title: {
+                         display: true,
+                         text: 'Value'
+                     }
+                 }
+             },
+             interaction: {
+                 intersect: false,
+                 mode: 'index'
+             }
+         }
+     });
+ }
+
+ // Get chart data for specific metric type
+ function getChartData(metricType) {
+     const selectedRun = document.getElementById('runSelect').value;
+     const runs = selectedRun === 'all' ? trainingData.runs : [trainingData.runs[selectedRun]];
+
+     const datasets = [];
+
+     console.log(`Getting ${metricType} data for ${runs.length} runs:`, runs.map(r => r.run_name));
+
+     runs.forEach((run, runIndex) => {
+         const color = colors[runIndex % colors.length];
+
+         if (metricType === 'loss') {
+             if (run.training_metrics && run.training_metrics.length > 0) {
+                 const data = run.training_metrics.map(m => ({ x: m.step, y: m.loss }));
+                 console.log(`Loss data for ${run.run_name}:`, data.slice(0, 5), '...', data.slice(-5));
+                 datasets.push({
+                     label: run.run_name,
+                     data: data,
+                     borderColor: color,
+                     backgroundColor: color + '20',
+                     borderWidth: 2,
+                     fill: false,
+                     tension: 0.1
+                 });
+             }
+         } else if (metricType === 'lr') {
+             if (run.training_metrics && run.training_metrics.length > 0) {
+                 const data = run.training_metrics.map(m => ({ x: m.step, y: m.learning_rate }));
+                 console.log(`LR data for ${run.run_name}:`, data.slice(0, 5), '...', data.slice(-5));
+                 datasets.push({
+                     label: run.run_name,
+                     data: data,
+                     borderColor: color,
+                     backgroundColor: color + '20',
+                     borderWidth: 2,
+                     fill: false,
+                     tension: 0.1
+                 });
+             }
+         } else if (metricType === 'eval') {
+             if (run.evaluation_results && run.evaluation_results.length > 0) {
+                 const data = run.evaluation_results.map(m => ({ x: m.step, y: m.paloma }));
+                 console.log(`Eval data for ${run.run_name}:`, data.slice(0, 5), '...', data.slice(-5));
+                 datasets.push({
+                     label: run.run_name,
+                     data: data,
+                     borderColor: color,
+                     backgroundColor: color + '20',
+                     borderWidth: 2,
+                     fill: false,
+                     tension: 0.1
+                 });
+             }
+         }
+     });
+
+     console.log(`Final ${metricType} datasets:`, datasets);
+     return { datasets };
+ }
+
+ // Get combined chart data
+ function getCombinedChartData() {
+     const selectedRun = document.getElementById('runSelect').value;
+     const runs = selectedRun === 'all' ? trainingData.runs : [trainingData.runs[selectedRun]];
+
+     const datasets = [];
+
+     runs.forEach((run, runIndex) => {
+         const color = colors[runIndex % colors.length];
+
+         // Training loss
+         if (run.training_metrics && run.training_metrics.length > 0) {
+             datasets.push({
+                 label: `${run.run_name} - Loss`,
+                 data: run.training_metrics.map(m => ({ x: m.step, y: m.loss })),
+                 borderColor: color,
+                 backgroundColor: color + '20',
+                 borderWidth: 2,
+                 fill: false,
+                 tension: 0.1
+             });
+         }
+
+         // Learning rate (scaled)
+         if (run.training_metrics && run.training_metrics.length > 0) {
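+             // Rescale the LR series onto the loss axis: the peak learning rate is
+             // mapped to the peak loss so both curves stay visible on one linear scale.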
+             const maxLR = Math.max(...run.training_metrics.map(m => m.learning_rate));
+             const maxLoss = Math.max(...run.training_metrics.map(m => m.loss));
+             const scaleFactor = maxLoss / maxLR;
+
+             datasets.push({
+                 label: `${run.run_name} - LR (scaled)`,
+                 data: run.training_metrics.map(m => ({ x: m.step, y: m.learning_rate * scaleFactor })),
+                 borderColor: color + '80',
+                 backgroundColor: color + '10',
+                 borderWidth: 1,
+                 fill: false,
+                 tension: 0.1
+             });
+         }
+     });
+
+     return { datasets };
+ }
+
+ // Update all charts based on current selection
+ function updateCharts() {
+     if (charts.loss) {
+         charts.loss.data = getChartData('loss');
+         charts.loss.update();
+     }
+
+     if (charts.lr) {
+         charts.lr.data = getChartData('lr');
+         charts.lr.update();
+     }
+
+     if (charts.eval) {
+         charts.eval.data = getChartData('eval');
+         charts.eval.update();
+     }
+
+     if (charts.combined) {
+         charts.combined.data = getCombinedChartData();
+         charts.combined.update();
+     }
+ }
+
+ // Update run summary section
+ function updateRunSummary() {
+     const container = document.getElementById('runSummary');
+     const selectedRun = document.getElementById('runSelect').value;
+     const runs = selectedRun === 'all' ? trainingData.runs : [trainingData.runs[selectedRun]];
+
+     let html = '<div class="run-grid">';
+
+     runs.forEach(run => {
+         const trainingPoints = run.training_metrics ? run.training_metrics.length : 0;
+         const evalPoints = run.evaluation_results ? run.evaluation_results.length : 0;
+
+         let finalLoss = 'N/A';
+         let finalLR = 'N/A';
+         let finalPaloma = 'N/A';
+         let stepRange = 'N/A';
+
+         if (run.training_metrics && run.training_metrics.length > 0) {
+             const first = run.training_metrics[0];
+             const last = run.training_metrics[run.training_metrics.length - 1];
+             finalLoss = last.loss.toFixed(4);
+             finalLR = last.learning_rate.toExponential(2);
+             stepRange = `${first.step} → ${last.step}`;
+         }
+
+         if (run.evaluation_results && run.evaluation_results.length > 0) {
+             const last = run.evaluation_results[run.evaluation_results.length - 1];
+             if (isFinite(last.paloma)) {
+                 finalPaloma = last.paloma.toExponential(2);
+             } else {
+                 finalPaloma = '∞';
+             }
+         }
+
+         const logFiles = run.log_files ? run.log_files.join(', ') : run.log_file;
+
+         html += `
+             <div class="run-card">
+                 <h4>${run.run_name}</h4>
+                 <p><strong>Logs:</strong> ${logFiles}</p>
+                 <div class="metric">
+                     <span>Step Range:</span>
+                     <span class="value">${stepRange}</span>
+                 </div>
+                 <div class="metric">
+                     <span>Training Points:</span>
+                     <span class="value">${trainingPoints}</span>
+                 </div>
+                 <div class="metric">
+                     <span>Evaluation Points:</span>
+                     <span class="value">${evalPoints}</span>
+                 </div>
+                 <div class="metric">
+                     <span>Final Loss:</span>
+                     <span class="value">${finalLoss}</span>
+                 </div>
+                 <div class="metric">
+                     <span>Final LR:</span>
+                     <span class="value">${finalLR}</span>
+                 </div>
+                 <div class="metric">
+                     <span>Final Paloma:</span>
+                     <span class="value">${finalPaloma}</span>
+                 </div>
+             </div>
+         `;
+     });
+
+     html += '</div>';
+     container.innerHTML = html;
+ }
+
+ // Update configuration details section
+ function updateConfigDetails() {
+     const container = document.getElementById('configDetails');
+     const selectedRun = document.getElementById('runSelect').value;
+     const runs = selectedRun === 'all' ? trainingData.runs : [trainingData.runs[selectedRun]];
+
+     let html = '<div class="config-grid">';
+
+     // Get unique config keys
+     const allKeys = new Set();
+     runs.forEach(run => {
+         if (run.config) {
+             Object.keys(run.config).forEach(key => allKeys.add(key));
+         }
+     });
+
+     allKeys.forEach(key => {
+         const values = runs.map(run => run.config && run.config[key] !== undefined ? run.config[key] : 'N/A');
+         const uniqueValues = [...new Set(values)];
+         const displayValue = uniqueValues.length === 1 ? uniqueValues[0] : `${uniqueValues.join(' / ')}`;
+
+         html += `
+             <div class="config-item">
+                 <div class="label">${key.replace(/_/g, ' ').toUpperCase()}</div>
+                 <div class="value">${displayValue}</div>
+             </div>
+         `;
+     });
+
+     html += '</div>';
+     container.innerHTML = html;
+ }
+
+ // Utility function to format large numbers
+ function formatNumber(num) {
+     if (num >= 1e9) return (num / 1e9).toFixed(2) + 'B';
+     if (num >= 1e6) return (num / 1e6).toFixed(2) + 'M';
+     if (num >= 1e3) return (num / 1e3).toFixed(2) + 'K';
+     return num.toString();
+ }
plots/data.json ADDED
The diff for this file is too large to render.
 
plots/index.html ADDED
@@ -0,0 +1,72 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>Pico Training Metrics Dashboard</title>
+     <link rel="stylesheet" href="style.css">
+     <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
+ </head>
+ <body>
+     <div class="container">
+         <header>
+             <h1>🚀 Pico Training Metrics Dashboard</h1>
+             <p>Real-time visualization of training progress across all model runs</p>
+         </header>
+
+         <div class="controls">
+             <div class="run-selector">
+                 <label for="runSelect">Select Run:</label>
+                 <select id="runSelect">
+                     <option value="all">All Runs</option>
+                 </select>
+             </div>
+             <div class="metric-toggle">
+                 <label>
+                     <input type="checkbox" id="showTraining" checked> Training Loss
+                 </label>
+                 <label>
+                     <input type="checkbox" id="showLearningRate" checked> Learning Rate
+                 </label>
+                 <label>
+                     <input type="checkbox" id="showEvaluation" checked> Paloma Evaluation
+                 </label>
+             </div>
+         </div>
+
+         <div class="charts-container">
+             <div class="chart-card">
+                 <h3>📈 Training Loss Over Time</h3>
+                 <canvas id="lossChart"></canvas>
+             </div>
+
+             <div class="chart-card">
+                 <h3>🎯 Learning Rate Schedule</h3>
+                 <canvas id="lrChart"></canvas>
+             </div>
+
+             <div class="chart-card">
+                 <h3>📊 Paloma Evaluation Metrics</h3>
+                 <canvas id="evalChart"></canvas>
+             </div>
+
+             <div class="chart-card">
+                 <h3>🔄 Combined View</h3>
+                 <canvas id="combinedChart"></canvas>
+             </div>
+         </div>
+
+         <div class="run-summary">
+             <h3>📋 Run Summary</h3>
+             <div id="runSummary"></div>
+         </div>
+
+         <div class="config-details">
+             <h3>⚙️ Model Configuration</h3>
+             <div id="configDetails"></div>
+         </div>
+     </div>
+
+     <script src="code.js"></script>
+ </body>
+ </html>
plots/style.css ADDED
@@ -0,0 +1,258 @@
+ * {
+     margin: 0;
+     padding: 0;
+     box-sizing: border-box;
+ }
+
+ body {
+     font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     min-height: 100vh;
+     color: #333;
+ }
+
+ .container {
+     max-width: 1400px;
+     margin: 0 auto;
+     padding: 20px;
+ }
+
+ header {
+     text-align: center;
+     margin-bottom: 30px;
+     color: white;
+ }
+
+ header h1 {
+     font-size: 2.5rem;
+     margin-bottom: 10px;
+     text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
+ }
+
+ header p {
+     font-size: 1.1rem;
+     opacity: 0.9;
+ }
+
+ .controls {
+     background: white;
+     padding: 20px;
+     border-radius: 12px;
+     box-shadow: 0 8px 32px rgba(0,0,0,0.1);
+     margin-bottom: 30px;
+     display: flex;
+     justify-content: space-between;
+     align-items: center;
+     flex-wrap: wrap;
+     gap: 20px;
+ }
+
+ .run-selector select {
+     padding: 8px 16px;
+     border: 2px solid #e1e5e9;
+     border-radius: 8px;
+     font-size: 14px;
+     background: white;
+     cursor: pointer;
+     transition: border-color 0.3s ease;
+ }
+
+ .run-selector select:focus {
+     outline: none;
+     border-color: #667eea;
+ }
+
+ .metric-toggle {
+     display: flex;
+     gap: 20px;
+     flex-wrap: wrap;
+ }
+
+ .metric-toggle label {
+     display: flex;
+     align-items: center;
+     gap: 8px;
+     cursor: pointer;
+     font-weight: 500;
+     color: #555;
+ }
+
+ .metric-toggle input[type="checkbox"] {
+     width: 18px;
+     height: 18px;
+     accent-color: #667eea;
+ }
+
+ .charts-container {
+     display: grid;
+     grid-template-columns: repeat(auto-fit, minmax(600px, 1fr));
+     gap: 30px;
+     margin-bottom: 30px;
+ }
+
+ .chart-card {
+     background: white;
+     padding: 25px;
+     border-radius: 12px;
+     box-shadow: 0 8px 32px rgba(0,0,0,0.1);
+     transition: transform 0.3s ease, box-shadow 0.3s ease;
+ }
+
+ .chart-card:hover {
+     transform: translateY(-5px);
+     box-shadow: 0 12px 40px rgba(0,0,0,0.15);
+ }
+
+ .chart-card h3 {
+     margin-bottom: 20px;
+     color: #333;
+     font-size: 1.2rem;
+     display: flex;
+     align-items: center;
+     gap: 8px;
+ }
+
+ .chart-card canvas {
+     max-height: 400px;
+     width: 100% !important;
+ }
+
+ .run-summary, .config-details {
+     background: white;
+     padding: 25px;
+     border-radius: 12px;
+     box-shadow: 0 8px 32px rgba(0,0,0,0.1);
+     margin-bottom: 30px;
+ }
+
+ .run-summary h3, .config-details h3 {
+     margin-bottom: 20px;
+     color: #333;
+     font-size: 1.2rem;
+     display: flex;
+     align-items: center;
+     gap: 8px;
+ }
+
+ .run-grid {
+     display: grid;
+     grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
+     gap: 20px;
+ }
+
+ .run-card {
+     background: #f8f9fa;
+     padding: 20px;
+     border-radius: 8px;
+     border-left: 4px solid #667eea;
+ }
+
+ .run-card h4 {
+     color: #667eea;
+     margin-bottom: 10px;
+     font-size: 1.1rem;
+ }
+
+ .run-card p {
+     margin-bottom: 8px;
+     color: #666;
+     font-size: 0.9rem;
+ }
+
+ .run-card .metric {
+     display: flex;
+     justify-content: space-between;
+     margin-bottom: 5px;
+ }
+
+ .run-card .metric .value {
+     font-weight: 600;
+     color: #333;
+ }
+
+ .config-grid {
+     display: grid;
+     grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
+     gap: 15px;
+ }
+
+ .config-item {
+     background: #f8f9fa;
+     padding: 15px;
+     border-radius: 8px;
+     text-align: center;
+ }
+
+ .config-item .label {
+     font-size: 0.8rem;
+     color: #666;
+     text-transform: uppercase;
+     letter-spacing: 0.5px;
+     margin-bottom: 5px;
+ }
+
+ .config-item .value {
+     font-size: 1.2rem;
+     font-weight: 600;
+     color: #333;
+ }
+
+ @media (max-width: 768px) {
+     .container {
+         padding: 15px;
+     }
+
+     header h1 {
+         font-size: 2rem;
+     }
+
+     .controls {
+         flex-direction: column;
+         align-items: stretch;
+     }
+
+     .charts-container {
+         grid-template-columns: 1fr;
+     }
+
+     .chart-card {
+         padding: 20px;
+     }
+
+     .run-grid, .config-grid {
+         grid-template-columns: 1fr;
+     }
+ }
+
+ /* Chart.js customizations */
+ .chartjs-tooltip {
+     background: rgba(0,0,0,0.8) !important;
+     color: white !important;
+     border-radius: 8px !important;
+     padding: 10px !important;
+     font-size: 12px !important;
+ }
+
+ /* Loading state */
+ .loading {
+     text-align: center;
+     padding: 40px;
+     color: #666;
+ }
+
+ .loading::after {
+     content: '';
+     display: inline-block;
+     width: 20px;
+     height: 20px;
+     border: 3px solid #f3f3f3;
+     border-top: 3px solid #667eea;
+     border-radius: 50%;
+     animation: spin 1s linear infinite;
+     margin-left: 10px;
+ }
+
+ @keyframes spin {
+     0% { transform: rotate(0deg); }
+     100% { transform: rotate(360deg); }
+ }
pyproject.toml ADDED
@@ -0,0 +1,33 @@
+ [tool.poetry]
+ name = "pico-train"
+ version = "1.0.0"
+ description = "A minimalistic framework for transparently training language models and storing comprehensive checkpoints for in-depth learning dynamics research"
+ authors = ["Richard Diehl Martinez <richard@picolm.io>"]
+ license = "Apache 2.0"
+ readme = "README.md"
+ packages = [{include = "src"}]
+
+ [tool.poetry.scripts]
+ train = "scripts.train:main"
+
+ [tool.poetry.dependencies]
+ python = "^3.10,<3.13"
+ lightning = "^2.4.0"
+ click = "^8.1.7"
+ wandb = "^0.18.1"
+ huggingface-hub = {extras = ["cli"], version = "^0.25.1"}
+ datasets = "^3.0.1,<3.2.0"
+ transformers = "^4.45.2"
+ pre-commit = "^4.0.1"
+ torch = "^2.5.1"
+ evaluate = "^0.4.3"
+ deepspeed = "^0.16.2"
+ rich = "^13.9.4"
+
+ [tool.poetry.group.dev.dependencies]
+ ipykernel = "^6.29.5"
+ jupyter = "^1.1.1"
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
scripts/README.md ADDED
@@ -0,0 +1,109 @@
+ # Scripts Directory
+
+ This directory contains utility scripts for the Pico training framework.
+
+ ## generate_data.py
+
+ A script to automatically generate `data.json` from training log files for the dashboard.
+
+ ### What it does
+
+ This script parses log files from the `runs/` directory and extracts:
+ - **Training metrics**: Loss, learning rate, and inf/NaN counts at each step
+ - **Evaluation results**: Paloma evaluation metrics
+ - **Model configuration**: Architecture parameters (d_model, n_layers, etc.)
+
+ ### Usage
+
+ ```bash
+ # Generate data.json from the default runs directory
+ python scripts/generate_data.py
+
+ # Specify custom runs directory
+ python scripts/generate_data.py --runs-dir /path/to/runs
+
+ # Specify custom output file
+ python scripts/generate_data.py --output /path/to/output.json
+ ```
+
+ ### How it works
+
+ 1. **Scans runs directory**: Looks for subdirectories containing training runs
+ 2. **Finds log files**: Locates `.log` files in each run's `logs/` subdirectory
+ 3. **Parses log content**: Uses regex patterns to extract structured data (see the sketch below)
+ 4. **Generates JSON**: Creates a structured JSON file for the dashboard
+
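+ As a rough illustration of step 3 (a minimal sketch only; the real patterns live in
+ `scripts/generate_data.py` and match a full multi-line log block at once):
+
+ ```python
+ import re
+
+ # `log_text` is assumed to hold the contents of one .log file
+ step_pattern = re.compile(r"Step (\d+) -- 🔄 Training Metrics")
+ steps = [int(m.group(1)) for m in step_pattern.finditer(log_text)]
+ ```
+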
+ ### Log Format Requirements
+
+ The script expects log files with the following format:
+
+ ```
+ 2025-08-29 02:09:12 - pico-train - INFO - Step 500 -- 🔄 Training Metrics
+ 2025-08-29 02:09:12 - pico-train - INFO - ├── Loss: 10.8854
+ 2025-08-29 02:09:12 - pico-train - INFO - ├── Learning Rate: 3.13e-06
+ 2025-08-29 02:09:12 - pico-train - INFO - └── Inf/NaN count: 0
+ ```
+
+ And evaluation results:
+
+ ```
+ 2025-08-29 02:15:26 - pico-train - INFO - Step 1000 -- 📊 Evaluation Results
+ 2025-08-29 02:15:26 - pico-train - INFO - └── paloma: 7.125172406420199e+27
+ ```
+
+ ### Output Format
+
+ The generated `data.json` has this structure:
+
+ ```json
+ {
+   "runs": [
+     {
+       "run_name": "model-name",
+       "log_file": "log_filename.log",
+       "training_metrics": [
+         {
+           "step": 0,
+           "loss": 10.9914,
+           "learning_rate": 0.0,
+           "inf_nan_count": 0
+         }
+       ],
+       "evaluation_results": [
+         {
+           "step": 1000,
+           "paloma": 59434.76600609756
+         }
+       ],
+       "config": {
+         "d_model": 96,
+         "n_layers": 12,
+         "max_seq_len": 2048,
+         "vocab_size": 50304,
+         "lr": 0.0003,
+         "max_steps": 200000,
+         "batch_size": 8
+       }
+     }
+   ],
+   "summary": {
+     "total_runs": 1,
+     "run_names": ["model-name"]
+   }
+ }
+ ```
+
+ ### When to use
+
+ - **After training**: Generate updated dashboard data
+ - **Adding new runs**: Include new training sessions in the dashboard
+ - **Debugging**: Verify log parsing is working correctly
+ - **Dashboard setup**: Initial setup of the training metrics dashboard
+
+ ### Troubleshooting
+
+ If the script doesn't find any data:
+ 1. Check that log files exist in `runs/*/logs/`
+ 2. Verify log format matches the expected pattern
+ 3. Ensure log files contain training metrics entries
+ 4. Check file permissions and encoding
scripts/generate_data.py ADDED
@@ -0,0 +1,198 @@
+ #!/usr/bin/env python3
+ """
+ Script to generate data.json from training log files.
+
+ This script parses log files from the runs directory and extracts:
+ - Training metrics (loss, learning rate, inf/nan count)
+ - Evaluation results (paloma metrics)
+ - Model configuration parameters
+
+ The output is saved to plots/data.json for the dashboard.
+ """
+
+ import json
+ import re
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+
+ def parse_training_metrics(log_content: str) -> List[Dict[str, Any]]:
+     """Parse training metrics from log content."""
+     metrics = []
+
+     # Pattern to match training metrics entries with timestamp and log level
+     pattern = r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} - pico-train - INFO - Step (\d+) -- 🔄 Training Metrics\n\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} - pico-train - INFO - ├── Loss: ([\d.]+)\n\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} - pico-train - INFO - ├── Learning Rate: ([\d.e+-]+)\n\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} - pico-train - INFO - └── Inf/NaN count: (\d+)"
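+     # NOTE: the pattern above only matches when these four log lines appear
+     # back-to-back; any interleaved log output between them breaks the match.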
+
+     matches = re.findall(pattern, log_content)
+
+     for step, loss, lr, inf_nan in matches:
+         metrics.append(
+             {
+                 "step": int(step),
+                 "loss": float(loss),
+                 "learning_rate": float(lr),
+                 "inf_nan_count": int(inf_nan),
+             }
+         )
+
+     return sorted(metrics, key=lambda x: x["step"])
+
+
+ def parse_evaluation_results(log_content: str) -> List[Dict[str, Any]]:
+     """Parse evaluation results from log content."""
+     results = []
+
+     # Pattern to match evaluation results with timestamp and log level
+     pattern = r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} - pico-train - INFO - Step (\d+) -- 📊 Evaluation Results\n\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} - pico-train - INFO - └── paloma: ([\d.e+-]+)"
+
+     matches = re.findall(pattern, log_content)
+
+     for step, paloma in matches:
+         try:
+             paloma_value = float(paloma)
+             results.append({"step": int(step), "paloma": paloma_value})
+         except ValueError:
+             # Skip entries whose paloma value cannot be parsed as a float
+             continue
+
+     return sorted(results, key=lambda x: x["step"])
+
+
+ def extract_config_from_log(log_content: str) -> Dict[str, Any]:
+     """Extract model configuration from log content."""
+     config = {}
+
+     # Extract key model parameters
+     patterns = {
+         "d_model": r"d_model: (\d+)",
+         "n_layers": r"n_layers: (\d+)",
+         "max_seq_len": r"max_seq_len: (\d+)",
+         "vocab_size": r"vocab_size: (\d+)",
+         "lr": r"lr: ([\d.e+-]+)",
+         "max_steps": r"max_steps: (\d+)",
+         "batch_size": r"batch_size: (\d+)",
+     }
+
+     for key, pattern in patterns.items():
+         match = re.search(pattern, log_content)
+         if match:
+             try:
+                 if key in [
+                     "d_model",
+                     "n_layers",
+                     "max_seq_len",
+                     "vocab_size",
+                     "max_steps",
+                     "batch_size",
+                 ]:
+                     config[key] = int(match.group(1))
+                 else:
+                     config[key] = float(match.group(1))
+             except ValueError:
+                 continue
+
+     return config
+
+
+ def process_run_directory(run_path: Path) -> Optional[Dict[str, Any]]:
+     """Process a single run directory and extract all data."""
+     run_name = run_path.name
+
+     # Find log files
+     logs_dir = run_path / "logs"
+     if not logs_dir.exists():
+         return None
+
+     log_files = list(logs_dir.glob("*.log"))
+     if not log_files:
+         return None
+
+     # Use the most recent log file for configuration
+     latest_log = max(log_files, key=lambda x: x.stat().st_mtime)
+
+     # Read log content
+     log_content = latest_log.read_text(encoding="utf-8")
+
+     # Extract data
+     training_metrics = parse_training_metrics(log_content)
+     evaluation_results = parse_evaluation_results(log_content)
+     config = extract_config_from_log(log_content)
+
+     # If no training metrics found, skip this run
+     if not training_metrics:
+         return None
+
+     return {
+         "run_name": run_name,
+         "log_file": latest_log.name,
+         "training_metrics": training_metrics,
+         "evaluation_results": evaluation_results,
+         "config": config,
+     }
+
+
+ def generate_data_json(runs_dir: str = "runs", output_file: str = "plots/data.json"):
+     """Generate data.json from all run directories."""
+     runs_path = Path(runs_dir)
+     if not runs_path.exists():
+         print(f"Runs directory {runs_dir} not found!")
+         return
+
+     runs_data = []
+
+     # Process each run directory
+     for run_dir in runs_path.iterdir():
+         if run_dir.is_dir():
+             print(f"Processing run: {run_dir.name}")
+             run_data = process_run_directory(run_dir)
+             if run_data:
+                 runs_data.append(run_data)
+                 print(f"  ✓ Found {len(run_data['training_metrics'])} training metrics")
+                 print(
+                     f"  ✓ Found {len(run_data['evaluation_results'])} evaluation results"
+                 )
+             else:
+                 print("  ✗ No valid data found")
+
+     if not runs_data:
+         print("No valid runs found!")
+         return
+
+     # Create output data structure
+     output_data = {
+         "runs": runs_data,
+         "summary": {
+             "total_runs": len(runs_data),
+             "run_names": [run["run_name"] for run in runs_data],
+         },
+     }
+
+     # Ensure output directory exists
+     output_path = Path(output_file)
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+
+     # Write to file
+     with open(output_path, "w", encoding="utf-8") as f:
+         json.dump(output_data, f, indent=2, ensure_ascii=False)
+
+     print(f"\n✓ Generated {output_file} with {len(runs_data)} runs")
+     print(
+         f"✓ Total training metrics: {sum(len(run['training_metrics']) for run in runs_data)}"
+     )
+     print(
+         f"✓ Total evaluation results: {sum(len(run['evaluation_results']) for run in runs_data)}"
+     )
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(
+         description="Generate data.json from training logs"
+     )
+     parser.add_argument("--runs-dir", default="runs", help="Path to runs directory")
+     parser.add_argument("--output", default="plots/data.json", help="Output file path")
+
+     args = parser.parse_args()
+
+     generate_data_json(args.runs_dir, args.output)
scripts/train.py ADDED
@@ -0,0 +1,30 @@
+ #!/usr/bin/env python3
+ """
+ A minimal script to train the Pico language model. In practice, you should just use the
+ `poetry run train` command to run the training pipeline. Doing so will invoke this script.
+ Training logic is located in `src/training/trainer.py`.
+ """
+
+ from pathlib import Path
+
+ import click
+
+ from src.training.trainer import Trainer
+
+
+ @click.command()
+ @click.option(
+     "--config_path",
+     "config_path",
+     type=click.Path(exists=True, path_type=Path),
+     help="Path to the training configuration file",
+ )
+ def main(config_path: Path) -> None:
+     """Train the Pico language model using the specified configuration."""
+
+     trainer = Trainer(config_path=str(config_path))
+     trainer.train()
+
+
+ if __name__ == "__main__":
+     main()
setup.sh ADDED
@@ -0,0 +1,200 @@
+ #!/bin/bash
+ # This script sets up the project by installing dependencies, checking for a poetry environment,
+ # and installing pre-commit hooks.
+
+ # Add color and formatting variables at the top
+ GREEN='\033[0;32m'
+ BLUE='\033[0;34m'
+ YELLOW='\033[1;33m'
+ RED='\033[0;31m'
+ NC='\033[0m' # No Color
+ BOLD='\033[1m'
+
+ # Initialize error tracking
+ ERRORS_FOUND=0
+
+ # Function for section headers
+ print_section() {
+     echo -e "\n${BOLD}${BLUE}=== $1 ===${NC}\n"
+ }
+
+ # Function for success messages
+ print_success() {
+     echo -e "${GREEN}✓ $1${NC}"
+ }
+
+ # Function for warnings
+ print_warning() {
+     echo -e "${YELLOW}⚠ $1${NC}"
+ }
+
+ # --- GIT LFS SETUP --- #
+ print_section "Git LFS Setup"
+ if ! command -v git-lfs &> /dev/null; then
+     print_warning "git-lfs is not installed. Some model checkpointing functionality may not work correctly."
+     ERRORS_FOUND=$((ERRORS_FOUND + 1))
+
+     # Check the operating system
+     if [[ "$OSTYPE" == "darwin"* ]]; then
+         # macOS
+         echo -e "${YELLOW} You can install it using Homebrew:${NC}"
+         echo " brew install git-lfs"
+     elif [[ "$OSTYPE" == "linux-gnu"* ]]; then
+         # Linux
+         echo -e "${YELLOW} You can install it using your package manager:${NC}"
+         if command -v apt-get &> /dev/null; then
+             # Ubuntu/Debian
+             echo " curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash"
+             echo " sudo apt-get install git-lfs"
+         elif command -v yum &> /dev/null; then
+             # CentOS/RHEL
+             echo " curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | sudo bash"
+             echo " sudo yum install git-lfs"
+         else
+             print_warning "Could not detect package manager. Please install git-lfs manually."
+         fi
+     else
+         print_warning "Unsupported operating system. Please install git-lfs manually."
+     fi
+ else
+     git-lfs install
+     print_success "git-lfs installed and initialized"
+ fi
+
+ # --- CUDA VERSION CHECK --- #
+ print_section "CUDA Version Check"
+ if command -v nvidia-smi &> /dev/null; then
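+     # nvidia-smi prints a banner such as "... CUDA Version: 12.4 ..." (example value);
+     # the sed expression below captures that numeric version from the output.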
+     CUDA_VERSION=$(nvidia-smi | sed -n 's/.*CUDA Version: \([0-9.]*\).*/\1/p')
+
+     if [[ -z "$CUDA_VERSION" ]]; then
+         ERRORS_FOUND=$((ERRORS_FOUND + 1))
+         print_warning "nvidia-smi failed to communicate with the NVIDIA driver."
+         echo -e "${YELLOW} Ensure that the latest NVIDIA driver is installed and running.${NC}"
+     else
+         MAJOR_VERSION=${CUDA_VERSION%.*}
+         MINOR_VERSION=${CUDA_VERSION#*.}
+
+         if [ "$MAJOR_VERSION" -lt 12 ] || ([ "$MAJOR_VERSION" -eq 12 ] && [ "$MINOR_VERSION" -lt 1 ]); then
+             ERRORS_FOUND=$((ERRORS_FOUND + 1))
+             print_warning "CUDA version ${MAJOR_VERSION}.${MINOR_VERSION} detected."
+             echo -e "${YELLOW} Some multi-node communication GPU features may not work properly.${NC}"
+             echo -e "${YELLOW} CUDA version 12.1 or newer is recommended.${NC}"
+         else
+             print_success "CUDA version ${MAJOR_VERSION}.${MINOR_VERSION} detected"
+         fi
+     fi
+ else
+     ERRORS_FOUND=$((ERRORS_FOUND + 1))
+     print_warning "nvidia-smi not found. Unable to check CUDA version."
+     echo -e "${YELLOW} Ensure that NVIDIA drivers and CUDA 12.1 or newer are installed for GPU support.${NC}"
+ fi
+
+
+ # ---- ENVIRONMENT VARIABLES ---- #
+ print_section "Environment Variables"
+ if [ -f .env ]; then
+     print_success "Loading environment variables from .env..."
+     source .env
+     if [[ -n "$HF_TOKEN" && -n "$WANDB_API_KEY" ]]; then
+         print_success "Both HF_TOKEN and WANDB_API_KEY are set and loaded!"
+     else
+         print_warning "One or both of HF_TOKEN and WANDB_API_KEY are not set."
+     fi
+ else
+     print_warning "No .env file found."
+     echo -e "${YELLOW} You might need to create one with HF_TOKEN and WANDB_API_KEY${NC}"
+     echo -e "${YELLOW} Example .env contents:${NC}"
+     echo " export HF_TOKEN=your_huggingface_token"
+     echo " export WANDB_API_KEY=your_wandb_key"
+     ERRORS_FOUND=$((ERRORS_FOUND + 1))
+ fi
+
+ # ---- POETRY SETUP ---- #
+ print_section "Poetry Setup"
+
+ # First check if Poetry is installed
+ if ! command -v poetry &> /dev/null; then
+     echo "Poetry not found. Installing..."
+
+     # Run the installation command
+     curl -sSL https://install.python-poetry.org | python3 -
+     POETRY_INSTALL_STATUS=$?
+
+     if [ $POETRY_INSTALL_STATUS -ne 0 ]; then
+         print_warning "Poetry installation failed!"
+         ERRORS_FOUND=$((ERRORS_FOUND + 1))
+     else
+         export PATH="$HOME/.local/bin:$PATH"
+
+         # Verify installation succeeded
+         if ! command -v poetry &> /dev/null; then
+             print_warning "Poetry was installed but cannot be found in PATH!"
+             echo -e "${YELLOW} Try adding this to your shell profile:${NC}"
+             echo " export PATH=\"\$HOME/.local/bin:\$PATH\""
+             ERRORS_FOUND=$((ERRORS_FOUND + 1))
+         else
+             print_success "Poetry installed successfully"
+         fi
+     fi
+ else
+     print_success "Poetry already installed"
+ fi
+
+ # Then check for virtual environment
+ if [ ! -d ".venv" ]; then
+     echo "No virtual environment found. Creating one..."
+     poetry config virtualenvs.in-project true
+
+     # Create virtual environment and install dependencies
+     poetry install --with dev
+     POETRY_VENV_STATUS=$?
+
+     if [ $POETRY_VENV_STATUS -ne 0 ]; then
+         print_warning "Failed to create Poetry virtual environment!"
+         ERRORS_FOUND=$((ERRORS_FOUND + 1))
+     else
+         print_success "Poetry environment created successfully"
+     fi
+ else
+     print_success "Poetry environment already exists"
+ fi
+
+ # ---- PRE-COMMIT SETUP ---- #
+ print_section "Pre-commit Setup"
+
+ # Install pre-commit hooks
+ echo "Installing pre-commit hooks..."
+ poetry run pre-commit install
+ if [ $? -ne 0 ]; then
+     print_warning "Failed to install pre-commit hooks!"
+     ERRORS_FOUND=$((ERRORS_FOUND + 1))
+ else
+     print_success "Pre-commit hooks installed"
+ fi
+
+ # Run pre-commit hooks on all files
+ echo "Running pre-commit hooks on all files..."
+ poetry run pre-commit run --all-files
+ if [ $? -ne 0 ]; then
+     print_warning "Pre-commit encountered issues with some files"
+     ERRORS_FOUND=$((ERRORS_FOUND + 1))
+ else
+     print_success "Pre-commit initial run complete"
+ fi
+
+ # --- Final Status Message --- #
+
+ # Final status message
+ print_section "Setup Status"
+ if [ $ERRORS_FOUND -eq 0 ]; then
+     print_success "Setup Complete! 🎉"
+     print_success "To activate the virtual environment, run: poetry env activate"
+ else
+     print_warning "Setup completed with warnings and errors! Please check the messages above."
+     echo -e "${YELLOW} ${ERRORS_FOUND} issue(s) were detected that may affect functionality.${NC}"
+     if [ -d ".venv" ]; then
+         echo -e "${YELLOW} You can still activate the environment with: poetry env activate${NC}"
+     else
+         echo -e "${RED} The virtual environment setup failed. Fix the issues before proceeding.${NC}"
+     fi
+ fi
src/checkpointing/__init__.py ADDED
@@ -0,0 +1,23 @@
+ """
+ Pico Checkpointing Package
+
+ We subdivide the checkpointing into training, evaluation, and learning_dynamics. Training
+ checkpoints store the model, optimizer, and learning rate scheduler. Evaluation checkpoints
+ store the evaluation results on the defined metrics. Learning dynamics checkpoints store the
+ activations, weights, and gradients used for learning dynamics analysis.
+ """
+
+ from .evaluation import save_evaluation_results
+ from .learning_dynamics import (
+     compute_learning_dynamics_states,
+     save_learning_dynamics_states,
+ )
+ from .training import load_checkpoint, save_checkpoint
+
+ __all__ = [
+     "compute_learning_dynamics_states",
+     "load_checkpoint",
+     "save_checkpoint",
+     "save_evaluation_results",
+     "save_learning_dynamics_states",
+ ]
src/checkpointing/evaluation.py ADDED
@@ -0,0 +1,68 @@
+ """
+ Utilities for checkpointing evaluation-related states (i.e. evaluation results, etc.)
+
+ We save the evaluation results in a JSON file at the step-specific evaluation results directory.
+ """
+
+ import json
+ import os
+ from typing import Any, Dict
+
+ from huggingface_hub import upload_folder
+ from lightning.fabric import Fabric
+ from lightning.fabric.utilities.rank_zero import rank_zero_only
+
+ from src.config import CheckpointingConfig
+ from src.training.utils.io import use_backoff
+
+
+ @rank_zero_only
+ @use_backoff()
+ def save_evaluation_results(
+     checkpointing_config: CheckpointingConfig,
+     checkpoint_step: int,
+     fabric: Fabric,
+     evaluation_results: Dict[str, Any],
+ ) -> None:
+     """Save evaluation results to disk and optionally to HuggingFace Hub.
+
+     The evaluation results are saved in the following directory structure:
+         {checkpointing_config.runs_dir}/
+         └── {checkpointing_config.run_name}/
+             └── {checkpointing_config.eval_results_dir}/
+                 └── step_{checkpoint_step}.json
+
+     NOTE: this function is only called on rank 0 to avoid conflicts; assumes that the evaluation
+     results are gathered on rank 0.
+
+     Args:
+         checkpointing_config: Configuration object containing checkpoint settings
+         checkpoint_step: Current training checkpoint step (i.e. number of learning steps taken)
+         fabric: Lightning Fabric instance
+         evaluation_results: Dictionary containing evaluation metrics
+     """
+
+     run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name)
+     eval_results_dir = os.path.join(
+         run_dir, checkpointing_config.evaluation.eval_results_dir
+     )
+
+     os.makedirs(eval_results_dir, exist_ok=True)
+
+     curr_eval_results_path = os.path.join(
+         eval_results_dir, f"step_{checkpoint_step}.json"
+     )
+
+     # save out as json
+     with open(curr_eval_results_path, "w") as f:
+         json.dump(evaluation_results, f)
+
+     if checkpointing_config.save_to_hf:
+         upload_folder(
+             folder_path=eval_results_dir,
+             path_in_repo=checkpointing_config.evaluation.eval_results_dir,
+             repo_id=checkpointing_config.hf_checkpoint.repo_id,
+             commit_message=f"Saving Evaluation Results -- Step {checkpoint_step}",
+             revision=checkpointing_config.run_name,
+             token=os.getenv("HF_TOKEN"),
+         )
src/checkpointing/learning_dynamics.py ADDED
@@ -0,0 +1,424 @@
+ """
+ Utilities for checkpointing learning dynamics-related states (i.e. activations, weights, grads, etc.)
+
+ We save the learning dynamics states in a subdirectory of the checkpointing directory.
+ """
+
+ import os
+ import re
+ from typing import Dict, Optional
+
+ import deepspeed
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from datasets import Dataset
+ from huggingface_hub import upload_folder
+ from lightning.fabric import Fabric
+ from lightning.fabric.strategies import DeepSpeedStrategy
+ from lightning.fabric.utilities.rank_zero import rank_zero_only
+ from torch.nn import functional as F
+ from torch.utils.data import DataLoader
+ from transformers import PreTrainedTokenizerBase
+
+ from src.config import CheckpointingConfig
+ from src.config.checkpointing_config import LearningDynamicsCheckpointingConfig
+ from src.training.utils.initialization import initialize_model
+ from src.training.utils.io import use_backoff
+
+
+ # NOTE: DeepSpeed requires a dummy optimizer to be passed in to the setup function
+ class DummyOptimizer(optim.Optimizer):
+     def __init__(self, params):
+         super().__init__(params, defaults={})
+
+
+ class CheckpointStateExtractor:
+     """
+     Class to extract and save the states of a model at a given checkpoint step for learning
+     dynamics research.
+     """
+
+     def __init__(
+         self,
+         learning_dynamics_config: LearningDynamicsCheckpointingConfig,
+         fabric: Fabric,
+         model: nn.Module,
+     ):
+         self.learning_dynamics_config = learning_dynamics_config
+         self.fabric = fabric
+         self.model = model
+
+     def extract_states(self, dataloader, compute_gradients: bool = False):
+         """Extracts model states (activations, weights, and optionally gradients).
+
+         Given a dataloader, this function will perform a forward pass of the model on each batch,
+         and save the activations and weights at each layer. If compute_gradients is True, it will
+         also compute the gradients of the model parameters.
+
+         Args:
+             dataloader: The dataloader containing the dataset to extract states from.
+             compute_gradients: Whether to compute the gradients of the model parameters.
+
+         Returns:
+             A dictionary containing the activations, weights, and optionally gradients of the model.
+         """
+         checkpoint_activations = {}
+         checkpoint_weights = {}
+
+         # NOTE: to extract activations and weights, we need to set up forward hooks on the layers
+         # of the model that we are interested in. This is a good intro to forward hooks if you
+         # are not familiar: https://web.stanford.edu/~nanbhas/blog/forward-hooks-pytorch/
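+         # For reference: a hook registered via module.register_forward_hook(fn) is
+         # called as fn(module, inputs, output) after each forward pass of that module,
+         # and the returned handle can later be detached with handle.remove().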
+         forward_hooks = self._setup_forward_hooks(
+             checkpoint_activations,
+             checkpoint_weights,
+         )
+
+         ########################################################
+         #
+         # Forward Pass: Extract activations and weights; and compute gradients
+         #
+         ########################################################
+
+         for sub_batch in dataloader:
+             _input_ids = torch.tensor(sub_batch["input_ids"], device=self.fabric.device)
+
+             if compute_gradients:
+                 if "labels" in sub_batch:
+                     input_ids = _input_ids
+                     labels = torch.tensor(
+                         sub_batch["labels"], device=self.fabric.device
+                     )
+                 else:
+                     input_ids = _input_ids[:, :-1]
+                     labels = _input_ids[:, 1:]
+             else:
+                 input_ids = _input_ids
+                 labels = None
+
+             if labels is None:
+                 # we can throw away the outputs, we are only interested in the hidden states
+                 with torch.no_grad():
+                     _ = self.model(input_ids)
+             else:
+                 # NOTE: if we are computing gradients, calling backwards will compute the gradients
+                 # of the model parameters.
+                 outputs, _ = self.model(input_ids)
+                 outputs = outputs.transpose(1, 2)
+                 loss = F.cross_entropy(outputs, labels)
+                 self.fabric.backward(loss, model=self.model)
+
+         # cleanup forward hooks
+         # NOTE: this is not strictly necessary, since self.model is a separate copy of the original
+         # model, but it is good practice to remove the hooks after the forward pass is complete.
+         for hook in forward_hooks:
+             hook.remove()
+
+         ########################################################
+         #
+         # Extract gradients from the target tensors of the model
+         #
+         ########################################################
+
+         layer_suffixes = self.learning_dynamics_config.layer_suffixes
+         checkpoint_gradients = {}
+         if compute_gradients:
+             for name, param in self.model.named_parameters():
+                 # only do this for the weight matrix of the layer_suffixes
+                 if (
+                     any(layer_suffix in name for layer_suffix in layer_suffixes)
+                     and "weight" in name
+                 ):
+                     if isinstance(self.fabric.strategy, DeepSpeedStrategy):
+                         _grad = deepspeed.utils.safe_get_full_grad(param)
+                     else:
+                         _grad = param.grad
+
+                     assert _grad is not None, f"Gradient is None for layer: {name}"
+                     name = re.sub(r"\.weight", "", name)
+                     checkpoint_gradients[name] = _grad.detach().cpu()
+
+             # zero out the gradients
+             self.model.zero_grad()
+
+         return checkpoint_activations, checkpoint_weights, checkpoint_gradients
+
+     ########################################################
+     #
+     # Setup forward hooks to save activations and weights at each layer
+     #
+     ########################################################
+
+     def _setup_forward_hooks(self, checkpoint_activations, checkpoint_weights):
+         """Setup forward hooks for the model to save activations and weights at each layer.
+
+         This function will set up forward hooks on the layers of the model that we are interested
+         in. The forward hooks will save the activations and weights at each layer whenever the
+         forward pass is performed.
+
+         Args:
+             checkpoint_activations: A dictionary to store the activations at each layer.
+             checkpoint_weights: A dictionary to store the weights at each layer.
+
+         Returns:
+             A list of forward hooks. We do this so that we can remove the hooks after the forward pass
+             is complete.
+         """
+
+         forward_hooks = []
+         layer_suffixes = self.learning_dynamics_config.layer_suffixes
+
+         for name, module in self.model.named_modules():
+             if any(layer_suffix in name for layer_suffix in layer_suffixes):
+                 _forward_hook = module.register_forward_hook(
+                     self._get_forward_hook(
+                         name, checkpoint_activations, checkpoint_weights
+                     )
+                 )
+                 forward_hooks.append(_forward_hook)
+         return forward_hooks
+
+     def _get_forward_hook(
+         self, module_name, checkpoint_activations, checkpoint_weights
+     ):
+         """Get a forward hook for a given module.
+
+         This function is called by the _setup_forward_hooks function to set up a forward hook for a
+         given module. This function is a closure that captures the module_name,
+         checkpoint_activations, and checkpoint_weights.
+
+         Args:
+             module_name: The name of the module to set up a forward hook for.
+             checkpoint_activations: A dictionary to store the activations at each layer.
+             checkpoint_weights: A dictionary to store the weights at each layer.
+
+         Returns:
+             A forward hook for the given module.
+         """
+
+         def _forward_hook(module, _, module_out):
+             sequence_idx = self.learning_dynamics_config.sequence_idx
+
+             local_activations = module_out[:, sequence_idx, :].detach()
+
+             # Gather activations from all processes using fabric
+             gathered_activations = self.fabric.all_gather(local_activations)
+
+             # Reshape from [num_processes, batch_size, hidden_dim] to [total_batch_size, hidden_dim]
+             # NOTE: transposing allows us to interleave the activations from each process so that
+             # they are in the correct order. (i.e. activation N is from data sample N)
+             gathered_activations = gathered_activations.transpose(0, 1).reshape(
+                 -1, gathered_activations.shape[-1]
+             )
+
+             # check if there is already a key for the module name
+             if module_name not in checkpoint_activations:
+                 # if there is no key, then we create a new key and store the hidden states
+                 checkpoint_activations[module_name] = (
+                     gathered_activations.detach().cpu()
+                 )
+
+                 # extract the weight matrix just once
+                 weight_matrix = module.weight.detach().cpu()
+                 checkpoint_weights[module_name] = weight_matrix
+             else:
+                 # if there is already a key, then we concatenate the new hidden states to the existing ones
+                 checkpoint_activations[module_name] = torch.cat(
+                     (
+                         checkpoint_activations[module_name],
+                         gathered_activations.detach().cpu(),
+                     )
+                 )
+
+         return _forward_hook
+
+
+ def compute_learning_dynamics_states(
+     checkpointing_config: CheckpointingConfig,
+     fabric: Fabric,
+     model: nn.Module,
+     dataset: Dataset,
+     compute_gradients: bool = False,
+ ) -> Dict[str, torch.Tensor]:
+     """Computes the learning dynamics metrics for a given checkpoint step.
+
+     Uses the CheckpointStateExtractor to extract the activations, weights, and optionally gradients
+     of the model at a given checkpoint step.
+
+     Args:
+         checkpointing_config: The configuration object for checkpointing.
+         fabric: The Fabric instance for distributed training.
+         model: The model to extract states from.
+         dataset: The dataset to extract states from.
+         compute_gradients: Whether to compute the gradients of the model parameters.
+
+     Returns:
+         A dictionary containing the activations, weights, and optionally gradients of the model.
+     """
+
+     # NOTE: Synchronizing processes for fabric dataloader setup
+     fabric.barrier()
+     model.to("cpu")  # Offloading model to CPU
+
+     # Setting up Dataloader for learning dynamics
+     def _collate_fn(batch):
+         return {"input_ids": [entry["input_ids"] for entry in batch]}
+
+     batch_size = checkpointing_config.learning_dynamics.batch_size
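+     # NOTE: the integer division below assumes the configured batch size is evenly
+     # divisible by the world size; any remainder silently shrinks the effective batch.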
+ sub_batch_size = batch_size // fabric.world_size
269
+
270
+ # NOTE: Make sure to set drop_last to False, otherwise the last batch will be dropped
271
+ # and we will not have a complete set of activations for the last sample. Also,
272
+ # we need to set shuffle to False, otherwise the activations will be shuffled across
273
+ # processes and we will not be able to interleave them correctly.
274
+ extractor_dataloader = DataLoader(
275
+ dataset,
276
+ batch_size=sub_batch_size,
277
+ shuffle=False,
278
+ collate_fn=_collate_fn,
279
+ drop_last=False,
280
+ )
281
+ extractor_dataloader = fabric.setup_dataloaders(
282
+ extractor_dataloader, use_distributed_sampler=True
283
+ )
284
+
285
+ # Create a new model instance with same parameters but zero gradients
286
+ _model = initialize_model(model.config)
287
+ _model.load_state_dict(model.state_dict())
288
+
289
+ if isinstance(fabric.strategy, DeepSpeedStrategy):
290
+ _model, _ = fabric.setup(_model, DummyOptimizer(_model.parameters()))
291
+ else:
292
+ _model = fabric.setup(_model)
293
+
294
+ _model.zero_grad()
295
+
296
+ # setup forward hooks for the model to save activations and weights at each layer
297
+ state_extractor = CheckpointStateExtractor(
298
+ checkpointing_config.learning_dynamics, fabric, _model
299
+ )
300
+
301
+ checkpoint_activations, checkpoint_weights, checkpoint_gradients = (
302
+ state_extractor.extract_states(
303
+ extractor_dataloader, compute_gradients=compute_gradients
304
+ )
305
+ )
306
+
307
+ del _model
308
+ torch.cuda.empty_cache()
309
+
310
+ # NOTE: Synchronizing processes for model setup
311
+ fabric.barrier()
312
+
313
+ model.to(fabric.device)
314
+
315
+ # NOTE: Trimming down the activations to match the dataset size;
316
+ # This is because the DataSampler might add extra samples to the dataset to make it evenly divisible
317
+ # by the number of processes. We need to remove these extra samples.
318
+ for layer_name, layer_activations in checkpoint_activations.items():
319
+ if len(layer_activations) > len(dataset):
320
+ checkpoint_activations[layer_name] = layer_activations[: len(dataset)]
321
+ elif len(layer_activations) < len(dataset):
322
+ raise ValueError(
323
+ f"Number of activations ({len(layer_activations)}) in layer {layer_name} does not match number of samples in dataset ({len(dataset)})"
324
+ )
325
+
326
+ return {
327
+ "activations": checkpoint_activations,
328
+ "weights": checkpoint_weights,
329
+ "gradients": checkpoint_gradients,
330
+ }
331
+
332
+
333
+ @rank_zero_only
334
+ @use_backoff()
335
+ def save_learning_dynamics_states(
336
+ checkpointing_config: CheckpointingConfig,
337
+ checkpoint_step: int,
338
+ prefix: str,
339
+ fabric: Fabric,
340
+ learning_dynamics_states: Dict[str, torch.Tensor],
341
+ learning_dynamics_dataset: Optional[Dataset] = None,
342
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
343
+ ) -> None:
344
+ """Save the learning dynamics metrics to the checkpointing directory.
345
+
346
+ By default only the learning dynamics states are saved. If the learning dynamics dataset
347
+ is provided, it is also saved; if a tokenizer is provided, the dataset is also detokenized
348
+ (i.e. a new column with the text is added to the dataset).
349
+
350
+ The learning dynamics dataset is saved in the checkpointing directory as a HuggingFace
351
+ dataset.
352
+
353
+ Creates a versioned checkpoint directory with the following structure:
354
+
355
+ {checkpointing_config.runs_dir}/
356
+ └── {checkpointing_config.run_name}/
357
+ └── {checkpointing_config.checkpoints_dir}/
358
+ ├── step_{checkpoint_step}/
359
+ │ └── {checkpointing_config.learning_dynamics_dir}/ # Learning Dynamics files
360
+ │ ├── {prefix}_activations.pt
361
+ │ ├── {prefix}_weights.pt
362
+ │ └── {prefix}_gradients.pt
363
+ │ └── {prefix}_data/ # if learning_dynamics_dataset is provided
364
+ └── latest -> step_{checkpoint_step}/
365
+
366
+ NOTE: this function is only called on rank 0
367
+
368
+ Args:
369
+ checkpointing_config: The configuration object for checkpointing.
370
+ checkpoint_step: The checkpoint step at which the learning dynamics states were computed.
371
+ prefix: The prefix for the learning dynamics states.
372
+ fabric: The Fabric instance for distributed training.
373
+ learning_dynamics_states: The learning dynamics states to save.
374
+ learning_dynamics_dataset: The dataset containing learning dynamics data,
375
+ including input IDs that need to be decoded. (optional)
376
+ tokenizer: The tokenizer used to decode input IDs into text. (optional)
377
+ """
378
+
379
+ runs_dir = checkpointing_config.runs_dir
380
+ run_name = checkpointing_config.run_name
381
+ checkpoints_dir = checkpointing_config.checkpoints_dir
382
+ learning_dynamics_dir = checkpointing_config.learning_dynamics_dir
383
+
384
+ run_path = os.path.join(runs_dir, run_name)
385
+ root_checkpoint_path = os.path.join(run_path, checkpoints_dir)
386
+ checkpoint_path = os.path.join(root_checkpoint_path, f"step_{checkpoint_step}")
387
+ learning_dynamics_path = os.path.join(checkpoint_path, learning_dynamics_dir)
388
+ os.makedirs(learning_dynamics_path, exist_ok=True)
389
+
390
+ # save the learning dynamics states
391
+ for key, value in learning_dynamics_states.items():
392
+ if value is not None and len(value) > 0:
393
+ torch.save(
394
+ value, os.path.join(learning_dynamics_path, f"{prefix}_{key}.pt")
395
+ )
396
+
397
+ if learning_dynamics_dataset is not None:
398
+ if tokenizer is not None:
399
+ # go through dataset and decode the input ids; and add back into dataset
400
+ detokenized_dataset = {"input_ids": [], "text": []}
401
+
402
+ for entry in learning_dynamics_dataset:
403
+ input_ids = entry["input_ids"]
404
+ decoded_text = tokenizer.decode(input_ids, skip_special_tokens=True)
405
+ detokenized_dataset["input_ids"].append(input_ids)
406
+ detokenized_dataset["text"].append(decoded_text)
407
+
408
+ learning_dynamics_dataset = Dataset.from_dict(detokenized_dataset)
409
+
410
+ learning_dynamics_dataset_path = os.path.join(
411
+ learning_dynamics_path, f"{prefix}_data"
412
+ )
413
+ learning_dynamics_dataset.save_to_disk(learning_dynamics_dataset_path)
414
+
415
+ if checkpointing_config.save_to_hf:
416
+ # Upload the HF model
417
+ upload_folder(
418
+ folder_path=learning_dynamics_path,
419
+ path_in_repo=learning_dynamics_dir,
420
+ repo_id=checkpointing_config.hf_checkpoint.repo_id,
421
+ commit_message=f"Saving Learning Dynamics Data ({prefix}) -- Step {checkpoint_step}",
422
+ revision=checkpointing_config.run_name,
423
+ token=os.getenv("HF_TOKEN"),
424
+ )
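For reference, a minimal sketch of loading these states back for analysis. The paths follow the directory layout documented above; the run name, step, and "train" prefix are illustrative placeholders:

import os

import torch
from datasets import load_from_disk

# Hypothetical run/step/prefix -- adjust to your own checkpoint layout
ld_path = os.path.join("runs", "demo", "checkpoints", "step_1000", "learning_dynamics")

# The states were written with torch.save, so torch.load recovers the tensors
activations = torch.load(os.path.join(ld_path, "train_activations.pt"))
weights = torch.load(os.path.join(ld_path, "train_weights.pt"))

# The (optionally detokenized) dataset was written with Dataset.save_to_disk
data = load_from_disk(os.path.join(ld_path, "train_data"))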
src/checkpointing/training.py ADDED
@@ -0,0 +1,287 @@
1
+ """
2
+ Utilities for checkpointing training-related states (i.e. model, optimizer, lr_scheduler, etc.)
3
+
4
+ We save both a HuggingFace model and a Fabric-specific checkpoint. The HuggingFace model is
5
+ saved at the step-specific checkpoint directory, while the Fabric-specific checkpoint is saved
6
+ in a subdirectory. This is done to facilitate easier versioning of the HuggingFace model files
7
+ (which are what gets uploaded to the Hub).
8
+ """
9
+
10
+ import os
11
+ from dataclasses import asdict
12
+ from typing import Any, Dict, Tuple, Union
13
+
14
+ import yaml
15
+ from huggingface_hub import upload_file, upload_folder
16
+ from lightning.fabric import Fabric
17
+ from lightning.fabric.strategies import DeepSpeedStrategy
18
+ from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
19
+ from torch import nn
20
+ from torch.optim import Optimizer
21
+ from torch.optim.lr_scheduler import LRScheduler
22
+ from transformers import PreTrainedTokenizerBase
23
+
24
+ from src.config import CheckpointingConfig
25
+ from src.training.utils.io import use_backoff
26
+
27
+
28
+ @use_backoff()
29
+ def load_checkpoint(
30
+ checkpointing_config: CheckpointingConfig,
31
+ checkpoint_step: Union[str, int],
32
+ fabric: Fabric,
33
+ model: nn.Module,
34
+ optimizer: Optimizer,
35
+ lr_scheduler: LRScheduler,
36
+ ) -> Tuple[nn.Module, Optimizer, LRScheduler, int]:
37
+ """Load model checkpoint and associated states from a given step.
38
+
39
+ Args:
40
+ checkpointing_config: Configuration object containing checkpoint settings
41
+ checkpoint_step: The step at which to load the checkpoint
42
+ fabric: Lightning Fabric instance for distributed training support
43
+ model: The model instance to load weights into
44
+ optimizer: The optimizer instance to load states into
45
+ lr_scheduler: The learning rate scheduler to load states into
46
+
47
+ Returns:
48
+ Tuple containing the model, optimizer, lr_scheduler, and checkpoint step.
49
+ Returns None if no checkpoint is found.
50
+ """
51
+
52
+ if isinstance(checkpoint_step, int):
53
+ checkpoint_step = f"step_{checkpoint_step}"
54
+
55
+ checkpoint_path = os.path.join(
56
+ checkpointing_config.runs_dir,
57
+ checkpointing_config.run_name,
58
+ checkpointing_config.checkpoints_dir,
59
+ checkpoint_step,
60
+ )
61
+
62
+ if not os.path.exists(checkpoint_path):
63
+ return None
64
+
65
+ # Load from specified fabric checkpoint subdirectory
66
+ fabric_checkpoint_path = os.path.join(
67
+ checkpoint_path, checkpointing_config.fabric_checkpoint_dir
68
+ )
69
+
70
+ checkpoint_state = {
71
+ "_model": model,
72
+ "_optimizer": optimizer,
73
+ "_lr_scheduler": lr_scheduler,
74
+ }
75
+
76
+ if not isinstance(fabric.strategy, DeepSpeedStrategy):
77
+ fabric_load_file = os.path.join(
78
+ fabric_checkpoint_path, checkpointing_config.fabric_checkpoint_filename
79
+ )
80
+ else:
81
+ # Deepspeed checkpoints create sub-directory with distributed checkpoint file
82
+ fabric_load_file = fabric_checkpoint_path
83
+
84
+ extra_state = fabric.load(fabric_load_file, state=checkpoint_state)
85
+
86
+ # NOTE: extra_state will contain any additional states that were saved in the checkpoint
87
+ checkpoint_step = extra_state["_checkpoint_step"]
88
+
89
+ if "_rng_states" in extra_state:
90
+ _rng_states = extra_state["_rng_states"]
91
+ _set_rng_states(_rng_states)
92
+
93
+ return model, optimizer, lr_scheduler, checkpoint_step
94
+
95
+
96
+ @use_backoff()
97
+ def save_checkpoint(
98
+ configs: Dict[str, Any],
99
+ checkpoint_step: int,
100
+ fabric: Fabric,
101
+ model: nn.Module,
102
+ optimizer: Optimizer,
103
+ lr_scheduler: LRScheduler,
104
+ tokenizer: PreTrainedTokenizerBase,
105
+ upload_logs: bool = False,
106
+ ) -> None:
107
+ """Save training checkpoint and associated states to disk and optionally to HuggingFace Hub.
108
+
109
+ We save the following files:
110
+ - HuggingFace model files (config.json, model.safetensors)
111
+ - Tokenizer files (tokenizer.json, tokenizer_config.json)
112
+ - Fabric-specific files - fabric state of the model, optimizer, and lr_scheduler. If using
113
+ DeepSpeed, the checkpoint is saved in a subdirectory, otherwise it is saved in a single file.
114
+
115
+ Note that the HuggingFace model files are saved at the step-specific checkpoint directory, while the
116
+ Fabric-specific files are saved in a subdirectory. This is done to facilitate easier
117
+ versioning of the HuggingFace model files (which are what gets uploaded to the Hub).
118
+
119
+ NOTE: Why do we save a HF model at all? We do this because it makes it easier to load the model
120
+ in a separate script for evaluation and to play nicely with the HuggingFace Hub.
121
+
122
+ Creates a versioned checkpoint directory with the following structure:
123
+
124
+ {checkpointing_config.runs_dir}/
125
+ └── {checkpointing_config.run_name}/
126
+ ├── training_config.yaml # Training config
127
+ └── {checkpointing_config.checkpoints_dir}/
128
+ ├── step_{checkpoint_step}/
129
+ │ ├── config.json # HuggingFace model config
130
+ │ ├── model.safetensors # HuggingFace model weights
131
+ │ ├── pico_{model_type}.py # HuggingFace custom model class
132
+ │ ├── tokenizer.json # Tokenizer vocab
133
+ │ ├── tokenizer_config.json # Tokenizer config
134
+ │ └── {checkpointing_config.fabric_checkpoint_dir}/ # Fabric-specific files
135
+ │ └── checkpoint/ # Distributed model checkpoint files (if using DeepSpeed)
136
+ │ OR
137
+ │ └── checkpoint.pt # Single checkpoint file (if using other strategies)
138
+ └── latest -> step_{checkpoint_step}/
139
+
140
+ Args:
141
+ configs: A dictionary containing the initialized configuration objects.
142
+ checkpoint_step: The current training checkpoint step (i.e. number of learning steps taken)
143
+ fabric: Lightning Fabric instance for distributed training support
144
+ model: The model instance to save
145
+ optimizer: The optimizer instance to save
146
+ lr_scheduler: The learning rate scheduler to save
147
+ tokenizer: The tokenizer to save
148
+ upload_logs: Whether to upload training logs to HF Hub (default: False)
149
+
150
+ """
151
+
152
+ checkpointing_config = configs["checkpointing"]
153
+
154
+ # Get the directories from the training config
155
+ runs_dir = checkpointing_config.runs_dir
156
+ checkpoints_dir = checkpointing_config.checkpoints_dir
157
+ fabric_checkpoint_dir = checkpointing_config.fabric_checkpoint_dir
158
+ logs_dir = checkpointing_config.logs_dir
159
+
160
+ run_path = os.path.join(runs_dir, checkpointing_config.run_name)
161
+ root_checkpoint_path = os.path.join(run_path, checkpoints_dir)
162
+ checkpoint_path = os.path.join(root_checkpoint_path, f"step_{checkpoint_step}")
163
+
164
+ # Create directories
165
+ os.makedirs(checkpoint_path, exist_ok=True)
166
+
167
+ ########################################################
168
+ #
169
+ # Save HuggingFace files
170
+ #
171
+ ########################################################
172
+
173
+ # NOTE: we convert the Pico model to a HuggingFace model before saving it. See `model.py`
174
+ # for more details.
175
+ if fabric.global_rank == 0:
176
+ hf_model = model.convert_to_hf_model()
177
+ hf_model.save_pretrained(checkpoint_path)
178
+ tokenizer.save_pretrained(checkpoint_path)
179
+
180
+ ########################################################
181
+ #
182
+ # Save Fabric-specific files
183
+ #
184
+ ########################################################
185
+
186
+ # Create fabric-specific subdirectory
187
+ fabric_checkpoint_path = os.path.join(checkpoint_path, fabric_checkpoint_dir)
188
+ os.makedirs(fabric_checkpoint_path, exist_ok=True)
189
+
190
+ # Save model states (use underscore to avoid conflicts with third-party libraries)
191
+ checkpoint_state = {
192
+ "_model": model,
193
+ "_optimizer": optimizer,
194
+ "_lr_scheduler": lr_scheduler,
195
+ "_checkpoint_step": checkpoint_step,
196
+ }
197
+
198
+ if not isinstance(fabric.strategy, DeepSpeedStrategy):
199
+ checkpoint_state["_rng_states"] = _collect_rng_states()
200
+ fabric_save_file = os.path.join(
201
+ fabric_checkpoint_path, checkpointing_config.fabric_checkpoint_filename
202
+ )
203
+ else:
204
+ # Deepspeed checkpoints create sub-directory with distributed checkpoint file
205
+ fabric_save_file = fabric_checkpoint_path
206
+
207
+ fabric.save(fabric_save_file, checkpoint_state)
208
+
209
+ if fabric.global_rank == 0:
210
+ # Save the training config at the run root (written only once)
211
+ config_path = os.path.join(run_path, "training_config.yaml")
212
+ if not os.path.exists(config_path):
213
+ # Converting dataclasses to joined dicts and saving to file
214
+ _training_config = {}
215
+ for config_name, config in configs.items():
216
+ _training_config[config_name] = asdict(config)
217
+ with open(config_path, "w") as f:
218
+ yaml.dump(_training_config, f)
219
+
220
+ # Update latest symlink
221
+ latest_symlink_path = os.path.join(root_checkpoint_path, "latest")
222
+ if os.path.lexists(latest_symlink_path):
223
+ os.remove(latest_symlink_path)
224
+ os.symlink(
225
+ f"step_{checkpoint_step}", latest_symlink_path, target_is_directory=True
226
+ )
227
+
228
+ ########################################################
229
+ #
230
+ # Push to HuggingFace Hub (if configured)
231
+ #
232
+ ########################################################
233
+
234
+ if fabric.global_rank == 0:
235
+ # Push only on rank zero thread
236
+
237
+ if checkpointing_config.save_to_hf:
238
+ repo_id = checkpointing_config.hf_checkpoint.repo_id
239
+
240
+ # Upload the HF model
241
+ hf_model.push_to_hub(
242
+ repo_id=repo_id,
243
+ commit_message=f"Saving HF Model -- Step {checkpoint_step}",
244
+ revision=checkpointing_config.run_name,
245
+ token=os.getenv("HF_TOKEN"),
246
+ )
247
+
248
+ if checkpoint_step == 0:
249
+ # Uploading Tokenizer during first step since it never changes
250
+ tokenizer.push_to_hub(
251
+ repo_id=repo_id,
252
+ commit_message=f"Saving Tokenizer -- Step {checkpoint_step}",
253
+ revision=checkpointing_config.run_name,
254
+ token=os.getenv("HF_TOKEN"),
255
+ )
256
+
257
+ # Upload training config, also only in first step
258
+ upload_file(
259
+ path_or_fileobj=config_path,
260
+ path_in_repo="training_config.yaml",
261
+ repo_id=repo_id,
262
+ commit_message=f"Saving Training Config -- Step {checkpoint_step}",
263
+ revision=checkpointing_config.run_name,
264
+ token=os.getenv("HF_TOKEN"),
265
+ )
266
+
267
+ # Upload the fabric checkpoint directory
268
+ upload_folder(
269
+ folder_path=fabric_checkpoint_path,
270
+ path_in_repo=fabric_checkpoint_dir,
271
+ repo_id=repo_id,
272
+ commit_message=f"Saving Fabric Checkpoint -- Step {checkpoint_step}",
273
+ revision=checkpointing_config.run_name,
274
+ token=os.getenv("HF_TOKEN"),
275
+ )
276
+
277
+ # Upload logs if requested
278
+ if upload_logs:
279
+ logs_path = os.path.join(run_path, logs_dir)
280
+ upload_folder(
281
+ folder_path=logs_path,
282
+ path_in_repo=logs_dir,
283
+ repo_id=repo_id,
284
+ commit_message=f"Saving Logs -- Step {checkpoint_step}",
285
+ revision=checkpointing_config.run_name,
286
+ token=os.getenv("HF_TOKEN"),
287
+ )
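A rough sketch of how save_checkpoint pairs with load_checkpoint for auto-resume; the fabric, model, optimizer, and lr_scheduler objects are assumed to already be initialized, as in the trainer:

# Resume from the "latest" symlink; load_checkpoint returns None when no
# checkpoint exists yet, in which case training starts from step 0.
resume_state = load_checkpoint(
    checkpointing_config, "latest", fabric, model, optimizer, lr_scheduler
)
if resume_state is not None:
    model, optimizer, lr_scheduler, start_step = resume_state
else:
    start_step = 0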
src/config/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ Pico Config Package
3
+
4
+ The modules of this package are where you can specify the hyperparameters for the Pico model,
5
+ the dataset, the training process, evaluation, etc.
6
+
7
+ As with everything else in Pico, we've designed the configuration setup to be as flexible
8
+ as possible. By default the configs are implemented as vanilla dataclasses -- this makes it easy to
9
+ switch to a different config management system, such as Hydra, if you want.
10
+
11
+ Some things to NOTE:
12
+ - All hyperparameters are initialized with default values, which can be overridden.
13
+ - The default vocab size is set to the size of the OLMo tokenizer.
14
+ """
15
+
16
+ # For convenience, we export the config classes here
17
+ from .checkpointing_config import CheckpointingConfig
18
+ from .data_config import DataConfig
19
+ from .evaluation_config import EvaluationConfig
20
+ from .model_config import ModelConfig
21
+ from .monitoring_config import MonitoringConfig
22
+ from .training_config import TrainingConfig
23
+
24
+ __all__ = [
25
+ "CheckpointingConfig",
26
+ "DataConfig",
27
+ "EvaluationConfig",
28
+ "ModelConfig",
29
+ "MonitoringConfig",
30
+ "TrainingConfig",
31
+ ]
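Because these are vanilla dataclasses, overriding a default is just a constructor argument; a small illustration (the values are arbitrary):

from src.config import ModelConfig, TrainingConfig

model_config = ModelConfig(d_model=96, n_layers=6)  # shrink the default model
training_config = TrainingConfig()  # keep every default
print(model_config.d_model, training_config.max_steps)  # 96 200000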
src/config/_constants.py ADDED
@@ -0,0 +1,18 @@
1
+ """
2
+ Constants used throughout the codebase
3
+ """
4
+
5
+ # Basic Training Constants used throughout the codebase
6
+ VOCAB_SIZE = 50304
7
+ MAX_SEQ_LEN = 2048
8
+ BATCH_SIZE = 1024
9
+ GRADIENT_ACCUMULATION_STEPS = 128
10
+
11
+ # Directories used to store training runs, checkpoints, logs, and evaluation results
12
+ RUNS_DIR = "runs"
13
+ CHECKPOINTS_DIR = "checkpoints"
14
+ LOGS_DIR = "logs"
15
+ FABRIC_CHECKPOINT_DIR = "fabric_state"
16
+ FABRIC_CHECKPOINT_FILENAME = "checkpoint.pt"
17
+ LEARNING_DYNAMICS_DIR = "learning_dynamics"
18
+ EVAL_RESULTS_DIR = "eval_results"
src/config/checkpointing_config.py ADDED
@@ -0,0 +1,97 @@
1
+ """
2
+ Checkpointing Config
3
+
4
+ Specifies the hyperparameters for the checkpointing process; checkpointing is used to save
5
+ the model and optimizer states, as well as the learning dynamics metrics.
6
+ """
7
+
8
+ from dataclasses import dataclass, field
9
+ from typing import List, Optional
10
+
11
+ from ._constants import (
12
+ CHECKPOINTS_DIR,
13
+ EVAL_RESULTS_DIR,
14
+ FABRIC_CHECKPOINT_DIR,
15
+ FABRIC_CHECKPOINT_FILENAME,
16
+ LEARNING_DYNAMICS_DIR,
17
+ LOGS_DIR,
18
+ RUNS_DIR,
19
+ )
20
+
21
+
22
+ @dataclass
23
+ class TrainingCheckpointingConfig:
24
+ # Automatically resume training from the most recent checkpoint
25
+ auto_resume: bool = True
26
+
27
+
28
+ @dataclass
29
+ class EvaluationCheckpointingConfig:
30
+ # Directory in which evaluation results are saved
31
+ eval_results_dir: str = EVAL_RESULTS_DIR
32
+
33
+
34
+ @dataclass
35
+ class LearningDynamicsCheckpointingConfig:
36
+ # Suffixes of the layers to compute learning dynamics for
37
+ layer_suffixes: List[str] = field(
38
+ default_factory=lambda: [
39
+ "attention.v_proj",
40
+ "attention.o_proj",
41
+ "swiglu.w_2",
42
+ ]
43
+ )
44
+
45
+ # Sequence index at which to extract hidden states; by default, we extract the hidden states
46
+ # at the last token of the sequence (-1)
47
+ sequence_idx: int = -1
48
+
49
+ # size of the sub-batch used for extracting learning dynamics states
50
+ batch_size: int = 8
51
+
52
+ # Path to evaluation dataset - used across learning dynamics checkpointing for consistency
53
+ # NOTE: set to None to disable extracting learning dynamics states for an eval_batch
54
+ # NOTE: this dataset should be small, ideally just a batch of additional data
55
+ eval_data: Optional[str] = "pico-lm/pretokenized-paloma-tinsy"
56
+
57
+
58
+ @dataclass
59
+ class HuggingFaceCheckpointingConfig:
60
+ # Should be in the format of <(username or organization name)>/<repo_name>, e.g. pico-lm/demo
61
+ repo_id: str = ""
62
+
63
+ # HuggingFace Collection Slug (specifies a tag for the run)
64
+ collection_slug: Optional[str] = None
65
+
66
+
67
+ @dataclass
68
+ class CheckpointingConfig:
69
+ # Assign a name to the run
70
+ run_name: Optional[str] = None
71
+
72
+ # Defining checkpointing directories
73
+ runs_dir: str = RUNS_DIR
74
+ checkpoints_dir: str = CHECKPOINTS_DIR
75
+ logs_dir: str = LOGS_DIR
76
+ fabric_checkpoint_dir: str = FABRIC_CHECKPOINT_DIR
77
+ fabric_checkpoint_filename: str = FABRIC_CHECKPOINT_FILENAME
78
+ learning_dynamics_dir: str = LEARNING_DYNAMICS_DIR
79
+
80
+ # How often to save checkpoints
81
+ save_every_n_steps: int = 1000
82
+
83
+ # Whether to save checkpoints to HuggingFace
84
+ save_to_hf: Optional[bool] = False
85
+ hf_checkpoint: HuggingFaceCheckpointingConfig = field(
86
+ default_factory=HuggingFaceCheckpointingConfig
87
+ )
88
+
89
+ training: TrainingCheckpointingConfig = field(
90
+ default_factory=TrainingCheckpointingConfig
91
+ )
92
+ evaluation: EvaluationCheckpointingConfig = field(
93
+ default_factory=EvaluationCheckpointingConfig
94
+ )
95
+ learning_dynamics: LearningDynamicsCheckpointingConfig = field(
96
+ default_factory=LearningDynamicsCheckpointingConfig
97
+ )
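For example, enabling Hub uploads only requires filling in the nested HF config; a minimal sketch using the placeholder repo id from the docstring above:

config = CheckpointingConfig(
    run_name="demo",  # hypothetical run name
    save_to_hf=True,
    hf_checkpoint=HuggingFaceCheckpointingConfig(repo_id="pico-lm/demo"),
)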
src/config/data_config.py ADDED
@@ -0,0 +1,36 @@
1
+ """
2
+ Data Config
3
+
4
+ Specifies the hyperparameters for the dataset, dataloader, and tokenizer.
5
+ """
6
+
7
+ from dataclasses import dataclass, field
8
+
9
+ from ._constants import BATCH_SIZE, VOCAB_SIZE
10
+
11
+
12
+ @dataclass
13
+ class DatasetConfig:
14
+ # Defines the HuggingFace name of a dataset
15
+ name: str = "pico-lm/pretokenized-dolma"
16
+
17
+
18
+ @dataclass
19
+ class DataLoaderConfig:
20
+ # NOTE: You should only change these values jointly with the training config; so that the
21
+ # sub-batch size is consistent with the gradient accumulation steps
22
+ batch_size: int = BATCH_SIZE
23
+
24
+
25
+ @dataclass
26
+ class TokenizerConfig:
27
+ # Specify a tokenizer to use
28
+ name: str = "allenai/OLMo-7B-0724-hf"
29
+ vocab_size: int = VOCAB_SIZE
30
+
31
+
32
+ @dataclass
33
+ class DataConfig:
34
+ dataset: DatasetConfig = field(default_factory=DatasetConfig)
35
+ dataloader: DataLoaderConfig = field(default_factory=DataLoaderConfig)
36
+ tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
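Note how the defaults line up: a dataloader batch of BATCH_SIZE = 1024 sequences split across GRADIENT_ACCUMULATION_STEPS = 128 accumulation steps yields sub-batches of 8, matching the learning-dynamics batch size. A quick sanity check:

from src.config._constants import BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS

sub_batch_size = BATCH_SIZE // GRADIENT_ACCUMULATION_STEPS
assert sub_batch_size == 8  # 1024 / 128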
src/config/evaluation_config.py ADDED
@@ -0,0 +1,28 @@
1
+ """
2
+ Evaluation Config
3
+
4
+ Specifies the hyperparameters for the evaluation process, i.e. what metrics to compute, etc.
5
+ """
6
+
7
+ from dataclasses import dataclass, field
8
+ from typing import List, Optional
9
+
10
+ from src.config._constants import MAX_SEQ_LEN
11
+
12
+
13
+ @dataclass
14
+ class PalomaEvaluationConfig:
15
+ dataset_name: str = "pico-lm/pretokenized-paloma-tinsy"
16
+ dataset_split: str = "val"
17
+ max_length: int = MAX_SEQ_LEN
18
+ batch_size: int = 16
19
+
20
+
21
+ @dataclass
22
+ class EvaluationConfig:
23
+ # Evaluation metrics to compute: by default, we compute the perplexity of the model on the paloma dataset
24
+ metrics: Optional[List[str]] = field(default_factory=lambda: ["paloma"])
25
+
26
+ # NOTE: Add other evaluation configs here
27
+ # Each evaluation metric should have its own config
28
+ paloma: PalomaEvaluationConfig = field(default_factory=PalomaEvaluationConfig)
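Following the pattern above, a new metric gets its own dataclass plus a field on EvaluationConfig; a hypothetical sketch (MyMetricEvaluationConfig and its fields are invented for illustration):

@dataclass
class MyMetricEvaluationConfig:
    dataset_name: str = "my-org/my-eval-set"  # hypothetical dataset
    batch_size: int = 16

# ...and on EvaluationConfig:
#     my_metric: MyMetricEvaluationConfig = field(default_factory=MyMetricEvaluationConfig)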
src/config/model_config.py ADDED
@@ -0,0 +1,33 @@
1
+ """
2
+ Model Config
3
+
4
+ Specifies the hyperparameters for the Pico model/model architecture.
5
+ """
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Optional
9
+
10
+ from ._constants import BATCH_SIZE, MAX_SEQ_LEN, VOCAB_SIZE
11
+
12
+
13
+ @dataclass
14
+ class ModelConfig:
15
+ model_type: str = "pico_decoder"
16
+
17
+ # Pico Decoder default hyperparameters
18
+
19
+ d_model: int = 768
20
+ n_layers: int = 12
21
+
22
+ vocab_size: int = VOCAB_SIZE
23
+ batch_size: int = BATCH_SIZE
24
+ max_seq_len: int = MAX_SEQ_LEN
25
+
26
+ attention_n_heads: int = 12
27
+ attention_n_kv_heads: Optional[int] = 4
28
+
29
+ activation_hidden_dim: int = 3072
30
+
31
+ norm_eps: float = 1e-6
32
+
33
+ position_emb_theta: float = 10000.0
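With these defaults, each attention head gets d_model / attention_n_heads = 768 / 12 = 64 dimensions, each KV head is shared by 12 / 4 = 3 query heads (grouped-query attention), and the SwiGLU hidden dimension is the usual 4 * d_model. A quick check:

config = ModelConfig()
assert config.d_model // config.attention_n_heads == 64  # head_dim
assert config.attention_n_heads // config.attention_n_kv_heads == 3  # GQA groups
assert config.activation_hidden_dim == 4 * config.d_model  # 3072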
src/config/monitoring_config.py ADDED
@@ -0,0 +1,29 @@
1
+ """
2
+ Monitoring Config
3
+
4
+ Specifies the monitoring process, e.g. how to log metrics and keep track of training progress.
5
+ """
6
+
7
+ from dataclasses import dataclass, field
8
+
9
+
10
+ @dataclass
11
+ class LoggingConfig:
12
+ log_level: str = "INFO"
13
+ log_every_n_steps: int = 100
14
+
15
+
16
+ @dataclass
17
+ class WandbConfig:
18
+ # configure logging to Weights and Biases
19
+ project: str = ""
20
+ entity: str = ""
21
+
22
+
23
+ @dataclass
24
+ class MonitoringConfig:
25
+ logging: LoggingConfig = field(default_factory=LoggingConfig)
26
+
27
+ # Weights and Biases
28
+ save_to_wandb: bool = False
29
+ wandb: WandbConfig = field(default_factory=WandbConfig)
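Turning on Weights and Biases logging is then just a matter of filling in the nested config; the project and entity names below are placeholders:

monitoring_config = MonitoringConfig(
    save_to_wandb=True,
    wandb=WandbConfig(project="pico-demo", entity="my-team"),
)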
src/config/training_config.py ADDED
@@ -0,0 +1,40 @@
1
+ """
2
+ Training Config
3
+
4
+ Specifies the hyperparameters for the training process, i.e. the optimizer, learning rate, etc.
5
+ """
6
+
7
+ from dataclasses import dataclass, field
8
+
9
+ from ._constants import GRADIENT_ACCUMULATION_STEPS
10
+
11
+
12
+ @dataclass
13
+ class FabricConfig:
14
+ # Configure nodes/devices for parallelised training
15
+ num_nodes: int = 1
16
+ num_devices: int = 1
17
+ precision: str = "bf16-mixed"
18
+ # Hardware accelerator to use, can be cpu/cuda/mps etc.
19
+ accelerator: str = "cuda"
20
+
21
+
22
+ @dataclass
23
+ class OptimizationConfig:
24
+ # Optimizer
25
+ optimizer: str = "adamw"
26
+ lr: float = 3e-4
27
+
28
+ # Learning Rate Scheduler
29
+ lr_scheduler: str = "linear_with_warmup"
30
+ lr_warmup_steps: int = 2500
31
+
32
+ # Define number of gradient accumulation steps
33
+ gradient_accumulation_steps: int = GRADIENT_ACCUMULATION_STEPS
34
+
35
+
36
+ @dataclass
37
+ class TrainingConfig:
38
+ fabric: FabricConfig = field(default_factory=FabricConfig)
39
+ optimization: OptimizationConfig = field(default_factory=OptimizationConfig)
40
+ max_steps: int = 200_000
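At these defaults, and assuming one full dataloader batch per optimizer step, each step consumes BATCH_SIZE * MAX_SEQ_LEN = 1024 * 2048 ≈ 2.1M tokens, so a full 200,000-step run sees roughly 420B tokens:

from src.config._constants import BATCH_SIZE, MAX_SEQ_LEN

tokens_per_step = BATCH_SIZE * MAX_SEQ_LEN  # 2,097,152
total_tokens = tokens_per_step * 200_000    # ~4.2e11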
src/evaluation/__init__.py ADDED
@@ -0,0 +1,103 @@
1
+ """
2
+ Pico Evaluation Package
3
+
4
+ This package implements the evaluation pipeline for the Pico language model. It provides
5
+ functionality to evaluate model performance using various metrics and handles the complete
6
+ evaluation workflow.
7
+
8
+ We recommend that each evaluation metric have its own config and be
9
+ implemented as a module in the `evaluation/tasks` directory that exposes a `run_<metric_name>` function.
10
+
11
+ NOTE: Out of the box we only support Paloma, but the structure is designed to be flexible and
12
+ you are meant to add whatever metrics you want. One of the main reasons we save
13
+ the model in the HuggingFace format is that it's easy to use third-party evaluation
14
+ libraries/frameworks.
15
+ """
16
+
17
+ import os
18
+
19
+ import torch
20
+ from lightning.fabric import Fabric
21
+ from torch import nn
22
+
23
+ from src.config import CheckpointingConfig, EvaluationConfig
24
+
25
+ from .tasks.paloma import run_paloma_evaluation
26
+
27
+
28
+ def run_evaluation(
29
+ evaluation_config: EvaluationConfig,
30
+ checkpointing_config: CheckpointingConfig,
31
+ fabric: Fabric,
32
+ model: nn.Module,
33
+ ) -> None:
34
+ """Run model evaluation using specified metrics in `evaluation_config`.
35
+
36
+ This function orchestrates the complete evaluation pipeline by:
37
+ 1. Resolving the model checkpoint path (either specified or latest) to load the model from;
38
+ during training, this is the path to the latest checkpoint in the run directory.
39
+ 2. Iterating over each evaluation metric, and running the corresponding evaluation function.
40
+ NOTE: we suggest you follow the pattern of the Paloma evaluation function, and implement
41
+ your own evaluation function for each metric in the `evaluation/tasks` directory.
42
+ 3. Aggregating results across all metrics in a dictionary, and returning it.
43
+
44
+ Args:
45
+ evaluation_config (EvaluationConfig): Configuration object containing:
46
+ - metrics (List[str]): Metrics to evaluate; each metric should have its
47
+ own config. Currently supported: ["paloma"];
48
+ - paloma (PalomaEvaluationConfig): Configuration for Paloma evaluation
49
+ - max_length (int): Maximum sequence length
50
+ - batch_size (int): Evaluation batch size
51
+ checkpointing_config (CheckpointingConfig): Configuration object for checkpointing
52
+ fabric (Fabric): Lightning Fabric instance
53
+ model (nn.Module): Original model instance
54
+
55
+ Returns:
56
+ Dict[str, float]: Dictionary mapping metric names to their values
57
+ Example: {"paloma": 3.45}
58
+
59
+ Raises:
60
+ ValueError: If an unsupported evaluation metric is requested
61
+
62
+ Example:
63
+ results = run_evaluation(
64
+ EvaluationConfig(
65
+ run_name="experiment_1",
66
+ metrics=["paloma"],
67
+ paloma=PalomaConfig(max_length=2048, batch_size=16)
68
+ )
69
+ )
70
+
71
+ """
72
+
73
+ fabric.barrier()
74
+
75
+ model.to("cpu") # Offloading model to CPU
76
+
77
+ evaluation_results = {}
78
+
79
+ # NOTE: Evaluation is only run on the first process (rank 0) so that third-party evaluation
81
+ # libraries can decide for themselves how to handle distributed evaluation.
81
+ if fabric.global_rank == 0:
82
+ run_name = checkpointing_config.run_name
83
+ model_path = f"{os.getcwd()}/{checkpointing_config.runs_dir}/{run_name}/{checkpointing_config.checkpoints_dir}/latest"
84
+ os.makedirs(model_path, exist_ok=True)
85
+
86
+ for metric in evaluation_config.metrics:
87
+ # NOTE: add your own metrics here
88
+ if metric == "paloma":
89
+ evaluation_result = run_paloma_evaluation(
90
+ model_path, evaluation_config.paloma
91
+ )
92
+ else:
93
+ raise ValueError(f"Metric {metric} not supported")
94
+
95
+ evaluation_results[metric] = evaluation_result
96
+
97
+ torch.cuda.empty_cache()
98
+
99
+ fabric.barrier()
100
+
101
+ model.to(fabric.device)
102
+
103
+ return evaluation_results
src/evaluation/tasks/paloma.py ADDED
@@ -0,0 +1,52 @@
1
+ """
2
+ Paloma is a comprehensive evaluation benchmark for large language models (LLMs) that focuses
3
+ on measuring perplexity across diverse text domains.
4
+
5
+ To evaluate on Paloma, we use the HuggingFace evaluate library.
6
+
7
+ For more details, see: https://huggingface.co/datasets/allenai/paloma
8
+ """
9
+
10
+ import evaluate
11
+ from datasets import load_dataset
12
+ from datasets.utils.logging import disable_progress_bar, enable_progress_bar
13
+
14
+ from src.config.evaluation_config import PalomaEvaluationConfig
15
+
16
+
17
+ def run_paloma_evaluation(
18
+ model_path: str,
19
+ paloma_config: PalomaEvaluationConfig,
20
+ ) -> float:
21
+ """Run Perplexity evaluation on the Paloma evaluation dataset.
22
+
23
+ We use the HuggingFace evaluate library to load in and compute the perplexity metric.
24
+
25
+ Args:
26
+ model_path (str): Path to the model checkpoint to be evaluated
27
+ paloma_config (PalomaEvaluationConfig): Configuration for Paloma evaluation
28
+ """
29
+
30
+ disable_progress_bar()
31
+
32
+ # load custom evaluation space, see https://huggingface.co/spaces/pico-lm/perplexity
33
+ perplexity = evaluate.load("pico-lm/perplexity")
34
+
35
+ dataset = load_dataset(
36
+ paloma_config.dataset_name, split=paloma_config.dataset_split
37
+ )["text"]
38
+
39
+ # compute perplexity score on Paloma dataset
40
+ perplexity_result = perplexity.compute(
41
+ model_id=model_path,
42
+ predictions=dataset,
43
+ add_start_token=False,
44
+ max_length=paloma_config.max_length,
45
+ batch_size=paloma_config.batch_size,
46
+ trust_remote_code=True,
47
+ )
48
+
49
+ mean_perplexity = perplexity_result["mean_perplexity"]
50
+
51
+ enable_progress_bar()
52
+ return mean_perplexity
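A minimal sketch of invoking this directly on a saved checkpoint; the path is illustrative and just needs to point at a directory the perplexity space can load a model from:

from src.config.evaluation_config import PalomaEvaluationConfig

ppl = run_paloma_evaluation(
    "runs/demo/checkpoints/latest",  # hypothetical checkpoint path
    PalomaEvaluationConfig(batch_size=8),
)
print(f"mean perplexity: {ppl:.2f}")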
src/model/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ """
2
+ Model Package
3
+
4
+ This package contains Pico models (currently only the Pico Decoder). We plan to implement other
5
+ architectures in the future.
6
+
7
+ If you have other models you'd like to implement, we recommend you add modules to this package.
8
+ """
9
+
10
+ from .pico_decoder import PicoDecoder
11
+
12
+ __all__ = ["PicoDecoder"]
src/model/pico_decoder.py ADDED
@@ -0,0 +1,911 @@
1
+ """
2
+ Pico Decoder: A Lightweight Causal Transformer Language Model
3
+
4
+ Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
+
6
+ Everything is written with a modular design for easy modification and experimentation.
7
+
8
+ Key features:
9
+ - RMSNorm for layer normalization
10
+ - Rotary Positional Embeddings (RoPE)
11
+ - Multi-head attention with KV-cache support
12
+ - SwiGLU activation function
13
+ - KV-cache for faster autoregressive generation
15
+
15
+ - KV-cache for faster autoregressive generation
16
+
17
+ References:
18
+ - RoPE: https://arxiv.org/abs/2104.09864
19
+ - SwiGLU: https://arxiv.org/abs/2002.05202
20
+ - LLAMA: https://arxiv.org/abs/2302.13971
21
+
22
+ Adapted from:
23
+ - OLMO: https://github.com/allenai/OLMo
24
+ - LLAMA: https://github.com/meta-llama/llama
25
+ """
26
+
27
+ from dataclasses import asdict
28
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+ # Handle PyTorch version compatibility for attention backend
35
+ try:
36
+ from torch.nn.attention import SDPBackend, sdpa_kernel
37
+
38
+ HAS_TORCH_ATTENTION = True
39
+ except ImportError:
40
+ # Fallback for older PyTorch versions
41
+ HAS_TORCH_ATTENTION = False
42
+ SDPBackend = None
43
+ sdpa_kernel = None
44
+
45
+ from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
46
+ from transformers.generation import GenerationConfig
47
+ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
48
+
49
+ try:
50
+ if TYPE_CHECKING:
51
+ # We need to do this to avoid importing these when creating the HF-compatible models
52
+ from src.config import ModelConfig
53
+ except ImportError:
54
+ pass
55
+
56
+ ########################################################
57
+ #
58
+ # Layer Normalization
59
+ #
60
+ ########################################################
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ """Root Mean Square Layer Normalization.
65
+
66
+ A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
67
+ resulting in improved stability and performance.
68
+
69
+ Args:
70
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration object containing normalization parameters
71
+ - config.norm_eps: Small constant for numerical stability
72
+ - config.d_model: Model dimension for the weight parameter
73
+
74
+ References:
75
+ https://arxiv.org/abs/1910.07467
76
+ """
77
+
78
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
79
+ super().__init__()
80
+ self.eps = config.norm_eps
81
+ self.weight = nn.Parameter(torch.ones(config.d_model))
82
+
83
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
84
+ """
85
+ Normalizes the input tensor by its RMS value.
86
+ """
87
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ """
91
+ Applies RMS normalization to the input tensor and scales it by the weight parameter.
92
+ """
93
+ output = self._norm(x.float()).type_as(x)
94
+ return output * self.weight
95
+
96
+
97
+ ########################################################
98
+ #
99
+ # Positional Embedding
100
+ #
101
+ ########################################################
102
+
103
+
104
+ class RoPE(nn.Module):
105
+ """Rotary Positional Embeddings (RoPE).
106
+
107
+ Implements position-dependent rotation of keys and queries in attention mechanism,
108
+ allowing better modeling of relative positions in sequences. Uses complex number
109
+ operations for efficient rotation.
110
+
111
+ Args:
112
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration containing:
113
+ - config.position_emb_theta: Base for frequency computation
114
+ - config.d_model: Model dimension
115
+ - config.attention_n_heads: Number of attention heads
116
+ - config.max_seq_len: Maximum sequence length
117
+
118
+ References:
119
+ https://arxiv.org/abs/2104.09864
120
+ """
121
+
122
+ _freqs_cis_tensor: torch.Tensor | None = None
123
+
124
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
125
+ super().__init__()
126
+
127
+ self.theta = config.position_emb_theta
128
+ self.dim = config.d_model // config.attention_n_heads
129
+
130
+ max_seq_len = config.max_seq_len
131
+
132
+ # only gets set once, and then reused for all RoPE instances
133
+ if RoPE._freqs_cis_tensor is None:
134
+ RoPE._freqs_cis_tensor = self._setup_freqs_cis(
135
+ max_seq_len, self.theta, self.dim
136
+ )
137
+
138
+ # register _freqs_cis buffer
139
+ # can be easily recomputed so persistent=False
140
+ self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
141
+
142
+ @classmethod
143
+ def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
144
+ """Setup Frequency Tensor for RoPE Embeddings
145
+
146
+ Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
147
+
148
+ Note that other implementations use cos and sin directly, but the complex
150
+ number representation is (probably) more efficient:
150
+
151
+ e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
152
+ """
153
+ _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
154
+ positions = torch.arange(seq_len)
155
+ freqs = torch.outer(positions, _freqs)
156
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
157
+
158
+ def get_freqs_cis(
159
+ self, input_shape: torch.Size, start_pos: int, end_pos: int
160
+ ) -> torch.Tensor:
161
+ """Reshape Frequency Tensor for RoPE Embeddings
162
+
163
+ Makes the frequency tensor broadcastable with the input tensor.
164
+ """
165
+ _freqs_cis = self._freqs_cis[start_pos:end_pos]
166
+ ndim = len(input_shape)
167
+ assert 0 <= 1 < ndim
168
+ assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
169
+
170
+ # TODO: Check whether this is correct (might be able to remove this)
171
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
172
+ return _freqs_cis.view(*shape)
173
+
174
+ def forward(
175
+ self,
176
+ queries: torch.Tensor,
177
+ keys: torch.Tensor,
178
+ start_pos: int = 0,
179
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
180
+ """Apply RoPE Embeddings to Queries and Keys
181
+
182
+ Applies the rotary positional embeddings to the input tensors via complex num multiplication
183
+
184
+ NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
185
+ """
186
+ queries_ = torch.view_as_complex(
187
+ queries.float().reshape(*queries.shape[:-1], -1, 2)
188
+ )
189
+ keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
190
+
191
+ input_shape = (
192
+ queries_.shape
193
+ ) # (batch_size, seq_len, n_heads, head_dim/2); keys may have fewer (KV) heads, which broadcasts
194
+ freqs_start_pos = start_pos
195
+ freqs_end_pos = freqs_start_pos + queries_.shape[1]
196
+
197
+ freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
198
+
199
+ queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
200
+ keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
201
+ return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
202
+
203
+
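As a small check of the Euler's-formula identity used in _setup_freqs_cis, torch.polar(ones, freqs) is exactly cos(freqs) + i*sin(freqs):

import torch

freqs = torch.tensor([0.0, 0.5, 1.0])
cis = torch.polar(torch.ones_like(freqs), freqs)
assert torch.allclose(cis.real, torch.cos(freqs))
assert torch.allclose(cis.imag, torch.sin(freqs))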
204
+ ########################################################
205
+ #
206
+ # Attention
207
+ #
208
+ ########################################################
209
+
210
+
211
+ class Attention(nn.Module):
212
+ """Multi-head Attention with Group Query Attention support.
213
+
214
+ Implements scaled dot-product attention and supports:
215
+ - Grouped Query Attention (GQA)
216
+ - Key-Value caching for efficient inference
217
+ - RoPE integration
218
+
219
+ Args:
220
+ config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
221
+ - config.attention_n_heads: Number of attention heads
222
+ - config.attention_n_kv_heads: Number of key/value heads
223
+ - config.d_model: Model dimension
224
+ - config.batch_size: Maximum batch size
225
+ - config.max_seq_len: Maximum sequence length
226
+
227
+ Shape:
228
+ - Input: (batch_size, seq_len, d_model)
229
+ - Output: (batch_size, seq_len, d_model)
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
235
+ ):
236
+ super().__init__()
237
+
238
+ self.n_heads = config.attention_n_heads
239
+ self.n_kv_heads = config.attention_n_kv_heads
240
+
241
+ self.batch_size = config.batch_size
242
+ self.max_seq_len = config.max_seq_len
243
+
244
+ d_model = config.d_model
245
+ self.head_dim = d_model // self.n_heads
246
+
247
+ self.n_rep = self.n_heads // self.n_kv_heads
248
+
249
+ self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
250
+ self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
251
+ self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
252
+ self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
253
+
254
+ self.rope = RoPE(config)
255
+
256
+ def forward(
257
+ self,
258
+ input: torch.Tensor,
259
+ mask: Optional[torch.Tensor] = None,
260
+ past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
261
+ use_cache: bool = False,
262
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
263
+ """Forward pass for the attention mechanism.
264
+
265
+ Computes queries, keys, and values for the attention mechanism. Applies rotary positional
266
+ embeddings to the queries and keys, and then computes attention scores and outputs.
267
+
268
+ For an introduction to the attention mechanism, see:
269
+ https://arxiv.org/abs/1706.03762
270
+
271
+ A few things to note:
272
+ - The past_key_values is used to implement the KV cache, which is used to speed up
273
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
274
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
275
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
276
+ its own KV cache - this KV cache is implemented as a tuple.
277
+ """
278
+ bsz, seq_len, _ = input.shape
279
+ _queries, _keys, _values = (
280
+ self.q_proj(input),
281
+ self.k_proj(input),
282
+ self.v_proj(input),
283
+ )
284
+
285
+ # Reshaping for multi-head attention
286
+ queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
287
+ keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
288
+ values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
289
+
290
+ # The start position is used to apply the RoPE embeddings to only the new tokens
291
+ # when using the kv_cache in the attention mechanism.
292
+ # We want to start from the last position in the cache.
293
+ start_pos = 0
294
+ if past_key_values is not None and past_key_values[0] is not None:
295
+ start_pos = past_key_values[0].shape[1]
296
+
297
+ # apply rotary positional embeddings
298
+ queries, keys = self.rope(queries, keys, start_pos)
299
+
300
+ if (
301
+ past_key_values is not None
302
+ and past_key_values[0] is not None
303
+ and past_key_values[1] is not None
304
+ ):
305
+ keys = torch.cat([past_key_values[0], keys], dim=1)
306
+ values = torch.cat([past_key_values[1], values], dim=1)
307
+
308
+ if use_cache:
309
+ cached_keys = keys
310
+ cached_values = values
311
+ else:
312
+ cached_keys = None
313
+ cached_values = None
314
+
315
+ queries = queries.transpose(1, 2)
316
+ keys = keys.transpose(1, 2)
317
+ values = values.transpose(1, 2)
318
+
319
+ apply_gqa = self.n_rep > 1
320
+ if apply_gqa and queries.device.type == "mps":
321
+ # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
322
+ # outside of the kernel to get the same effect.
323
+ # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
324
+ keys = keys.repeat_interleave(self.n_rep, dim=-3)
325
+ values = values.repeat_interleave(self.n_rep, dim=-3)
326
+ apply_gqa = False
327
+
328
+ if HAS_TORCH_ATTENTION:
329
+ backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
330
+ with sdpa_kernel(backends=backends):
331
+ attn_output = F.scaled_dot_product_attention(
332
+ queries.contiguous(),
333
+ keys.contiguous(),
334
+ values.contiguous(),
335
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
336
+ enable_gqa=apply_gqa,
337
+ )
338
+ else:
339
+ # Fallback for older PyTorch versions - use default backend
340
+ attn_output = F.scaled_dot_product_attention(
341
+ queries.contiguous(),
342
+ keys.contiguous(),
343
+ values.contiguous(),
344
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
345
+ enable_gqa=apply_gqa,
346
+ )
347
+
348
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
349
+ output = self.o_proj(attn_output)
350
+
351
+ return output, (cached_keys, cached_values)
352
+
353
+
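The MPS fallback above leans on a simple equivalence: repeating each KV head n_rep times along the heads dimension and running ordinary multi-head SDPA matches what enable_gqa=True computes. A shape-level sketch with the default head counts:

import torch

bsz, n_kv_heads, seq_len, head_dim, n_rep = 2, 4, 16, 64, 3
keys = torch.randn(bsz, n_kv_heads, seq_len, head_dim)

# Each of the 4 KV heads is duplicated 3 times so it can serve 12 query heads
expanded = keys.repeat_interleave(n_rep, dim=-3)
assert expanded.shape == (bsz, n_kv_heads * n_rep, seq_len, head_dim)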
354
+ ########################################################
355
+ #
356
+ # SwiGLU (Combines MLP and Activation)
357
+ #
358
+ ########################################################
359
+
360
+
361
+ class SwiGLU(nn.Module):
362
+ """SwiGLU Activation Function with Linear Projections.
363
+
364
+ Implements the SwiGLU activation function combined with linear transformations,
365
+ serving as the feed-forward network in transformer blocks.
366
+
367
+ Args:
368
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
369
+ - config.d_model: Model dimension
370
+ - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
371
+
372
+ References:
373
+ https://arxiv.org/abs/2002.05202
374
+ """
375
+
376
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
377
+ super().__init__()
378
+
379
+ model_dim = config.d_model
380
+ act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
381
+
382
+ self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
383
+ self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
384
+ self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
385
+
386
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
387
+ return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
388
+
389
+
390
+ ########################################################
391
+ #
392
+ # PicoDecoderBlock
393
+ #
394
+ ########################################################
395
+
396
+
397
+ class PicoDecoderBlock(nn.Module):
398
+ """Single Transformer Block with Attention and Feed-forward layers.
399
+
400
+ Implements a standard transformer block with:
401
+ - Multi-head attention with normalization and residual connection
402
+ - SwiGLU feed-forward network with normalization and residual connection
403
+
404
+ Args:
405
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
406
+ a HuggingFace PicoDecoderHFConfig
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
412
+ ):
413
+ super().__init__()
414
+
415
+ self.attention = Attention(config)
416
+ self.swiglu = SwiGLU(config)
417
+ self.attention_norm = RMSNorm(config)
418
+ self.swiglu_norm = RMSNorm(config)
419
+
420
+ def forward(
421
+ self,
422
+ input: torch.Tensor,
423
+ mask: Optional[torch.Tensor] = None,
424
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
425
+ use_cache: bool = False,
426
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
427
+ attention_output, cached_key_values = self.attention(
428
+ self.attention_norm(input),
429
+ mask=mask,
430
+ past_key_values=past_key_values,
431
+ use_cache=use_cache,
432
+ )
433
+ # NOTE: cached_key_values is None if use_cache is False
434
+
435
+ h = input + attention_output
436
+ out = h + self.swiglu(self.swiglu_norm(h))
437
+ return out, cached_key_values
438
+
439
+
440
+ ########################################################
441
+ #
442
+ # Pico Decoder (Causal Transformer Model)
443
+ #
444
+ ########################################################
445
+
446
+
447
+ class PicoDecoder(nn.Module):
448
+ """
449
+ Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
450
+ single autoregressive model.
451
+
452
+ For more information on the model, see the classes for the modules that make up the model.
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
458
+ ):
459
+ super().__init__()
460
+ self.config = model_config
461
+
462
+ self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
463
+ self.layers = nn.ModuleList(
464
+ [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
465
+ )
466
+ self.output_norm = RMSNorm(self.config)
467
+ self.de_embedding_proj = nn.Linear(
468
+ self.config.d_model, self.config.vocab_size, bias=False
469
+ )
470
+
471
+ def convert_to_hf_model(self) -> "PicoDecoderHF":
472
+ """Convert the Lightning model to a HuggingFace model."""
473
+ # Create HF config without fabric-specific settings
474
+ hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
475
+
476
+ # Create new HF model
477
+ hf_model = PicoDecoderHF(hf_config)
478
+
479
+ # Copy the state dict, prefixing keys to match the HF wrapper's submodule name
480
+ hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
481
+
482
+ return hf_model
483
+
484
+ def forward(
485
+ self,
486
+ input_ids: torch.Tensor,
487
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
488
+ use_cache: bool = False,
489
+ ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
490
+ """
491
+ This is the forward pass for the entire Pico model. It boils down to:
492
+ - Embedding the input ids
493
+ - Creating a causal mask
494
+ - Processing through the pico layers
495
+ - Projecting the output to logits
496
+
497
+ NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
498
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
499
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
500
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
501
+ its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
502
+ KV caches (so a tuple of tuples).
503
+ """
504
+
505
+ seq_len = input_ids.shape[-1]
506
+ h = self.embedding_proj(input_ids)
507
+
508
+ # Calculate start position from past cached KV pairs. Remember that each layer has its
509
+ # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
510
+ # correct layer and then for either the keys or values.
511
+ start_pos = 0
512
+ if (
513
+ past_key_values is not None
514
+ and past_key_values[0] is not None
515
+ and past_key_values[0][0] is not None
516
+ ):
517
+ start_pos = past_key_values[0][0].shape[1]
518
+
519
+ # Create causal mask for current sequence
520
+ mask = None
521
+ if seq_len > 1:
522
+ mask = torch.full((seq_len, seq_len), float("-inf"))
523
+ mask = torch.triu(mask, diagonal=1)
524
+
525
+ # If using KV cache, extend mask to cover cached sequence length
526
+ if past_key_values is not None:
527
+ # Add zeros for cached tokens (we can attend to all of them)
528
+ mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
529
+
530
+ mask = mask.to(h.device)
531
+
532
+ # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
533
+ # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
534
+ cached_key_values = () if use_cache else None
535
+
536
+ # Process through transformer blocks
537
+ for idx, layer in enumerate(self.layers):
538
+ layer_past_key_values = None
539
+ if past_key_values is not None:
540
+ try:
541
+ # Handle both tuple-based cache and HuggingFace cache objects
542
+ if hasattr(past_key_values, "__getitem__") and idx < len(
543
+ past_key_values
544
+ ):
545
+ layer_past_key_values = past_key_values[idx]
546
+ except (KeyError, IndexError, TypeError):
547
+ # If we can't access the cache properly, just skip it
548
+ layer_past_key_values = None
549
+
550
+ h, layer_cached_key_values = layer(
551
+ h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
552
+ )
553
+
554
+ if use_cache:
555
+ cached_key_values += (layer_cached_key_values,)
556
+
557
+ # Final norm and projection
558
+ h = self.output_norm(h)
559
+ logits = self.de_embedding_proj(h).float()
560
+
561
+ return logits, cached_key_values
562
+
563
+
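A minimal greedy-decoding sketch showing how the tuple-of-tuples KV cache threads through repeated forward passes (the untrained model, random prompt, and argmax decoding are all just for illustration):

import torch

model = PicoDecoder(ModelConfig())  # untrained, illustration only
input_ids = torch.randint(0, 50304, (1, 8))  # fake prompt

cache = None
for _ in range(4):
    # After the first pass, only the newest token needs to be fed in;
    # the cache supplies the keys/values for everything before it.
    step_ids = input_ids if cache is None else input_ids[:, -1:]
    logits, cache = model(step_ids, past_key_values=cache, use_cache=True)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    input_ids = torch.cat([input_ids, next_token], dim=-1)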
564
+ ########################################################
565
+ #
566
+ # HuggingFace Wrapper for the Pico Decoder model.
567
+ #
568
+ ########################################################
569
+
570
+
571
+ class PicoDecoderHFConfig(PretrainedConfig):
572
+ """Config class for the Pico Decoder HuggingFace wrapper."""
573
+
574
+ model_type = "pico_decoder"
575
+
576
+ @classmethod
577
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
578
+ """
579
+ Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
580
+ this is because some kwargs require special handling, which would make this
582
+ class brittle.
582
+ """
583
+ pico_config = cls(**config_dict)
584
+
585
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
586
+ unused_kwargs = {
587
+ key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
588
+ }
589
+
590
+ if return_unused_kwargs:
591
+ return pico_config, unused_kwargs
592
+ return pico_config
593
+
594
+ @classmethod
595
+ def from_dataclass(cls, model_config: "ModelConfig"):
596
+ """Initialise from our custom config dataclass."""
597
+ return cls.from_dict(asdict(model_config))
598
+
599
+
600
+ class PicoDecoderHF(PreTrainedModel, GenerationMixin):
601
+ """
602
+ HuggingFace wrapper for the Pico model with generation support.
603
+
604
+ Many evaluation frameworks require a model to be set up as a HuggingFace model, so we provide a simple
605
+ wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
606
+ Pico model as well as the model wrapped in this HuggingFace class.
607
+
608
+ This also lets you do cool things like:
609
+
610
+ `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
611
+ """
612
+
613
+ config_class = PicoDecoderHFConfig
614
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
615
+ main_input_name = "input_ids"
616
+
617
+ def __init__(self, config: PicoDecoderHFConfig):
618
+ super().__init__(config)
619
+ self.pico_decoder = PicoDecoder(config)
620
+ # Initialize generation config with defaults
621
+ self.generation_config = GenerationConfig()
622
+ # Set some reasonable defaults for the model
623
+ if hasattr(config, "max_position_embeddings"):
624
+ self.generation_config.max_length = config.max_position_embeddings
625
+ if hasattr(config, "vocab_size"):
626
+ self.generation_config.vocab_size = config.vocab_size
627
+
628
+ def forward(
629
+ self,
630
+ input_ids: torch.Tensor,
631
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
632
+ use_cache: bool = False,
633
+ **kwargs,
634
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
635
+ """HuggingFace forward pass wrapper.
636
+
637
+ Forward pass for the HuggingFace version of the Pico model: a thin wrapper around the
638
+ Pico model's forward pass that returns the output as a HuggingFace CausalLMOutput.
639
+ """
640
+ logits, past_key_values = self.pico_decoder(
641
+ input_ids, past_key_values, use_cache
642
+ )
643
+ if use_cache:
644
+ return CausalLMOutputWithPast(
645
+ logits=logits,
646
+ past_key_values=past_key_values,
647
+ )
648
+ else:
649
+ return CausalLMOutput(
650
+ logits=logits,
651
+ )
652
+
653
+ def prepare_inputs_for_generation(
654
+ self,
655
+ input_ids: torch.LongTensor,
656
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
657
+ attention_mask: Optional[torch.LongTensor] = None,
658
+ **kwargs,
659
+ ) -> Dict[str, Any]:
660
+ """
661
+ Prepare inputs for generation.
662
+
663
+ Args:
664
+ input_ids: Input token IDs
665
+ past_key_values: Cached key-value pairs from previous forward passes
666
+ attention_mask: Attention mask for the input
667
+ **kwargs: Additional arguments
668
+
669
+ Returns:
670
+ Dictionary containing prepared inputs
671
+ """
672
+ # If we have past_key_values, we only need the last token
673
+ if past_key_values is not None:
674
+ input_ids = input_ids[:, -1:]
675
+
676
+ return {
677
+ "input_ids": input_ids,
678
+ "past_key_values": past_key_values,
679
+ "use_cache": True,
680
+ }
681
+
682
+ def get_input_embeddings(self):
683
+ """Get the input embeddings layer."""
684
+ return self.pico_decoder.embedding_proj
685
+
686
+ def set_input_embeddings(self, value):
687
+ """Set the input embeddings layer."""
688
+ self.pico_decoder.embedding_proj = value
689
+
690
+ def get_output_embeddings(self):
691
+ """Get the output embeddings layer."""
692
+ return self.pico_decoder.de_embedding_proj
693
+
694
+ def set_output_embeddings(self, value):
695
+ """Set the output embeddings layer."""
696
+ self.pico_decoder.de_embedding_proj = value
697
+
698
+ def get_lm_head(self):
699
+ """Get the language model head."""
700
+ return self.pico_decoder.de_embedding_proj
701
+
702
+ def can_generate(self) -> bool:
703
+ """Check if the model can generate text."""
704
+ return True
705
+
706
+ @property
707
+ def is_encoder_decoder(self) -> bool:
708
+ """Check if the model is an encoder-decoder model."""
709
+ return False
710
+
711
+ @property
712
+ def can_use_cache(self) -> bool:
713
+ """Check if the model can use KV cache."""
714
+ return True
715
+
716
+ def resize_token_embeddings(
717
+ self, new_num_tokens: Optional[int] = None
718
+ ) -> torch.nn.Embedding:
719
+ """Resize token embeddings."""
720
+ old_embeddings = self.get_input_embeddings()
721
+ if new_num_tokens is None:
722
+ new_num_tokens = old_embeddings.num_embeddings
723
+
724
+ new_embeddings = torch.nn.Embedding(
725
+ new_num_tokens, old_embeddings.embedding_dim
726
+ )
727
+ # NOTE: copy only the overlapping rows so that shrinking the vocabulary does not
+ # raise a shape-mismatch error
+ num_tokens_to_copy = min(new_num_tokens, old_embeddings.num_embeddings)
+ new_embeddings.weight.data[:num_tokens_to_copy] = old_embeddings.weight.data[
+ :num_tokens_to_copy
+ ]
730
+
731
+ self.pico_decoder.embedding_proj = new_embeddings
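+ # NOTE: the output projection below is re-created from scratch, so any learned
+ # de-embedding weights are not carried over by this resize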
732
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
733
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
734
+ )
735
+
736
+ return new_embeddings
737
+
738
+
739
+ # Register for auto classes
740
+ PicoDecoderHFConfig.register_for_auto_class()
741
+ PicoDecoderHF.register_for_auto_class("AutoModel")
742
+ PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
743
+
744
+
745
+ ########################################################
746
+ #
747
+ # New PicoDecoderForCausalLM class for generation support
748
+ #
749
+ ########################################################
750
+
751
+
752
+ class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
753
+ """
754
+ PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
755
+
756
+ This class is designed to work with existing checkpoints and provides full generation support.
757
+ It inherits from the right base classes that HuggingFace expects for text generation.
758
+ """
759
+
760
+ config_class = PicoDecoderHFConfig
761
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
762
+ main_input_name = "input_ids"
763
+
764
+ def __init__(self, config: PicoDecoderHFConfig):
765
+ super().__init__(config)
766
+ self.pico_decoder = PicoDecoder(config)
767
+ # Initialize generation config with defaults
768
+ self.generation_config = GenerationConfig()
769
+ # Set some reasonable defaults for the model
770
+ if hasattr(config, "max_position_embeddings"):
771
+ self.generation_config.max_length = config.max_position_embeddings
772
+ if hasattr(config, "vocab_size"):
773
+ self.generation_config.vocab_size = config.vocab_size
774
+
775
+ def forward(
776
+ self,
777
+ input_ids: torch.Tensor,
778
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
779
+ use_cache: bool = False,
780
+ **kwargs,
781
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
782
+ """Forward pass for text generation."""
783
+ logits, past_key_values = self.pico_decoder(
784
+ input_ids, past_key_values, use_cache
785
+ )
786
+ if use_cache:
787
+ return CausalLMOutputWithPast(
788
+ logits=logits,
789
+ past_key_values=past_key_values,
790
+ )
791
+ else:
792
+ return CausalLMOutput(
793
+ logits=logits,
794
+ )
795
+
796
+ def prepare_inputs_for_generation(
797
+ self,
798
+ input_ids: torch.LongTensor,
799
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
800
+ attention_mask: Optional[torch.LongTensor] = None,
801
+ **kwargs,
802
+ ) -> Dict[str, Any]:
803
+ """Prepare inputs for generation."""
804
+ # If we have past_key_values, we only need the last token
805
+ if past_key_values is not None:
806
+ input_ids = input_ids[:, -1:]
807
+
808
+ return {
809
+ "input_ids": input_ids,
810
+ "past_key_values": past_key_values,
811
+ "use_cache": True,
812
+ }
813
+
814
+ def get_input_embeddings(self):
815
+ """Get the input embeddings layer."""
816
+ return self.pico_decoder.embedding_proj
817
+
818
+ def set_input_embeddings(self, value):
819
+ """Set the input embeddings layer."""
820
+ self.pico_decoder.embedding_proj = value
821
+
822
+ def get_output_embeddings(self):
823
+ """Get the output embeddings layer."""
824
+ return self.pico_decoder.de_embedding_proj
825
+
826
+ def set_output_embeddings(self, value):
827
+ """Set the output embeddings layer."""
828
+ self.pico_decoder.de_embedding_proj = value
829
+
830
+ def get_lm_head(self):
831
+ """Get the language model head."""
832
+ return self.pico_decoder.de_embedding_proj
833
+
834
+ def can_generate(self) -> bool:
835
+ """Check if the model can generate text."""
836
+ return True
837
+
838
+ @property
839
+ def is_encoder_decoder(self) -> bool:
840
+ """Check if the model is an encoder-decoder model."""
841
+ return False
842
+
843
+ @property
844
+ def can_use_cache(self) -> bool:
845
+ """Check if the model can use KV cache."""
846
+ return True
847
+
848
+ def resize_token_embeddings(
849
+ self, new_num_tokens: Optional[int] = None
850
+ ) -> torch.nn.Embedding:
851
+ """Resize token embeddings."""
852
+ old_embeddings = self.get_input_embeddings()
853
+ if new_num_tokens is None:
854
+ new_num_tokens = old_embeddings.num_embeddings
855
+
856
+ new_embeddings = torch.nn.Embedding(
857
+ new_num_tokens, old_embeddings.embedding_dim
858
+ )
859
+ # NOTE: copy only the overlapping rows so that shrinking the vocabulary does not
+ # raise a shape-mismatch error
+ num_tokens_to_copy = min(new_num_tokens, old_embeddings.num_embeddings)
+ new_embeddings.weight.data[:num_tokens_to_copy] = old_embeddings.weight.data[
+ :num_tokens_to_copy
+ ]
862
+
863
+ self.pico_decoder.embedding_proj = new_embeddings
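+ # NOTE: the output projection below is re-created from scratch, so any learned
+ # de-embedding weights are not carried over by this resize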
864
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
865
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
866
+ )
867
+
868
+ return new_embeddings
869
+
870
+ @classmethod
871
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
872
+ """
873
+ Load a pretrained model from a checkpoint.
874
+
875
+ This method handles loading from both the old PicoDecoderHF format and the new format.
876
+ """
877
+ # First try to load with the new class
878
+ try:
879
+ return super().from_pretrained(
880
+ pretrained_model_name_or_path, *model_args, **kwargs
881
+ )
882
+ except Exception as e:
883
+ print(f"Failed to load with new class: {e}")
884
+ print("Attempting to load with legacy class and convert...")
885
+
886
+ # Try to load with the old class and convert
887
+ try:
888
+ from transformers import AutoModel
889
+
890
+ old_model = AutoModel.from_pretrained(
891
+ pretrained_model_name_or_path,
892
+ trust_remote_code=True,
893
+ *model_args,
894
+ **kwargs,
895
+ )
896
+
897
+ # Create new model instance
898
+ new_model = cls(old_model.config)
899
+
900
+ # Copy state dict
901
+ new_model.load_state_dict(old_model.state_dict(), strict=False)
902
+
903
+ return new_model
904
+
905
+ except Exception as e2:
906
+ print(f"Failed to convert from legacy format: {e2}")
907
+ raise e
908
+
909
+
910
+ # Register the new class
911
+ PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
src/training/trainer.py ADDED
@@ -0,0 +1,753 @@
1
+ """
2
+ Pico Language Model Trainer
3
+
4
+ This Trainer implements a minimalistic end-to-end training pipeline of the Pico language model with
5
+ distributed training support via Lightning Fabric. It provides a modular and configurable training
6
+ pipeline with the features:
7
+
8
+ - Configuration Management: YAML-based configuration for all aspects of training
9
+ - Distributed Training: Multi-GPU support via Lightning Fabric
10
+ - Checkpointing: Regular model saving and training state recovery
11
+ - Evaluation: Periodic model evaluation on validation datasets
12
+ - Logging: Comprehensive metric tracking and experiment monitoring
13
+ - Optimization: Support for gradient accumulation, clipping, and LR scheduling
14
+ """
15
+
16
+ import logging
17
+ import os
18
+ import platform
19
+ from typing import Any, Dict
20
+
21
+ import lightning as L
22
+ import psutil
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import yaml
26
+ from datasets import Dataset, load_dataset
27
+ from lightning.fabric.utilities.rank_zero import rank_zero_only
28
+
29
+ from src.checkpointing import (
30
+ compute_learning_dynamics_states,
31
+ load_checkpoint,
32
+ save_checkpoint,
33
+ save_evaluation_results,
34
+ save_learning_dynamics_states,
35
+ )
36
+ from src.evaluation import run_evaluation
37
+ from src.training.utils import (
38
+ initialize_configuration,
39
+ initialize_dataloader,
40
+ initialize_dataset,
41
+ initialize_fabric,
42
+ initialize_hf_checkpointing,
43
+ initialize_logging,
44
+ initialize_lr_scheduler,
45
+ initialize_model,
46
+ initialize_optimizer,
47
+ initialize_run_dir,
48
+ initialize_tokenizer,
49
+ initialize_wandb,
50
+ )
51
+ from src.training.utils.logging import pretty_print_yaml_config
52
+
53
+
54
+ class Trainer:
55
+ def __init__(self, config_path: str):
56
+ """
57
+ Initializes the Trainer class. This Trainer class implements a `train` method, which is the
58
+ main entry point for training the Pico model. Before calling `train`, the Trainer class
59
+ initializes the following:
60
+
61
+ - Configuration loading and validation
62
+ - Model, optimizer, and dataset setup
63
+ - Logging and experiment tracking setup
64
+ - Checkpoint management
65
+
66
+ Args:
67
+ config_path (str): Path to the YAML configuration file containing any overrides.
68
+ """
69
+
70
+ ########################################################
71
+ #
72
+ # Basic Initialization of Configs, Fabric, Model, Optimizer, etc.
73
+ #
74
+ ########################################################
75
+
76
+ # Setup Config
77
+ self.configs = initialize_configuration(config_path)
78
+
79
+ # Setup Run Directory (i.e. where we store checkpoints, logs, etc.)
80
+ initialize_run_dir(checkpointing_config=self.configs["checkpointing"])
81
+
82
+ # Setup Logger
83
+ if self.configs["monitoring"].save_to_wandb:
84
+ wandb_logger = initialize_wandb(
85
+ monitoring_config=self.configs["monitoring"],
86
+ checkpointing_config=self.configs["checkpointing"],
87
+ )
88
+ else:
89
+ wandb_logger = None
90
+
91
+ # Setup Fabric
92
+ self.fabric = initialize_fabric(
93
+ training_config=self.configs["training"],
94
+ wandb_logger=wandb_logger,
95
+ )
96
+ L.seed_everything(42, verbose=False)
97
+
98
+ # Optimize for Tensor Cores on RTX 5090
99
+ if self.fabric.device.type == "cuda":
100
+ torch.set_float32_matmul_precision(
101
+ "high"
102
+ ) # Enables TF32 matmuls on Tensor Cores for better throughput
103
+ print(
104
+ "Enabled Tensor Core optimization: torch.set_float32_matmul_precision('high')"
105
+ )
106
+
107
+ # Set up logging
108
+ self.logger = initialize_logging(
109
+ monitoring_config=self.configs["monitoring"],
110
+ checkpointing_config=self.configs["checkpointing"],
111
+ fabric=self.fabric,
112
+ )
113
+
114
+ # Setup Model, Optimizer, and Dataloaders
115
+ self.model = initialize_model(model_config=self.configs["model"])
116
+ self.optimizer = initialize_optimizer(
117
+ training_config=self.configs["training"], model=self.model
118
+ )
119
+ self.lr_scheduler = initialize_lr_scheduler(
120
+ training_config=self.configs["training"], optimizer=self.optimizer
121
+ )
122
+
123
+ # Wrap model and optimizer with Fabric
124
+ self.model, self.optimizer = self.fabric.setup(self.model, self.optimizer)
125
+
126
+ # Setup HuggingFace Checkpointing
127
+ if self.configs["checkpointing"].save_to_hf:
128
+ initialize_hf_checkpointing(
129
+ checkpointing_config=self.configs["checkpointing"], fabric=self.fabric
130
+ )
131
+
132
+ ########################################################
133
+ #
134
+ # Boilerplate to deal with loading/resuming from checkpoints
135
+ #
136
+ ########################################################
137
+
138
+ self.should_load_checkpoint = self.configs["checkpointing"].training.auto_resume
139
+
140
+ # Possibly load a checkpoint
141
+ if self.should_load_checkpoint:
142
+ resume_checkpoint = load_checkpoint(
143
+ checkpointing_config=self.configs["checkpointing"],
144
+ checkpoint_step="latest",
145
+ fabric=self.fabric,
146
+ model=self.model,
147
+ optimizer=self.optimizer,
148
+ lr_scheduler=self.lr_scheduler,
149
+ )
150
+
151
+ if resume_checkpoint:
152
+ (
153
+ self.model,
154
+ self.optimizer,
155
+ self.lr_scheduler,
156
+ self.initial_batch_step,
157
+ ) = resume_checkpoint
158
+ else:
159
+ self.initial_batch_step = 0
160
+ else:
161
+ self.initial_batch_step = 0
162
+
163
+ ########################################################
164
+ #
165
+ # Initialization of Dataset & DataLoader (possibly fast-forwarding to correct batch)
166
+ #
167
+ ########################################################
168
+
169
+ self.train_dataset, fast_forward_steps = initialize_dataset(
170
+ data_config=self.configs["data"],
171
+ fabric=self.fabric,
172
+ initial_batch_step=self.initial_batch_step,
173
+ return_fast_forward_steps=True,
174
+ )
175
+
176
+ self.train_dataloader = initialize_dataloader(
177
+ data_config=self.configs["data"],
178
+ training_config=self.configs["training"],
179
+ fabric=self.fabric,
180
+ dataset=self.train_dataset,
181
+ )
182
+ self.train_dataloader = self.fabric.setup_dataloaders(
183
+ self.train_dataloader, use_distributed_sampler=False
184
+ )
185
+
186
+ self.tokenizer = initialize_tokenizer(data_config=self.configs["data"])
187
+
188
+ # NOTE: We may need to fast-forward the iterator to the correct step so that we can
189
+ # continue from the correct batch of data we would have seen had training not
190
+ # previously stopped.
191
+ train_iterator = iter(self.train_dataloader)
192
+ if fast_forward_steps > 0:
193
+ fast_forward_sub_steps = (
194
+ fast_forward_steps
195
+ * self.configs["training"].optimization.gradient_accumulation_steps
196
+ )
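+ # e.g. fast-forwarding 5 batch steps with 32 gradient accumulation steps skips
+ # 5 * 32 = 160 sub-batches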
197
+ for _ in range(fast_forward_sub_steps):
198
+ next(train_iterator)
199
+
200
+ self.train_iterator = train_iterator
201
+
202
+ # NOTE: Sychronizing processes after fast-forwarding iterator
203
+ self.fabric.barrier()
204
+
205
+ ########################################################
206
+ #
207
+ # Helper flags used during training for checkpointing and evaluation
208
+ #
209
+ ########################################################
210
+
211
+ # Helper flag to determine if we should evaluate the model
212
+ self.should_evaluate = (
213
+ self.configs["evaluation"].metrics is not None
214
+ and len(self.configs["evaluation"].metrics) > 0
215
+ )
216
+
217
+ self.should_compute_learning_dynamics = (
218
+ self.configs["checkpointing"].learning_dynamics.layer_suffixes is not None
219
+ and len(self.configs["checkpointing"].learning_dynamics.layer_suffixes) > 0
220
+ )
221
+
222
+ if self.should_compute_learning_dynamics:
223
+ if self.configs["checkpointing"].learning_dynamics.eval_data is not None:
224
+ self.learning_dynamics_eval_dataset = load_dataset(
225
+ self.configs["checkpointing"].learning_dynamics.eval_data,
226
+ split="val",
227
+ )
228
+ else:
229
+ self.learning_dynamics_eval_dataset = None
230
+
231
+ def train(self) -> None:
232
+ """Execute the main training pipeline.
233
+
234
+ This method orchestrates the complete training process by:
235
+ 1. Creating an initial checkpoint to save the starting state and evaluate the model as a
236
+ baseline
237
+ 2. Running the main training loop via `_training_loop`
238
+ 3. Handling final checkpointing and evaluation
239
+
240
+ The training progress is tracked through checkpoints and evaluations
241
+ at intervals specified in the configuration.
242
+ """
243
+
244
+ ########################################################
245
+ #
246
+ # Initial Checkpointing and Evaluation
247
+ #
248
+ ########################################################
249
+
250
+ # Save Initial Checkpoint -- If the checkpoint already exists, this performs a no-op
251
+ save_checkpoint(
252
+ configs=self.configs,
253
+ checkpoint_step=self.initial_batch_step,
254
+ fabric=self.fabric,
255
+ model=self.model,
256
+ optimizer=self.optimizer,
257
+ lr_scheduler=self.lr_scheduler,
258
+ tokenizer=self.tokenizer,
259
+ )
260
+
261
+ # Save Initial Evaluation Results
262
+ if self.should_evaluate:
263
+ if self.initial_batch_step == 0:
264
+ evaluation_results = run_evaluation(
265
+ evaluation_config=self.configs["evaluation"],
266
+ checkpointing_config=self.configs["checkpointing"],
267
+ fabric=self.fabric,
268
+ model=self.model,
269
+ )
270
+ self._log_evaluation_results(
271
+ evaluation_results, self.initial_batch_step
272
+ )
273
+ save_evaluation_results(
274
+ checkpointing_config=self.configs["checkpointing"],
275
+ fabric=self.fabric,
276
+ evaluation_results=evaluation_results,
277
+ checkpoint_step=self.initial_batch_step,
278
+ )
279
+ else:
280
+ # NOTE: If the run crashed while evaluating, we need to restart the evaluation
281
+ eval_results_path = os.path.join(
282
+ self.configs["checkpointing"].evaluation.eval_results_dir,
283
+ f"step_{self.initial_batch_step}.json",
284
+ )
285
+ if not os.path.exists(eval_results_path):
286
+ evaluation_results = run_evaluation(
287
+ evaluation_config=self.configs["evaluation"],
288
+ checkpointing_config=self.configs["checkpointing"],
289
+ fabric=self.fabric,
290
+ model=self.model,
291
+ )
292
+ self._log_evaluation_results(
293
+ evaluation_results, self.initial_batch_step
294
+ )
295
+ save_evaluation_results(
296
+ checkpointing_config=self.configs["checkpointing"],
297
+ fabric=self.fabric,
298
+ evaluation_results=evaluation_results,
299
+ checkpoint_step=self.initial_batch_step,
300
+ )
301
+
302
+ ########################################################
303
+ #
304
+ # Main Training Loop (see `_training_loop` for details)
305
+ #
306
+ ########################################################
307
+
308
+ if self.initial_batch_step < self.configs["training"].max_steps:
309
+ self._log_training_configuration()
310
+ final_step = self._training_loop()
311
+ else:
312
+ final_step = self.initial_batch_step
313
+
314
+ ########################################################
315
+ #
316
+ # Final Checkpointing and Evaluation
317
+ #
318
+ ########################################################
319
+
320
+ # Save Learning Dynamics States
321
+ if self.should_compute_learning_dynamics:
322
+ if self.learning_dynamics_eval_dataset is not None:
323
+ self.log(f"Step {final_step} -- 📈 Saving Learning Dynamics")
324
+ learning_dynamics_val_states = compute_learning_dynamics_states(
325
+ checkpointing_config=self.configs["checkpointing"],
326
+ fabric=self.fabric,
327
+ model=self.model,
328
+ dataset=self.learning_dynamics_eval_dataset,
329
+ compute_gradients=True,
330
+ )
331
+ save_learning_dynamics_states(
332
+ checkpointing_config=self.configs["checkpointing"],
333
+ fabric=self.fabric,
334
+ learning_dynamics_states=learning_dynamics_val_states,
335
+ checkpoint_step=final_step,
336
+ prefix="val",
337
+ )
338
+
339
+ # Handle checkpointing and final evaluation
340
+ if final_step % self.configs["checkpointing"].save_every_n_steps != 0:
341
+ self.log(f"Step {final_step} -- 💾 Saving Final Checkpoint")
342
+ save_checkpoint(
343
+ configs=self.configs,
344
+ checkpoint_step=final_step,
345
+ fabric=self.fabric,
346
+ model=self.model,
347
+ optimizer=self.optimizer,
348
+ lr_scheduler=self.lr_scheduler,
349
+ tokenizer=self.tokenizer,
350
+ )
351
+
352
+ # Final evaluation
353
+ if self.should_evaluate:
354
+ evaluation_results = run_evaluation(
355
+ evaluation_config=self.configs["evaluation"],
356
+ checkpointing_config=self.configs["checkpointing"],
357
+ fabric=self.fabric,
358
+ model=self.model,
359
+ )
360
+ self._log_evaluation_results(evaluation_results, final_step)
361
+ save_evaluation_results(
362
+ checkpointing_config=self.configs["checkpointing"],
363
+ checkpoint_step=final_step,
364
+ fabric=self.fabric,
365
+ evaluation_results=evaluation_results,
366
+ )
367
+
368
+ self.log(f"🎉 Training complete! Final step: {final_step}")
369
+
370
+ if final_step < self.configs["training"].max_steps:
371
+ self.log(
372
+ f"\t Note: Training stopped before max steps ({self.configs['training'].max_steps})",
373
+ level=logging.WARNING,
374
+ )
375
+
376
+ # Cleanup distributed training
377
+ self.fabric.barrier()
378
+ if torch.cuda.is_available():
379
+ torch.cuda.empty_cache()
380
+ if torch.distributed.is_initialized():
381
+ torch.distributed.destroy_process_group()
382
+
383
+ del self.train_dataloader # NOTE: shuts down any dataloader worker processes
384
+
385
+ self.fabric.barrier()
386
+
387
+ def _training_loop(self) -> int:
388
+ """Execute the main training loop.
389
+
390
+ This method orchestrates the core training loop and includes the following features:
391
+ - Gradient accumulation
392
+ - Gradient clipping
393
+ - Periodic model evaluation and checkpointing
394
+ - Learning Dynamics Checkpointing
395
+ - Learning rate scheduling
396
+ - Logging of training metrics including loss and learning rate
397
+ - Handling of infinite/NaN losses
398
+
399
+ Returns:
400
+ int: The final step count reached during training.
401
+ NOTE: A complete training run should match the configured max_steps.
402
+ """
403
+ # Setup training loop variables
404
+ batch_step = self.initial_batch_step
405
+
406
+ # NOTE: these are used to compute the average loss over a training interval.
407
+ # This is more accurate than using the loss at the end of the interval.
408
+ interval_loss = torch.tensor(0.0, device=self.fabric.device)
409
+ interval_steps = torch.tensor(0, device=self.fabric.device)
410
+ interval_inf_or_nan_count = torch.tensor(0, device=self.fabric.device)
411
+
412
+ if self.should_compute_learning_dynamics:
413
+ # NOTE: we basically re-construct the full batch here so that we can compute learning dynamics
414
+ training_batch = {"input_ids": []}
415
+
416
+ # NOTE: determine what sub-batch we should start from
417
+ initial_sub_batch_step = (
418
+ batch_step
419
+ * self.configs["training"].optimization.gradient_accumulation_steps
420
+ )
421
+
422
+ ###############################################################
423
+ #
424
+ # Core loop starts here
425
+ # NOTE: the ratio between sub_batch_step and batch_step
426
+ # is the configured number of gradient_accumulation_steps
427
+ # i.e. with 32 configured gradient accumulation steps,
428
+ # there are 32 sub_batch_steps for each batch_step
429
+ #
430
+ ###############################################################
431
+
432
+ for sub_batch_step, sub_batch in enumerate(
433
+ self.train_iterator, start=initial_sub_batch_step
434
+ ):
435
+ # NOTE: We want to store the entire training batch whenever we are computing learning dynamics
436
+ # and we are at a checkpointing step.
437
+ should_store_training_batch = self.should_compute_learning_dynamics and (
438
+ batch_step % self.configs["checkpointing"].save_every_n_steps == 0
439
+ )
440
+
441
+ ########################################################
442
+ #
443
+ # Forward Pass
444
+ #
445
+ ########################################################
446
+
447
+ _input_ids = torch.tensor(sub_batch["input_ids"], device=self.fabric.device)
448
+ input_ids = _input_ids[:, :-1]
449
+ labels = _input_ids[:, 1:]
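+ # NOTE: standard next-token prediction -- for a sequence [a, b, c, d], the model is
+ # fed [a, b, c] and trained to predict the shifted targets [b, c, d]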
450
+
451
+ if should_store_training_batch:
452
+ gathered_input_ids = self.fabric.all_gather(_input_ids)
453
+
454
+ # NOTE: On multi-GPU, we need to reshape the input_ids to be a 2D tensor; on
455
+ # a single GPU, the input_ids are already a 2D tensor.
456
+ if self.fabric.world_size > 1:
457
+ gathered_input_ids = gathered_input_ids.reshape(
458
+ -1, *gathered_input_ids.shape[2:]
459
+ )
460
+
461
+ training_batch["input_ids"].extend(gathered_input_ids.tolist())
462
+
463
+ # Forward pass
464
+ model_output, _ = self.model(input_ids)
465
+ model_output = model_output.transpose(1, 2)
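+ # NOTE: F.cross_entropy expects logits of shape (batch, vocab_size, seq_len), so the
+ # vocabulary dimension is moved to axis 1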
466
+
467
+ ########################################################
468
+ #
469
+ # Gradient accumulation
470
+ #
471
+ ########################################################
472
+
473
+ should_accumulate_gradients = (sub_batch_step + 1) % self.configs[
474
+ "training"
475
+ ].optimization.gradient_accumulation_steps != 0
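+ # e.g. with 4 gradient accumulation steps, sub-steps 0-2 only accumulate gradients,
+ # and every 4th sub-step falls through to the optimizer step below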
476
+
477
+ with self.fabric.no_backward_sync(
478
+ self.model, enabled=should_accumulate_gradients
479
+ ):
480
+ loss = F.cross_entropy(model_output, labels)
481
+ self.fabric.backward(
482
+ loss
483
+ / self.configs["training"].optimization.gradient_accumulation_steps,
484
+ model=self.model,
485
+ )
486
+
487
+ if torch.isnan(loss) or torch.isinf(loss):
488
+ interval_inf_or_nan_count += 1
489
+ else:
490
+ interval_loss += loss.item()
491
+ interval_steps += 1
492
+
493
+ # NOTE: if we are still accumulating gradients, we skip the logging and optimization steps
494
+ if should_accumulate_gradients:
495
+ continue
496
+
497
+ ########################################################
498
+ #
499
+ # Logging
500
+ #
501
+ ########################################################
502
+
503
+ if batch_step % self.configs["monitoring"].logging.log_every_n_steps == 0:
504
+ self._log_training_metrics(
505
+ interval_loss=interval_loss,
506
+ interval_steps=interval_steps,
507
+ interval_inf_or_nan_count=interval_inf_or_nan_count,
508
+ batch_step=batch_step,
509
+ )
510
+ interval_loss = torch.tensor(0.0, device=self.fabric.device)
511
+ interval_steps = torch.tensor(0, device=self.fabric.device)
512
+ interval_inf_or_nan_count = torch.tensor(0, device=self.fabric.device)
513
+
514
+ ########################################################
515
+ #
516
+ # Learning Dynamics Checkpointing
517
+ #
518
+ ########################################################
519
+
520
+ if batch_step % self.configs["checkpointing"].save_every_n_steps == 0:
521
+ if self.should_compute_learning_dynamics:
522
+ self.log(f"Step {batch_step} -- 📈 Saving Learning Dynamics")
523
+
524
+ # Training Batch Learning Dynamics
525
+ training_batch_dataset = Dataset.from_dict(training_batch)
526
+
527
+ learning_dynamics_train_states = compute_learning_dynamics_states(
528
+ checkpointing_config=self.configs["checkpointing"],
529
+ fabric=self.fabric,
530
+ model=self.model,
531
+ dataset=training_batch_dataset,
532
+ compute_gradients=True,
533
+ )
534
+
535
+ save_learning_dynamics_states(
536
+ checkpointing_config=self.configs["checkpointing"],
537
+ checkpoint_step=batch_step,
538
+ prefix="train",
539
+ fabric=self.fabric,
540
+ learning_dynamics_states=learning_dynamics_train_states,
541
+ learning_dynamics_dataset=training_batch_dataset,
542
+ tokenizer=self.tokenizer,
543
+ )
544
+ training_batch = {
545
+ "input_ids": []
546
+ } # Resetting training_batch for next training batch
547
+
548
+ # Validation Data Learning Dynamics
549
+ if self.learning_dynamics_eval_dataset is not None:
550
+ learning_dynamics_val_states = compute_learning_dynamics_states(
551
+ checkpointing_config=self.configs["checkpointing"],
552
+ fabric=self.fabric,
553
+ model=self.model,
554
+ dataset=self.learning_dynamics_eval_dataset,
555
+ compute_gradients=True,
556
+ )
557
+ save_learning_dynamics_states(
558
+ checkpointing_config=self.configs["checkpointing"],
559
+ checkpoint_step=batch_step,
560
+ prefix="val",
561
+ fabric=self.fabric,
562
+ learning_dynamics_states=learning_dynamics_val_states,
563
+ )
564
+
565
+ ########################################################
566
+ #
567
+ # Optimization step
568
+ #
569
+ ########################################################
570
+
571
+ self.optimizer.step()
572
+ self.optimizer.zero_grad()
573
+ self.lr_scheduler.step()
574
+
575
+ batch_step += 1
576
+
577
+ ########################################################
578
+ #
579
+ # Training Checkpointing and evaluation
580
+ #
581
+ ########################################################
582
+
583
+ if batch_step % self.configs["checkpointing"].save_every_n_steps == 0:
584
+ self.log(f"Step {batch_step} -- 💾 Saving Checkpoint")
585
+ save_checkpoint(
586
+ configs=self.configs,
587
+ checkpoint_step=batch_step,
588
+ fabric=self.fabric,
589
+ model=self.model,
590
+ optimizer=self.optimizer,
591
+ lr_scheduler=self.lr_scheduler,
592
+ tokenizer=self.tokenizer,
593
+ )
594
+
595
+ if self.should_evaluate:
596
+ evaluation_results = run_evaluation(
597
+ evaluation_config=self.configs["evaluation"],
598
+ checkpointing_config=self.configs["checkpointing"],
599
+ fabric=self.fabric,
600
+ model=self.model,
601
+ )
602
+ if evaluation_results is not None:
603
+ self._log_evaluation_results(evaluation_results, batch_step)
604
+ save_evaluation_results(
605
+ checkpointing_config=self.configs["checkpointing"],
606
+ fabric=self.fabric,
607
+ evaluation_results=evaluation_results,
608
+ checkpoint_step=batch_step,
609
+ )
610
+
611
+ # Break if we've reached training steps
612
+ if batch_step >= self.configs["training"].max_steps:
613
+ break
614
+
615
+ return batch_step
616
+
617
+ ########################################################
618
+ #
619
+ # Trainer Logging Functionalities
620
+ #
621
+ ########################################################
622
+
623
+ def _log_training_metrics(
624
+ self,
625
+ interval_loss: torch.Tensor,
626
+ interval_steps: torch.Tensor,
627
+ interval_inf_or_nan_count: torch.Tensor,
628
+ batch_step: int,
629
+ ):
630
+ """
631
+ Gathers together the training metrics computed across all processes in distributed training
632
+ and logs them in a tree-style format.
633
+ """
634
+ gathered_interval_loss = self.fabric.all_reduce(
635
+ interval_loss, reduce_op="sum"
636
+ ).item()
637
+ gathered_interval_inf_or_nan_count = self.fabric.all_reduce(
638
+ interval_inf_or_nan_count, reduce_op="sum"
639
+ ).item()
640
+ gathered_interval_steps = self.fabric.all_reduce(
641
+ interval_steps, reduce_op="sum"
642
+ ).item()
643
+
644
+ avg_loss = (
645
+ gathered_interval_loss / gathered_interval_steps
646
+ if gathered_interval_steps > 0
647
+ else float("inf")
648
+ )
649
+
650
+ self.fabric.log("train/loss", avg_loss, step=batch_step)
651
+ self.fabric.log(
652
+ "trainer/inf_or_nan_count",
653
+ gathered_interval_inf_or_nan_count,
654
+ step=batch_step,
655
+ )
656
+ self.fabric.log(
657
+ "trainer/learning_rate",
658
+ self.lr_scheduler.get_last_lr()[0],
659
+ step=batch_step,
660
+ )
661
+
662
+ # Log to console in tree format
663
+ self.log(f"Step {batch_step} -- 🔄 Training Metrics")
664
+ self.log(f"├── Loss: {avg_loss:.4f}")
665
+ self.log(f"├── Learning Rate: {self.lr_scheduler.get_last_lr()[0]:.2e}")
666
+ self.log(f"└── Inf/NaN count: {gathered_interval_inf_or_nan_count}")
667
+
668
+ def _log_evaluation_results(
669
+ self, evaluation_results: Dict[str, Any], batch_step: int
670
+ ):
671
+ """Log model evaluation metrics to experiment tracking system and console."""
672
+ self.log(f"Step {batch_step} -- 📊 Evaluation Results")
673
+ for i, (metric, result) in enumerate(evaluation_results.items()):
674
+ prefix = "└──" if i == len(evaluation_results) - 1 else "├──"
675
+ self.log(f"{prefix} {metric}: {result}")
676
+ self.fabric.log(f"eval/{metric}", result, step=batch_step)
677
+
678
+ def _log_training_configuration(self):
679
+ """
680
+ Log training configuration details as well as runtime information about the hardware,
681
+ software, and batch settings.
682
+
683
+ This function is called at the beginning of the training loop to provide a summary of the
684
+ training configuration.
685
+ """
686
+
687
+ total_params = sum(p.numel() for p in self.model.parameters())
688
+ trainable_params = sum(
689
+ p.numel() for p in self.model.parameters() if p.requires_grad
690
+ )
691
+ global_batch_size = self.configs["data"].dataloader.batch_size
692
+ per_device_batch_size = self.train_dataloader.batch_size
693
+ gradient_accumulation_steps = self.configs[
694
+ "training"
695
+ ].optimization.gradient_accumulation_steps
696
+
697
+ device_type = ""
698
+ fabric_device = str(self.fabric.device)
699
+ if torch.cuda.is_available() and "cuda" in fabric_device:
700
+ device_type = torch.cuda.get_device_name(self.fabric.device)
701
+ elif torch.backends.mps.is_available() and "mps" in fabric_device:
702
+ device_type = "MPS (Apple Silicon)"
703
+ else:
704
+ device_type = "CPU"
705
+
706
+ training_config_path = os.path.join(
707
+ self.configs["checkpointing"].runs_dir,
708
+ self.configs["checkpointing"].run_name,
709
+ "training_config.yaml",
710
+ )
711
+ if os.path.exists(training_config_path):
712
+ self.log("=" * 50)
713
+ self.log("✨ Training Configuration")
714
+ self.log("=" * 50)
715
+ training_config = yaml.safe_load(open(training_config_path, "r"))
716
+ pretty_print_yaml_config(self.logger, training_config)
717
+
718
+ self.log("=" * 50)
719
+ self.log("⛭ Runtime Summary:")
720
+ self.log("=" * 50)
721
+ self.log(f"Starting from step: {self.initial_batch_step}")
722
+
723
+ self.log("Model Setup:")
724
+ self.log(f"└─ Total Parameters: {total_params:,}")
725
+ self.log(f"└─ Trainable Parameters: {trainable_params:,}")
726
+
727
+ self.log("Distributed Setup:")
728
+ self.log(f"└─ Number of Devices: {self.fabric.world_size}")
729
+ self.log(f"└─ Device Type: {device_type}")
730
+ self.log(
731
+ f"└─ Available Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
732
+ if torch.cuda.is_available()
733
+ else f"└─ Available Memory: {psutil.virtual_memory().total / 1e9:.2f} GB"
734
+ )
735
+
736
+ self.log("Software Setup:")
737
+ self.log(f"└─ Python Version: {platform.python_version()}")
738
+ self.log(f"└─ PyTorch Version: {torch.__version__}")
739
+ self.log(
740
+ f"└─ CUDA Version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}"
741
+ )
742
+ self.log(f"└─ Operating System: {platform.system()} {platform.release()}")
743
+
744
+ self.log("Batch Size Configuration:")
745
+ self.log(f"└─ Global Batch Size: {global_batch_size}")
746
+ self.log(f"└─ Per Device Batch Size: {per_device_batch_size}")
747
+ self.log(f"└─ Gradient Accumulation Steps: {gradient_accumulation_steps}")
748
+ self.log("=" * 50)
749
+
750
+ @rank_zero_only
751
+ def log(self, msg: str, level: int = logging.INFO) -> None:
752
+ """NOTE: Log messages only from rank zero process."""
753
+ self.logger.log(level, msg)
src/training/utils/__init__.py ADDED
@@ -0,0 +1,34 @@
1
+ """
2
+ Utility package that contains functions for the training process, e.g. initialization, logging, etc.
3
+ """
4
+
5
+ # For convenience, we export the initialization functions here
6
+ from .initialization import (
7
+ initialize_configuration,
8
+ initialize_dataloader,
9
+ initialize_dataset,
10
+ initialize_fabric,
11
+ initialize_hf_checkpointing,
12
+ initialize_logging,
13
+ initialize_lr_scheduler,
14
+ initialize_model,
15
+ initialize_optimizer,
16
+ initialize_run_dir,
17
+ initialize_tokenizer,
18
+ initialize_wandb,
19
+ )
20
+
21
+ __all__ = [
22
+ "initialize_configuration",
23
+ "initialize_dataloader",
24
+ "initialize_dataset",
25
+ "initialize_fabric",
26
+ "initialize_hf_checkpointing",
27
+ "initialize_logging",
28
+ "initialize_lr_scheduler",
29
+ "initialize_model",
30
+ "initialize_optimizer",
31
+ "initialize_run_dir",
32
+ "initialize_tokenizer",
33
+ "initialize_wandb",
34
+ ]
src/training/utils/data.py ADDED
@@ -0,0 +1,35 @@
1
+ """
2
+ Utilities for data loading and processing.
3
+ """
4
+
5
+ from torch.utils.data import IterableDataset
6
+
7
+
8
+ class ShardedIterableDataset(IterableDataset):
9
+ """
10
+ A super simple implementation of a sharded iterable dataset that enables DataParallelism
11
+ across multiple workers. Ensures that each worker gets a unique shard of the dataset.
12
+
13
+ NOTE: Also works fine if there is only one worker.
14
+ """
15
+
16
+ def __init__(self, dataset, rank, world_size):
17
+ self.dataset = dataset
18
+ self.rank = rank
19
+ self.world_size = world_size
20
+
21
+ def __iter__(self):
22
+ iterator = iter(self.dataset)
23
+ # NOTE: Start by skipping to this worker's shard
24
+ for _ in range(self.rank):
25
+ next(iterator)
26
+
27
+ # NOTE: Yield every world_size-th item
28
+ while True:
29
+ try:
30
+ yield next(iterator)
31
+ # Skip other workers' samples
32
+ for _ in range(self.world_size - 1):
33
+ next(iterator)
34
+ except StopIteration:
35
+ break
src/training/utils/initialization.py ADDED
@@ -0,0 +1,702 @@
1
+ """
2
+ Utilities for initializing components of the training process.
3
+
4
+ Here, we initialize all of the components that are part of the learning process. From logging,
5
+ and checkpointing to the optimizer to the dataset and the dataloader, this file contains the
6
+ logic for setting up the classes and functions that are used in the training loop.
7
+
8
+ As always, this code is meant to be basic. We hard-code the obvious defaults, and leave the
9
+ more experimental stuff to you.
10
+ """
11
+
12
+ import logging
13
+ import math
14
+ import os
15
+ import warnings
16
+ from dataclasses import fields, is_dataclass
17
+ from datetime import datetime
18
+ from typing import Dict, Optional, Union
19
+
20
+ import lightning as L
21
+ import torch
22
+ import yaml
23
+ from datasets import Dataset, DownloadConfig, load_dataset
24
+ from datasets import config as datasets_config
25
+ from huggingface_hub import add_collection_item, create_branch, create_repo
26
+ from lightning.fabric.loggers import Logger as FabricLogger
27
+ from lightning.fabric.utilities.rank_zero import rank_zero_only
28
+ from torch.utils.data import DataLoader
29
+ from transformers import AutoTokenizer
30
+
31
+ import wandb
32
+ from src.config import (
33
+ CheckpointingConfig,
34
+ DataConfig,
35
+ EvaluationConfig,
36
+ ModelConfig,
37
+ MonitoringConfig,
38
+ TrainingConfig,
39
+ )
40
+ from src.model import PicoDecoder
41
+ from src.training.utils.io import use_backoff
42
+ from wandb.integration.lightning.fabric import WandbLogger
43
+
44
+ warnings.filterwarnings(
45
+ "ignore",
46
+ message=".*This integration is tested and supported for lightning Fabric.*",
47
+ )
48
+ warnings.filterwarnings(
49
+ "ignore",
50
+ message=".*Please report any issues to.*",
51
+ )
52
+
53
+ ########################################################
54
+ #
55
+ # Basic Initialization
56
+ #
57
+ ########################################################
58
+
59
+
60
+ def _apply_config_overrides(config, overrides: dict):
61
+ """Recursively apply configuration overrides to a dataclass config object.
62
+
63
+ Args:
64
+ config: Base configuration object (must be a dataclass)
65
+ overrides: Dictionary of override values matching config structure
66
+
67
+ Returns:
68
+ Modified config object with the overrides applied.
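+
+ Example (a sketch; assumes the config defines a nested `optimization.lr` field):
+ >>> _apply_config_overrides(TrainingConfig(), {"optimization": {"lr": 3e-4}})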
69
+ """
70
+ for field in fields(config):
71
+ field_value = getattr(config, field.name)
72
+ if is_dataclass(field_value):
73
+ _apply_config_overrides(field_value, overrides.get(field.name, {}))
74
+ else:
75
+ if field.name in overrides:
76
+ setattr(config, field.name, overrides[field.name])
77
+ return config
78
+
79
+
80
+ def initialize_configuration(
81
+ config_path: Optional[str] = None,
82
+ ) -> Dict[
83
+ str,
84
+ Union[
85
+ DataConfig,
86
+ ModelConfig,
87
+ TrainingConfig,
88
+ EvaluationConfig,
89
+ MonitoringConfig,
90
+ CheckpointingConfig,
91
+ ],
92
+ ]:
93
+ """Initialize configuration objects with optional overrides from a YAML file.
94
+
95
+ This function initializes all of the configuration objects, and then applies
96
+ any overrides from the config_path file. If no config_path is provided,
97
+ the function will use the default configuration objects.
98
+
99
+ Args:
100
+ config_path: Path to a YAML file containing configuration overrides.
101
+
102
+ Returns:
103
+ A dictionary containing the initialized configuration objects.
104
+ """
105
+ data_config = DataConfig()
106
+ model_config = ModelConfig()
107
+ training_config = TrainingConfig()
108
+ evaluation_config = EvaluationConfig()
109
+ monitoring_config = MonitoringConfig()
110
+ checkpointing_config = CheckpointingConfig()
111
+
112
+ if config_path:
113
+ overrides = yaml.safe_load(open(config_path, "r"))
114
+ data_config = _apply_config_overrides(data_config, overrides.get("data", {}))
115
+ model_config = _apply_config_overrides(model_config, overrides.get("model", {}))
116
+ training_config = _apply_config_overrides(
117
+ training_config, overrides.get("training", {})
118
+ )
119
+ evaluation_config = _apply_config_overrides(
120
+ evaluation_config, overrides.get("evaluation", {})
121
+ )
122
+ monitoring_config = _apply_config_overrides(
123
+ monitoring_config, overrides.get("monitoring", {})
124
+ )
125
+ checkpointing_config = _apply_config_overrides(
126
+ checkpointing_config, overrides.get("checkpointing", {})
127
+ )
128
+
129
+ configs = {
130
+ "data": data_config,
131
+ "model": model_config,
132
+ "training": training_config,
133
+ "evaluation": evaluation_config,
134
+ "monitoring": monitoring_config,
135
+ "checkpointing": checkpointing_config,
136
+ }
137
+
138
+ return configs
139
+
140
+
141
+ def initialize_run_dir(checkpointing_config: CheckpointingConfig) -> str:
142
+ """Initialize a directory for the current training run.
143
+
144
+ Creates a unique directory for storing training, evaluation, and logging artifacts.
145
+ If no run name is specified in the config, generates a timestamp-based name.
146
+
147
+ Args:
148
+ checkpointing_config: Configuration object containing run settings.
149
+ NOTE: Must have a 'run_name' attribute that can be None, in which case
150
+ a timestamp-based name will be generated.
151
+
152
+ Returns:
153
+ str: The path to the run directory.
154
+ """
155
+ run_name = checkpointing_config.run_name
156
+ if run_name is None:
157
+ run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
158
+ checkpointing_config.run_name = run_name
159
+
160
+ run_dir = os.path.join(checkpointing_config.runs_dir, run_name)
161
+
162
+ os.makedirs(run_dir, exist_ok=True)
163
+ return run_dir
164
+
165
+
166
+ def initialize_fabric(
167
+ training_config: TrainingConfig, wandb_logger: Optional[FabricLogger] = None
168
+ ):
169
+ """Initialize Lightning Fabric for distributed training.
170
+
171
+ Sets up a Lightning Fabric instance with the specified configuration for
172
+ handling distributed training, mixed precision, and logging.
173
+
174
+ Args:
175
+ training_config: Configuration object containing fabric settings
176
+ (accelerator, precision, devices, etc.).
177
+ wandb_logger: Optional weights and biases logger instance for experiment tracking
178
+
179
+ Returns:
180
+ L.Fabric: Initialized Lightning Fabric instance.
181
+
182
+ Example:
183
+ >>> fabric = initialize_fabric(training_config, wandb_logger)
184
+ """
185
+
186
+ total_devices = (
187
+ training_config.fabric.num_devices * training_config.fabric.num_nodes
188
+ )
189
+
190
+ if total_devices > 1:
191
+ strategy = "deepspeed_stage_2"
192
+ else:
193
+ strategy = "auto" # Sets up SingleDevice Strategy by default
194
+
195
+ # NOTE: The strategy is set to use either DeepSpeed (Zero Stage 2) on multi-GPU,
196
+ # or SingleDevice Strategy on single-GPU set ups. If you'd like to use a different strategy,
197
+ # you can change the strategy flag in the fabric initialization, but be aware that this might
198
+ # cause issues with checkpointing, evaluation, etc.
199
+
200
+ fabric = L.Fabric(
201
+ accelerator=training_config.fabric.accelerator,
202
+ precision=training_config.fabric.precision,
203
+ devices=training_config.fabric.num_devices,
204
+ num_nodes=training_config.fabric.num_nodes,
205
+ loggers=[wandb_logger] if wandb_logger is not None else None,
206
+ strategy=strategy,
207
+ )
208
+
209
+ fabric.launch()
210
+
211
+ return fabric
212
+
213
+
214
+ ########################################################
215
+ #
216
+ # Dataset and Tokenization Initialization
217
+ #
218
+ ########################################################
219
+
220
+
221
+ @use_backoff(max_retries=20)
222
+ def initialize_dataset(
223
+ data_config: DataConfig,
224
+ fabric: L.Fabric,
225
+ initial_batch_step: Optional[int] = 0,
226
+ return_fast_forward_steps: bool = False,
227
+ ):
228
+ """Initialize dataset based on the given config.
229
+
230
+ This function will return a dataset object, and optionally a fast_forward_steps value.
231
+
232
+ The fast_forward_steps value is the number of steps that we need to fast-forward an iterator by,
233
+ so that we can continue from a certain batch of data we would have seen had training not previously
234
+ stopped. Depending on how the dataset is loaded, the number of steps to fast-forward may be
235
+ different from the initial_batch_step value.
236
+
237
+ NOTE: This functionality is primarily useful for streaming datasets (which for large
238
+ datasets is most of the time).
239
+
240
+ Args:
241
+ data_config: Configuration object containing dataset settings.
242
+ fabric: A Lightning Fabric instance.
243
+ initial_batch_step: The initial batch step to fast-forward to.
244
+ return_fast_forward_steps: Whether to return the fast-forward steps value.
245
+
246
+ Returns:
247
+ Dataset: Initialized dataset object.
248
+ Optional[int]: Number of steps to fast-forward the iterator by, if return_fast_forward_steps is True.
249
+ """
250
+
251
+ datasets_config.STREAMING_READ_MAX_RETRIES = 40 # default is 20
252
+ datasets_config.STREAMING_READ_RETRY_INTERVAL = 10 # default is 5
253
+ download_config = DownloadConfig(
254
+ max_retries=20, # default is 1 and can lead to premature HTTPS errors
255
+ )
256
+
257
+ fast_forward_steps = 0
258
+
259
+ if data_config.dataset.name == "pico-lm/pretokenized-dolma":
260
+ # NOTE: We know that the dataset is sharded into 10,000 shards, so we can easily compute
261
+ # the data file that we need to load in that contains the batch of data at
262
+ # initial_batch_step.
263
+
264
+ if initial_batch_step is not None:
265
+ examples_per_shard = 20_480
266
+ total_shards = 10_000
267
+ batches_per_shard = examples_per_shard // data_config.dataloader.batch_size
268
+ shard_idx = initial_batch_step // batches_per_shard
269
+
270
+ data_files = [
271
+ f"data/train-{str(_shard_idx).zfill(5)}-of-{total_shards}.parquet"
272
+ for _shard_idx in range(shard_idx, total_shards)
273
+ ]
274
+
275
+ fast_forward_steps = initial_batch_step % batches_per_shard
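+ # e.g. with a dataloader batch size of 1024: batches_per_shard = 20480 // 1024 = 20,
+ # so initial_batch_step = 45 maps to shard_idx = 2 and fast_forward_steps = 5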
276
+ else:
277
+ data_files = None
278
+
279
+ base_dataset = load_dataset(
280
+ data_config.dataset.name,
281
+ split="train",
282
+ streaming=True,
283
+ data_files=data_files,
284
+ download_config=download_config,
285
+ )
286
+ else:
287
+ # NOTE: For other datasets, you might want to add some custom loading logic, especially
288
+ # to help with loading or fast-forwarding to the correct batch.
289
+
290
+ base_dataset = load_dataset(
291
+ data_config.dataset.name,
292
+ split="train",
293
+ streaming=True,
294
+ download_config=download_config,
295
+ )
296
+
297
+ if data_config.dataset.name == "pico-lm/pretokenized-dolma":
298
+ from .data import ShardedIterableDataset
299
+
300
+ # NOTE: We wrap the dataset in a ShardedIterableDataset, which is a custom class that
301
+ # allows us to shard an iterable dataset across multiple processes. This is useful for
302
+ # distributed training, where we want data-parallelism.
303
+ dataset = ShardedIterableDataset(
304
+ base_dataset, fabric.global_rank, fabric.world_size
305
+ )
306
+ else:
307
+ dataset = base_dataset
308
+
309
+ if return_fast_forward_steps:
310
+ return dataset, fast_forward_steps
311
+ else:
312
+ return dataset
313
+
314
+
315
+ def initialize_tokenizer(data_config: DataConfig):
316
+ """Initialize the tokenizer for text processing.
317
+
318
+ This function can be extended to include custom tokenization logic.
319
+
320
+ Args:
321
+ data_config: Configuration object containing tokenizer settings.
322
+
323
+ Returns:
324
+ AutoTokenizer: A HuggingFace tokenizer instance.
325
+ """
326
+
327
+ return AutoTokenizer.from_pretrained(data_config.tokenizer.name)
328
+
329
+
330
+ def initialize_dataloader(
331
+ data_config: DataConfig,
332
+ training_config: TrainingConfig,
333
+ fabric: L.Fabric,
334
+ dataset: Dataset,
335
+ ):
336
+ """Initialize the DataLoader for efficient batch processing.
337
+
338
+ Creates a PyTorch DataLoader that handles batching and data loading for training.
339
+ Configured specifically for streaming tokenized text datasets.
340
+
341
+ You might also want to extend this function to add a sampler, or some sort of custom
342
+ collate function. For the default dataset, we don't need any of this, because the data are
343
+ pre-shuffled, and pre-tokenized.
344
+
345
+ Args:
346
+ data_config: Configuration object containing dataloader settings.
347
+ training_config: Configuration object containing training settings.
348
+ fabric: A Lightning Fabric instance.
349
+ dataset: A HuggingFace Dataset object containing tokenized text data.
350
+ Expected to have 'input_ids' field in its items.
351
+
352
+ Returns:
353
+ DataLoader: PyTorch DataLoader instance configured for the dataset.
354
+ """
355
+
356
+ def _collate_fn(batch):
357
+ return {"input_ids": [entry["input_ids"] for entry in batch]}
358
+
359
+ sub_batch_size = data_config.dataloader.batch_size // (
360
+ fabric.world_size * training_config.optimization.gradient_accumulation_steps
361
+ )
362
+
363
+ # NOTE: We use the sub-batch size for the dataloader, which is the full batch size
364
+ # divided by the world size and the gradient accumulation steps. This ensures that the
365
+ # effective (global) batch size matches the configured value.
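+ # e.g. a global batch size of 1024 on 2 devices with 4 gradient accumulation steps
+ # yields a per-device sub-batch size of 1024 // (2 * 4) = 128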
366
+
367
+ return DataLoader(
368
+ dataset,
369
+ batch_size=sub_batch_size,
370
+ shuffle=False, # Keep sequential for streaming datasets
371
+ pin_memory=True, # Speeds up transfer to GPU
372
+ collate_fn=_collate_fn,
373
+ )
374
+
375
+
376
+ ########################################################
377
+ #
378
+ # Model Initialization
379
+ #
380
+ ########################################################
381
+
382
+
383
+ def initialize_model(model_config: ModelConfig):
384
+ """Initialize the model for training.
385
+
386
+ Loads in a given model implemented in the `src.model` package and returns it.
387
+
388
+ NOTE: out of the box we currently only support the PicoDecoder model (a causal transformer
389
+ language model). If you'd like to implement your own model, you can do so by adding a new
390
+ model class in the `src.model` package, and then adding a new entry here.
391
+
392
+ Args:
393
+ model_config: Configuration object containing model settings.
394
+
395
+ Returns:
396
+ PyTorch model instance.
397
+
398
+ """
399
+ if model_config.model_type == "pico_decoder":
400
+ return PicoDecoder(model_config)
401
+ else:
402
+ raise ValueError(f"Invalid model type: {model_config.model_type}")
403
+
404
+
405
+ ########################################################
406
+ #
407
+ # Optimizer and Scheduler
408
+ #
409
+ ########################################################
410
+
411
+
412
+ def initialize_optimizer(training_config: TrainingConfig, model: torch.nn.Module):
413
+ """Initialize the optimizer for model training.
414
+
415
+ Creates an optimizer instance based on the configuration settings.
416
+
417
+ Add whatever other optimizers you want here.
418
+
419
+ Args:
420
+ training_config: Configuration object containing optimizer settings.
421
+ Must have:
422
+ - optimization.optimizer (str): Name of the optimizer ("adamw")
423
+ - optimization.lr (float): Learning rate for the optimizer
424
+ model: PyTorch model whose parameters will be optimized.
425
+
426
+ Returns:
427
+ torch.optim.Optimizer: Configured optimizer instance.
428
+
429
+ """
430
+
431
+ if training_config.optimization.optimizer == "adamw":
432
+ optimizer = torch.optim.AdamW(
433
+ model.parameters(), lr=training_config.optimization.lr
434
+ )
435
+ else:
436
+ raise ValueError(f"Invalid optimizer: {training_config.optimization.optimizer}")
437
+
438
+ return optimizer
439
+
440
+
+ def initialize_lr_scheduler(
+     training_config: TrainingConfig, optimizer: torch.optim.Optimizer
+ ):
+     """Initialize a learning rate scheduler with warmup and decay.
+
+     The default is a learning rate scheduler that implements a linear warmup followed by
+     a linear decay. The learning rate increases linearly from 0 to the initial lr
+     during warmup, then decreases linearly to 0 during the remaining steps.
+
+     Add other types of learning rate schedulers here.
+
+     Args:
+         training_config: Configuration object containing optimizer and scheduler settings.
+         optimizer: PyTorch optimizer whose learning rate will be scheduled.
+
+     Returns:
+         torch.optim.lr_scheduler.LambdaLR: Learning rate scheduler instance.
+     """
+
+     if training_config.optimization.lr_scheduler == "linear_with_warmup":
+         # Credit where credit is due:
+         # https://github.com/huggingface/transformers/blob/e71a01a104dd663c730e494eb0b6467bb51df357/src/transformers/optimization.py#L102
+         def _lr_lambda(curr_step, num_warmup_steps, max_steps):
+             if curr_step < num_warmup_steps:
+                 return float(curr_step) / float(max(1, num_warmup_steps))
+             else:
+                 return max(
+                     0.0,
+                     float(max_steps - curr_step)
+                     / float(max(1, max_steps - num_warmup_steps)),
+                 )
+
+         lr_lambda = lambda step: _lr_lambda(  # noqa: E731
+             step,
+             training_config.optimization.lr_warmup_steps,
+             training_config.max_steps,
+         )
+         lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+             optimizer,
+             lr_lambda,
+         )
+     elif training_config.optimization.lr_scheduler == "cosine":
+         # Cosine decay with warmup: linear warmup followed by cosine decay.
+         # This provides a sustained learning rate over long training runs.
+         def _cosine_lr_lambda(curr_step, num_warmup_steps, max_steps):
+             if curr_step < num_warmup_steps:
+                 # Linear warmup
+                 return float(curr_step) / float(max(1, num_warmup_steps))
+             else:
+                 # Cosine decay, floored at 0.1 * initial_lr (never decays all the way to 0)
+                 progress = float(curr_step - num_warmup_steps) / float(
+                     max(1, max_steps - num_warmup_steps)
+                 )
+                 return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
+
+         lr_lambda = lambda step: _cosine_lr_lambda(  # noqa: E731
+             step,
+             training_config.optimization.lr_warmup_steps,
+             training_config.max_steps,
+         )
+         lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+             optimizer,
+             lr_lambda,
+         )
+     else:
+         raise ValueError(
+             f"Invalid learning rate scheduler: {training_config.optimization.lr_scheduler}"
+         )
+
+     return lr_scheduler
+
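The two schedules are easiest to compare by evaluating the multipliers directly; a small standalone sketch, assuming 100 warmup steps and 1,000 total steps (illustrative values only):

    import math

    num_warmup_steps, max_steps = 100, 1_000

    def linear(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        return max(0.0, (max_steps - step) / max(1, max_steps - num_warmup_steps))

    def cosine(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        progress = (step - num_warmup_steps) / max(1, max_steps - num_warmup_steps)
        return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

    for step in (0, 50, 100, 550, 1_000):
        print(step, round(linear(step), 3), round(cosine(step), 3))
    # 0     0.0   0.0
    # 50    0.5   0.5
    # 100   1.0   1.0
    # 550   0.5   0.5
    # 1000  0.0   0.1   <- cosine never drops below 0.1 * initial_lr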
+
+ ########################################################
+ #
+ # Experiment Monitoring (Logging, Experiment Tracking, etc.)
+ #
+ ########################################################
+
+
+ def _initialize_log_file(checkpointing_config: CheckpointingConfig) -> str:
+     """Create and initialize a timestamped log file in the run's log directory.
+
+     Sets up a log file with a unique timestamp in the run's logging directory.
+     Creates the necessary directory structure if it doesn't exist.
+
+     Directory Structure:
+         {checkpointing_config.runs_dir}/
+         └── {checkpointing_config.run_name}/
+             └── {checkpointing_config.logs_dir}/
+                 └── log_YYYYMMDD_HHMMSS.log
+
+     Args:
+         checkpointing_config: Configuration object containing checkpointing settings.
+
+     Returns:
+         str: Path to the created log file.
+     """
+
+     run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name)
+     logs_dir = os.path.join(run_dir, checkpointing_config.logs_dir)
+     os.makedirs(logs_dir, exist_ok=True)
+
+     # Timestamp ensures each (re)start of the run gets its own log file
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     log_file_name = f"log_{timestamp}.log"
+     log_file_path = os.path.join(logs_dir, log_file_name)
+
+     open(log_file_path, "w").close()  # Create an empty log file
+
+     return log_file_path
+
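To make the config-to-path mapping concrete, a sketch with hypothetical config values (the real values come from `CheckpointingConfig`):

    import os
    from datetime import datetime

    # Hypothetical config values, for illustration only:
    runs_dir, run_name, logs_dir = "runs", "pico-decoder-tiny", "logs"

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    print(os.path.join(runs_dir, run_name, logs_dir, f"log_{timestamp}.log"))
    # e.g. runs/pico-decoder-tiny/logs/log_20250101_120000.log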
+
+ @use_backoff()
+ def initialize_wandb(
+     monitoring_config: MonitoringConfig, checkpointing_config: CheckpointingConfig
+ ):
+     """Initialize Weights and Biases.
+
+     Sets up a Weights and Biases logger based on the configuration settings. If the run
+     is auto-resuming from a checkpoint, tries to re-attach to the existing wandb run with
+     the same name instead of starting a new one.
+
+     Args:
+         monitoring_config: Configuration object containing monitoring settings.
+         checkpointing_config: Configuration object containing checkpointing settings.
+
+     Returns:
+         Optional[WandbLogger]: An experiment tracker instance.
+     """
+
+     assert (
+         monitoring_config.wandb.project is not None
+         and monitoring_config.wandb.project != ""
+     ), "Wandb project must be provided if wandb is to be used."
+     assert (
+         monitoring_config.wandb.entity is not None
+         and monitoring_config.wandb.entity != ""
+     ), "Wandb entity must be provided if wandb is to be used."
+
+     _run_id = None
+     if checkpointing_config.training.auto_resume:
+         # If we are loading a checkpoint, we can try to find the run id of the previous run
+         previous_runs = wandb.Api().runs(
+             path=f"{monitoring_config.wandb.entity}/{monitoring_config.wandb.project}",
+             filters={"display_name": checkpointing_config.run_name},
+         )
+         try:
+             if len(previous_runs) == 1:
+                 _run_id = previous_runs[0].id
+         except ValueError:
+             # The runs query is resolved lazily and can fail (e.g. if the project does
+             # not exist yet); in that case we simply start a fresh run.
+             pass
+
+     wandb_logger = WandbLogger(
+         project=monitoring_config.wandb.project,
+         entity=monitoring_config.wandb.entity,
+         id=_run_id,
+         name=checkpointing_config.run_name,
+     )
+
+     return wandb_logger
+
+
+ @rank_zero_only
+ def initialize_logging(
+     monitoring_config: MonitoringConfig,
+     checkpointing_config: CheckpointingConfig,
+     fabric: L.Fabric,
+ ):
+     """Initialize the default logging system, writing to both file and console.
+
+     The default logging system uses a file handler and a stream handler.
+
+     NOTE: this function is only called on rank 0.
+
+     Args:
+         monitoring_config: Configuration object containing monitoring settings.
+         checkpointing_config: Configuration object containing checkpointing settings.
+         fabric: Lightning Fabric instance for the current run.
+
+     Returns:
+         logger: Standard Python logger configured for file and console output.
+     """
+
+     # ---- Standard Local Logger ---- #
+     logger = logging.getLogger("pico-train")
+     logger.setLevel(logging.INFO)
+
+     # Create file handler
+     log_file_path = _initialize_log_file(checkpointing_config)
+     file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
+     file_handler.setLevel(monitoring_config.logging.log_level)
+
+     # Create formatter and add it to the handler
+     formatter = logging.Formatter(
+         "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+     )
+     file_handler.setFormatter(formatter)
+
+     # Add the handler to the logger
+     logger.addHandler(file_handler)
+
+     # Add a stream handler for console output
+     stream_handler = logging.StreamHandler()
+     stream_handler.setLevel(monitoring_config.logging.log_level)
+     stream_handler.setFormatter(formatter)
+     logger.addHandler(stream_handler)
+
+     return logger
+
+
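Once initialized, the logger behaves like any standard Python logger, with every record going to both the timestamped log file and the console; a usage sketch (the config objects and fabric instance are assumed to exist already):

    logger = initialize_logging(monitoring_config, checkpointing_config, fabric)
    logger.info("Starting training run")
    # -> 2025-01-01 12:00:00 - pico-train - INFO - Starting training run
    #    (written to the console and to runs/<run_name>/<logs_dir>/log_*.log)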
+ ########################################################
+ #
+ # HuggingFace/Remote Checkpointing
+ #
+ ########################################################
+
+
+ @rank_zero_only
+ @use_backoff()
+ def initialize_hf_checkpointing(
+     checkpointing_config: CheckpointingConfig, fabric: L.Fabric
+ ):
+     """Initialize HuggingFace Checkpointing.
+
+     Creates a HuggingFace repository if it doesn't exist, and creates a branch named
+     after the run.
+
+     NOTE: this function is only called on rank 0.
+
+     Args:
+         checkpointing_config: Configuration object containing checkpointing settings;
+             must have a 'hf_checkpoint' attribute that specifies the HuggingFace
+             repository id and collection slug (if applicable) to save the checkpoint to.
+         fabric: Lightning Fabric instance for the current run.
+
+     Raises:
+         Exception: If the HuggingFace repository cannot be created after the retry
+             attempts made by @use_backoff.
+     """
+
+     huggingface_repo_id = checkpointing_config.hf_checkpoint.repo_id
+     assert (
+         huggingface_repo_id is not None and huggingface_repo_id != ""
+     ), "hf_checkpoint.repo_id must be provided."
+
+     repo = create_repo(huggingface_repo_id, exist_ok=True)
+
+     # A repo can be created without a specified namespace (it will default to the
+     # username), but the rest of the HF calls need the fully qualified name. That name
+     # is returned by create_repo, so we update the config for later calls.
+     checkpointing_config.hf_checkpoint.repo_id = repo.repo_id
+     huggingface_repo_id = repo.repo_id
+
+     if checkpointing_config.hf_checkpoint.collection_slug:
+         add_collection_item(
+             checkpointing_config.hf_checkpoint.collection_slug,
+             huggingface_repo_id,
+             repo.repo_type,
+             exists_ok=True,
+         )
+
+     create_branch(
+         repo_id=huggingface_repo_id,
+         branch=checkpointing_config.run_name,
+         exist_ok=True,
+     )
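Because each run checkpoints to its own branch, a given run's artifacts can later be pulled by passing the run name as the revision; a sketch, using illustrative repo and run names and assuming a config.json was pushed to that branch:

    from huggingface_hub import hf_hub_download

    # Illustrative names; substitute your own repo_id and run_name.
    path = hf_hub_download(
        repo_id="ThomasTheMaker/pico-decoder-tiny",
        filename="config.json",
        revision="pico-decoder-tiny-dolma5M-v1",  # branch created above
    )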
src/training/utils/io.py ADDED
@@ -0,0 +1,52 @@
+ """Defines a retry wrapper for IO operations."""
+
+ import time
+ from functools import wraps
+
+
+ def use_backoff(max_retries=2, initial_delay=1, backoff_factor=2):
+     """
+     Universal retry wrapper with exponential backoff for any function, but primarily for
+     loading and storing HuggingFace datasets and objects.
+
+     Example usage:
+
+     >>> @use_backoff(max_retries=10, initial_delay=1, backoff_factor=2)
+     >>> def important_io_operation(x):
+     >>>     return x + 1
+
+     Args:
+         max_retries: Maximum number of retry attempts (default: 2)
+         initial_delay: Initial delay between retries in seconds (default: 1)
+         backoff_factor: Multiplier for the delay between retries (default: 2)
+
+     Returns:
+         A decorator that retries the wrapped function up to max_retries times with
+         exponential backoff
+
+     Raises:
+         Exception: If all retries fail
+     """
+
+     def _decorator(fn):
+         @wraps(fn)
+         def wrapper(*args, **kwargs):
+             current_delay = initial_delay
+             last_exception = None
+
+             for attempt in range(max_retries):
+                 try:
+                     return fn(*args, **kwargs)
+                 except Exception as e:
+                     last_exception = e
+                     if attempt < max_retries - 1:  # Don't sleep on the last attempt
+                         time.sleep(current_delay)
+                         current_delay *= backoff_factor
+
+             raise Exception(
+                 f"IO Operation failed after {max_retries} attempts: {str(last_exception)}"
+             )
+
+         return wrapper
+
+     return _decorator
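The decorator is easiest to sanity-check with a deliberately flaky function; a small sketch using only the defaults defined above:

    calls = {"n": 0}

    @use_backoff()  # max_retries=2, initial_delay=1, backoff_factor=2
    def flaky_read():
        calls["n"] += 1
        if calls["n"] < 2:
            raise IOError("transient failure")
        return "ok"

    print(flaky_read())  # attempt 1 fails, then succeeds after a 1s sleep -> "ok"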
src/training/utils/logging.py ADDED
@@ -0,0 +1,48 @@
+ """
+ Miscellaneous logging utilities.
+ """
+
+ from io import StringIO
+
+ import yaml
+ from lightning.fabric.utilities.rank_zero import rank_zero_only
+ from rich.console import Console
+ from rich.panel import Panel
+
+
+ @rank_zero_only
+ def pretty_print_yaml_config(logger, config: dict) -> None:
+     """
+     Pretty print a config with rich formatting. Assumes that the config has already been
+     converted to a dictionary - this can be done by calling `asdict` on the dataclass or
+     by loading the config from a yaml file.
+
+     NOTE: this function is only called on rank 0.
+
+     Args:
+         logger: Logger object to log the formatted output to.
+         config: Dictionary containing the config to pretty print.
+     """
+     # Create string buffer
+     output = StringIO()
+     console = Console(file=output, force_terminal=False)
+
+     # Convert to YAML string first
+     yaml_str = yaml.dump(
+         config, default_flow_style=False, sort_keys=False, Dumper=yaml.SafeDumper
+     )
+
+     # Create formatted panel
+     panel = Panel(
+         yaml_str,
+         border_style="blue",
+         padding=(0, 1),  # Reduced padding
+         expand=False,  # Don't expand to terminal width
+     )
+
+     # Print to buffer
+     console.print(panel)
+
+     # Log the formatted output line by line
+     for line in output.getvalue().splitlines():
+         logger.info(line)
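A typical call site converts the dataclass config first; a sketch, assuming a config dataclass instance (e.g. a TrainingConfig) and an already-initialized logger:

    from dataclasses import asdict

    # `training_config` and `logger` are assumed to exist already.
    pretty_print_yaml_config(logger, asdict(training_config))
    # Each line of the blue-bordered YAML panel is emitted via logger.info(...)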