Upload folder using huggingface_hub
- .env.example +2 -0
- .gitignore +174 -0
- .pre-commit-config.yaml +10 -0
- LICENSE +201 -0
- README.md +170 -0
- configs/examples/demo.yaml +48 -0
- configs/examples/pico-decoder-large.yaml +35 -0
- configs/examples/pico-decoder-medium.yaml +35 -0
- configs/examples/pico-decoder-small.yaml +35 -0
- configs/examples/pico-decoder-tiny.yaml +35 -0
- configs/pico-decoder-tiny-dolma10M-v1.yaml +78 -0
- configs/pico-decoder-tiny-dolma20M-v1.yaml +78 -0
- configs/pico-decoder-tiny-dolma5M-v1.yaml +78 -0
- plots/.gitignore +74 -0
- plots/404.html +33 -0
- plots/README.md +90 -0
- plots/code.js +550 -0
- plots/data.json +0 -0
- plots/index.html +72 -0
- plots/style.css +258 -0
- pyproject.toml +33 -0
- scripts/README.md +109 -0
- scripts/generate_data.py +198 -0
- scripts/train.py +30 -0
- setup.sh +200 -0
- src/checkpointing/__init__.py +23 -0
- src/checkpointing/evaluation.py +68 -0
- src/checkpointing/learning_dynamics.py +424 -0
- src/checkpointing/training.py +287 -0
- src/config/__init__.py +31 -0
- src/config/_constants.py +18 -0
- src/config/checkpointing_config.py +97 -0
- src/config/data_config.py +36 -0
- src/config/evaluation_config.py +28 -0
- src/config/model_config.py +33 -0
- src/config/monitoring_config.py +29 -0
- src/config/training_config.py +40 -0
- src/evaluation/__init__.py +103 -0
- src/evaluation/tasks/paloma.py +52 -0
- src/model/__init__.py +12 -0
- src/model/pico_decoder.py +911 -0
- src/training/trainer.py +753 -0
- src/training/utils/__init__.py +34 -0
- src/training/utils/data.py +35 -0
- src/training/utils/initialization.py +702 -0
- src/training/utils/io.py +52 -0
- src/training/utils/logging.py +48 -0
.env.example
ADDED
@@ -0,0 +1,2 @@
+WANDB_API_KEY=your_wandb_key
+HF_TOKEN=your_huggingface_token
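Both variables are consumed at runtime. As a minimal sketch of how a script might read them once the `.env` file exists (assuming the `python-dotenv` package; the repository's own setup may load them differently, e.g. via `setup.sh`):

```python
import os

from dotenv import load_dotenv  # assumed dependency; any .env loader works

load_dotenv()  # reads key=value pairs from ./.env into the process environment
hf_token = os.environ["HF_TOKEN"]        # used when pushing checkpoints to the Hub
wandb_key = os.environ["WANDB_API_KEY"]  # used when logging to Weights & Biases
```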
.gitignore
ADDED
@@ -0,0 +1,174 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+poetry.lock
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Data
+data/
+
+# Checkpoint and Logging Directories
+runs/
+wandb/
+# configs/
+
+.vscode/
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,10 @@
+repos:
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  # Ruff version.
+  rev: v0.7.1
+  hooks:
+    # Run the linter.
+    - id: ruff
+      args: [ --fix, --extend-select, I ]
+    # Run the formatter.
+    - id: ruff-format
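With this configuration committed, running `pre-commit install` once in a fresh clone registers the Ruff hooks with git, and `pre-commit run --all-files` applies the linter (including import sorting via `--extend-select I`) and the formatter to the whole tree before the first commit.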
LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
README.md
ADDED
@@ -0,0 +1,170 @@
+# 🚀 **Pico Train**
+
+Pico Train is a lightweight framework for training language models—from tiny-scale (~1M parameters) to mid-scale (~1B parameters)—with built-in rich checkpointing that captures activations, gradients, and model states, enabling detailed learning dynamics research.
+
+Our **suite of pre-trained models** is already publicly available on our [Hugging Face organization](https://huggingface.co/pico-lm), and a dedicated companion library for advanced analysis—[**pico-analyze**](https://github.com/pico-lm/pico-analyze)—is fully released for deeper checkpoint studies.
+
+> For a **detailed run-through**, check out the **full tutorial** on our website at [picolm.io](https://picolm.io).
+
+---
+
+## **Key Features**
+
+1. **Pico Decoder: LLAMA-style Transformer Architecture**
+   - RMSNorm, RoPE, multi-head self-attention with KV-cache, and SwiGLU activations
+   - Currently supports the **pico-decoder** model, with future expansions planned (pico-diffusion, pico-statespace, etc.)
+
+2. **Comprehensive Checkpoints**
+   - Saves model states, optimizer states, and training metadata
+   - Enriched with **activation and gradient** snapshots for interpretability
+
+3. **Focused Scale Range**
+   - Optimized to train models from **1M to 1B parameters**, where learning dynamics research is most viable
+
+4. **Clean, Pre-tokenized Data**
+   - Uses a pre-tokenized, pre-shuffled version of [Dolma](https://allenai.org/dolma) that we make available on [Hugging Face](https://huggingface.co/datasets/pico-lm/pretokenized-dolma)
+   - Facilitates training models using identical data for **consistency** and **comparability**
+
+5. **Research Ready**
+   - Minimal, well-documented code suitable for **forking and tailoring**
+   - Logs essential metrics (e.g. perplexity) throughout training
+   - Works seamlessly with [pico-analyze](https://github.com/pico-lm/pico-analyze) for advanced post-training interpretation
+
+---
+
+## **Training Philosophy**
+
+All models in the Pico suite (both pre-trained and user-trained):
+
+- Employ **identical architectures** and **optimizer settings**
+- **Share** the same data order and tokens
+- Automatically log **rich checkpoint data** (including activations, gradients)
+- Facilitate **direct cross-scale comparisons**
+
+This uniformity means you can isolate model size as the primary variable, giving you clearer insights into **how model capacity affects learning**.
+
+---
+
+## **Resources**
+
+- **Pre-trained Models** (1M–1B parameters), publicly hosted on [Hugging Face](https://huggingface.co/pico-lm)
+- **Pre-tokenized Datasets** for straightforward streaming-based training
+- **Extensive Checkpoints** logging activation and gradient snapshots
+- **Evaluation Metrics** (perplexity and more) tracked at each checkpoint
+
+---
+
+## **Core Components**
+
+- **Pico-Decoder Model**
+  - LLAMA-style auto-regressive transformer
+  - RMSNorm
+  - RoPE (Rotary Positional Embeddings)
+  - Multi-head attention with KV-cache
+  - SwiGLU activation
+
+  *Future plans include additional architectures like pico-diffusion and pico-statespace.*
+
+- **Training & Checkpointing**
+  - Automatic storage of model and optimizer states
+  - Periodic hooks for saving **learning dynamics** (activations, gradients)
+  - Optional logging to Weights & Biases
+
+- **Config-Driven Setup**
+  - Specify architecture, optimizer, dataset, and logging settings in YAML
+  - Straightforward to extend or modify
+
+---
+
+## **Quick Start**
+
+1. **Clone the Repository**
+
+   ```bash
+   git clone https://github.com/pico-lm/pico-train
+   cd pico-train
+   ```
+
+2. **Configure Environment**
+
+   Create a `.env` file at the root with your Hugging Face and Weights & Biases tokens:
+   ```bash
+   export HF_TOKEN=your_huggingface_token
+   export WANDB_API_KEY=your_wandb_key
+   ```
+
+3. **Install Dependencies**
+
+   ```bash
+   source setup.sh
+   ```
+   This script checks your environment, installs necessary tools, and sets up a Poetry virtual environment.
+
+4. **Train Your Model Suite**
+
+   - Edit (or create) a config file (e.g., `configs/demo.yaml`) to specify your architecture and training preferences.
+   - Then run:
+     ```bash
+     poetry run train --config_path configs/demo.yaml
+     ```
+   - This launches training, automatically checkpointing states and saving learning dynamics data.
+
+5. **Explore Checkpoints**
+   - By default, checkpoints are stored under `runs/YOUR_RUN_NAME/checkpoints/`.
+   - Each checkpoint contains:
+     - **Model state** (PyTorch + Hugging Face formats)
+     - **Optimizer state**
+     - **Gradients and activations** for interpretability
+     - **Evaluation logs** (e.g. perplexity) and metrics
+
+---
+
+## **Repository Structure**
+
+- **`src/model/pico_decoder.py`**
+  - Core LLAMA-style decoder implementation (attention, RMSNorm, RoPE, etc.)
+
+- **`src/training/trainer.py`**
+  - Main training loop
+  - Manages distributed and multi-node settings
+  - Collects/logs metrics
+  - Orchestrates checkpoint saving
+
+- **`src/checkpointing`**
+  - Logic for saving model states, gradients, activations
+  - Tools for uploading checkpoints to Hugging Face
+
+- **`src/config`**
+  - Flexible Dataclass-based config system (model and training hyperparameters, checkpointing, logging)
+
+- **`configs/demo.yaml`**
+  - Example config with default values for quick experimentation
+
+---
+
+## **Advanced Analysis with Pico Analyze**
+
+For deeper checkpoint analysis—comparing gradients, tracking representation shifts, measuring sparsity—use our companion repository [**pico-analyze**](https://github.com/pico-lm/pico-analyze). It automatically processes **pico-train** checkpoints and applies advanced metrics like **CKA**, **PWCCA**, **Gini**, **Hoyer**, and more to reveal **how** your models learn over time.
+
+---
+
+## **License**
+
+Pico is open-source under the [Apache License 2.0](LICENSE).
+
+---
+
+## **Citation**
+
+If you use **Pico** in your research, please cite:
+
+```bibtex
+@software{pico2025,
+  author = {Diehl Martinez, Richard},
+  title = {Pico: A Lightweight Framework for Studying Language Model Learning Dynamics},
+  year = {2025},
+  url = {https://github.com/pico-lm}
+}
+```
+
+**Happy Training!** For more information and tutorials, visit our website at [picolm.io](https://picolm.io).
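The pre-trained suite mentioned in the README lives on the Hugging Face Hub. A minimal sketch of fetching one of those checkpoints for local inspection (the repo id mirrors the example configs below; the exact file layout inside the repo is produced by `src/checkpointing`):

```python
from huggingface_hub import snapshot_download

# Download all files of a pre-trained run into the local HF cache and return
# the path; "pico-lm/pico-decoder-tiny" matches the repo_id used in
# configs/examples/pico-decoder-tiny.yaml.
local_dir = snapshot_download(repo_id="pico-lm/pico-decoder-tiny")
print(f"Checkpoint downloaded to {local_dir}")
```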
configs/examples/demo.yaml
ADDED
@@ -0,0 +1,48 @@
+# Demo config file
+# You can follow this template to create your own config file
+# Refer to the config files in the configs/ directory to see all the available options
+
+data:
+  dataloader:
+    batch_size: 32
+
+checkpointing:
+  run_name: "pico-decoder-demo-1"
+  save_every_n_steps: 50
+
+  save_to_hf: true
+  hf_checkpoint:
+    repo_id: "pico-lm/demo"
+
+  learning_dynamics:
+    batch_size: 16
+
+model:
+  d_model: 96
+  activation_hidden_dim: 384
+
+evaluation:
+  paloma:
+    batch_size: 32
+
+monitoring:
+
+  save_to_wandb: true
+  wandb:
+    project: "pico-demo"
+    entity: "pico-lm"
+
+  logging:
+    log_every_n_steps: 10
+
+training:
+  max_steps: 100
+
+  optimization:
+    lr: 0.001
+    lr_warmup_steps: 30
+
+    gradient_accumulation_steps: 2
+
+  fabric:
+    num_devices: 1
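The README describes `src/config` as a dataclass-based config system; a minimal sketch of how a YAML file like the one above can be overlaid onto such dataclasses (the class and field names here are illustrative, not the repository's actual definitions):

```python
from dataclasses import dataclass

import yaml  # PyYAML, assumed available


@dataclass
class OptimizationConfig:
    # Field names mirror the training.optimization block of the demo config.
    lr: float = 0.001
    lr_warmup_steps: int = 30
    gradient_accumulation_steps: int = 2


with open("configs/examples/demo.yaml") as f:
    raw = yaml.safe_load(f)

# Overlay the YAML values onto the dataclass defaults.
opt = OptimizationConfig(**raw["training"]["optimization"])
print(opt.lr, opt.lr_warmup_steps, opt.gradient_accumulation_steps)
```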
configs/examples/pico-decoder-large.yaml
ADDED
@@ -0,0 +1,35 @@
+# Demo config file
+# You can follow this template to create your own config file
+# Refer to the config files in the configs/ directory to see all the available options
+
+checkpointing:
+  run_name: "pico-decoder-large-1"
+  save_to_hf: true
+  hf_checkpoint:
+    repo_id: "pico-lm/pico-decoder-large"
+
+  learning_dynamics:
+    batch_size: 128
+
+model:
+  d_model: 1536
+  activation_hidden_dim: 6144
+
+monitoring:
+  save_to_wandb: true
+  wandb:
+    project: "pico-decoder"
+    entity: "pico-lm"
+
+training:
+  optimization:
+    gradient_accumulation_steps: 8
+
+  fabric:
+    num_nodes: 4
+    num_devices: 4
+
+evaluation:
+  paloma:
+    batch_size: 16
+
configs/examples/pico-decoder-medium.yaml
ADDED
@@ -0,0 +1,35 @@
+# Demo config file
+# You can follow this template to create your own config file
+# Refer to the config files in the configs/ directory to see all the available options
+
+checkpointing:
+  run_name: "pico-decoder-medium-1"
+  save_to_hf: true
+  hf_checkpoint:
+    repo_id: "pico-lm/pico-decoder-medium"
+
+  learning_dynamics:
+    batch_size: 128
+
+model:
+  d_model: 768
+  activation_hidden_dim: 3072
+
+monitoring:
+  save_to_wandb: true
+  wandb:
+    project: "pico-decoder"
+    entity: "pico-lm"
+
+training:
+  optimization:
+    gradient_accumulation_steps: 8
+
+  fabric:
+    num_nodes: 4
+    num_devices: 4
+
+evaluation:
+  paloma:
+    batch_size: 16
+
configs/examples/pico-decoder-small.yaml
ADDED
@@ -0,0 +1,35 @@
+# Demo config file
+# You can follow this template to create your own config file
+# Refer to the config files in the configs/ directory to see all the available options
+
+checkpointing:
+  run_name: "pico-decoder-small-1"
+  save_to_hf: true
+  hf_checkpoint:
+    repo_id: "pico-lm/pico-decoder-small"
+
+  learning_dynamics:
+    batch_size: 128
+
+model:
+  d_model: 384
+  activation_hidden_dim: 1536
+
+monitoring:
+  save_to_wandb: true
+  wandb:
+    project: "pico-decoder"
+    entity: "pico-lm"
+
+training:
+  optimization:
+    gradient_accumulation_steps: 8
+
+  fabric:
+    num_nodes: 4
+    num_devices: 4
+
+evaluation:
+  paloma:
+    batch_size: 16
+
configs/examples/pico-decoder-tiny.yaml
ADDED
@@ -0,0 +1,35 @@
+# Demo config file
+# You can follow this template to create your own config file
+# Refer to the config files in the configs/ directory to see all the available options
+
+checkpointing:
+  run_name: "pico-decoder-tiny-1"
+  save_to_hf: true
+  hf_checkpoint:
+    repo_id: "pico-lm/pico-decoder-tiny"
+
+  learning_dynamics:
+    batch_size: 256
+
+model:
+  d_model: 96
+  activation_hidden_dim: 384
+
+monitoring:
+  save_to_wandb: true
+  wandb:
+    project: "pico-decoder"
+    entity: "pico-lm"
+
+training:
+  optimization:
+    gradient_accumulation_steps: 4
+
+  fabric:
+    num_nodes: 4
+    num_devices: 4
+
+evaluation:
+  paloma:
+    batch_size: 32
+
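The four example configs above differ mainly in width: `activation_hidden_dim` is always `4 * d_model`. A rough, illustrative parameter estimate for a LLAMA-style decoder at these widths (the layer count and vocabulary size here are assumptions; the real defaults live in `src/config/model_config.py`):

```python
def approx_params(d_model: int, n_layers: int = 12, vocab_size: int = 50304) -> int:
    """Back-of-envelope parameter count; n_layers and vocab_size are assumed."""
    attn = 4 * d_model * d_model       # Q, K, V and output projections
    ffn = 3 * d_model * (4 * d_model)  # SwiGLU uses three weight matrices
    return n_layers * (attn + ffn) + vocab_size * d_model  # plus embeddings

for d in (96, 384, 768, 1536):  # tiny, small, medium, large
    print(f"d_model={d}: ~{approx_params(d):,} parameters")
```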
configs/pico-decoder-tiny-dolma10M-v1.yaml
ADDED
@@ -0,0 +1,78 @@
+# High Quality Training Config - Optimized for H100 80GB Performance
+# Fast training configuration maintaining identical model quality
+# Optimized for H100 80GB with maximum throughput while preserving stability
+# Updated for efficient training on Dolma 10M tokens with H100-optimized hyperparameters
+
+checkpointing:
+  run_name: "pico-decoder-tiny-dolma10M-v1"
+  save_to_hf: true
+  hf_checkpoint:
+    repo_id: "ThomasTheMaker/pico-decoder-tiny"
+  save_every_n_steps: 2000  # Reduced checkpoint frequency for faster training
+
+  learning_dynamics:
+    batch_size: 1  # Minimal batch size for learning dynamics
+    eval_data: null  # Disable learning dynamics to save memory
+
+model:
+  d_model: 96
+  activation_hidden_dim: 384
+  dropout: 0.15  # Increased dropout for stronger regularization
+  attention_dropout: 0.15  # Increased attention dropout
+  layer_norm_eps: 1e-5  # Tighter normalization for stability
+  weight_init_type: "truncated_normal"  # Truncated normal for stability
+  layer_norm_type: "rms_norm"  # RMSNorm for better stability
+  use_qk_norm: true  # Query-Key normalization for attention stability
+
+monitoring:
+  save_to_wandb: false
+  wandb:
+    project: "pico-decoder-tiny"
+    entity: "boymyc"
+  logging:
+    log_every_n_steps: 100  # Reduced logging frequency for faster training
+
+training:
+  max_steps: 100000  # Longer training for better convergence
+  optimization:
+    lr: 0.0002  # Scaled learning rate for larger batch size (4x increase)
+    lr_warmup_steps: 2000  # Reduced warmup for faster convergence
+    lr_scheduler: "cosine"  # Cosine decay over full dataset for sustained learning
+    weight_decay: 0.02  # Increased weight decay for stronger regularization
+    max_grad_norm: 0.5  # Tighter gradient clipping for stability
+    gradient_accumulation_steps: 1  # Reduced for faster training with larger batches
+    optimizer: "adamw"
+    adam_beta1: 0.9  # Standard AdamW beta1
+    adam_beta2: 0.999  # Standard AdamW beta2
+    adam_epsilon: 1e-8  # Tighter epsilon for numerical stability and convergence
+
+  fabric:
+    num_nodes: 1
+    num_devices: 1
+    precision: "bf16-mixed"  # BF16 for Tensor Core optimization
+
+evaluation:
+  paloma:
+    batch_size: 1  # Minimal evaluation batch size
+  eval_every_n_steps: 1000  # Reduced evaluation frequency for faster training
+
+data:
+  dataset:
+    name: "ThomasTheMaker/pretokenized-dolma-10M"  # Updated to the 10M-token dataset
+  dataloader:
+    batch_size: 16  # Conservative H100 optimization - 4x larger for stable fast training
+  tokenizer:
+    name: "allenai/OLMo-7B-0724-hf"
+    vocab_size: 50304
+
+# H100-optimized training strategy for fast, memory-safe training:
+# 1. Conservative batch size (16) with scaled learning rate (0.0002) for stable H100 utilization
+# 2. Reduced gradient accumulation (1 step) for faster optimization cycles
+# 3. Shorter warmup (2000 steps) for quicker convergence with larger batches
+# 4. Reduced evaluation frequency (1000 steps) to minimize training interruptions
+# 5. Reduced checkpoint/logging frequency to minimize I/O overhead
+# 6. Same model architecture and regularization for identical final performance
+# 7. Expected 4-6x training speedup while maintaining model quality and memory safety
+# 8. Memory usage: ~15-25GB of 80GB H100 VRAM (safe utilization avoiding OOM)
+# 9. Maintains all stability features: RMSNorm, QK-Norm, dropout, weight decay
+# 10. Same convergence quality with significant speedup and no memory issues
configs/pico-decoder-tiny-dolma20M-v1.yaml
ADDED
@@ -0,0 +1,78 @@
+# High Quality Training Config - Optimized for H100 80GB Performance
+# Fast training configuration maintaining identical model quality
+# Optimized for H100 80GB with maximum throughput while preserving stability
+# Updated for efficient training on Dolma 20M tokens with H100-optimized hyperparameters
+
+checkpointing:
+  run_name: "pico-decoder-tiny-dolma20M-v1"
+  save_to_hf: false
+  hf_checkpoint:
+    repo_id: "ThomasTheMaker/pico-decoder-tiny"
+  save_every_n_steps: 1000  # Reduced checkpoint frequency for faster training
+
+  learning_dynamics:
+    batch_size: 1  # Minimal batch size for learning dynamics
+    eval_data: null  # Disable learning dynamics to save memory
+
+model:
+  d_model: 96
+  activation_hidden_dim: 384
+  dropout: 0.15  # Increased dropout for stronger regularization
+  attention_dropout: 0.15  # Increased attention dropout
+  layer_norm_eps: 1e-5  # Tighter normalization for stability
+  weight_init_type: "truncated_normal"  # Truncated normal for stability
+  layer_norm_type: "rms_norm"  # RMSNorm for better stability
+  use_qk_norm: true  # Query-Key normalization for attention stability
+
+monitoring:
+  save_to_wandb: false
+  wandb:
+    project: "pico-decoder-tiny"
+    entity: "boymyc"
+  logging:
+    log_every_n_steps: 100  # Reduced logging frequency for faster training
+
+training:
+  max_steps: 100000  # Longer training for better convergence
+  optimization:
+    lr: 0.0002  # Scaled learning rate for larger batch size (4x increase)
+    lr_warmup_steps: 2000  # Reduced warmup for faster convergence
+    lr_scheduler: "cosine"  # Cosine decay over full dataset for sustained learning
+    weight_decay: 0.02  # Increased weight decay for stronger regularization
+    max_grad_norm: 0.5  # Tighter gradient clipping for stability
+    gradient_accumulation_steps: 1  # Reduced for faster training with larger batches
+    optimizer: "adamw"
+    adam_beta1: 0.9  # Standard AdamW beta1
+    adam_beta2: 0.999  # Standard AdamW beta2
+    adam_epsilon: 1e-8  # Tighter epsilon for numerical stability and convergence
+
+  fabric:
+    num_nodes: 1
+    num_devices: 1
+    precision: "bf16-mixed"  # BF16 for Tensor Core optimization
+
+evaluation:
+  paloma:
+    batch_size: 1  # Minimal evaluation batch size
+  eval_every_n_steps: 1000  # Reduced evaluation frequency for faster training
+
+data:
+  dataset:
+    name: "ThomasTheMaker/pretokenized-dolma-20M"  # Updated to the 20M-token dataset
+  dataloader:
+    batch_size: 16  # Conservative H100 optimization - 4x larger for stable fast training
+  tokenizer:
+    name: "allenai/OLMo-7B-0724-hf"
+    vocab_size: 50304
+
+# H100-optimized training strategy for fast, memory-safe training:
+# 1. Conservative batch size (16) with scaled learning rate (0.0002) for stable H100 utilization
+# 2. Reduced gradient accumulation (1 step) for faster optimization cycles
+# 3. Shorter warmup (2000 steps) for quicker convergence with larger batches
+# 4. Reduced evaluation frequency (1000 steps) to minimize training interruptions
+# 5. Reduced checkpoint/logging frequency to minimize I/O overhead
+# 6. Same model architecture and regularization for identical final performance
+# 7. Expected 4-6x training speedup while maintaining model quality and memory safety
+# 8. Memory usage: ~15-25GB of 80GB H100 VRAM (safe utilization avoiding OOM)
+# 9. Maintains all stability features: RMSNorm, QK-Norm, dropout, weight decay
+# 10. Same convergence quality with significant speedup and no memory issues
configs/pico-decoder-tiny-dolma5M-v1.yaml
ADDED
@@ -0,0 +1,78 @@
+# High Quality Training Config - Optimized for superior model performance
+# This configuration prioritizes model quality over training speed
+# Designed for RTX 5090 with focus on preventing overfitting and maximizing generalization
+# Updated for scaling training on Dolma 5M tokens with stability-focused hyperparameters
+
+checkpointing:
+  run_name: "pico-decoder-tiny-dolma5M-v1"
+  save_to_hf: true
+  hf_checkpoint:
+    repo_id: "ThomasTheMaker/pico-decoder-tiny"
+  save_every_n_steps: 500  # Frequent checkpoints for quality monitoring
+
+  learning_dynamics:
+    batch_size: 1  # Minimal batch size for learning dynamics
+    eval_data: null  # Disable learning dynamics to save memory
+
+model:
+  d_model: 96
+  activation_hidden_dim: 384
+  dropout: 0.15  # Increased dropout for stronger regularization
+  attention_dropout: 0.15  # Increased attention dropout
+  layer_norm_eps: 1e-5  # Tighter normalization for stability
+  weight_init_type: "truncated_normal"  # Truncated normal for stability
+  layer_norm_type: "rms_norm"  # RMSNorm for better stability
+  use_qk_norm: true  # Query-Key normalization for attention stability
+
+monitoring:
+  save_to_wandb: false
+  wandb:
+    project: "pico-decoder-tiny"
+    entity: "boymyc"
+  logging:
+    log_every_n_steps: 25  # Very frequent logging for quality monitoring
+
+training:
+  max_steps: 100000  # Longer training for better convergence
+  optimization:
+    lr: 0.00005  # Even lower learning rate for precision training
+    lr_warmup_steps: 8000  # Extended warmup for stability
+    lr_scheduler: "cosine"  # Cosine decay over full dataset for sustained learning
+    weight_decay: 0.02  # Increased weight decay for stronger regularization
+    max_grad_norm: 0.5  # Tighter gradient clipping for stability
+    gradient_accumulation_steps: 4  # Increased for better gradient estimates
+    optimizer: "adamw"
+    adam_beta1: 0.9  # Standard AdamW beta1
+    adam_beta2: 0.999  # Standard AdamW beta2
+    adam_epsilon: 1e-8  # Tighter epsilon for numerical stability and convergence
+
+  fabric:
+    num_nodes: 1
+    num_devices: 1
+    precision: "bf16-mixed"  # BF16 for Tensor Core optimization
+
+evaluation:
+  paloma:
+    batch_size: 1  # Minimal evaluation batch size
+  eval_every_n_steps: 250  # Very frequent evaluation for quality monitoring
+
+data:
+  dataset:
+    name: "ThomasTheMaker/pretokenized-dolma-5M"  # Updated to the 5M-token dataset
+  dataloader:
+    batch_size: 4  # Reduced for more stable training
+  tokenizer:
+    name: "allenai/OLMo-7B-0724-hf"
+    vocab_size: 50304
+
+# Stability-focused training strategy for large-scale Dolma training:
+# 1. Cosine learning rate schedule for sustained learning over full dataset
+# 2. Truncated normal weight initialization to prevent extreme outliers
+# 3. RMSNorm for better gradient stability during long training runs
+# 4. Query-Key normalization (QK-Norm) to prevent attention logit overflow
+# 5. AdamW epsilon 1e-8 for improved training stability and convergence
+# 6. Extended warmup (8000 steps) for stable foundation
+# 7. Stronger regularization (dropout 0.15, weight decay 0.02)
+# 8. Tighter gradient clipping (0.5) for stability
+# 9. More frequent evaluation (every 250 steps) for quality monitoring
+# 10. Longer training (100000 steps) for full convergence on 5M tokens
plots/.gitignore
ADDED
@@ -0,0 +1,74 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+firebase-debug.log*
+firebase-debug.*.log*
+
+# Firebase cache
+.firebase/
+
+# Firebase config
+
+# Uncomment this if you'd like others to create their own Firebase project.
+# For a team working on the same Firebase project(s), it is recommended to leave
+# it commented so all members can deploy to the same project(s) in .firebaserc.
+# .firebaserc
+
+# Runtime data
+pids
+*.pid
+*.seed
+*.pid.lock
+
+# Directory for instrumented libs generated by jscoverage/JSCover
+lib-cov
+
+# Coverage directory used by tools like istanbul
+coverage
+
+# nyc test coverage
+.nyc_output
+
+# Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files)
+.grunt
+
+# Bower dependency directory (https://bower.io/)
+bower_components
+
+# node-waf configuration
+.lock-wscript
+
+# Compiled binary addons (http://nodejs.org/api/addons.html)
+build/Release
+
+# Dependency directories
+node_modules/
+
+# Optional npm cache directory
+.npm
+
+# Optional eslint cache
+.eslintcache
+
+# Optional REPL history
+.node_repl_history
+
+# Output of 'npm pack'
+*.tgz
+
+# Yarn Integrity file
+.yarn-integrity
+
+# dotenv environment variables file
+.env
+
+# dataconnect generated files
+.dataconnect
+
+# firebase files
+
+.firebaserc
+firebase.json
plots/404.html
ADDED
@@ -0,0 +1,33 @@
+<!DOCTYPE html>
+<html>
+<head>
+  <meta charset="utf-8">
+  <meta name="viewport" content="width=device-width, initial-scale=1">
+  <title>Page Not Found</title>
+
+  <style media="screen">
+    body { background: #ECEFF1; color: rgba(0,0,0,0.87); font-family: Roboto, Helvetica, Arial, sans-serif; margin: 0; padding: 0; }
+    #message { background: white; max-width: 360px; margin: 100px auto 16px; padding: 32px 24px 16px; border-radius: 3px; }
+    #message h3 { color: #888; font-weight: normal; font-size: 16px; margin: 16px 0 12px; }
+    #message h2 { color: #ffa100; font-weight: bold; font-size: 16px; margin: 0 0 8px; }
+    #message h1 { font-size: 22px; font-weight: 300; color: rgba(0,0,0,0.6); margin: 0 0 16px;}
+    #message p { line-height: 140%; margin: 16px 0 24px; font-size: 14px; }
+    #message a { display: block; text-align: center; background: #039be5; text-transform: uppercase; text-decoration: none; color: white; padding: 16px; border-radius: 4px; }
+    #message, #message a { box-shadow: 0 1px 3px rgba(0,0,0,0.12), 0 1px 2px rgba(0,0,0,0.24); }
+    #load { color: rgba(0,0,0,0.4); text-align: center; font-size: 13px; }
+    @media (max-width: 600px) {
+      body, #message { margin-top: 0; background: white; box-shadow: none; }
+      body { border-top: 16px solid #ffa100; }
+    }
+  </style>
+</head>
+<body>
+  <div id="message">
+    <h2>404</h2>
+    <h1>Page Not Found</h1>
+    <p>The specified file was not found on this website. Please check the URL for mistakes and try again.</p>
+    <h3>Why am I seeing this?</h3>
+    <p>This page was generated by the Firebase Command-Line Interface. To modify it, edit the <code>404.html</code> file in your project's configured <code>public</code> directory.</p>
+  </div>
+</body>
+</html>
plots/README.md
ADDED
@@ -0,0 +1,90 @@
+# 🚀 Pico Training Metrics Dashboard
+
+A beautiful, interactive web dashboard for visualizing training progress across all your Pico model runs.
+
+## ✨ Features
+
+- **📈 Training Loss Visualization**: Track loss curves over time for all runs
+- **🎯 Learning Rate Schedules**: Monitor LR progression and warmup patterns
+- **📊 Paloma Evaluation**: View perplexity metrics during training
+- **🔄 Combined View**: See all metrics together for easy comparison
+- **🎨 Interactive Charts**: Built with Chart.js for smooth interactions
+- **📱 Responsive Design**: Works on desktop and mobile devices
+- **⚙️ Run Comparison**: Compare different model configurations side-by-side
+
+## 🚀 Quick Start
+
+1. **Generate Data**: First, run the data generation script to parse your training logs:
+   ```bash
+   python scripts/generate_data.py
+   ```
+
+2. **View the Dashboard**: Open `index.html` in your web browser
+3. **Select Runs**: Use the dropdown to view specific runs or all runs together
+4. **Toggle Metrics**: Check/uncheck boxes to show/hide different metric types
+5. **Explore Charts**: Hover over data points for detailed information
+
+## 📁 Files
+
+- `index.html` - Main dashboard interface
+- `style.css` - Modern, responsive styling
+- `code.js` - Interactive chart functionality
+- `data.json` - Training metrics data (auto-generated from logs)
+
+## 🔧 Data Source
+
+The dashboard automatically extracts training metrics from:
+- Training loss at each step
+- Learning rate progression
+- Paloma evaluation results
+- Model configuration parameters
+
+## 🔄 Updating Data
+
+To refresh the dashboard with new training data:
+1. **Run new training sessions** - logs will be saved to `runs/*/logs/`
+2. **Generate updated data.json**:
+   ```bash
+   python scripts/generate_data.py
+   ```
+3. **Refresh the dashboard** - new runs will appear automatically
+
+## 🎨 Chart Types
+
+1. **Training Loss**: Line charts showing loss reduction over time
+2. **Learning Rate**: Logarithmic scale for LR schedule visualization
+3. **Evaluation**: Paloma perplexity metrics during training
+4. **Combined**: All metrics on one chart for easy comparison
+
+## 💡 Usage Tips
+
+- **Compare Runs**: Select "All Runs" to see how different configurations perform
+- **Zoom In**: Use the chart zoom features to focus on specific training phases
+- **Export**: Right-click charts to save as images
+- **Mobile**: Dashboard is fully responsive for mobile devices
+
+## 🎯 Key Metrics Tracked
+
+- **Training Loss**: Primary performance indicator
+- **Learning Rate**: Schedule adherence and warmup progress
+- **Paloma Perplexity**: Model evaluation quality
+- **Inf/NaN Counts**: Training stability monitoring
+- **Model Config**: Architecture and hyperparameter details
+
+## 🌟 Design Features
+
+- **Modern UI**: Clean, professional interface
+- **Color Coding**: Distinct colors for each model run
+- **Responsive Layout**: Adapts to different screen sizes
+- **Interactive Elements**: Hover effects and smooth animations
+- **Professional Typography**: Easy-to-read fonts and spacing
+
+## 📚 Documentation
+
+For more details on generating the data.json file, see:
+- `scripts/README.md` - Complete script documentation
+- `scripts/generate_data.py` - The data generation script
+
+---
+
+Built with ❤️ for the Pico Language Model training community
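`code.js` (the next file) fetches `data.json` and expects a top-level `runs` array whose entries carry at least a `run_name`. A minimal sketch of writing that shape from Python; the metric fields here are assumptions about what `scripts/generate_data.py` extracts from `runs/*/logs/`:

```python
import json

# Hypothetical parsed metrics; the real extraction lives in scripts/generate_data.py.
data = {
    "runs": [
        {
            "run_name": "pico-decoder-demo-1",
            "steps": [10, 20, 30],
            "loss": [9.8, 8.6, 7.9],
            "learning_rate": [3.3e-4, 6.7e-4, 1.0e-3],
        }
    ]
}

with open("plots/data.json", "w") as f:
    json.dump(data, f, indent=2)
```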
plots/code.js
ADDED
@@ -0,0 +1,550 @@
// Global variables
let trainingData = null;
let charts = {};

// Color palette for different runs
const colors = [
    '#667eea', '#764ba2', '#f093fb', '#f5576c', '#4facfe', '#00f2fe',
    '#43e97b', '#38f9d7', '#fa7093', '#fee140', '#a8edea', '#fed6e3'
];

// Initialize the dashboard
document.addEventListener('DOMContentLoaded', function() {
    loadData();
    setupEventListeners();
});

// Load training data from JSON file
async function loadData() {
    try {
        const response = await fetch('data.json');
        trainingData = await response.json();

        // Merge continuation logs from the same model run
        mergeContinuationLogs();

        populateRunSelector();
        createCharts();
        updateRunSummary();
        updateConfigDetails();

        console.log('Data loaded and merged successfully:', trainingData);
    } catch (error) {
        console.error('Error loading data:', error);
        document.body.innerHTML = '<div class="loading">Error loading training data. Please check the console for details.</div>';
    }
}

// Merge continuation logs from the same model run
function mergeContinuationLogs() {
    const runGroups = {};

    // Group runs by base model name
    trainingData.runs.forEach(run => {
        const baseName = run.run_name;
        if (!runGroups[baseName]) {
            runGroups[baseName] = [];
        }
        runGroups[baseName].push(run);
    });

    // Merge runs with the same base name
    const mergedRuns = [];

    Object.entries(runGroups).forEach(([baseName, runs]) => {
        if (runs.length === 1) {
            // Single run, no merging needed
            mergedRuns.push(runs[0]);
        } else {
            // Multiple runs to merge
            console.log(`Merging ${runs.length} continuation logs for ${baseName}`);

            const mergedRun = {
                run_name: baseName,
                log_files: runs.map(r => r.log_file),
                training_metrics: [],
                evaluation_results: [],
                config: runs[0].config || {}
            };

            // Merge training metrics (they should be continuous)
            runs.forEach(run => {
                if (run.training_metrics) {
                    mergedRun.training_metrics.push(...run.training_metrics);
                }
            });

            // Merge evaluation results (they should be continuous)
            runs.forEach(run => {
                if (run.evaluation_results) {
                    mergedRun.evaluation_results.push(...run.evaluation_results);
                }
            });

            // Sort by step number to ensure proper ordering
            mergedRun.training_metrics.sort((a, b) => a.step - b.step);
            mergedRun.evaluation_results.sort((a, b) => a.step - b.step);

            // Remove duplicates based on step number (relies on the sort above)
            mergedRun.training_metrics = mergedRun.training_metrics.filter((metric, index, self) =>
                index === 0 || metric.step !== self[index - 1].step
            );
            mergedRun.evaluation_results = mergedRun.evaluation_results.filter((result, index, self) =>
                index === 0 || result.step !== self[index - 1].step
            );

            console.log(`Merged ${baseName}: ${mergedRun.training_metrics.length} training points, ${mergedRun.evaluation_results.length} eval points`);
            mergedRuns.push(mergedRun);
        }
    });

    trainingData.runs = mergedRuns;
}

// Setup event listeners for controls
function setupEventListeners() {
    document.getElementById('runSelect').addEventListener('change', function() {
        updateCharts();
        updateRunSummary();
        updateConfigDetails();
    });
    document.getElementById('showTraining').addEventListener('change', updateCharts);
    document.getElementById('showLearningRate').addEventListener('change', updateCharts);
    document.getElementById('showEvaluation').addEventListener('change', updateCharts);
}

// Populate run selector dropdown
function populateRunSelector() {
    const select = document.getElementById('runSelect');
    const runs = trainingData.runs;

    // Clear existing options
    select.innerHTML = '<option value="all">All Runs</option>';

    runs.forEach((run, index) => {
        const option = document.createElement('option');
        option.value = index;
        option.textContent = run.run_name;
        select.appendChild(option);
    });
}

// Create all charts
function createCharts() {
    createLossChart();
    createLRChart();
    createEvalChart();
    createCombinedChart();
}

// Create training loss chart
function createLossChart() {
    const ctx = document.getElementById('lossChart').getContext('2d');

    charts.loss = new Chart(ctx, {
        type: 'line',
        data: getChartData('loss'),
        options: {
            responsive: true,
            maintainAspectRatio: false,
            plugins: {
                title: {
                    display: true,
                    text: 'Training Loss Over Time'
                },
                legend: {
                    position: 'top'
                }
            },
            scales: {
                x: {
                    type: 'linear',
                    title: {
                        display: true,
                        text: 'Training Step'
                    }
                },
                y: {
                    title: {
                        display: true,
                        text: 'Loss'
                    },
                    beginAtZero: false
                }
            },
            interaction: {
                intersect: false,
                mode: 'index'
            }
        }
    });
}

// Create learning rate chart
function createLRChart() {
    const ctx = document.getElementById('lrChart').getContext('2d');

    charts.lr = new Chart(ctx, {
        type: 'line',
        data: getChartData('lr'),
        options: {
            responsive: true,
            maintainAspectRatio: false,
            plugins: {
                title: {
                    display: true,
                    text: 'Learning Rate Schedule'
                },
                legend: {
                    position: 'top'
                }
            },
            scales: {
                x: {
                    type: 'linear',
                    title: {
                        display: true,
                        text: 'Training Step'
                    }
                },
                y: {
                    title: {
                        display: true,
                        text: 'Learning Rate'
                    },
                    type: 'logarithmic'
                }
            },
            interaction: {
                intersect: false,
                mode: 'index'
            }
        }
    });
}

// Create evaluation chart
function createEvalChart() {
    const ctx = document.getElementById('evalChart').getContext('2d');

    charts.eval = new Chart(ctx, {
        type: 'line',
        data: getChartData('eval'),
        options: {
            responsive: true,
            maintainAspectRatio: false,
            plugins: {
                title: {
                    display: true,
                    text: 'Paloma Evaluation Metrics'
                },
                legend: {
                    position: 'top'
                }
            },
            scales: {
                x: {
                    type: 'linear',
                    title: {
                        display: true,
                        text: 'Training Step'
                    }
                },
                y: {
                    title: {
                        display: true,
                        text: 'Perplexity'
                    },
                    type: 'logarithmic'
                }
            },
            interaction: {
                intersect: false,
                mode: 'index'
            }
        }
    });
}

// Create combined chart
function createCombinedChart() {
    const ctx = document.getElementById('combinedChart').getContext('2d');

    charts.combined = new Chart(ctx, {
        type: 'line',
        data: getCombinedChartData(),
        options: {
            responsive: true,
            maintainAspectRatio: false,
            plugins: {
                title: {
                    display: true,
                    text: 'Combined Training Metrics'
                },
                legend: {
                    position: 'top'
                }
            },
            scales: {
                x: {
                    type: 'linear',
                    title: {
                        display: true,
                        text: 'Training Step'
                    }
                },
                y: {
                    title: {
                        display: true,
                        text: 'Value'
                    }
                }
            },
            interaction: {
                intersect: false,
                mode: 'index'
            }
        }
    });
}

// Get chart data for a specific metric type
function getChartData(metricType) {
    const selectedRun = document.getElementById('runSelect').value;
    const runs = selectedRun === 'all' ? trainingData.runs : [trainingData.runs[selectedRun]];

    const datasets = [];

    console.log(`Getting ${metricType} data for ${runs.length} runs:`, runs.map(r => r.run_name));

    runs.forEach((run, runIndex) => {
        const color = colors[runIndex % colors.length];

        if (metricType === 'loss') {
            if (run.training_metrics && run.training_metrics.length > 0) {
                const data = run.training_metrics.map(m => ({ x: m.step, y: m.loss }));
                console.log(`Loss data for ${run.run_name}:`, data.slice(0, 5), '...', data.slice(-5));
                datasets.push({
                    label: run.run_name,
                    data: data,
                    borderColor: color,
                    backgroundColor: color + '20',
                    borderWidth: 2,
                    fill: false,
                    tension: 0.1
                });
            }
        } else if (metricType === 'lr') {
            if (run.training_metrics && run.training_metrics.length > 0) {
                const data = run.training_metrics.map(m => ({ x: m.step, y: m.learning_rate }));
                console.log(`LR data for ${run.run_name}:`, data.slice(0, 5), '...', data.slice(-5));
                datasets.push({
                    label: run.run_name,
                    data: data,
                    borderColor: color,
                    backgroundColor: color + '20',
                    borderWidth: 2,
                    fill: false,
                    tension: 0.1
                });
            }
        } else if (metricType === 'eval') {
            if (run.evaluation_results && run.evaluation_results.length > 0) {
                const data = run.evaluation_results.map(m => ({ x: m.step, y: m.paloma }));
                console.log(`Eval data for ${run.run_name}:`, data.slice(0, 5), '...', data.slice(-5));
                datasets.push({
                    label: run.run_name,
                    data: data,
                    borderColor: color,
                    backgroundColor: color + '20',
                    borderWidth: 2,
                    fill: false,
                    tension: 0.1
                });
            }
        }
    });

    console.log(`Final ${metricType} datasets:`, datasets);
    return { datasets };
}

// Get combined chart data
function getCombinedChartData() {
    const selectedRun = document.getElementById('runSelect').value;
    const runs = selectedRun === 'all' ? trainingData.runs : [trainingData.runs[selectedRun]];

    const datasets = [];

    runs.forEach((run, runIndex) => {
        const color = colors[runIndex % colors.length];

        // Training loss
        if (run.training_metrics && run.training_metrics.length > 0) {
            datasets.push({
                label: `${run.run_name} - Loss`,
                data: run.training_metrics.map(m => ({ x: m.step, y: m.loss })),
                borderColor: color,
                backgroundColor: color + '20',
                borderWidth: 2,
                fill: false,
                tension: 0.1
            });
        }

        // Learning rate (scaled so it is visible on the same axis as the loss)
        if (run.training_metrics && run.training_metrics.length > 0) {
            const maxLR = Math.max(...run.training_metrics.map(m => m.learning_rate));
            const maxLoss = Math.max(...run.training_metrics.map(m => m.loss));
            const scaleFactor = maxLoss / maxLR;

            datasets.push({
                label: `${run.run_name} - LR (scaled)`,
                data: run.training_metrics.map(m => ({ x: m.step, y: m.learning_rate * scaleFactor })),
                borderColor: color + '80',
                backgroundColor: color + '10',
                borderWidth: 1,
                fill: false,
                tension: 0.1
            });
        }
    });

    return { datasets };
}

// Update all charts based on the current selection
function updateCharts() {
    // Respect the metric checkboxes by hiding or showing the corresponding chart cards
    updateChartVisibility();

    if (charts.loss) {
        charts.loss.data = getChartData('loss');
        charts.loss.update();
    }

    if (charts.lr) {
        charts.lr.data = getChartData('lr');
        charts.lr.update();
    }

    if (charts.eval) {
        charts.eval.data = getChartData('eval');
        charts.eval.update();
    }

    if (charts.combined) {
        charts.combined.data = getCombinedChartData();
        charts.combined.update();
    }
}

// Show or hide each chart card based on the metric checkbox toggles
function updateChartVisibility() {
    const toggles = {
        lossChart: document.getElementById('showTraining').checked,
        lrChart: document.getElementById('showLearningRate').checked,
        evalChart: document.getElementById('showEvaluation').checked
    };
    Object.entries(toggles).forEach(([canvasId, visible]) => {
        const card = document.getElementById(canvasId).closest('.chart-card');
        if (card) {
            card.style.display = visible ? '' : 'none';
        }
    });
}

// Update run summary section
function updateRunSummary() {
    const container = document.getElementById('runSummary');
    const selectedRun = document.getElementById('runSelect').value;
    const runs = selectedRun === 'all' ? trainingData.runs : [trainingData.runs[selectedRun]];

    let html = '<div class="run-grid">';

    runs.forEach(run => {
        const trainingPoints = run.training_metrics ? run.training_metrics.length : 0;
        const evalPoints = run.evaluation_results ? run.evaluation_results.length : 0;

        let finalLoss = 'N/A';
        let finalLR = 'N/A';
        let finalPaloma = 'N/A';
        let stepRange = 'N/A';

        if (run.training_metrics && run.training_metrics.length > 0) {
            const first = run.training_metrics[0];
            const last = run.training_metrics[run.training_metrics.length - 1];
            finalLoss = last.loss.toFixed(4);
            finalLR = last.learning_rate.toExponential(2);
            stepRange = `${first.step} → ${last.step}`;
        }

        if (run.evaluation_results && run.evaluation_results.length > 0) {
            const last = run.evaluation_results[run.evaluation_results.length - 1];
            if (isFinite(last.paloma)) {
                finalPaloma = last.paloma.toExponential(2);
            } else {
                finalPaloma = '∞';
            }
        }

        const logFiles = run.log_files ? run.log_files.join(', ') : run.log_file;

        html += `
            <div class="run-card">
                <h4>${run.run_name}</h4>
                <p><strong>Logs:</strong> ${logFiles}</p>
                <div class="metric">
                    <span>Step Range:</span>
                    <span class="value">${stepRange}</span>
                </div>
                <div class="metric">
                    <span>Training Points:</span>
                    <span class="value">${trainingPoints}</span>
                </div>
                <div class="metric">
                    <span>Evaluation Points:</span>
                    <span class="value">${evalPoints}</span>
                </div>
                <div class="metric">
                    <span>Final Loss:</span>
                    <span class="value">${finalLoss}</span>
                </div>
                <div class="metric">
                    <span>Final LR:</span>
                    <span class="value">${finalLR}</span>
                </div>
                <div class="metric">
                    <span>Final Paloma:</span>
                    <span class="value">${finalPaloma}</span>
                </div>
            </div>
        `;
    });

    html += '</div>';
    container.innerHTML = html;
}

// Update configuration details section
function updateConfigDetails() {
    const container = document.getElementById('configDetails');
    const selectedRun = document.getElementById('runSelect').value;
    const runs = selectedRun === 'all' ? trainingData.runs : [trainingData.runs[selectedRun]];

    let html = '<div class="config-grid">';

    // Get unique config keys
    const allKeys = new Set();
    runs.forEach(run => {
        if (run.config) {
            Object.keys(run.config).forEach(key => allKeys.add(key));
        }
    });

    allKeys.forEach(key => {
        const values = runs.map(run => run.config && run.config[key] !== undefined ? run.config[key] : 'N/A');
        const uniqueValues = [...new Set(values)];
        const displayValue = uniqueValues.length === 1 ? uniqueValues[0] : `${uniqueValues.join(' / ')}`;

        html += `
            <div class="config-item">
                <div class="label">${key.replace(/_/g, ' ').toUpperCase()}</div>
                <div class="value">${displayValue}</div>
            </div>
        `;
    });

    html += '</div>';
    container.innerHTML = html;
}

// Utility function to format large numbers
function formatNumber(num) {
    if (num >= 1e9) return (num / 1e9).toFixed(2) + 'B';
    if (num >= 1e6) return (num / 1e6).toFixed(2) + 'M';
    if (num >= 1e3) return (num / 1e3).toFixed(2) + 'K';
    return num.toString();
}

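One small idiom worth noting in the code above: expressions like `color + '20'` append a two-digit hexadecimal alpha channel to a `#rrggbb` color, producing an `#rrggbbaa` value (roughly 12% opacity here) that the charts use for translucent fills and for the faded scaled-LR line.
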
plots/data.json
ADDED
The diff for this file is too large to render.
plots/index.html
ADDED
@@ -0,0 +1,72 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Pico Training Metrics Dashboard</title>
    <link rel="stylesheet" href="style.css">
    <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
</head>
<body>
    <div class="container">
        <header>
            <h1>🚀 Pico Training Metrics Dashboard</h1>
            <p>Real-time visualization of training progress across all model runs</p>
        </header>

        <div class="controls">
            <div class="run-selector">
                <label for="runSelect">Select Run:</label>
                <select id="runSelect">
                    <option value="all">All Runs</option>
                </select>
            </div>
            <div class="metric-toggle">
                <label>
                    <input type="checkbox" id="showTraining" checked> Training Loss
                </label>
                <label>
                    <input type="checkbox" id="showLearningRate" checked> Learning Rate
                </label>
                <label>
                    <input type="checkbox" id="showEvaluation" checked> Paloma Evaluation
                </label>
            </div>
        </div>

        <div class="charts-container">
            <div class="chart-card">
                <h3>📈 Training Loss Over Time</h3>
                <canvas id="lossChart"></canvas>
            </div>

            <div class="chart-card">
                <h3>🎯 Learning Rate Schedule</h3>
                <canvas id="lrChart"></canvas>
            </div>

            <div class="chart-card">
                <h3>📊 Paloma Evaluation Metrics</h3>
                <canvas id="evalChart"></canvas>
            </div>

            <div class="chart-card">
                <h3>🔄 Combined View</h3>
                <canvas id="combinedChart"></canvas>
            </div>
        </div>

        <div class="run-summary">
            <h3>📋 Run Summary</h3>
            <div id="runSummary"></div>
        </div>

        <div class="config-details">
            <h3>⚙️ Model Configuration</h3>
            <div id="configDetails"></div>
        </div>
    </div>

    <script src="code.js"></script>
</body>
</html>
plots/style.css
ADDED
@@ -0,0 +1,258 @@
* {
    margin: 0;
    padding: 0;
    box-sizing: border-box;
}

body {
    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    min-height: 100vh;
    color: #333;
}

.container {
    max-width: 1400px;
    margin: 0 auto;
    padding: 20px;
}

header {
    text-align: center;
    margin-bottom: 30px;
    color: white;
}

header h1 {
    font-size: 2.5rem;
    margin-bottom: 10px;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
}

header p {
    font-size: 1.1rem;
    opacity: 0.9;
}

.controls {
    background: white;
    padding: 20px;
    border-radius: 12px;
    box-shadow: 0 8px 32px rgba(0,0,0,0.1);
    margin-bottom: 30px;
    display: flex;
    justify-content: space-between;
    align-items: center;
    flex-wrap: wrap;
    gap: 20px;
}

.run-selector select {
    padding: 8px 16px;
    border: 2px solid #e1e5e9;
    border-radius: 8px;
    font-size: 14px;
    background: white;
    cursor: pointer;
    transition: border-color 0.3s ease;
}

.run-selector select:focus {
    outline: none;
    border-color: #667eea;
}

.metric-toggle {
    display: flex;
    gap: 20px;
    flex-wrap: wrap;
}

.metric-toggle label {
    display: flex;
    align-items: center;
    gap: 8px;
    cursor: pointer;
    font-weight: 500;
    color: #555;
}

.metric-toggle input[type="checkbox"] {
    width: 18px;
    height: 18px;
    accent-color: #667eea;
}

.charts-container {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(600px, 1fr));
    gap: 30px;
    margin-bottom: 30px;
}

.chart-card {
    background: white;
    padding: 25px;
    border-radius: 12px;
    box-shadow: 0 8px 32px rgba(0,0,0,0.1);
    transition: transform 0.3s ease, box-shadow 0.3s ease;
}

.chart-card:hover {
    transform: translateY(-5px);
    box-shadow: 0 12px 40px rgba(0,0,0,0.15);
}

.chart-card h3 {
    margin-bottom: 20px;
    color: #333;
    font-size: 1.2rem;
    display: flex;
    align-items: center;
    gap: 8px;
}

.chart-card canvas {
    max-height: 400px;
    width: 100% !important;
}

.run-summary, .config-details {
    background: white;
    padding: 25px;
    border-radius: 12px;
    box-shadow: 0 8px 32px rgba(0,0,0,0.1);
    margin-bottom: 30px;
}

.run-summary h3, .config-details h3 {
    margin-bottom: 20px;
    color: #333;
    font-size: 1.2rem;
    display: flex;
    align-items: center;
    gap: 8px;
}

.run-grid {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
    gap: 20px;
}

.run-card {
    background: #f8f9fa;
    padding: 20px;
    border-radius: 8px;
    border-left: 4px solid #667eea;
}

.run-card h4 {
    color: #667eea;
    margin-bottom: 10px;
    font-size: 1.1rem;
}

.run-card p {
    margin-bottom: 8px;
    color: #666;
    font-size: 0.9rem;
}

.run-card .metric {
    display: flex;
    justify-content: space-between;
    margin-bottom: 5px;
}

.run-card .metric .value {
    font-weight: 600;
    color: #333;
}

.config-grid {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
    gap: 15px;
}

.config-item {
    background: #f8f9fa;
    padding: 15px;
    border-radius: 8px;
    text-align: center;
}

.config-item .label {
    font-size: 0.8rem;
    color: #666;
    text-transform: uppercase;
    letter-spacing: 0.5px;
    margin-bottom: 5px;
}

.config-item .value {
    font-size: 1.2rem;
    font-weight: 600;
    color: #333;
}

@media (max-width: 768px) {
    .container {
        padding: 15px;
    }

    header h1 {
        font-size: 2rem;
    }

    .controls {
        flex-direction: column;
        align-items: stretch;
    }

    .charts-container {
        grid-template-columns: 1fr;
    }

    .chart-card {
        padding: 20px;
    }

    .run-grid, .config-grid {
        grid-template-columns: 1fr;
    }
}

/* Chart.js customizations */
.chartjs-tooltip {
    background: rgba(0,0,0,0.8) !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 10px !important;
    font-size: 12px !important;
}

/* Loading state */
.loading {
    text-align: center;
    padding: 40px;
    color: #666;
}

.loading::after {
    content: '';
    display: inline-block;
    width: 20px;
    height: 20px;
    border: 3px solid #f3f3f3;
    border-top: 3px solid #667eea;
    border-radius: 50%;
    animation: spin 1s linear infinite;
    margin-left: 10px;
}

@keyframes spin {
    0% { transform: rotate(0deg); }
    100% { transform: rotate(360deg); }
}
pyproject.toml
ADDED
@@ -0,0 +1,33 @@
[tool.poetry]
name = "pico-train"
version = "1.0.0"
description = "A minimalistic framework for transparently training language models and storing comprehensive checkpoints for in-depth learning dynamics research"
authors = ["Richard Diehl Martinez <richard@picolm.io>"]
license = "Apache 2.0"
readme = "README.md"
packages = [{include = "src"}]

[tool.poetry.scripts]
train = "scripts.train:main"

[tool.poetry.dependencies]
python = "^3.10,<3.13"
lightning = "^2.4.0"
click = "^8.1.7"
wandb = "^0.18.1"
huggingface-hub = {extras = ["cli"], version = "^0.25.1"}
datasets = "^3.0.1,<3.2.0"
transformers = "^4.45.2"
pre-commit = "^4.0.1"
torch = "^2.5.1"
evaluate = "^0.4.3"
deepspeed = "^0.16.2"
rich = "^13.9.4"

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.29.5"
jupyter = "^1.1.1"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

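For orientation, a typical workflow with this `pyproject.toml` might look as follows; `configs/my-run.yaml` is a placeholder for an actual training config:

```bash
# Install dependencies (including the dev group) into the project environment
poetry install --with dev

# Invoke scripts/train.py through the `train` entry point defined above
poetry run train --config_path configs/my-run.yaml
```
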
scripts/README.md
ADDED
@@ -0,0 +1,109 @@
# Scripts Directory

This directory contains utility scripts for the Pico training framework.

## generate_data.py

A script to automatically generate `data.json` from training log files for the dashboard.

### What it does

This script parses log files from the `runs/` directory and extracts:
- **Training metrics**: Loss, learning rate, and inf/NaN counts at each step
- **Evaluation results**: Paloma evaluation metrics
- **Model configuration**: Architecture parameters (d_model, n_layers, etc.)

### Usage

```bash
# Generate data.json from the default runs directory
python scripts/generate_data.py

# Specify a custom runs directory
python scripts/generate_data.py --runs-dir /path/to/runs

# Specify a custom output file
python scripts/generate_data.py --output /path/to/output.json
```

### How it works

1. **Scans runs directory**: Looks for subdirectories containing training runs
2. **Finds log files**: Locates `.log` files in each run's `logs/` subdirectory
3. **Parses log content**: Uses regex patterns to extract structured data (see the sketch below)
4. **Generates JSON**: Creates a structured JSON file for the dashboard

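To make step 3 concrete, here is a minimal sketch of the regex-based extraction on a shortened two-line sample; the pattern below is a simplified version of the one in `generate_data.py`, and the timestamp and values are illustrative:

```python
import re

sample = (
    "2025-08-29 02:09:12 - pico-train - INFO - Step 500 -- 🔄 Training Metrics\n"
    "2025-08-29 02:09:12 - pico-train - INFO - ├── Loss: 10.8854\n"
)

# Capture the step from the header line and the loss from the line after it
pattern = r"INFO - Step (\d+) -- 🔄 Training Metrics\n.*?INFO - ├── Loss: ([\d.]+)"

for step, loss in re.findall(pattern, sample):
    print({"step": int(step), "loss": float(loss)})
# -> {'step': 500, 'loss': 10.8854}
```
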
### Log Format Requirements

The script expects log files with the following format:

```
2025-08-29 02:09:12 - pico-train - INFO - Step 500 -- 🔄 Training Metrics
2025-08-29 02:09:12 - pico-train - INFO - ├── Loss: 10.8854
2025-08-29 02:09:12 - pico-train - INFO - ├── Learning Rate: 3.13e-06
2025-08-29 02:09:12 - pico-train - INFO - └── Inf/NaN count: 0
```

And evaluation results:

```
2025-08-29 02:15:26 - pico-train - INFO - Step 1000 -- 📊 Evaluation Results
2025-08-29 02:15:26 - pico-train - INFO - └── paloma: 7.125172406420199e+27
```

### Output Format

The generated `data.json` has this structure:

```json
{
  "runs": [
    {
      "run_name": "model-name",
      "log_file": "log_filename.log",
      "training_metrics": [
        {
          "step": 0,
          "loss": 10.9914,
          "learning_rate": 0.0,
          "inf_nan_count": 0
        }
      ],
      "evaluation_results": [
        {
          "step": 1000,
          "paloma": 59434.76600609756
        }
      ],
      "config": {
        "d_model": 96,
        "n_layers": 12,
        "max_seq_len": 2048,
        "vocab_size": 50304,
        "lr": 0.0003,
        "max_steps": 200000,
        "batch_size": 8
      }
    }
  ],
  "summary": {
    "total_runs": 1,
    "run_names": ["model-name"]
  }
}
```

### When to use

- **After training**: Generate updated dashboard data
- **Adding new runs**: Include new training sessions in the dashboard
- **Debugging**: Verify log parsing is working correctly
- **Dashboard setup**: Initial setup of the training metrics dashboard

### Troubleshooting

If the script doesn't find any data (a quick shell check is sketched after this list):
1. Check that log files exist in `runs/*/logs/`
2. Verify the log format matches the expected pattern
3. Ensure log files contain training metrics entries
4. Check file permissions and encoding

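A quick way to sanity-check the first three points from a shell (a sketch; adjust the paths to your layout):

```bash
# List candidate log files, then count training-metric blocks per file
ls runs/*/logs/*.log
grep -c "Training Metrics" runs/*/logs/*.log
```
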
scripts/generate_data.py
ADDED
@@ -0,0 +1,198 @@
#!/usr/bin/env python3
"""
Script to generate data.json from training log files.

This script parses log files from the runs directory and extracts:
- Training metrics (loss, learning rate, inf/nan count)
- Evaluation results (paloma metrics)
- Model configuration parameters

The output is saved to plots/data.json for the dashboard.
"""

import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional


def parse_training_metrics(log_content: str) -> List[Dict[str, Any]]:
    """Parse training metrics from log content."""
    metrics = []

    # Pattern to match training metrics entries with timestamp and log level
    pattern = r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} - pico-train - INFO - Step (\d+) -- 🔄 Training Metrics\n\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} - pico-train - INFO - ├── Loss: ([\d.]+)\n\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} - pico-train - INFO - ├── Learning Rate: ([\d.e+-]+)\n\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} - pico-train - INFO - └── Inf/NaN count: (\d+)"

    matches = re.findall(pattern, log_content)

    for step, loss, lr, inf_nan in matches:
        metrics.append(
            {
                "step": int(step),
                "loss": float(loss),
                "learning_rate": float(lr),
                "inf_nan_count": int(inf_nan),
            }
        )

    return sorted(metrics, key=lambda x: x["step"])


def parse_evaluation_results(log_content: str) -> List[Dict[str, Any]]:
    """Parse evaluation results from log content."""
    results = []

    # Pattern to match evaluation results with timestamp and log level
    pattern = r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} - pico-train - INFO - Step (\d+) -- 📊 Evaluation Results\n\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} - pico-train - INFO - └── paloma: ([\d.e+-]+)"

    matches = re.findall(pattern, log_content)

    for step, paloma in matches:
        try:
            paloma_value = float(paloma)
            results.append({"step": int(step), "paloma": paloma_value})
        except ValueError:
            # Skip if paloma value is not a valid number (e.g., "inf")
            continue

    return sorted(results, key=lambda x: x["step"])


def extract_config_from_log(log_content: str) -> Dict[str, Any]:
    """Extract model configuration from log content."""
    config = {}

    # Extract key model parameters
    patterns = {
        "d_model": r"d_model: (\d+)",
        "n_layers": r"n_layers: (\d+)",
        "max_seq_len": r"max_seq_len: (\d+)",
        "vocab_size": r"vocab_size: (\d+)",
        "lr": r"lr: ([\d.e+-]+)",
        "max_steps": r"max_steps: (\d+)",
        "batch_size": r"batch_size: (\d+)",
    }

    for key, pattern in patterns.items():
        match = re.search(pattern, log_content)
        if match:
            try:
                if key in [
                    "d_model",
                    "n_layers",
                    "max_seq_len",
                    "vocab_size",
                    "max_steps",
                    "batch_size",
                ]:
                    config[key] = int(match.group(1))
                else:
                    config[key] = float(match.group(1))
            except ValueError:
                continue

    return config


def process_run_directory(run_path: Path) -> Optional[Dict[str, Any]]:
    """Process a single run directory and extract all data."""
    run_name = run_path.name

    # Find log files
    logs_dir = run_path / "logs"
    if not logs_dir.exists():
        return None

    log_files = list(logs_dir.glob("*.log"))
    if not log_files:
        return None

    # Use the most recent log file for configuration
    latest_log = max(log_files, key=lambda x: x.stat().st_mtime)

    # Read log content
    log_content = latest_log.read_text(encoding="utf-8")

    # Extract data
    training_metrics = parse_training_metrics(log_content)
    evaluation_results = parse_evaluation_results(log_content)
    config = extract_config_from_log(log_content)

    # If no training metrics found, skip this run
    if not training_metrics:
        return None

    return {
        "run_name": run_name,
        "log_file": latest_log.name,
        "training_metrics": training_metrics,
        "evaluation_results": evaluation_results,
        "config": config,
    }


def generate_data_json(runs_dir: str = "runs", output_file: str = "plots/data.json"):
    """Generate data.json from all run directories."""
    runs_path = Path(runs_dir)
    if not runs_path.exists():
        print(f"Runs directory {runs_dir} not found!")
        return

    runs_data = []

    # Process each run directory
    for run_dir in runs_path.iterdir():
        if run_dir.is_dir():
            print(f"Processing run: {run_dir.name}")
            run_data = process_run_directory(run_dir)
            if run_data:
                runs_data.append(run_data)
                print(f"  ✓ Found {len(run_data['training_metrics'])} training metrics")
                print(
                    f"  ✓ Found {len(run_data['evaluation_results'])} evaluation results"
                )
            else:
                print("  ✗ No valid data found")

    if not runs_data:
        print("No valid runs found!")
        return

    # Create output data structure
    output_data = {
        "runs": runs_data,
        "summary": {
            "total_runs": len(runs_data),
            "run_names": [run["run_name"] for run in runs_data],
        },
    }

    # Ensure output directory exists
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Write to file
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False)

    print(f"\n✓ Generated {output_file} with {len(runs_data)} runs")
    print(
        f"✓ Total training metrics: {sum(len(run['training_metrics']) for run in runs_data)}"
    )
    print(
        f"✓ Total evaluation results: {sum(len(run['evaluation_results']) for run in runs_data)}"
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate data.json from training logs"
    )
    parser.add_argument("--runs-dir", default="runs", help="Path to runs directory")
    parser.add_argument("--output", default="plots/data.json", help="Output file path")

    args = parser.parse_args()

    generate_data_json(args.runs_dir, args.output)

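For reference, a successful invocation prints a per-run summary along the following lines; the run name and counts here are made up for illustration:

```
$ python scripts/generate_data.py
Processing run: my-tiny-run
  ✓ Found 401 training metrics
  ✓ Found 20 evaluation results

✓ Generated plots/data.json with 1 runs
✓ Total training metrics: 401
✓ Total evaluation results: 20
```
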
scripts/train.py
ADDED
@@ -0,0 +1,30 @@
#!/usr/bin/env python3
"""
A minimal script to train the Pico language model. In practice, you should just use the
`poetry run train` command to run the training pipeline. Doing so will invoke this script.
Training logic is located in `src/training/trainer.py`.
"""

from pathlib import Path

import click

from src.training.trainer import Trainer


@click.command()
@click.option(
    "--config_path",
    "config_path",
    type=click.Path(exists=True, path_type=Path),
    help="Path to the training configuration file",
)
def main(config_path: Path) -> None:
    """Train the Pico language model using the specified configuration."""

    trainer = Trainer(config_path=str(config_path))
    trainer.train()


if __name__ == "__main__":
    main()
ADDED
@@ -0,0 +1,200 @@
#!/bin/bash
# This script sets up the project by installing dependencies, checking for a poetry environment,
# and installing pre-commit hooks.

# Add color and formatting variables at the top
GREEN='\033[0;32m'
BLUE='\033[0;34m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m' # No Color
BOLD='\033[1m'

# Initialize error tracking
ERRORS_FOUND=0

# Function for section headers
print_section() {
    echo -e "\n${BOLD}${BLUE}=== $1 ===${NC}\n"
}

# Function for success messages
print_success() {
    echo -e "${GREEN}✓ $1${NC}"
}

# Function for warnings
print_warning() {
    echo -e "${YELLOW}⚠ $1${NC}"
}

# --- GIT LFS SETUP --- #
print_section "Git LFS Setup"
if ! command -v git-lfs &> /dev/null; then
    print_warning "git-lfs is not installed. Some model checkpointing functionality may not work correctly."
    ERRORS_FOUND=$((ERRORS_FOUND + 1))

    # Check the operating system
    if [[ "$OSTYPE" == "darwin"* ]]; then
        # macOS
        echo -e "${YELLOW}  You can install it using Homebrew:${NC}"
        echo "    brew install git-lfs"
    elif [[ "$OSTYPE" == "linux-gnu"* ]]; then
        # Linux
        echo -e "${YELLOW}  You can install it using your package manager:${NC}"
        if command -v apt-get &> /dev/null; then
            # Ubuntu/Debian
            echo "    curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash"
            echo "    sudo apt-get install git-lfs"
        elif command -v yum &> /dev/null; then
            # CentOS/RHEL
            echo "    curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | sudo bash"
            echo "    sudo yum install git-lfs"
        else
            print_warning "Could not detect package manager. Please install git-lfs manually."
        fi
    else
        print_warning "Unsupported operating system. Please install git-lfs manually."
    fi
else
    git-lfs install
    print_success "git-lfs installed and initialized"
fi

# --- CUDA VERSION CHECK --- #
print_section "CUDA Version Check"
if command -v nvidia-smi &> /dev/null; then
    CUDA_VERSION=$(nvidia-smi | sed -n 's/.*CUDA Version: \([0-9.]*\).*/\1/p')

    if [[ -z "$CUDA_VERSION" ]]; then
        ERRORS_FOUND=$((ERRORS_FOUND + 1))
        print_warning "nvidia-smi failed to communicate with the NVIDIA driver."
        echo -e "${YELLOW}  Ensure that the latest NVIDIA driver is installed and running.${NC}"
    else
        MAJOR_VERSION=${CUDA_VERSION%.*}
        MINOR_VERSION=${CUDA_VERSION#*.}

        if [ "$MAJOR_VERSION" -lt 12 ] || ([ "$MAJOR_VERSION" -eq 12 ] && [ "$MINOR_VERSION" -lt 1 ]); then
            ERRORS_FOUND=$((ERRORS_FOUND + 1))
            print_warning "CUDA version ${MAJOR_VERSION}.${MINOR_VERSION} detected."
            echo -e "${YELLOW}  Some multi-node communication GPU features may not work properly.${NC}"
            echo -e "${YELLOW}  CUDA version 12.1 or newer is recommended.${NC}"
        else
            print_success "CUDA version ${MAJOR_VERSION}.${MINOR_VERSION} detected"
        fi
    fi
else
    ERRORS_FOUND=$((ERRORS_FOUND + 1))
    print_warning "nvidia-smi not found. Unable to check CUDA version."
    echo -e "${YELLOW}  Ensure that NVIDIA drivers and CUDA 12.1 or newer are installed for GPU support.${NC}"
fi


# ---- ENVIRONMENT VARIABLES ---- #
print_section "Environment Variables"
if [ -f .env ]; then
    print_success "Loading environment variables from .env..."
    source .env
    if [[ -n "$HF_TOKEN" && -n "$WANDB_API_KEY" ]]; then
        print_success "Both HF_TOKEN and WANDB_API_KEY are set and loaded!"
    else
        print_warning "One or both of HF_TOKEN and WANDB_API_KEY are not set."
    fi
else
    print_warning "No .env file found."
    echo -e "${YELLOW}  You might need to create one with HF_TOKEN and WANDB_API_KEY${NC}"
    echo -e "${YELLOW}  Example .env contents:${NC}"
    echo "    export HF_TOKEN=your_huggingface_token"
    echo "    export WANDB_API_KEY=your_wandb_key"
    ERRORS_FOUND=$((ERRORS_FOUND + 1))
fi

# ---- POETRY SETUP ---- #
print_section "Poetry Setup"

# First check if Poetry is installed
if ! command -v poetry &> /dev/null; then
    echo "Poetry not found. Installing..."

    # Run the installation command
    curl -sSL https://install.python-poetry.org | python3 -
    POETRY_INSTALL_STATUS=$?

    if [ $POETRY_INSTALL_STATUS -ne 0 ]; then
        print_warning "Poetry installation failed!"
        ERRORS_FOUND=$((ERRORS_FOUND + 1))
    else
        export PATH="$HOME/.local/bin:$PATH"

        # Verify installation succeeded
        if ! command -v poetry &> /dev/null; then
            print_warning "Poetry was installed but cannot be found in PATH!"
            echo -e "${YELLOW}  Try adding this to your shell profile:${NC}"
            echo "    export PATH=\"\$HOME/.local/bin:\$PATH\""
            ERRORS_FOUND=$((ERRORS_FOUND + 1))
        else
            print_success "Poetry installed successfully"
        fi
    fi
else
    print_success "Poetry already installed"
fi

# Then check for virtual environment
if [ ! -d ".venv" ]; then
    echo "No virtual environment found. Creating one..."
    poetry config virtualenvs.in-project true

    # Create virtual environment and install dependencies
    poetry install --with dev
    POETRY_VENV_STATUS=$?

    if [ $POETRY_VENV_STATUS -ne 0 ]; then
        print_warning "Failed to create Poetry virtual environment!"
        ERRORS_FOUND=$((ERRORS_FOUND + 1))
    else
        print_success "Poetry environment created successfully"
    fi
else
    print_success "Poetry environment already exists"
fi

# ---- PRE-COMMIT SETUP ---- #
print_section "Pre-commit Setup"

# Install pre-commit hooks
echo "Installing pre-commit hooks..."
poetry run pre-commit install
if [ $? -ne 0 ]; then
    print_warning "Failed to install pre-commit hooks!"
    ERRORS_FOUND=$((ERRORS_FOUND + 1))
else
    print_success "Pre-commit hooks installed"
fi

# Run pre-commit hooks on all files
echo "Running pre-commit hooks on all files..."
poetry run pre-commit run --all-files
if [ $? -ne 0 ]; then
    print_warning "Pre-commit encountered issues with some files"
    ERRORS_FOUND=$((ERRORS_FOUND + 1))
else
    print_success "Pre-commit initial run complete"
fi

# --- Final Status Message --- #
print_section "Setup Status"
if [ $ERRORS_FOUND -eq 0 ]; then
    print_success "Setup Complete! 🎉"
    print_success "To activate the virtual environment, run: poetry env activate"
else
    print_warning "Setup completed with warnings and errors! Please check the messages above."
    echo -e "${YELLOW}  ${ERRORS_FOUND} issue(s) were detected that may affect functionality.${NC}"
    if [ -d ".venv" ]; then
        echo -e "${YELLOW}  You can still activate the environment with: poetry env activate${NC}"
    else
        echo -e "${RED}  The virtual environment setup failed. Fix the issues before proceeding.${NC}"
    fi
fi

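Typical usage, from the repository root:

```bash
bash setup.sh
```
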
src/checkpointing/__init__.py
ADDED
@@ -0,0 +1,23 @@
"""
|
2 |
+
Pico Checkpointing Package
|
3 |
+
|
4 |
+
We subdivide the checkpointing into training, evaluation, and learning_dynamics. Training
|
5 |
+
checkpoints store the model, optimizer, and learning rate scheduler. Evaluation checkpoints store
|
6 |
+
the evaluation results on the defined metrics. Learning dynamics checkpoints store activations and gradients used for
|
7 |
+
learning dynamics analysis.
|
8 |
+
"""
|
9 |
+
|
10 |
+
from .evaluation import save_evaluation_results
|
11 |
+
from .learning_dynamics import (
|
12 |
+
compute_learning_dynamics_states,
|
13 |
+
save_learning_dynamics_states,
|
14 |
+
)
|
15 |
+
from .training import load_checkpoint, save_checkpoint
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
"compute_learning_dynamics_states",
|
19 |
+
"load_checkpoint",
|
20 |
+
"save_checkpoint",
|
21 |
+
"save_evaluation_results",
|
22 |
+
"save_learning_dynamics_states",
|
23 |
+
]
|
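A minimal sketch of how these helpers are imported elsewhere in the codebase, using exactly the names exported in `__all__` above:

```python
from src.checkpointing import (
    compute_learning_dynamics_states,
    load_checkpoint,
    save_checkpoint,
    save_evaluation_results,
    save_learning_dynamics_states,
)
```
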
src/checkpointing/evaluation.py
ADDED
@@ -0,0 +1,68 @@
"""
|
2 |
+
Utilities for checkpointing evaluation-related states (i.e. evaluation results, etc.)
|
3 |
+
|
4 |
+
We save the evaluation results in a JSON file at the step-specific evaluation results directory.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
from typing import Any, Dict
|
10 |
+
|
11 |
+
from huggingface_hub import upload_folder
|
12 |
+
from lightning.fabric import Fabric
|
13 |
+
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
14 |
+
|
15 |
+
from src.config import CheckpointingConfig
|
16 |
+
from src.training.utils.io import use_backoff
|
17 |
+
|
18 |
+
|
19 |
+
@rank_zero_only
|
20 |
+
@use_backoff()
|
21 |
+
def save_evaluation_results(
|
22 |
+
checkpointing_config: CheckpointingConfig,
|
23 |
+
checkpoint_step: int,
|
24 |
+
fabric: Fabric,
|
25 |
+
evaluation_results: Dict[str, Any],
|
26 |
+
) -> None:
|
27 |
+
"""Save evaluation results to disk and optionally to HuggingFace Hub.
|
28 |
+
|
29 |
+
The evaluation results are saved in the following directory structure:
|
30 |
+
{checkpointing_config.runs_dir}/
|
31 |
+
└── {checkpointing_config.run_name}/
|
32 |
+
└── {checkpointing_config.eval_results_dir}/
|
33 |
+
└── step_{checkpoint_step}.json
|
34 |
+
|
35 |
+
NOTE: this function is only called on rank 0 to avoid conflicts; assumes that the evaluation
|
36 |
+
results are gathered on rank 0.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
checkpointing_config: Configuration object containing checkpoint settings
|
40 |
+
checkpoint_step: Current training checkpoint step (i.e. number of learning steps taken)
|
41 |
+
fabric: Lightning Fabric instance
|
42 |
+
evaluation_results: Dictionary containing evaluation metrics
|
43 |
+
"""
|
44 |
+
|
45 |
+
run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name)
|
46 |
+
eval_results_dir = os.path.join(
|
47 |
+
run_dir, checkpointing_config.evaluation.eval_results_dir
|
48 |
+
)
|
49 |
+
|
50 |
+
os.makedirs(eval_results_dir, exist_ok=True)
|
51 |
+
|
52 |
+
curr_eval_results_path = os.path.join(
|
53 |
+
eval_results_dir, f"step_{checkpoint_step}.json"
|
54 |
+
)
|
55 |
+
|
56 |
+
# save out as json
|
57 |
+
with open(curr_eval_results_path, "w") as f:
|
58 |
+
json.dump(evaluation_results, f)
|
59 |
+
|
60 |
+
if checkpointing_config.save_to_hf:
|
61 |
+
upload_folder(
|
62 |
+
folder_path=eval_results_dir,
|
63 |
+
path_in_repo=checkpointing_config.evaluation.eval_results_dir,
|
64 |
+
repo_id=checkpointing_config.hf_checkpoint.repo_id,
|
65 |
+
commit_message=f"Saving Evaluation Results -- Step {checkpoint_step}",
|
66 |
+
revision=checkpointing_config.run_name,
|
67 |
+
token=os.getenv("HF_TOKEN"),
|
68 |
+
)
|
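A minimal usage sketch, assuming a single-process CPU run; the run name and metric value are made up, while `Fabric` and `CheckpointingConfig` are the real constructors used above:

from lightning.fabric import Fabric

from src.checkpointing import save_evaluation_results
from src.config import CheckpointingConfig

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

# With the default directories this writes runs/demo-run/eval_results/step_1000.json
save_evaluation_results(
    checkpointing_config=CheckpointingConfig(run_name="demo-run"),
    checkpoint_step=1000,
    fabric=fabric,
    evaluation_results={"paloma": 42.7},  # illustrative metric value
)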
src/checkpointing/learning_dynamics.py
ADDED
@@ -0,0 +1,424 @@
"""
Utilities for checkpointing learning dynamics-related states (i.e. activations, weights, grads, etc.)

We save the learning dynamics states in a subdirectory of the checkpointing directory.
"""

import os
import re
from typing import Dict, Optional

import deepspeed
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import Dataset
from huggingface_hub import upload_folder
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy
from lightning.fabric.utilities.rank_zero import rank_zero_only
from torch.nn import functional as F
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase

from src.config import CheckpointingConfig
from src.config.checkpointing_config import LearningDynamicsCheckpointingConfig
from src.training.utils.initialization import initialize_model
from src.training.utils.io import use_backoff


# NOTE: DeepSpeed requires a dummy optimizer to be passed in to the setup function
class DummyOptimizer(optim.Optimizer):
    def __init__(self, params):
        super().__init__(params, defaults={})


class CheckpointStateExtractor:
    """
    Class to extract and save the states of a model at a given checkpoint step for learning
    dynamics research.
    """

    def __init__(
        self,
        learning_dynamics_config: LearningDynamicsCheckpointingConfig,
        fabric: Fabric,
        model: nn.Module,
    ):
        self.learning_dynamics_config = learning_dynamics_config
        self.fabric = fabric
        self.model = model

    def extract_states(self, dataloader, compute_gradients: bool = False):
        """Extracts model states (activations, weights, and optionally gradients).

        Given a dataloader, this function will perform a forward pass of the model on each batch,
        and save the activations and weights at each layer. If compute_gradients is True, it will
        also compute the gradients of the model parameters.

        Args:
            dataloader: The dataloader containing the dataset to extract states from.
            compute_gradients: Whether to compute the gradients of the model parameters.

        Returns:
            A dictionary containing the activations, weights, and optionally gradients of the model.
        """
        checkpoint_activations = {}
        checkpoint_weights = {}

        # NOTE: to extract activations and weights, we need to set up forward hooks on the layers
        # of the model that we are interested in. This is a good intro to forward hooks if you
        # are not familiar: https://web.stanford.edu/~nanbhas/blog/forward-hooks-pytorch/
        forward_hooks = self._setup_forward_hooks(
            checkpoint_activations,
            checkpoint_weights,
        )

        ########################################################
        #
        # Forward Pass: Extract activations and weights; and compute gradients
        #
        ########################################################

        for sub_batch in dataloader:
            _input_ids = torch.tensor(sub_batch["input_ids"], device=self.fabric.device)

            if compute_gradients:
                if "labels" in sub_batch:
                    input_ids = _input_ids
                    labels = torch.tensor(
                        sub_batch["labels"], device=self.fabric.device
                    )
                else:
                    input_ids = _input_ids[:, :-1]
                    labels = _input_ids[:, 1:]
            else:
                input_ids = _input_ids
                labels = None

            if labels is None:
                # we can throw away the outputs, we are only interested in the hidden states
                with torch.no_grad():
                    _ = self.model(input_ids)
            else:
                # NOTE: if we are computing gradients, calling backwards will compute the gradients
                # of the model parameters.
                outputs, _ = self.model(input_ids)
                outputs = outputs.transpose(1, 2)
                loss = F.cross_entropy(outputs, labels)
                self.fabric.backward(loss, model=self.model)

        # cleanup forward hooks
        # NOTE: this is not strictly necessary, since self.model is a deepcopy of the original model,
        # but it is good practice to remove the hooks after the forward pass is complete.
        for hook in forward_hooks:
            hook.remove()

        ########################################################
        #
        # Extract gradients from the target tensors of the model
        #
        ########################################################

        layer_suffixes = self.learning_dynamics_config.layer_suffixes
        checkpoint_gradients = {}
        if compute_gradients:
            for name, param in self.model.named_parameters():
                # only do this for the weight matrix of the layer_suffixes
                if (
                    any(layer_suffix in name for layer_suffix in layer_suffixes)
                    and "weight" in name
                ):
                    if isinstance(self.fabric.strategy, DeepSpeedStrategy):
                        _grad = deepspeed.utils.safe_get_full_grad(param)
                    else:
                        _grad = param.grad

                    assert _grad is not None, f"Gradient is None for layer: {name}"
                    name = re.sub(r"\.weight", "", name)
                    checkpoint_gradients[name] = _grad.detach().cpu()

            # zero out the gradients
            self.model.zero_grad()

        return checkpoint_activations, checkpoint_weights, checkpoint_gradients

    ########################################################
    #
    # Setup forward hooks to save activations and weights at each layer
    #
    ########################################################

    def _setup_forward_hooks(self, checkpoint_activations, checkpoint_weights):
        """Setup forward hooks for the model to save activations and weights at each layer.

        This function will set up forward hooks on the layers of the model that we are interested in.
        The forward hooks will save the activations and weights at each layer whenever the forward pass
        is performed.

        Args:
            checkpoint_activations: A dictionary to store the activations at each layer.
            checkpoint_weights: A dictionary to store the weights at each layer.

        Returns:
            A list of forward hooks. We do this so that we can remove the hooks after the forward pass
            is complete.
        """

        forward_hooks = []
        layer_suffixes = self.learning_dynamics_config.layer_suffixes

        for name, module in self.model.named_modules():
            if any(layer_suffix in name for layer_suffix in layer_suffixes):
                _forward_hook = module.register_forward_hook(
                    self._get_forward_hook(
                        name, checkpoint_activations, checkpoint_weights
                    )
                )
                forward_hooks.append(_forward_hook)
        return forward_hooks

    def _get_forward_hook(
        self, module_name, checkpoint_activations, checkpoint_weights
    ):
        """Get a forward hook for a given module.

        This function is called by the _setup_forward_hooks function to set up a forward hook for a
        given module. This function is a closure that captures the module_name,
        checkpoint_activations, and checkpoint_weights.

        Args:
            module_name: The name of the module to setup a forward hook for.
            checkpoint_activations: A dictionary to store the activations at each layer.
            checkpoint_weights: A dictionary to store the weights at each layer.

        Returns:
            A forward hook for the given module.
        """

        def _forward_hook(module, _, module_out):
            sequence_idx = self.learning_dynamics_config.sequence_idx

            local_activations = module_out[:, sequence_idx, :].detach()

            # Gather activations from all processes using fabric
            gathered_activations = self.fabric.all_gather(local_activations)

            # Reshape from [num_processes, batch_size, hidden_dim] to [total_batch_size, hidden_dim]
            # NOTE: transposing allows us to interleave the activations from each process so that
            # they are in the correct order. (i.e. activation N is from data sample N)
            gathered_activations = gathered_activations.transpose(0, 1).reshape(
                -1, gathered_activations.shape[-1]
            )

            # check if there is already a key for the module name
            if module_name not in checkpoint_activations:
                # if there is no key, then we create a new key and store the hidden states
                checkpoint_activations[module_name] = (
                    gathered_activations.detach().cpu()
                )

                # extract the weight matrix just once
                weight_matrix = module.weight.detach().cpu()
                checkpoint_weights[module_name] = weight_matrix
            else:
                # if there is already a key, then we concatenate the new hidden states to the existing ones
                checkpoint_activations[module_name] = torch.cat(
                    (
                        checkpoint_activations[module_name],
                        gathered_activations.detach().cpu(),
                    )
                )

        return _forward_hook


def compute_learning_dynamics_states(
    checkpointing_config: CheckpointingConfig,
    fabric: Fabric,
    model: nn.Module,
    dataset: Dataset,
    compute_gradients: bool = False,
) -> Dict[str, torch.Tensor]:
    """Computes the learning dynamics metrics for a given checkpoint step.

    Uses the CheckpointStateExtractor to extract the activations, weights, and optionally gradients
    of the model at a given checkpoint step.

    Args:
        checkpointing_config: The configuration object for checkpointing.
        fabric: The Fabric instance for distributed training.
        model: The model to extract states from.
        dataset: The dataset to extract states from.
        compute_gradients: Whether to compute the gradients of the model parameters.

    Returns:
        A dictionary containing the activations, weights, and optionally gradients of the model.
    """

    # NOTE: Synchronizing processes for fabric dataloader setup
    fabric.barrier()
    model.to("cpu")  # Offloading model to CPU

    # Setting up Dataloader for learning dynamics
    def _collate_fn(batch):
        return {"input_ids": [entry["input_ids"] for entry in batch]}

    batch_size = checkpointing_config.learning_dynamics.batch_size
    sub_batch_size = batch_size // fabric.world_size

    # NOTE: Make sure to set drop_last to False, otherwise the last batch will be dropped
    # and we will not have a complete set of activations for the last sample. Also,
    # we need to set shuffle to False, otherwise the activations will be shuffled across
    # processes and we will not be able to interleave them correctly.
    extractor_dataloader = DataLoader(
        dataset,
        batch_size=sub_batch_size,
        shuffle=False,
        collate_fn=_collate_fn,
        drop_last=False,
    )
    extractor_dataloader = fabric.setup_dataloaders(
        extractor_dataloader, use_distributed_sampler=True
    )

    # Create a new model instance with same parameters but zero gradients
    _model = initialize_model(model.config)
    _model.load_state_dict(model.state_dict())

    if isinstance(fabric.strategy, DeepSpeedStrategy):
        _model, _ = fabric.setup(_model, DummyOptimizer(_model.parameters()))
    else:
        _model = fabric.setup(_model)

    _model.zero_grad()

    # setup forward hooks for the model to save activations and weights at each layer
    state_extractor = CheckpointStateExtractor(
        checkpointing_config.learning_dynamics, fabric, _model
    )

    checkpoint_activations, checkpoint_weights, checkpoint_gradients = (
        state_extractor.extract_states(
            extractor_dataloader, compute_gradients=compute_gradients
        )
    )

    del _model
    torch.cuda.empty_cache()

    # NOTE: Synchronizing processes for model setup
    fabric.barrier()

    model.to(fabric.device)

    # NOTE: Trimming down the activations to match the dataset size;
    # This is because the DataSampler might add extra samples to the dataset to make it evenly
    # divisible by the number of processes. We need to remove these extra samples.
    for layer_name, layer_activations in checkpoint_activations.items():
        if len(layer_activations) > len(dataset):
            checkpoint_activations[layer_name] = layer_activations[: len(dataset)]
        elif len(layer_activations) < len(dataset):
            raise ValueError(
                f"Number of activations ({len(layer_activations)}) in layer {layer_name} does not match number of samples in dataset ({len(dataset)})"
            )

    return {
        "activations": checkpoint_activations,
        "weights": checkpoint_weights,
        "gradients": checkpoint_gradients,
    }


@rank_zero_only
@use_backoff()
def save_learning_dynamics_states(
    checkpointing_config: CheckpointingConfig,
    checkpoint_step: int,
    prefix: str,
    fabric: Fabric,
    learning_dynamics_states: Dict[str, torch.Tensor],
    learning_dynamics_dataset: Optional[Dataset] = None,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
) -> None:
    """Save the learning dynamics metrics to the checkpointing directory.

    By default only the learning dynamics states are saved. If the learning dynamics dataset
    is provided, it is also saved; if a tokenizer is provided, the dataset is also detokenized
    (i.e. a new column with the text is added to the dataset).

    The learning dynamics dataset is saved in the checkpointing directory as a HuggingFace
    dataset.

    Creates a versioned checkpoint directory with the following structure:

        {checkpointing_config.runs_dir}/
        └── {checkpointing_config.run_name}/
            └── {checkpointing_config.checkpoints_dir}/
                ├── step_{checkpoint_step}/
                │   └── {checkpointing_config.learning_dynamics_dir}/  # Learning Dynamics files
                │       ├── {prefix}_activations.pt
                │       ├── {prefix}_weights.pt
                │       ├── {prefix}_gradients.pt
                │       └── {prefix}_data/  # if learning_dynamics_dataset is provided
                └── latest -> step_{checkpoint_step}/

    NOTE: this function is only called on rank 0

    Args:
        checkpointing_config: The configuration object for checkpointing.
        checkpoint_step: The checkpoint step at which the learning dynamics states were computed.
        prefix: The prefix for the learning dynamics states.
        fabric: The Fabric instance for distributed training.
        learning_dynamics_states: The learning dynamics states to save.
        learning_dynamics_dataset: The dataset containing learning dynamics data,
            including input IDs that need to be decoded. (optional)
        tokenizer: The tokenizer used to decode input IDs into text. (optional)
    """

    runs_dir = checkpointing_config.runs_dir
    run_name = checkpointing_config.run_name
    checkpoints_dir = checkpointing_config.checkpoints_dir
    learning_dynamics_dir = checkpointing_config.learning_dynamics_dir

    run_path = os.path.join(runs_dir, run_name)
    root_checkpoint_path = os.path.join(run_path, checkpoints_dir)
    checkpoint_path = os.path.join(root_checkpoint_path, f"step_{checkpoint_step}")
    learning_dynamics_path = os.path.join(checkpoint_path, learning_dynamics_dir)
    os.makedirs(learning_dynamics_path, exist_ok=True)

    # save the learning dynamics states
    for key, value in learning_dynamics_states.items():
        if value is not None and len(value) > 0:
            torch.save(
                value, os.path.join(learning_dynamics_path, f"{prefix}_{key}.pt")
            )

    if learning_dynamics_dataset is not None:
        if tokenizer is not None:
            # go through dataset and decode the input ids; and add back into dataset
            detokenized_dataset = {"input_ids": [], "text": []}

            for entry in learning_dynamics_dataset:
                input_ids = entry["input_ids"]
                decoded_text = tokenizer.decode(input_ids, skip_special_tokens=True)
                detokenized_dataset["input_ids"].append(input_ids)
                detokenized_dataset["text"].append(decoded_text)

            learning_dynamics_dataset = Dataset.from_dict(detokenized_dataset)

        learning_dynamics_dataset_path = os.path.join(
            learning_dynamics_path, f"{prefix}_data"
        )
        learning_dynamics_dataset.save_to_disk(learning_dynamics_dataset_path)

    if checkpointing_config.save_to_hf:
        # Upload the learning dynamics data
        upload_folder(
            folder_path=learning_dynamics_path,
            path_in_repo=learning_dynamics_dir,
            repo_id=checkpointing_config.hf_checkpoint.repo_id,
            commit_message=f"Saving Learning Dynamics Data ({prefix}) -- Step {checkpoint_step}",
            revision=checkpointing_config.run_name,
            token=os.getenv("HF_TOKEN"),
        )
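The forward-hook closures above are the crux of the extractor. A self-contained toy version of the same pattern (plain PyTorch, no Pico dependencies, made-up layer name) looks like this:

import torch
import torch.nn as nn

activations = {}

def make_hook(name):
    # closure over `name`, mirroring _get_forward_hook
    def hook(module, inputs, output):
        # keep only the last sequence position, like sequence_idx = -1
        activations[name] = output[:, -1, :].detach().cpu()
    return hook

layer = nn.Linear(8, 8)
handle = layer.register_forward_hook(make_hook("demo_layer"))
_ = layer(torch.randn(2, 4, 8))  # (batch, seq_len, hidden)
handle.remove()  # remove hooks once the forward pass is done

print(activations["demo_layer"].shape)  # torch.Size([2, 8])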
src/checkpointing/training.py
ADDED
@@ -0,0 +1,287 @@
"""
Utilities for checkpointing training-related states (i.e. model, optimizer, lr_scheduler, etc.)

We save both a HuggingFace model and a Fabric-specific checkpoint. The HuggingFace model is
saved at the step-specific checkpoint directory, while the Fabric-specific checkpoint is saved
in a subdirectory. This is done to facilitate easier versioning of the HuggingFace model files
(which are what gets uploaded to the Hub).
"""

import os
from dataclasses import asdict
from typing import Any, Dict, Tuple, Union

import yaml
from huggingface_hub import upload_file, upload_folder
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedTokenizerBase

from src.config import CheckpointingConfig
from src.training.utils.io import use_backoff


@use_backoff()
def load_checkpoint(
    checkpointing_config: CheckpointingConfig,
    checkpoint_step: Union[str, int],
    fabric: Fabric,
    model: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: LRScheduler,
) -> Tuple[nn.Module, Optimizer, LRScheduler, int]:
    """Load model checkpoint and associated states from a given step.

    Args:
        checkpointing_config: Configuration object containing checkpoint settings
        checkpoint_step: The step at which to load the checkpoint
        fabric: Lightning Fabric instance for distributed training support
        model: The model instance to load weights into
        optimizer: The optimizer instance to load states into
        lr_scheduler: The learning rate scheduler to load states into

    Returns:
        Tuple containing the model, optimizer, lr_scheduler, and checkpoint step.
        Returns None if no checkpoint is found.
    """

    if isinstance(checkpoint_step, int):
        checkpoint_step = f"step_{checkpoint_step}"

    checkpoint_path = os.path.join(
        checkpointing_config.runs_dir,
        checkpointing_config.run_name,
        checkpointing_config.checkpoints_dir,
        checkpoint_step,
    )

    if not os.path.exists(checkpoint_path):
        return None

    # Load from specified fabric checkpoint subdirectory
    fabric_checkpoint_path = os.path.join(
        checkpoint_path, checkpointing_config.fabric_checkpoint_dir
    )

    checkpoint_state = {
        "_model": model,
        "_optimizer": optimizer,
        "_lr_scheduler": lr_scheduler,
    }

    if not isinstance(fabric.strategy, DeepSpeedStrategy):
        fabric_load_file = os.path.join(
            fabric_checkpoint_path, checkpointing_config.fabric_checkpoint_filename
        )
    else:
        # Deepspeed checkpoints create sub-directory with distributed checkpoint file
        fabric_load_file = fabric_checkpoint_path

    extra_state = fabric.load(fabric_load_file, state=checkpoint_state)

    # NOTE: extra_state will contain any additional states that were saved in the checkpoint
    checkpoint_step = extra_state["_checkpoint_step"]

    if "_rng_states" in extra_state:
        _rng_states = extra_state["_rng_states"]
        _set_rng_states(_rng_states)

    return model, optimizer, lr_scheduler, checkpoint_step


@use_backoff()
def save_checkpoint(
    configs: Dict[str, Any],
    checkpoint_step: int,
    fabric: Fabric,
    model: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: LRScheduler,
    tokenizer: PreTrainedTokenizerBase,
    upload_logs: bool = False,
) -> None:
    """Save training checkpoint and associated states to disk and optionally to HuggingFace Hub.

    We save the following files:
        - HuggingFace model files (config.json, pytorch_model.bin)
        - Tokenizer files (vocab.json, merges.txt)
        - Fabric-specific files - fabric state of the model, optimizer, and lr_scheduler. If using
          DeepSpeed, the checkpoint is saved in a subdirectory, otherwise it is saved in a single file.

    Note that the HuggingFace model files are saved at the step-specific checkpoint directory,
    while the Fabric-specific files are saved in a subdirectory. This is done to facilitate easier
    versioning of the HuggingFace model files (which are what gets uploaded to the Hub).

    NOTE: Why do we save a HF model at all? We do this because it makes it easier to load the model
    in a separate script for evaluation and to play nicely with the HuggingFace Hub.

    Creates a versioned checkpoint directory with the following structure:

        {checkpointing_config.runs_dir}/
        └── {checkpointing_config.run_name}/
            ├── training_config.yaml  # Training config
            └── {checkpointing_config.checkpoints_dir}/
                ├── step_{checkpoint_step}/
                │   ├── config.json  # HuggingFace model config
                │   ├── model.safetensors  # HuggingFace model weights
                │   ├── pico_{model_type}.py  # HuggingFace custom model class
                │   ├── tokenizer.json  # Tokenizer vocab
                │   ├── tokenizer_config.json  # Tokenizer config
                │   └── {checkpointing_config.fabric_checkpoint_dir}/  # Fabric-specific files
                │       └── checkpoint/  # Distributed model checkpoint files (if using DeepSpeed)
                │           OR
                │       └── checkpoint.pt  # Single checkpoint file (if using other strategies)
                └── latest -> step_{checkpoint_step}/

    Args:
        configs: A dictionary containing the initialized configuration objects.
        checkpoint_step: The current training checkpoint step (i.e. number of learning steps taken)
        fabric: Lightning Fabric instance for distributed training support
        model: The model instance to save
        optimizer: The optimizer instance to save
        lr_scheduler: The learning rate scheduler to save
        tokenizer: The tokenizer to save
        upload_logs: Whether to upload training logs to HF Hub (default: False)
    """

    checkpointing_config = configs["checkpointing"]

    # Get the directories from the training config
    runs_dir = checkpointing_config.runs_dir
    checkpoints_dir = checkpointing_config.checkpoints_dir
    fabric_checkpoint_dir = checkpointing_config.fabric_checkpoint_dir
    logs_dir = checkpointing_config.logs_dir

    run_path = os.path.join(runs_dir, checkpointing_config.run_name)
    root_checkpoint_path = os.path.join(run_path, checkpoints_dir)
    checkpoint_path = os.path.join(root_checkpoint_path, f"step_{checkpoint_step}")

    # Create directories
    os.makedirs(checkpoint_path, exist_ok=True)

    ########################################################
    #
    # Save HuggingFace files
    #
    ########################################################

    # NOTE: we convert the Pico model to a HuggingFace model before saving it. See `model.py`
    # for more details.
    if fabric.global_rank == 0:
        hf_model = model.convert_to_hf_model()
        hf_model.save_pretrained(checkpoint_path)
        tokenizer.save_pretrained(checkpoint_path)

    ########################################################
    #
    # Save Fabric-specific files
    #
    ########################################################

    # Create fabric-specific subdirectory
    fabric_checkpoint_path = os.path.join(checkpoint_path, fabric_checkpoint_dir)
    os.makedirs(fabric_checkpoint_path, exist_ok=True)

    # Save model states (use underscore to avoid conflicts with third-party libraries)
    checkpoint_state = {
        "_model": model,
        "_optimizer": optimizer,
        "_lr_scheduler": lr_scheduler,
        "_checkpoint_step": checkpoint_step,
    }

    if not isinstance(fabric.strategy, DeepSpeedStrategy):
        checkpoint_state["_rng_states"] = _collect_rng_states()
        fabric_save_file = os.path.join(
            fabric_checkpoint_path, checkpointing_config.fabric_checkpoint_filename
        )
    else:
        # Deepspeed checkpoints create sub-directory with distributed checkpoint file
        fabric_save_file = fabric_checkpoint_path

    fabric.save(fabric_save_file, checkpoint_state)

    if fabric.global_rank == 0:
        # Save config in fabric directory
        config_path = os.path.join(run_path, "training_config.yaml")
        if not os.path.exists(config_path):
            # Converting dataclasses to joined dicts and saving to file
            _training_config = {}
            for config_name, config in configs.items():
                _training_config[config_name] = asdict(config)
            with open(config_path, "w") as f:
                yaml.dump(_training_config, f)

        # Update latest symlink
        latest_symlink_path = os.path.join(root_checkpoint_path, "latest")
        if os.path.lexists(latest_symlink_path):
            os.remove(latest_symlink_path)
        os.symlink(
            f"step_{checkpoint_step}", latest_symlink_path, target_is_directory=True
        )

    ########################################################
    #
    # Push to HuggingFace Hub (if configured)
    #
    ########################################################

    if fabric.global_rank == 0:
        # Push only on rank zero thread

        if checkpointing_config.save_to_hf:
            repo_id = checkpointing_config.hf_checkpoint.repo_id

            # Upload the HF model
            hf_model.push_to_hub(
                repo_id=repo_id,
                commit_message=f"Saving HF Model -- Step {checkpoint_step}",
                revision=checkpointing_config.run_name,
                token=os.getenv("HF_TOKEN"),
            )

            if checkpoint_step == 0:
                # Uploading Tokenizer during first step since it never changes
                tokenizer.push_to_hub(
                    repo_id=repo_id,
                    commit_message=f"Saving Tokenizer -- Step {checkpoint_step}",
                    revision=checkpointing_config.run_name,
                    token=os.getenv("HF_TOKEN"),
                )

                # Upload training config, also only in first step
                upload_file(
                    path_or_fileobj=config_path,
                    path_in_repo="training_config.yaml",
                    repo_id=repo_id,
                    commit_message=f"Saving Training Config -- Step {checkpoint_step}",
                    revision=checkpointing_config.run_name,
                    token=os.getenv("HF_TOKEN"),
                )

            # Upload the fabric checkpoint directory
            upload_folder(
                folder_path=fabric_checkpoint_path,
                path_in_repo=fabric_checkpoint_dir,
                repo_id=repo_id,
                commit_message=f"Saving Fabric Checkpoint -- Step {checkpoint_step}",
                revision=checkpointing_config.run_name,
                token=os.getenv("HF_TOKEN"),
            )

            # Upload logs if requested
            if upload_logs:
                logs_path = os.path.join(run_path, logs_dir)
                upload_folder(
                    folder_path=logs_path,
                    path_in_repo=logs_dir,
                    repo_id=repo_id,
                    commit_message=f"Saving Logs -- Step {checkpoint_step}",
                    revision=checkpointing_config.run_name,
                    token=os.getenv("HF_TOKEN"),
                )
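Plugging in the default constants from src/config/_constants.py, the directory layout that save_checkpoint and load_checkpoint agree on can be traced with a few lines; the run name and step below are hypothetical:

import os

runs_dir, checkpoints_dir = "runs", "checkpoints"  # RUNS_DIR, CHECKPOINTS_DIR defaults
run_name, step = "demo-run", 1000                  # hypothetical values

checkpoint_path = os.path.join(runs_dir, run_name, checkpoints_dir, f"step_{step}")
fabric_state = os.path.join(checkpoint_path, "fabric_state", "checkpoint.pt")

print(checkpoint_path)  # runs/demo-run/checkpoints/step_1000
print(fabric_state)     # runs/demo-run/checkpoints/step_1000/fabric_state/checkpoint.pt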
src/config/__init__.py
ADDED
@@ -0,0 +1,31 @@
"""
Pico Config Package

The modules of this package are where you can specify the hyperparameters for the Pico model,
the dataset, the training process, evaluation, etc.

As with anything else in Pico, we've designed the configuration setup to be as flexible
as possible. By default the configs are implemented as vanilla dataclasses -- this makes it
easy to switch to a different config management system if you want, like hydra.

Some things to NOTE:
- All hyperparameters are initialized with default values, which can be overridden.
- The default vocab size is set to the size of the OLMo tokenizer.
"""

# For convenience, we export the config classes here
from .checkpointing_config import CheckpointingConfig
from .data_config import DataConfig
from .evaluation_config import EvaluationConfig
from .model_config import ModelConfig
from .monitoring_config import MonitoringConfig
from .training_config import TrainingConfig

__all__ = [
    "CheckpointingConfig",
    "DataConfig",
    "EvaluationConfig",
    "ModelConfig",
    "MonitoringConfig",
    "TrainingConfig",
]
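Because these are plain dataclasses, overriding defaults is just constructor arguments; a quick sketch with illustrative values:

from src.config import ModelConfig, TrainingConfig

model_config = ModelConfig(d_model=192, n_layers=6)  # illustrative small model
training_config = TrainingConfig(max_steps=50_000)

print(model_config.vocab_size)          # 50304 -- the OLMo tokenizer default
print(training_config.optimization.lr)  # 0.0003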
src/config/_constants.py
ADDED
@@ -0,0 +1,18 @@
"""
Constants used throughout the codebase
"""

# Basic Training Constants used throughout the codebase
VOCAB_SIZE = 50304
MAX_SEQ_LEN = 2048
BATCH_SIZE = 1024
GRADIENT_ACCUMULATION_STEPS = 128

# Directories used to store training runs, checkpoints, logs, and evaluation results
RUNS_DIR = "runs"
CHECKPOINTS_DIR = "checkpoints"
LOGS_DIR = "logs"
FABRIC_CHECKPOINT_DIR = "fabric_state"
FABRIC_CHECKPOINT_FILENAME = "checkpoint.pt"
LEARNING_DYNAMICS_DIR = "learning_dynamics"
EVAL_RESULTS_DIR = "eval_results"
src/config/checkpointing_config.py
ADDED
@@ -0,0 +1,97 @@
"""
Checkpointing Config

Specifies the hyperparameters for the checkpointing process; checkpointing is used to save
the model and optimizer states, as well as the learning dynamics metrics.
"""

from dataclasses import dataclass, field
from typing import List, Optional

from ._constants import (
    CHECKPOINTS_DIR,
    EVAL_RESULTS_DIR,
    FABRIC_CHECKPOINT_DIR,
    FABRIC_CHECKPOINT_FILENAME,
    LEARNING_DYNAMICS_DIR,
    LOGS_DIR,
    RUNS_DIR,
)


@dataclass
class TrainingCheckpointingConfig:
    # Automatically resume training from the most recent checkpoint
    auto_resume: bool = True


@dataclass
class EvaluationCheckpointingConfig:
    # Directory in which evaluation results are saved
    eval_results_dir: str = EVAL_RESULTS_DIR


@dataclass
class LearningDynamicsCheckpointingConfig:
    # Suffixes of the layers to compute learning dynamics for
    layer_suffixes: List[str] = field(
        default_factory=lambda: [
            "attention.v_proj",
            "attention.o_proj",
            "swiglu.w_2",
        ]
    )

    # Sequence index at which to extract hidden states; by default, we extract the hidden states
    # at the last token of the sequence (-1)
    sequence_idx: int = -1

    # size of the sub-batch used for extracting learning dynamics states
    batch_size: int = 8

    # Path to evaluation dataset - used across learning dynamics checkpointing for consistency
    # NOTE: set to None to disable extracting learning dynamics states for an eval_batch
    # NOTE: this dataset should be small, ideally just a batch of additional data
    eval_data: Optional[str] = "pico-lm/pretokenized-paloma-tinsy"


@dataclass
class HuggingFaceCheckpointingConfig:
    # Should be in the format of <(username or organization name)>/<repo_name>, e.g. pico-lm/demo
    repo_id: str = ""

    # HuggingFace Collection Slug (specifies a tag for the run)
    collection_slug: Optional[str] = None


@dataclass
class CheckpointingConfig:
    # Assign a name to the run
    run_name: Optional[str] = None

    # Defining checkpointing directories
    runs_dir: str = RUNS_DIR
    checkpoints_dir: str = CHECKPOINTS_DIR
    logs_dir: str = LOGS_DIR
    fabric_checkpoint_dir: str = FABRIC_CHECKPOINT_DIR
    fabric_checkpoint_filename: str = FABRIC_CHECKPOINT_FILENAME
    learning_dynamics_dir: str = LEARNING_DYNAMICS_DIR

    # How often to save checkpoints
    save_every_n_steps: int = 1000

    # Whether to save checkpoints to HuggingFace
    save_to_hf: Optional[bool] = False
    hf_checkpoint: HuggingFaceCheckpointingConfig = field(
        default_factory=HuggingFaceCheckpointingConfig
    )

    training: TrainingCheckpointingConfig = field(
        default_factory=TrainingCheckpointingConfig
    )
    evaluation: EvaluationCheckpointingConfig = field(
        default_factory=EvaluationCheckpointingConfig
    )
    learning_dynamics: LearningDynamicsCheckpointingConfig = field(
        default_factory=LearningDynamicsCheckpointingConfig
    )
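For instance, enabling Hub uploads only requires setting save_to_hf and a repo_id; this sketch reuses the pico-lm/demo example from the comment above, with a made-up run name:

from src.config import CheckpointingConfig
from src.config.checkpointing_config import HuggingFaceCheckpointingConfig

config = CheckpointingConfig(
    run_name="demo-run",
    save_every_n_steps=500,
    save_to_hf=True,
    hf_checkpoint=HuggingFaceCheckpointingConfig(repo_id="pico-lm/demo"),
)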
src/config/data_config.py
ADDED
@@ -0,0 +1,36 @@
"""
Data Config

Specifies the hyperparameters for the dataset, dataloader, and tokenizer.
"""

from dataclasses import dataclass, field

from ._constants import BATCH_SIZE, VOCAB_SIZE


@dataclass
class DatasetConfig:
    # Defines the HuggingFace name of a dataset
    name: str = "pico-lm/pretokenized-dolma"


@dataclass
class DataLoaderConfig:
    # NOTE: You should only change these values jointly with the training config, so that the
    # sub-batch size is consistent with the gradient accumulation steps
    batch_size: int = BATCH_SIZE


@dataclass
class TokenizerConfig:
    # Specify a tokenizer to use
    name: str = "allenai/OLMo-7B-0724-hf"
    vocab_size: int = VOCAB_SIZE


@dataclass
class DataConfig:
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    dataloader: DataLoaderConfig = field(default_factory=DataLoaderConfig)
    tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
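The coupling the DataLoaderConfig comment warns about is simple arithmetic. Assuming the per-step sub-batch is the global batch divided by the accumulation steps (an assumption about how the trainer splits batches, not stated in this file), the defaults imply:

from src.config._constants import BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS

# assumed relationship: per-step sub-batch = global batch / accumulation steps
sub_batch = BATCH_SIZE // GRADIENT_ACCUMULATION_STEPS
print(sub_batch)  # 1024 // 128 == 8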
src/config/evaluation_config.py
ADDED
@@ -0,0 +1,28 @@
"""
Evaluation Config

Specifies the hyperparameters for the evaluation process, i.e. what metrics to compute, etc.
"""

from dataclasses import dataclass, field
from typing import List, Optional

from src.config._constants import MAX_SEQ_LEN


@dataclass
class PalomaEvaluationConfig:
    dataset_name: str = "pico-lm/pretokenized-paloma-tinsy"
    dataset_split: str = "val"
    max_length: int = MAX_SEQ_LEN
    batch_size: int = 16


@dataclass
class EvaluationConfig:
    # Evaluation metrics to compute: by default, we compute the perplexity of the model on the paloma dataset
    metrics: Optional[List[str]] = field(default_factory=lambda: ["paloma"])

    # NOTE: Add other evaluation configs here
    # Each evaluation metric should have its own config
    paloma: PalomaEvaluationConfig = field(default_factory=PalomaEvaluationConfig)
src/config/model_config.py
ADDED
@@ -0,0 +1,33 @@
"""
Model Config

Specifies the hyperparameters for the Pico model/model architecture.
"""

from dataclasses import dataclass
from typing import Optional

from ._constants import BATCH_SIZE, MAX_SEQ_LEN, VOCAB_SIZE


@dataclass
class ModelConfig:
    model_type: str = "pico_decoder"

    # Pico Decoder default hyperparameters

    d_model: int = 768
    n_layers: int = 12

    vocab_size: int = VOCAB_SIZE
    batch_size: int = BATCH_SIZE
    max_seq_len: int = MAX_SEQ_LEN

    attention_n_heads: int = 12
    attention_n_kv_heads: Optional[int] = 4

    activation_hidden_dim: int = 3072

    norm_eps: float = 1e-6

    position_emb_theta: float = 10000.0
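The defaults encode a grouped-query attention layout; the geometry follows directly from the numbers:

d_model, n_heads, n_kv_heads = 768, 12, 4  # ModelConfig defaults

head_dim = d_model // n_heads           # 64 dims per attention head
queries_per_kv = n_heads // n_kv_heads  # 3 query heads share each KV head
print(head_dim, queries_per_kv)         # 64 3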
src/config/monitoring_config.py
ADDED
@@ -0,0 +1,29 @@
"""
Monitoring Config

Specifies the monitoring process, e.g. how to log metrics and keep track of training progress.
"""

from dataclasses import dataclass, field


@dataclass
class LoggingConfig:
    log_level: str = "INFO"
    log_every_n_steps: int = 100


@dataclass
class WandbConfig:
    # configure logging to Weights and Biases
    project: str = ""
    entity: str = ""


@dataclass
class MonitoringConfig:
    logging: LoggingConfig = field(default_factory=LoggingConfig)

    # Weights and Biases
    save_to_wandb: bool = False
    wandb: WandbConfig = field(default_factory=WandbConfig)
src/config/training_config.py
ADDED
@@ -0,0 +1,40 @@
"""
Training Config

Specifies the hyperparameters for the training process, i.e. the optimizer, learning rate, etc.
"""

from dataclasses import dataclass, field

from ._constants import GRADIENT_ACCUMULATION_STEPS


@dataclass
class FabricConfig:
    # Configure nodes/devices for parallelised training
    num_nodes: int = 1
    num_devices: int = 1
    precision: str = "bf16-mixed"
    # Hardware accelerator to use, can be cpu/cuda/mps etc.
    accelerator: str = "cuda"


@dataclass
class OptimizationConfig:
    # Optimizer
    optimizer: str = "adamw"
    lr: float = 3e-4

    # Learning Rate Scheduler
    lr_scheduler: str = "linear_with_warmup"
    lr_warmup_steps: int = 2500

    # Define number of gradient accumulation steps
    gradient_accumulation_steps: int = GRADIENT_ACCUMULATION_STEPS


@dataclass
class TrainingConfig:
    fabric: FabricConfig = field(default_factory=FabricConfig)
    optimization: OptimizationConfig = field(default_factory=OptimizationConfig)
    max_steps: int = 200_000
src/evaluation/__init__.py
ADDED
@@ -0,0 +1,103 @@
"""
Pico Evaluation Package

This package implements the evaluation pipeline for the Pico language model. It provides
functionality to evaluate model performance using various metrics and handles the complete
evaluation workflow.

We recommend that each evaluation metric have its own config, and be implemented as a module
in the `evaluation/tasks` directory that exposes a `run_<metric_name>` function.

NOTE: Out of the box we only support Paloma, but the structure is designed to be flexible and
you are meant to add whatever metrics you want. One of the main reasons we store the model
in the HuggingFace format is so that it's easy to use third-party evaluation
libraries/frameworks.
"""

import os
from typing import Dict

import torch
from lightning.fabric import Fabric
from torch import nn

from src.config import CheckpointingConfig, EvaluationConfig

from .tasks.paloma import run_paloma_evaluation


def run_evaluation(
    evaluation_config: EvaluationConfig,
    checkpointing_config: CheckpointingConfig,
    fabric: Fabric,
    model: nn.Module,
) -> Dict[str, float]:
    """Run model evaluation using specified metrics in `evaluation_config`.

    This function orchestrates the complete evaluation pipeline by:
        1. Resolving the model checkpoint path (either specified or latest) to load the model
           from; during training, this is the path to the latest checkpoint in the run directory.
        2. Iterating over each evaluation metric, and running the corresponding evaluation
           function. NOTE: we suggest you follow the pattern of the Paloma evaluation function,
           and implement your own evaluation function for each metric in the `evaluation/tasks`
           directory.
        3. Aggregating results across all metrics in a dictionary, and returning it.

    Args:
        evaluation_config (EvaluationConfig): Configuration object containing:
            - metrics (List[str]): Metrics to evaluate; each metric should have its
              own config. Currently supported: ["paloma"];
            - paloma (PalomaConfig): Configuration for Paloma evaluation
                - max_length (int): Maximum sequence length
                - limit_eval_examples (Optional[int]): Number of examples to evaluate
        checkpointing_config (CheckpointingConfig): Configuration object containing checkpoint settings
        fabric (Fabric): Lightning Fabric instance
        model (nn.Module): Original model instance

    Returns:
        Dict[str, float]: Dictionary mapping metric names to their values
            Example: {"paloma": 3.45}

    Raises:
        ValueError: If an unsupported evaluation metric is requested

    Example:
        results = run_evaluation(
            EvaluationConfig(
                run_name="experiment_1",
                metrics=["paloma"],
                paloma=PalomaConfig(max_length=2048, batch_size=16)
            )
        )
    """

    fabric.barrier()

    model.to("cpu")  # Offloading model to CPU

    evaluation_results = {}

    # NOTE: Evaluation is only run on the first process so that third-party evaluation libraries
    # can determine how to handle distributed evaluation themselves.
    if fabric.global_rank == 0:
        run_name = checkpointing_config.run_name
        model_path = f"{os.getcwd()}/{checkpointing_config.runs_dir}/{run_name}/{checkpointing_config.checkpoints_dir}/latest"
        os.makedirs(model_path, exist_ok=True)

        for metric in evaluation_config.metrics:
            # NOTE: add your own metrics here
            if metric == "paloma":
                evaluation_result = run_paloma_evaluation(
                    model_path, evaluation_config.paloma
                )
            else:
                raise ValueError(f"Metric {metric} not supported")

            evaluation_results[metric] = evaluation_result

        torch.cuda.empty_cache()

    fabric.barrier()

    model.to(fabric.device)

    return evaluation_results
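Following the run_<metric_name> convention the docstring recommends, a new task module would look roughly like this; everything here is hypothetical scaffolding (a toy metric, not repository code), included only to illustrate the shape of a task module:

from dataclasses import dataclass, field
from typing import List


@dataclass
class WordCountEvaluationConfig:
    # toy config standing in for a real metric's settings
    texts: List[str] = field(default_factory=lambda: ["hello world", "pico is small"])


def run_word_count_evaluation(model_path: str, config: WordCountEvaluationConfig) -> float:
    # a stand-in "metric": mean words per example (a real task would load the model)
    return sum(len(t.split()) for t in config.texts) / len(config.texts)


print(run_word_count_evaluation("unused/path", WordCountEvaluationConfig()))  # 2.5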
src/evaluation/tasks/paloma.py
ADDED
@@ -0,0 +1,52 @@
"""
Paloma is a comprehensive evaluation benchmark for large language models (LLMs) that focuses
on measuring perplexity across diverse text domains.

To evaluate on Paloma, we use the HuggingFace evaluate framework.

For more details, see: https://huggingface.co/datasets/allenai/paloma
"""

import evaluate
from datasets import load_dataset
from datasets.utils.logging import disable_progress_bar, enable_progress_bar

from src.config.evaluation_config import PalomaEvaluationConfig


def run_paloma_evaluation(
    model_path: str,
    paloma_config: PalomaEvaluationConfig,
) -> float:
    """Run Perplexity evaluation on the Paloma evaluation dataset.

    We use the HuggingFace evaluate library to load in and compute the perplexity metric.

    Args:
        model_path (str): Path to the model checkpoint to be evaluated
        paloma_config (PalomaEvaluationConfig): Configuration for Paloma evaluation

    Returns:
        The mean perplexity of the model on the Paloma dataset.
    """

    disable_progress_bar()

    # load custom evaluation space, see https://huggingface.co/spaces/pico-lm/perplexity
    perplexity = evaluate.load("pico-lm/perplexity")

    dataset = load_dataset(
        paloma_config.dataset_name, split=paloma_config.dataset_split
    )["text"]

    # compute perplexity score on Paloma dataset
    perplexity_result = perplexity.compute(
        model_id=model_path,
        predictions=dataset,
        add_start_token=False,
        max_length=paloma_config.max_length,
        batch_size=paloma_config.batch_size,
        trust_remote_code=True,
    )

    mean_perplexity = perplexity_result["mean_perplexity"]

    enable_progress_bar()
    return mean_perplexity
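An illustrative call; the checkpoint path is hypothetical, and the pico-lm/perplexity metric space plus the Paloma dataset are fetched from the Hub on first use:

from src.config.evaluation_config import PalomaEvaluationConfig
from src.evaluation.tasks.paloma import run_paloma_evaluation

paloma_config = PalomaEvaluationConfig(batch_size=8)  # smaller batch for a quick run
mean_ppl = run_paloma_evaluation("runs/demo-run/checkpoints/latest", paloma_config)
print(f"mean perplexity: {mean_ppl:.2f}")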
src/model/__init__.py
ADDED
@@ -0,0 +1,12 @@
"""
Model Package

This package contains Pico models (currently only the Pico Decoder). We plan to implement other
architectures in the future.

If you have other models you'd like to implement, we recommend you add modules to this package.
"""

from .pico_decoder import PicoDecoder

__all__ = ["PicoDecoder"]
src/model/pico_decoder.py
ADDED
@@ -0,0 +1,911 @@
+"""
+Pico Decoder: A Lightweight Causal Transformer Language Model
+
+Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
+
+Everything is written with a modular design for easy modification and experimentation.
+
+Key features:
+    - RMSNorm for layer normalization
+    - Rotary Positional Embeddings (RoPE)
+    - Multi-head attention with KV-cache support for faster autoregressive generation
+    - SwiGLU activation function
+    - Residual connections throughout
+
+References:
+    - RoPE: https://arxiv.org/abs/2104.09864
+    - SwiGLU: https://arxiv.org/abs/2002.05202
+    - LLAMA: https://arxiv.org/abs/2302.13971
+
+Adapted from:
+    - OLMO: https://github.com/allenai/OLMo
+    - LLAMA: https://github.com/meta/llama
+"""
+
+from dataclasses import asdict
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Handle PyTorch version compatibility for the attention backend
+try:
+    from torch.nn.attention import SDPBackend, sdpa_kernel
+
+    HAS_TORCH_ATTENTION = True
+except ImportError:
+    # Fallback for older PyTorch versions
+    HAS_TORCH_ATTENTION = False
+    SDPBackend = None
+    sdpa_kernel = None
+
+from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
+from transformers.generation import GenerationConfig
+from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
+
+try:
+    if TYPE_CHECKING:
+        # We need to do this to avoid importing these when creating the HF-compatible models
+        from src.config import ModelConfig
+except ImportError:
+    pass
+
+########################################################
+#
+# Layer Normalization
+#
+########################################################
+
+
+class RMSNorm(torch.nn.Module):
+    """Root Mean Square Layer Normalization.
+
+    A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
+    resulting in improved stability and performance.
+
+    Args:
+        config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration object containing normalization parameters
+            - config.norm_eps: Small constant for numerical stability
+            - config.d_model: Model dimension for the weight parameter
+
+    References:
+        https://arxiv.org/abs/1910.07467
+    """
+
+    def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
+        super().__init__()
+        self.eps = config.norm_eps
+        self.weight = nn.Parameter(torch.ones(config.d_model))
+
+    def _norm(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Normalizes the input tensor by its RMS value.
+        """
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Applies RMS normalization to the input tensor and scales it by the weight parameter.
+        """
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
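A quick numeric sanity check of the RMSNorm formula above (a toy sketch, not part of the module):

import torch

x = torch.tensor([[3.0, 4.0]])
rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True))      # sqrt((9 + 16) / 2) ≈ 3.5355
expected = x / rms                                     # ≈ [[0.8485, 1.1314]]
actual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)  # same formula as _norm
assert torch.allclose(actual, expected, atol=1e-3)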
+########################################################
+#
+# Positional Embedding
+#
+########################################################
+
+
+class RoPE(nn.Module):
+    """Rotary Positional Embeddings (RoPE).
+
+    Implements position-dependent rotation of keys and queries in the attention mechanism,
+    allowing better modeling of relative positions in sequences. Uses complex number
+    operations for efficient rotation.
+
+    Args:
+        config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration containing:
+            - config.position_emb_theta: Base for frequency computation
+            - config.d_model: Model dimension
+            - config.attention_n_heads: Number of attention heads
+            - config.max_seq_len: Maximum sequence length
+
+    References:
+        https://arxiv.org/abs/2104.09864
+    """
+
+    _freqs_cis_tensor: torch.Tensor | None = None
+
+    def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
+        super().__init__()
+
+        self.theta = config.position_emb_theta
+        self.dim = config.d_model // config.attention_n_heads
+
+        max_seq_len = config.max_seq_len
+
+        # only gets set once, and then reused for all RoPE instances
+        if RoPE._freqs_cis_tensor is None:
+            RoPE._freqs_cis_tensor = self._setup_freqs_cis(
+                max_seq_len, self.theta, self.dim
+            )
+
+        # register _freqs_cis buffer
+        # can be easily recomputed so persistent=False
+        self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
+
+    @classmethod
+    def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
+        """Setup Frequency Tensor for RoPE Embeddings
+
+        Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
+
+        Note that other implementations use cos and sin directly, but the complex-number
+        representation is (probably) more efficient:
+
+            e^(theta * i * t) = cos(theta * t) + i * sin(theta * t)  [Euler's formula]
+        """
+        _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+        positions = torch.arange(seq_len)
+        freqs = torch.outer(positions, _freqs)
+        return torch.polar(torch.ones_like(freqs), freqs)  # complex64
+
+    def get_freqs_cis(
+        self, input_shape: torch.Size, start_pos: int, end_pos: int
+    ) -> torch.Tensor:
+        """Reshape Frequency Tensor for RoPE Embeddings
+
+        Makes the frequency tensor broadcastable with the input tensor.
+        """
+        _freqs_cis = self._freqs_cis[start_pos:end_pos]
+        ndim = len(input_shape)
+        assert 0 <= 1 < ndim
+        assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
+
+        # TODO: Check whether this is correct (might be able to remove this)
+        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
+        return _freqs_cis.view(*shape)
+
+    def forward(
+        self,
+        queries: torch.Tensor,
+        keys: torch.Tensor,
+        start_pos: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Apply RoPE Embeddings to Queries and Keys
+
+        Applies the rotary positional embeddings to the input tensors via complex number multiplication.
+
+        NOTE: start_pos is used if we want to use the KV cache in the attention mechanism.
+        """
+        queries_ = torch.view_as_complex(
+            queries.float().reshape(*queries.shape[:-1], -1, 2)
+        )
+        keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
+
+        input_shape = (
+            queries_.shape
+        )  # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
+        freqs_start_pos = start_pos
+        freqs_end_pos = freqs_start_pos + queries_.shape[1]
+
+        freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
+
+        queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
+        keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
+        return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
+
+
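The complex-multiplication trick behind `_setup_freqs_cis` and `forward` can be checked in isolation: viewing a feature pair as a complex number and multiplying by e^{iθ} rotates it by θ. A toy sketch:

import math

import torch

pair = torch.tensor([1.0, 0.0])                   # an (x, y) feature pair
z = torch.view_as_complex(pair.reshape(-1, 2))    # 1 + 0j
theta = torch.tensor([math.pi / 2])
rot = torch.polar(torch.ones_like(theta), theta)  # e^{i*pi/2}
rotated = torch.view_as_real(z * rot).flatten()
print(rotated)                                    # ≈ tensor([0., 1.]): rotated by 90 degrees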
+########################################################
+#
+# Attention
+#
+########################################################
+
+
+class Attention(nn.Module):
+    """Multi-head Attention with Group Query Attention support.
+
+    Implements scaled dot-product attention and supports:
+    - Grouped Query Attention (GQA)
+    - Key-Value caching for efficient inference
+    - RoPE integration
+
+    Args:
+        config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
+            - config.attention_n_heads: Number of attention heads
+            - config.attention_n_kv_heads: Number of key/value heads
+            - config.d_model: Model dimension
+            - config.batch_size: Maximum batch size
+            - config.max_seq_len: Maximum sequence length
+
+    Shape:
+        - Input: (batch_size, seq_len, d_model)
+        - Output: (batch_size, seq_len, d_model)
+    """
+
+    def __init__(
+        self,
+        config: Union["ModelConfig", "PicoDecoderHFConfig"],
+    ):
+        super().__init__()
+
+        self.n_heads = config.attention_n_heads
+        self.n_kv_heads = config.attention_n_kv_heads
+
+        self.batch_size = config.batch_size
+        self.max_seq_len = config.max_seq_len
+
+        d_model = config.d_model
+        self.head_dim = d_model // self.n_heads
+
+        self.n_rep = self.n_heads // self.n_kv_heads
+
+        self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
+
+        self.rope = RoPE(config)
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        """Forward pass for the attention mechanism.
+
+        Computes queries, keys, and values for the attention mechanism. Applies rotary positional
+        embeddings to the queries and keys, and then computes attention scores and outputs.
+
+        For an introduction to the attention mechanism, see:
+        https://arxiv.org/abs/1706.03762
+
+        A few things to note:
+        - past_key_values implements the KV cache, which speeds up generation by caching the
+          KV pairs from previous forward passes. This is useful for tasks that generate multiple
+          tokens conditioned on previous tokens (e.g. language modeling, text generation, etc.).
+          Each layer has its own KV cache, implemented as a tuple.
+        """
+        bsz, seq_len, _ = input.shape
+        _queries, _keys, _values = (
+            self.q_proj(input),
+            self.k_proj(input),
+            self.v_proj(input),
+        )
+
+        # Reshaping for multi-head attention
+        queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
+        keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+        values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+
+        # The start position is used to apply the RoPE embeddings to only the new tokens
+        # when using the kv_cache in the attention mechanism.
+        # We want to start from the last position in the cache.
+        start_pos = 0
+        if past_key_values is not None and past_key_values[0] is not None:
+            start_pos = past_key_values[0].shape[1]
+
+        # apply rotary positional embeddings
+        queries, keys = self.rope(queries, keys, start_pos)
+
+        if (
+            past_key_values is not None
+            and past_key_values[0] is not None
+            and past_key_values[1] is not None
+        ):
+            keys = torch.cat([past_key_values[0], keys], dim=1)
+            values = torch.cat([past_key_values[1], values], dim=1)
+
+        if use_cache:
+            cached_keys = keys
+            cached_values = values
+        else:
+            cached_keys = None
+            cached_values = None
+
+        queries = queries.transpose(1, 2)
+        keys = keys.transpose(1, 2)
+        values = values.transpose(1, 2)
+
+        apply_gqa = self.n_rep > 1
+        if apply_gqa and queries.device.type == "mps":
+            # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
+            # outside of the kernel to get the same effect.
+            # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+            keys = keys.repeat_interleave(self.n_rep, dim=-3)
+            values = values.repeat_interleave(self.n_rep, dim=-3)
+            apply_gqa = False
+
+        if HAS_TORCH_ATTENTION:
+            backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
+            with sdpa_kernel(backends=backends):
+                attn_output = F.scaled_dot_product_attention(
+                    queries.contiguous(),
+                    keys.contiguous(),
+                    values.contiguous(),
+                    attn_mask=mask.to(queries.dtype) if mask is not None else None,
+                    enable_gqa=apply_gqa,
+                )
+        else:
+            # Fallback for older PyTorch versions - use default backend
+            # NOTE: the enable_gqa kwarg itself requires a recent PyTorch; on versions old
+            # enough to lack torch.nn.attention, this call may still fail when n_rep > 1.
+            attn_output = F.scaled_dot_product_attention(
+                queries.contiguous(),
+                keys.contiguous(),
+                values.contiguous(),
+                attn_mask=mask.to(queries.dtype) if mask is not None else None,
+                enable_gqa=apply_gqa,
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+        output = self.o_proj(attn_output)
+
+        return output, (cached_keys, cached_values)
+
+
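To make the per-layer cache concrete, here is a shape walk-through under hypothetical sizes: a prefill of 5 tokens followed by one decode step, with the new key concatenated along the sequence dimension exactly as in the forward pass above:

import torch

bsz, n_kv_heads, head_dim = 2, 4, 16
cached_keys = torch.randn(bsz, 5, n_kv_heads, head_dim)  # keys cached after prefill
new_keys = torch.randn(bsz, 1, n_kv_heads, head_dim)     # one decode-step key
start_pos = cached_keys.shape[1]                         # 5 -> RoPE offset for the new token
keys = torch.cat([cached_keys, new_keys], dim=1)         # shape (2, 6, 4, 16)
print(start_pos, keys.shape)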
+########################################################
+#
+# SwiGLU (Combines MLP and Activation)
+#
+########################################################
+
+
+class SwiGLU(nn.Module):
+    """SwiGLU Activation Function with Linear Projections.
+
+    Implements the SwiGLU activation function combined with linear transformations,
+    serving as the feed-forward network in transformer blocks.
+
+    Args:
+        config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
+            - config.d_model: Model dimension
+            - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
+
+    References:
+        https://arxiv.org/abs/2002.05202
+    """
+
+    def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
+        super().__init__()
+
+        model_dim = config.d_model
+        act_hidden_dim = config.activation_hidden_dim  # usually 4 * d_model
+
+        self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
+        self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
+        self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
+
+
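In formula form, the block computes SwiGLU(x) = W_2(SiLU(W_0 x) ⊙ W_1 x). A functional one-liner on toy shapes (a sketch with arbitrary dimensions, not the module's actual weights):

import torch
import torch.nn.functional as F

d_model, hidden = 8, 32
x = torch.randn(1, d_model)
w0 = torch.randn(hidden, d_model)
w1 = torch.randn(hidden, d_model)
w2 = torch.randn(d_model, hidden)
out = F.linear(F.silu(F.linear(x, w0)) * F.linear(x, w1), w2)
print(out.shape)  # torch.Size([1, 8])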
+########################################################
+#
+# PicoDecoderBlock
+#
+########################################################
+
+
+class PicoDecoderBlock(nn.Module):
+    """Single Transformer Block with Attention and Feed-forward layers.
+
+    Implements a standard transformer block with:
+    - Multi-head attention with normalization and residual connection
+    - SwiGLU feed-forward network with normalization and residual connection
+
+    Args:
+        config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
+            a HuggingFace PicoDecoderHFConfig
+    """
+
+    def __init__(
+        self,
+        config: Union["ModelConfig", "PicoDecoderHFConfig"],
+    ):
+        super().__init__()
+
+        self.attention = Attention(config)
+        self.swiglu = SwiGLU(config)
+        self.attention_norm = RMSNorm(config)
+        self.swiglu_norm = RMSNorm(config)
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[torch.Tensor]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        attention_output, cached_key_values = self.attention(
+            self.attention_norm(input),
+            mask=mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+        )
+        # NOTE: cached_key_values is None if use_cache is False
+
+        h = input + attention_output
+        out = h + self.swiglu(self.swiglu_norm(h))
+        return out, cached_key_values
+
+
+########################################################
+#
+# Pico Decoder (Causal Transformer Model)
+#
+########################################################
+
+
+class PicoDecoder(nn.Module):
+    """
+    Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
+    single autoregressive model.
+
+    For more information on the model, see the classes for the modules that make up the model.
+    """
+
+    def __init__(
+        self,
+        model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
+    ):
+        super().__init__()
+        self.config = model_config
+
+        self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
+        self.layers = nn.ModuleList(
+            [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
+        )
+        self.output_norm = RMSNorm(self.config)
+        self.de_embedding_proj = nn.Linear(
+            self.config.d_model, self.config.vocab_size, bias=False
+        )
+
+    def convert_to_hf_model(self) -> "PicoDecoderHF":
+        """Convert the Lightning model to a HuggingFace model."""
+        # Create HF config without fabric-specific settings
+        hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
+
+        # Create new HF model
+        hf_model = PicoDecoderHF(hf_config)
+
+        # Copy state dict, excluding fabric-specific keys
+        hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
+
+        return hf_model
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
+        """
+        This is the forward pass for the entire Pico model. It boils down to:
+        - Embedding the input ids
+        - Creating a causal mask
+        - Processing through the pico layers
+        - Projecting the output to logits
+
+        NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
+        generation by caching the KV pairs from previous forward passes. This is useful when doing
+        tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
+        modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
+        its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
+        KV caches (so a tuple of tuples).
+        """
+
+        seq_len = input_ids.shape[-1]
+        h = self.embedding_proj(input_ids)
+
+        # Calculate start position from past cached KV pairs. Remember that each layer has its
+        # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
+        # correct layer and then for either the keys or values.
+        start_pos = 0
+        if (
+            past_key_values is not None
+            and past_key_values[0] is not None
+            and past_key_values[0][0] is not None
+        ):
+            start_pos = past_key_values[0][0].shape[1]
+
+        # Create causal mask for current sequence
+        mask = None
+        if seq_len > 1:
+            mask = torch.full((seq_len, seq_len), float("-inf"))
+            mask = torch.triu(mask, diagonal=1)
+
+            # If using KV cache, extend mask to cover cached sequence length
+            if past_key_values is not None:
+                # Add zeros for cached tokens (we can attend to all of them)
+                mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
+
+            mask = mask.to(h.device)
+
+        # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
+        # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
+        cached_key_values = () if use_cache else None
+
+        # Process through transformer blocks
+        for idx, layer in enumerate(self.layers):
+            layer_past_key_values = None
+            if past_key_values is not None:
+                try:
+                    # Handle both tuple-based cache and HuggingFace cache objects
+                    if hasattr(past_key_values, "__getitem__") and idx < len(
+                        past_key_values
+                    ):
+                        layer_past_key_values = past_key_values[idx]
+                except (KeyError, IndexError, TypeError):
+                    # If we can't access the cache properly, just skip it
+                    layer_past_key_values = None
+
+            h, layer_cached_key_values = layer(
+                h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
+            )
+
+            if use_cache:
+                cached_key_values += (layer_cached_key_values,)
+
+        # Final norm and projection
+        h = self.output_norm(h)
+        logits = self.de_embedding_proj(h).float()
+
+        return logits, cached_key_values
+
+
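The mask construction in `forward` is worth seeing on toy sizes: with 3 new tokens and 2 cached tokens, cached positions get zeros (fully attendable) while future positions get -inf:

import torch

seq_len, start_pos = 3, 2
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
print(mask)
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])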
+########################################################
+#
+# HuggingFace Wrapper for the Pico Decoder model.
+#
+########################################################
+
+
+class PicoDecoderHFConfig(PretrainedConfig):
+    """Config class for the Pico Decoder HuggingFace wrapper."""
+
+    model_type = "pico_decoder"
+
+    @classmethod
+    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
+        """
+        Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
+        some kwargs require special handling, which can make this class brittle.
+        """
+        pico_config = cls(**config_dict)
+
+        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+        unused_kwargs = {
+            key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
+        }
+
+        if return_unused_kwargs:
+            return pico_config, unused_kwargs
+        return pico_config
+
+    @classmethod
+    def from_dataclass(cls, model_config: "ModelConfig"):
+        """Initialize from our custom config dataclass."""
+        return cls.from_dict(asdict(model_config))
+
+
+class PicoDecoderHF(PreTrainedModel, GenerationMixin):
+    """
+    HuggingFace wrapper for the Pico model with generation support.
+
+    Many evaluation frameworks require a model to be set up as a HuggingFace model, so we provide a
+    simple wrapper that does just that. When we save checkpoints of the Pico model, we save both the
+    normal Pico model and the model wrapped in this HuggingFace class.
+
+    This also lets you do cool things like:
+
+    `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
+    """
+
+    config_class = PicoDecoderHFConfig
+    _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
+    main_input_name = "input_ids"
+
+    def __init__(self, config: PicoDecoderHFConfig):
+        super().__init__(config)
+        self.pico_decoder = PicoDecoder(config)
+        # Initialize generation config with defaults
+        self.generation_config = GenerationConfig()
+        # Set some reasonable defaults for the model
+        if hasattr(config, "max_position_embeddings"):
+            self.generation_config.max_length = config.max_position_embeddings
+        if hasattr(config, "vocab_size"):
+            self.generation_config.vocab_size = config.vocab_size
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
+        """HuggingFace forward pass wrapper.
+
+        Forward pass for the HuggingFace version of the Pico model: a basic wrapper around the
+        Pico model's forward pass that returns the output as a HuggingFace CausalLMOutput.
+        """
+        logits, past_key_values = self.pico_decoder(
+            input_ids, past_key_values, use_cache
+        )
+        if use_cache:
+            return CausalLMOutputWithPast(
+                logits=logits,
+                past_key_values=past_key_values,
+            )
+        else:
+            return CausalLMOutput(
+                logits=logits,
+            )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """
+        Prepare inputs for generation.
+
+        Args:
+            input_ids: Input token IDs
+            past_key_values: Cached key-value pairs from previous forward passes
+            attention_mask: Attention mask for the input
+            **kwargs: Additional arguments
+
+        Returns:
+            Dictionary containing prepared inputs
+        """
+        # If we have past_key_values, we only need the last token
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": True,
+        }
+
+    def get_input_embeddings(self):
+        """Get the input embeddings layer."""
+        return self.pico_decoder.embedding_proj
+
+    def set_input_embeddings(self, value):
+        """Set the input embeddings layer."""
+        self.pico_decoder.embedding_proj = value
+
+    def get_output_embeddings(self):
+        """Get the output embeddings layer."""
+        return self.pico_decoder.de_embedding_proj
+
+    def set_output_embeddings(self, value):
+        """Set the output embeddings layer."""
+        self.pico_decoder.de_embedding_proj = value
+
+    def get_lm_head(self):
+        """Get the language model head."""
+        return self.pico_decoder.de_embedding_proj
+
+    def can_generate(self) -> bool:
+        """Check if the model can generate text."""
+        return True
+
+    @property
+    def is_encoder_decoder(self) -> bool:
+        """Check if the model is an encoder-decoder model."""
+        return False
+
+    @property
+    def can_use_cache(self) -> bool:
+        """Check if the model can use a KV cache."""
+        return True
+
+    def resize_token_embeddings(
+        self, new_num_tokens: Optional[int] = None
+    ) -> torch.nn.Embedding:
+        """Resize token embeddings."""
+        old_embeddings = self.get_input_embeddings()
+        if new_num_tokens is None:
+            new_num_tokens = old_embeddings.num_embeddings
+
+        new_embeddings = torch.nn.Embedding(
+            new_num_tokens, old_embeddings.embedding_dim
+        )
+        new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
+            old_embeddings.weight.data
+        )
+
+        # NOTE: the output projection is re-created from scratch here; its weights are
+        # freshly initialized rather than copied from the old projection.
+        self.pico_decoder.embedding_proj = new_embeddings
+        self.pico_decoder.de_embedding_proj = torch.nn.Linear(
+            old_embeddings.embedding_dim, new_num_tokens, bias=False
+        )
+
+        return new_embeddings
+
+
+# Register for auto classes
+PicoDecoderHFConfig.register_for_auto_class()
+PicoDecoderHF.register_for_auto_class("AutoModel")
+PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
+
+
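With the auto-class registration above, a saved checkpoint can be loaded and sampled from through the standard HuggingFace entry points. A minimal sketch ("path/to/checkpoint" is a placeholder; `trust_remote_code=True` is needed because pico_decoder is a custom architecture):

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")
inputs = tokenizer("Once upon a time", return_tensors="pt")
output_ids = model.generate(inputs["input_ids"], max_new_tokens=20)
print(tokenizer.decode(output_ids[0]))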
+########################################################
+#
+# New PicoDecoderForCausalLM class for generation support
+#
+########################################################
+
+
+class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
+    """
+    PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
+
+    This class is designed to work with existing checkpoints and provides full generation support.
+    It inherits from the base classes that HuggingFace expects for text generation.
+    """
+
+    config_class = PicoDecoderHFConfig
+    _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
+    main_input_name = "input_ids"
+
+    def __init__(self, config: PicoDecoderHFConfig):
+        super().__init__(config)
+        self.pico_decoder = PicoDecoder(config)
+        # Initialize generation config with defaults
+        self.generation_config = GenerationConfig()
+        # Set some reasonable defaults for the model
+        if hasattr(config, "max_position_embeddings"):
+            self.generation_config.max_length = config.max_position_embeddings
+        if hasattr(config, "vocab_size"):
+            self.generation_config.vocab_size = config.vocab_size
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
+        """Forward pass for text generation."""
+        logits, past_key_values = self.pico_decoder(
+            input_ids, past_key_values, use_cache
+        )
+        if use_cache:
+            return CausalLMOutputWithPast(
+                logits=logits,
+                past_key_values=past_key_values,
+            )
+        else:
+            return CausalLMOutput(
+                logits=logits,
+            )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """Prepare inputs for generation."""
+        # If we have past_key_values, we only need the last token
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": True,
+        }
+
+    def get_input_embeddings(self):
+        """Get the input embeddings layer."""
+        return self.pico_decoder.embedding_proj
+
+    def set_input_embeddings(self, value):
+        """Set the input embeddings layer."""
+        self.pico_decoder.embedding_proj = value
+
+    def get_output_embeddings(self):
+        """Get the output embeddings layer."""
+        return self.pico_decoder.de_embedding_proj
+
+    def set_output_embeddings(self, value):
+        """Set the output embeddings layer."""
+        self.pico_decoder.de_embedding_proj = value
+
+    def get_lm_head(self):
+        """Get the language model head."""
+        return self.pico_decoder.de_embedding_proj
+
+    def can_generate(self) -> bool:
+        """Check if the model can generate text."""
+        return True
+
+    @property
+    def is_encoder_decoder(self) -> bool:
+        """Check if the model is an encoder-decoder model."""
+        return False
+
+    @property
+    def can_use_cache(self) -> bool:
+        """Check if the model can use a KV cache."""
+        return True
+
+    def resize_token_embeddings(
+        self, new_num_tokens: Optional[int] = None
+    ) -> torch.nn.Embedding:
+        """Resize token embeddings."""
+        old_embeddings = self.get_input_embeddings()
+        if new_num_tokens is None:
+            new_num_tokens = old_embeddings.num_embeddings
+
+        new_embeddings = torch.nn.Embedding(
+            new_num_tokens, old_embeddings.embedding_dim
+        )
+        new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
+            old_embeddings.weight.data
+        )
+
+        # NOTE: as in PicoDecoderHF, the output projection is freshly initialized here.
+        self.pico_decoder.embedding_proj = new_embeddings
+        self.pico_decoder.de_embedding_proj = torch.nn.Linear(
+            old_embeddings.embedding_dim, new_num_tokens, bias=False
+        )
+
+        return new_embeddings
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        """
+        Load a pretrained model from a checkpoint.
+
+        This method handles loading from both the old PicoDecoderHF format and the new format.
+        """
+        # First try to load with the new class
+        try:
+            return super().from_pretrained(
+                pretrained_model_name_or_path, *model_args, **kwargs
+            )
+        except Exception as e:
+            print(f"Failed to load with new class: {e}")
+            print("Attempting to load with legacy class and convert...")
+
+            # Try to load with the old class and convert
+            try:
+                from transformers import AutoModel
+
+                old_model = AutoModel.from_pretrained(
+                    pretrained_model_name_or_path,
+                    trust_remote_code=True,
+                    *model_args,
+                    **kwargs,
+                )
+
+                # Create new model instance
+                new_model = cls(old_model.config)
+
+                # Copy state dict
+                new_model.load_state_dict(old_model.state_dict(), strict=False)
+
+                return new_model
+
+            except Exception as e2:
+                print(f"Failed to convert from legacy format: {e2}")
+                raise e
+
+
+# Register the new class
+PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
src/training/trainer.py
ADDED
@@ -0,0 +1,753 @@
+"""
+Pico Language Model Trainer
+
+This Trainer implements a minimalistic end-to-end training pipeline for the Pico language model with
+distributed training support via Lightning Fabric. It provides a modular and configurable training
+pipeline with the following features:
+
+- Configuration Management: YAML-based configuration for all aspects of training
+- Distributed Training: Multi-GPU support via Lightning Fabric
+- Checkpointing: Regular model saving and training state recovery
+- Evaluation: Periodic model evaluation on validation datasets
+- Logging: Comprehensive metric tracking and experiment monitoring
+- Optimization: Support for gradient accumulation, clipping, and LR scheduling
+"""
+
+import logging
+import os
+import platform
+from typing import Any, Dict
+
+import lightning as L
+import psutil
+import torch
+import torch.nn.functional as F
+import yaml
+from datasets import Dataset, load_dataset
+from lightning.fabric.utilities.rank_zero import rank_zero_only
+
+from src.checkpointing import (
+    compute_learning_dynamics_states,
+    load_checkpoint,
+    save_checkpoint,
+    save_evaluation_results,
+    save_learning_dynamics_states,
+)
+from src.evaluation import run_evaluation
+from src.training.utils import (
+    initialize_configuration,
+    initialize_dataloader,
+    initialize_dataset,
+    initialize_fabric,
+    initialize_hf_checkpointing,
+    initialize_logging,
+    initialize_lr_scheduler,
+    initialize_model,
+    initialize_optimizer,
+    initialize_run_dir,
+    initialize_tokenizer,
+    initialize_wandb,
+)
+from src.training.utils.logging import pretty_print_yaml_config
+
+
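Usage-wise, the Trainer below is driven by a single YAML path; a minimal entry point (mirroring scripts/train.py, using the example config shipped in this repo) looks like:

from src.training.trainer import Trainer

trainer = Trainer(config_path="configs/examples/demo.yaml")
trainer.train()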
+class Trainer:
+    def __init__(self, config_path: str):
+        """
+        Initializes the Trainer class. This Trainer class implements a `train` method, which is the
+        main entry point for training the Pico model. Before calling `train`, the Trainer class
+        initializes the following:
+
+        - Configuration loading and validation
+        - Model, optimizer, and dataset setup
+        - Logging and experiment tracking setup
+        - Checkpoint management
+
+        Args:
+            config_path (str): Path to the YAML configuration file containing any overrides.
+        """
+
+        ########################################################
+        #
+        # Basic Initialization of Configs, Fabric, Model, Optimizer, etc.
+        #
+        ########################################################
+
+        # Setup Config
+        self.configs = initialize_configuration(config_path)
+
+        # Setup Run Directory (i.e. where we store checkpoints, logs, etc.)
+        initialize_run_dir(checkpointing_config=self.configs["checkpointing"])
+
+        # Setup Logger
+        if self.configs["monitoring"].save_to_wandb:
+            wandb_logger = initialize_wandb(
+                monitoring_config=self.configs["monitoring"],
+                checkpointing_config=self.configs["checkpointing"],
+            )
+        else:
+            wandb_logger = None
+
+        # Setup Fabric
+        self.fabric = initialize_fabric(
+            training_config=self.configs["training"],
+            wandb_logger=wandb_logger,
+        )
+        L.seed_everything(42, verbose=False)
+
+        # Optimize for Tensor Cores on RTX 5090
+        if self.fabric.device.type == "cuda":
+            torch.set_float32_matmul_precision(
+                "high"
+            )  # Best performance for Tensor Cores
+            print(
+                "Enabled Tensor Core optimization: torch.set_float32_matmul_precision('high')"
+            )
+
+        # Set up logging
+        self.logger = initialize_logging(
+            monitoring_config=self.configs["monitoring"],
+            checkpointing_config=self.configs["checkpointing"],
+            fabric=self.fabric,
+        )
+
+        # Setup Model, Optimizer, and Dataloaders
+        self.model = initialize_model(model_config=self.configs["model"])
+        self.optimizer = initialize_optimizer(
+            training_config=self.configs["training"], model=self.model
+        )
+        self.lr_scheduler = initialize_lr_scheduler(
+            training_config=self.configs["training"], optimizer=self.optimizer
+        )
+
+        # Wrap model and optimizer with Fabric
+        self.model, self.optimizer = self.fabric.setup(self.model, self.optimizer)
+
+        # Setup HuggingFace Checkpointing
+        if self.configs["checkpointing"].save_to_hf:
+            initialize_hf_checkpointing(
+                checkpointing_config=self.configs["checkpointing"], fabric=self.fabric
+            )
+
+        ########################################################
+        #
+        # Boilerplate to deal with loading/resuming from checkpoints
+        #
+        ########################################################
+
+        self.should_load_checkpoint = self.configs["checkpointing"].training.auto_resume
+
+        # Possibly load a checkpoint
+        if self.should_load_checkpoint:
+            resume_checkpoint = load_checkpoint(
+                checkpointing_config=self.configs["checkpointing"],
+                checkpoint_step="latest",
+                fabric=self.fabric,
+                model=self.model,
+                optimizer=self.optimizer,
+                lr_scheduler=self.lr_scheduler,
+            )
+
+            if resume_checkpoint:
+                (
+                    self.model,
+                    self.optimizer,
+                    self.lr_scheduler,
+                    self.initial_batch_step,
+                ) = resume_checkpoint
+            else:
+                self.initial_batch_step = 0
+        else:
+            self.initial_batch_step = 0
+
+        ########################################################
+        #
+        # Initialization of Dataset & DataLoader (possibly fast-forwarding to correct batch)
+        #
+        ########################################################
+
+        self.train_dataset, fast_forward_steps = initialize_dataset(
+            data_config=self.configs["data"],
+            fabric=self.fabric,
+            initial_batch_step=self.initial_batch_step,
+            return_fast_forward_steps=True,
+        )
+
+        self.train_dataloader = initialize_dataloader(
+            data_config=self.configs["data"],
+            training_config=self.configs["training"],
+            fabric=self.fabric,
+            dataset=self.train_dataset,
+        )
+        self.train_dataloader = self.fabric.setup_dataloaders(
+            self.train_dataloader, use_distributed_sampler=False
+        )
+
+        self.tokenizer = initialize_tokenizer(data_config=self.configs["data"])
+
+        # NOTE: We may need to fast-forward the iterator to the correct step so that we can
+        # continue from the correct batch of data we would have seen had training not
+        # previously stopped.
+        train_iterator = iter(self.train_dataloader)
+        if fast_forward_steps > 0:
+            fast_forward_sub_steps = (
+                fast_forward_steps
+                * self.configs["training"].optimization.gradient_accumulation_steps
+            )
+            for _ in range(fast_forward_sub_steps):
+                next(train_iterator)
+
+        self.train_iterator = train_iterator
+
+        # NOTE: Synchronizing processes after fast-forwarding the iterator
+        self.fabric.barrier()
+
+        ########################################################
+        #
+        # Helper flags used during training for checkpointing and evaluation
+        #
+        ########################################################
+
+        # Helper flag to determine if we should evaluate the model
+        self.should_evaluate = (
+            self.configs["evaluation"].metrics is not None
+            and len(self.configs["evaluation"].metrics) > 0
+        )
+
+        self.should_compute_learning_dynamics = (
+            self.configs["checkpointing"].learning_dynamics.layer_suffixes is not None
+            and len(self.configs["checkpointing"].learning_dynamics.layer_suffixes) > 0
+        )
+
+        if self.should_compute_learning_dynamics:
+            if self.configs["checkpointing"].learning_dynamics.eval_data is not None:
+                self.learning_dynamics_eval_dataset = load_dataset(
+                    self.configs["checkpointing"].learning_dynamics.eval_data,
+                    split="val",
+                )
+            else:
+                self.learning_dynamics_eval_dataset = None
+
+
|
232 |
+
"""Execute the main training pipeline.
|
233 |
+
|
234 |
+
This method orchestrates the complete training process by:
|
235 |
+
1. Creating an initial checkpoint to save the starting state and evaluate the model as a
|
236 |
+
baseline
|
237 |
+
2. Running the main training loop via `_training_loop`
|
238 |
+
3. Handling final checkpointing and evaluation
|
239 |
+
|
240 |
+
The training progress is tracked through checkpoints and evaluations
|
241 |
+
at intervals specified in the configuration.
|
242 |
+
"""
|
243 |
+
|
244 |
+
########################################################
|
245 |
+
#
|
246 |
+
# Initial Checkpointing and Evaluation
|
247 |
+
#
|
248 |
+
########################################################
|
249 |
+
|
250 |
+
# Save Initial Checkpoint -- If the checkpoint already exists, this performs a no-op
|
251 |
+
save_checkpoint(
|
252 |
+
configs=self.configs,
|
253 |
+
checkpoint_step=self.initial_batch_step,
|
254 |
+
fabric=self.fabric,
|
255 |
+
model=self.model,
|
256 |
+
optimizer=self.optimizer,
|
257 |
+
lr_scheduler=self.lr_scheduler,
|
258 |
+
tokenizer=self.tokenizer,
|
259 |
+
)
|
260 |
+
|
261 |
+
# Save Initial Evaluation Results
|
262 |
+
if self.should_evaluate:
|
263 |
+
if self.initial_batch_step == 0:
|
264 |
+
evaluation_results = run_evaluation(
|
265 |
+
evaluation_config=self.configs["evaluation"],
|
266 |
+
checkpointing_config=self.configs["checkpointing"],
|
267 |
+
fabric=self.fabric,
|
268 |
+
model=self.model,
|
269 |
+
)
|
270 |
+
self._log_evaluation_results(
|
271 |
+
evaluation_results, self.initial_batch_step
|
272 |
+
)
|
273 |
+
save_evaluation_results(
|
274 |
+
checkpointing_config=self.configs["checkpointing"],
|
275 |
+
fabric=self.fabric,
|
276 |
+
evaluation_results=evaluation_results,
|
277 |
+
checkpoint_step=self.initial_batch_step,
|
278 |
+
)
|
279 |
+
else:
|
280 |
+
# NOTE: If the run crashed while evaluating, we need to restart the evaluation
|
281 |
+
eval_results_path = os.path.join(
|
282 |
+
self.configs["checkpointing"].evaluation.eval_results_dir,
|
283 |
+
f"step_{self.initial_batch_step}.json",
|
284 |
+
)
|
285 |
+
if not os.path.exists(eval_results_path):
|
286 |
+
evaluation_results = run_evaluation(
|
287 |
+
evaluation_config=self.configs["evaluation"],
|
288 |
+
checkpointing_config=self.configs["checkpointing"],
|
289 |
+
fabric=self.fabric,
|
290 |
+
model=self.model,
|
291 |
+
)
|
292 |
+
self._log_evaluation_results(
|
293 |
+
evaluation_results, self.initial_batch_step
|
294 |
+
)
|
295 |
+
save_evaluation_results(
|
296 |
+
checkpointing_config=self.configs["checkpointing"],
|
297 |
+
fabric=self.fabric,
|
298 |
+
evaluation_results=evaluation_results,
|
299 |
+
checkpoint_step=self.initial_batch_step,
|
300 |
+
)
|
301 |
+
|
302 |
+
########################################################
|
303 |
+
#
|
304 |
+
# Main Training Loop (see `_training_loop` for details)
|
305 |
+
#
|
306 |
+
########################################################
|
307 |
+
|
308 |
+
if self.initial_batch_step < self.configs["training"].max_steps:
|
309 |
+
self._log_training_configuration()
|
310 |
+
final_step = self._training_loop()
|
311 |
+
else:
|
312 |
+
final_step = self.initial_batch_step
|
313 |
+
|
314 |
+
########################################################
|
315 |
+
#
|
316 |
+
# Final Checkpointing and Evaluation
|
317 |
+
#
|
318 |
+
########################################################
|
319 |
+
|
320 |
+
# Save Learning Dynamics States
|
321 |
+
if self.should_compute_learning_dynamics:
|
322 |
+
if self.learning_dynamics_eval_dataset is not None:
|
323 |
+
self.log(f"Step {final_step} -- 📈 Saving Learning Dynamics")
|
324 |
+
learning_dynamics_val_states = compute_learning_dynamics_states(
|
325 |
+
checkpointing_config=self.configs["checkpointing"],
|
326 |
+
fabric=self.fabric,
|
327 |
+
model=self.model,
|
328 |
+
dataset=self.learning_dynamics_eval_dataset,
|
329 |
+
compute_gradients=True,
|
330 |
+
)
|
331 |
+
save_learning_dynamics_states(
|
332 |
+
checkpointing_config=self.configs["checkpointing"],
|
333 |
+
fabric=self.fabric,
|
334 |
+
learning_dynamics_states=learning_dynamics_val_states,
|
335 |
+
checkpoint_step=final_step,
|
336 |
+
prefix="val",
|
337 |
+
)
|
338 |
+
|
339 |
+
# Handle checkpointing and final evaluation
|
340 |
+
if final_step % self.configs["checkpointing"].save_every_n_steps != 0:
|
341 |
+
self.log(f"Step {final_step} -- 💾 Saving Final Checkpoint")
|
342 |
+
save_checkpoint(
|
343 |
+
configs=self.configs,
|
344 |
+
checkpoint_step=final_step,
|
345 |
+
fabric=self.fabric,
|
346 |
+
model=self.model,
|
347 |
+
optimizer=self.optimizer,
|
348 |
+
lr_scheduler=self.lr_scheduler,
|
349 |
+
tokenizer=self.tokenizer,
|
350 |
+
)
|
351 |
+
|
352 |
+
# Final evaluation
|
353 |
+
if self.should_evaluate:
|
354 |
+
evaluation_results = run_evaluation(
|
355 |
+
evaluation_config=self.configs["evaluation"],
|
356 |
+
checkpointing_config=self.configs["checkpointing"],
|
357 |
+
fabric=self.fabric,
|
358 |
+
model=self.model,
|
359 |
+
)
|
360 |
+
self._log_evaluation_results(evaluation_results, final_step)
|
361 |
+
save_evaluation_results(
|
362 |
+
checkpointing_config=self.configs["checkpointing"],
|
363 |
+
checkpoint_step=final_step,
|
364 |
+
fabric=self.fabric,
|
365 |
+
evaluation_results=evaluation_results,
|
366 |
+
)
|
367 |
+
|
368 |
+
self.log(f"🎉 Training complete! Final step: {final_step}")
|
369 |
+
|
370 |
+
if final_step < self.configs["training"].max_steps:
|
371 |
+
self.log(
|
372 |
+
f"\t Note: Training stopped before max steps ({self.configs['training'].max_steps})",
|
373 |
+
level=logging.WARNING,
|
374 |
+
)
|
375 |
+
|
376 |
+
# Cleanup distributed training
|
377 |
+
self.fabric.barrier()
|
378 |
+
if torch.cuda.is_available():
|
379 |
+
torch.cuda.empty_cache()
|
380 |
+
if torch.distributed.is_initialized():
|
381 |
+
torch.distributed.destroy_process_group()
|
382 |
+
|
383 |
+
del self.train_dataloader # NOTE: shutting down worker nodes
|
384 |
+
|
385 |
+
self.fabric.barrier()
|
386 |
+
|
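Before the loop itself, note the bookkeeping it relies on: one optimizer step (`batch_step`) spans `gradient_accumulation_steps` dataloader iterations (`sub_batch_step`), so resuming at a given batch step means fast-forwarding the iterator by the product of the two. A small arithmetic sketch with hypothetical numbers:

# Resuming at batch_step 100 with 32 gradient accumulation steps:
gradient_accumulation_steps = 32
batch_step = 100
initial_sub_batch_step = batch_step * gradient_accumulation_steps  # 3200
# Each sub-batch loss is divided by 32 before backward(), so the accumulated
# gradient matches the average loss over the full effective batch.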
+    def _training_loop(self) -> int:
+        """Execute the main training loop.
+
+        This method orchestrates the core training loop and includes the following features:
+        - Gradient accumulation
+        - Gradient clipping
+        - Periodic model evaluation and checkpointing
+        - Learning Dynamics Checkpointing
+        - Learning rate scheduling
+        - Logging of training metrics including loss and learning rate
+        - Handling of infinite/NaN losses
+
+        Returns:
+            int: The final step count reached during training.
+                NOTE: A complete training run should match the configured max_steps.
+        """
+        # Setup training loop variables
+        batch_step = self.initial_batch_step
+
+        # NOTE: these are used to compute the average loss over a training interval.
+        # This is more accurate than using the loss at the end of the interval.
+        interval_loss = torch.tensor(0.0, device=self.fabric.device)
+        interval_steps = torch.tensor(0, device=self.fabric.device)
+        interval_inf_or_nan_count = torch.tensor(0, device=self.fabric.device)
+
+        if self.should_compute_learning_dynamics:
+            # NOTE: we basically re-construct the full batch here so that we can compute learning dynamics
+            training_batch = {"input_ids": []}
+
+        # NOTE: determine what sub-batch we should start from
+        initial_sub_batch_step = (
+            batch_step
+            * self.configs["training"].optimization.gradient_accumulation_steps
+        )
+
+        ###############################################################
+        #
+        # Core loop starts here
+        # NOTE: the ratio between sub_batch_step and batch_step
+        # is the configured number of gradient_accumulation_steps
+        # i.e. with 32 configured gradient accumulation steps,
+        # there are 32 sub_batch_steps for each batch_step
+        #
+        ###############################################################
+
+        for sub_batch_step, sub_batch in enumerate(
+            self.train_iterator, start=initial_sub_batch_step
+        ):
+            # NOTE: We want to store the entire training batch whenever we are computing learning dynamics
+            # and we are at a checkpointing step.
+            should_store_training_batch = self.should_compute_learning_dynamics and (
+                batch_step % self.configs["checkpointing"].save_every_n_steps == 0
+            )
+
+            ########################################################
+            #
+            # Forward Pass
+            #
+            ########################################################
+
+            _input_ids = torch.tensor(sub_batch["input_ids"], device=self.fabric.device)
+            input_ids = _input_ids[:, :-1]
+            labels = _input_ids[:, 1:]
+
+            if should_store_training_batch:
+                gathered_input_ids = self.fabric.all_gather(_input_ids)
+
+                # NOTE: On multi-GPU, we need to reshape the input_ids to be a 2D tensor; on
+                # a single GPU, the input_ids are already a 2D tensor.
+                if self.fabric.world_size > 1:
+                    gathered_input_ids = gathered_input_ids.reshape(
+                        -1, *gathered_input_ids.shape[2:]
+                    )
+
+                training_batch["input_ids"].extend(gathered_input_ids.tolist())
+
+            # Forward pass
+            model_output, _ = self.model(input_ids)
+            model_output = model_output.transpose(1, 2)
+
+            ########################################################
+            #
+            # Gradient accumulation
+            #
+            ########################################################
+
+            should_accumulate_gradients = (sub_batch_step + 1) % self.configs[
+                "training"
+            ].optimization.gradient_accumulation_steps != 0
+
+            with self.fabric.no_backward_sync(
+                self.model, enabled=should_accumulate_gradients
+            ):
+                loss = F.cross_entropy(model_output, labels)
+                self.fabric.backward(
+                    loss
+                    / self.configs["training"].optimization.gradient_accumulation_steps,
+                    model=self.model,
+                )
+
+                if torch.isnan(loss) or torch.isinf(loss):
+                    interval_inf_or_nan_count += 1
+                else:
+                    interval_loss += loss.item()
+                    interval_steps += 1
+
+            # NOTE: if we are still mid-accumulation, we skip the logging and optimization steps
+            if should_accumulate_gradients:
+                continue
+
+            ########################################################
+            #
+            # Logging
+            #
+            ########################################################
+
+            if batch_step % self.configs["monitoring"].logging.log_every_n_steps == 0:
+                self._log_training_metrics(
+                    interval_loss=interval_loss,
+                    interval_steps=interval_steps,
+                    interval_inf_or_nan_count=interval_inf_or_nan_count,
+                    batch_step=batch_step,
+                )
+                interval_loss = torch.tensor(0.0, device=self.fabric.device)
+                interval_steps = torch.tensor(0, device=self.fabric.device)
|
512 |
+
interval_inf_or_nan_count = torch.tensor(0, device=self.fabric.device)
|
513 |
+
|
514 |
+
########################################################
|
515 |
+
#
|
516 |
+
# Learning Dynamics Checkpointing
|
517 |
+
#
|
518 |
+
########################################################
|
519 |
+
|
520 |
+
if batch_step % self.configs["checkpointing"].save_every_n_steps == 0:
|
521 |
+
if self.should_compute_learning_dynamics:
|
522 |
+
self.log(f"Step {batch_step} -- 📈 Saving Learning Dynamics")
|
523 |
+
|
524 |
+
# Training Batch Learning Dynamics
|
525 |
+
training_batch_dataset = Dataset.from_dict(training_batch)
|
526 |
+
|
527 |
+
learning_dynamics_train_states = compute_learning_dynamics_states(
|
528 |
+
checkpointing_config=self.configs["checkpointing"],
|
529 |
+
fabric=self.fabric,
|
530 |
+
model=self.model,
|
531 |
+
dataset=training_batch_dataset,
|
532 |
+
compute_gradients=True,
|
533 |
+
)
|
534 |
+
|
535 |
+
save_learning_dynamics_states(
|
536 |
+
checkpointing_config=self.configs["checkpointing"],
|
537 |
+
checkpoint_step=batch_step,
|
538 |
+
prefix="train",
|
539 |
+
fabric=self.fabric,
|
540 |
+
learning_dynamics_states=learning_dynamics_train_states,
|
541 |
+
learning_dynamics_dataset=training_batch_dataset,
|
542 |
+
tokenizer=self.tokenizer,
|
543 |
+
)
|
544 |
+
training_batch = {
|
545 |
+
"input_ids": []
|
546 |
+
} # Resetting training_batch for next training batch
|
547 |
+
|
548 |
+
# Validation Data Learning Dynamics
|
549 |
+
if self.learning_dynamics_eval_dataset is not None:
|
550 |
+
learning_dynamics_val_states = compute_learning_dynamics_states(
|
551 |
+
checkpointing_config=self.configs["checkpointing"],
|
552 |
+
fabric=self.fabric,
|
553 |
+
model=self.model,
|
554 |
+
dataset=self.learning_dynamics_eval_dataset,
|
555 |
+
compute_gradients=True,
|
556 |
+
)
|
557 |
+
save_learning_dynamics_states(
|
558 |
+
checkpointing_config=self.configs["checkpointing"],
|
559 |
+
checkpoint_step=batch_step,
|
560 |
+
prefix="val",
|
561 |
+
fabric=self.fabric,
|
562 |
+
learning_dynamics_states=learning_dynamics_val_states,
|
563 |
+
)
|
564 |
+
|
565 |
+
########################################################
|
566 |
+
#
|
567 |
+
# Optimization step
|
568 |
+
#
|
569 |
+
########################################################
|
570 |
+
|
571 |
+
self.optimizer.step()
|
572 |
+
self.optimizer.zero_grad()
|
573 |
+
self.lr_scheduler.step()
|
574 |
+
|
575 |
+
batch_step += 1
|
576 |
+
|
577 |
+
########################################################
|
578 |
+
#
|
579 |
+
# Training Checkpointing and evaluation
|
580 |
+
#
|
581 |
+
########################################################
|
582 |
+
|
583 |
+
if batch_step % self.configs["checkpointing"].save_every_n_steps == 0:
|
584 |
+
self.log(f"Step {batch_step} -- 💾 Saving Checkpoint")
|
585 |
+
save_checkpoint(
|
586 |
+
configs=self.configs,
|
587 |
+
checkpoint_step=batch_step,
|
588 |
+
fabric=self.fabric,
|
589 |
+
model=self.model,
|
590 |
+
optimizer=self.optimizer,
|
591 |
+
lr_scheduler=self.lr_scheduler,
|
592 |
+
tokenizer=self.tokenizer,
|
593 |
+
)
|
594 |
+
|
595 |
+
if self.should_evaluate:
|
596 |
+
evaluation_results = run_evaluation(
|
597 |
+
evaluation_config=self.configs["evaluation"],
|
598 |
+
checkpointing_config=self.configs["checkpointing"],
|
599 |
+
fabric=self.fabric,
|
600 |
+
model=self.model,
|
601 |
+
)
|
602 |
+
if evaluation_results is not None:
|
603 |
+
self._log_evaluation_results(evaluation_results, batch_step)
|
604 |
+
save_evaluation_results(
|
605 |
+
checkpointing_config=self.configs["checkpointing"],
|
606 |
+
fabric=self.fabric,
|
607 |
+
evaluation_results=evaluation_results,
|
608 |
+
checkpoint_step=batch_step,
|
609 |
+
)
|
610 |
+
|
611 |
+
# Break if we've reached training steps
|
612 |
+
if batch_step >= self.configs["training"].max_steps:
|
613 |
+
break
|
614 |
+
|
615 |
+
return batch_step
|
616 |
+
|
617 |
+
    ########################################################
    #
    # Trainer Logging Functionalities
    #
    ########################################################

    def _log_training_metrics(
        self,
        interval_loss: torch.Tensor,
        interval_steps: torch.Tensor,
        interval_inf_or_nan_count: torch.Tensor,
        batch_step: int,
    ):
        """
        Gathers together the training metrics computed across all processes in distributed training
        and logs them in a tree-style format.
        """
        gathered_interval_loss = self.fabric.all_reduce(
            interval_loss, reduce_op="sum"
        ).item()
        gathered_interval_inf_or_nan_count = self.fabric.all_reduce(
            interval_inf_or_nan_count, reduce_op="sum"
        ).item()
        gathered_interval_steps = self.fabric.all_reduce(
            interval_steps, reduce_op="sum"
        ).item()

        avg_loss = (
            gathered_interval_loss / gathered_interval_steps
            if gathered_interval_steps > 0
            else float("inf")
        )

        self.fabric.log("train/loss", avg_loss, step=batch_step)
        self.fabric.log(
            "trainer/inf_or_nan_count",
            gathered_interval_inf_or_nan_count,
            step=batch_step,
        )
        self.fabric.log(
            "trainer/learning_rate",
            self.lr_scheduler.get_last_lr()[0],
            step=batch_step,
        )

        # Log to console in tree format
        self.log(f"Step {batch_step} -- 🔄 Training Metrics")
        self.log(f"├── Loss: {avg_loss:.4f}")
        self.log(f"├── Learning Rate: {self.lr_scheduler.get_last_lr()[0]:.2e}")
        self.log(f"└── Inf/NaN count: {gathered_interval_inf_or_nan_count}")

    def _log_evaluation_results(
        self, evaluation_results: Dict[str, Any], batch_step: int
    ):
        """Log model evaluation metrics to experiment tracking system and console."""
        self.log(f"Step {batch_step} -- 📊 Evaluation Results")
        for i, (metric, result) in enumerate(evaluation_results.items()):
            prefix = "└──" if i == len(evaluation_results) - 1 else "├──"
            self.log(f"{prefix} {metric}: {result}")
            self.fabric.log(f"eval/{metric}", result, step=batch_step)

    def _log_training_configuration(self):
        """
        Log training configuration details as well as runtime information about the hardware,
        software, and batch settings.

        This function is called at the beginning of the training loop to provide a summary of the
        training configuration.
        """

        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        global_batch_size = self.configs["data"].dataloader.batch_size
        per_device_batch_size = self.train_dataloader.batch_size
        gradient_accumulation_steps = self.configs[
            "training"
        ].optimization.gradient_accumulation_steps

        device_type = ""
        fabric_device = str(self.fabric.device)
        if torch.cuda.is_available() and "cuda" in fabric_device:
            device_type = torch.cuda.get_device_name(self.fabric.device)
        elif torch.backends.mps.is_available() and "mps" in fabric_device:
            device_type = "MPS (Apple Silicon)"
        else:
            device_type = "CPU"

        training_config_path = os.path.join(
            self.configs["checkpointing"].runs_dir,
            self.configs["checkpointing"].run_name,
            "training_config.yaml",
        )
        if os.path.exists(training_config_path):
            self.log("=" * 50)
            self.log("✨ Training Configuration")
            self.log("=" * 50)
            training_config = yaml.safe_load(open(training_config_path, "r"))
            pretty_print_yaml_config(self.logger, training_config)

        self.log("=" * 50)
        self.log("⛭ Runtime Summary:")
        self.log("=" * 50)
        self.log(f"Starting from step: {self.initial_batch_step}")

        self.log("Model Setup:")
        self.log(f"└─ Total Parameters: {total_params:,}")
        self.log(f"└─ Trainable Parameters: {trainable_params:,}")

        self.log("Distributed Setup:")
        self.log(f"└─ Number of Devices: {self.fabric.world_size}")
        self.log(f"└─ Device Type: {device_type}")
        self.log(
            f"└─ Available Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
            if torch.cuda.is_available()
            else f"└─ Available Memory: {psutil.virtual_memory().total / 1e9:.2f} GB"
        )

        self.log("Software Setup:")
        self.log(f"└─ Python Version: {platform.python_version()}")
        self.log(f"└─ PyTorch Version: {torch.__version__}")
        self.log(
            f"└─ CUDA Version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}"
        )
        self.log(f"└─ Operating System: {platform.system()} {platform.release()}")

        self.log("Batch Size Configuration:")
        self.log(f"└─ Global Batch Size: {global_batch_size}")
        self.log(f"└─ Per Device Batch Size: {per_device_batch_size}")
        self.log(f"└─ Gradient Accumulation Steps: {gradient_accumulation_steps}")
        self.log("=" * 50)

    @rank_zero_only
    def log(self, msg: str, level: int = logging.INFO) -> None:
        """NOTE: Log messages only from rank zero process."""
        self.logger.log(level, msg)
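
The interplay between sub-batch steps and batch steps in _training_loop above is easiest to see with concrete numbers. The following standalone sketch replays just the bookkeeping; the values are illustrative assumptions, and the print stands in for the real forward/backward/step cycle:

gradient_accumulation_steps = 32  # assumed config value
initial_batch_step = 5  # e.g. resuming from a checkpoint saved at batch step 5

# A resumed run starts enumerating sub-batches at this offset, so it sees the
# same sub-batch indices it would have seen had training never stopped.
initial_sub_batch_step = initial_batch_step * gradient_accumulation_steps
assert initial_sub_batch_step == 160

# The optimizer only steps once a full batch worth of sub-batches has accumulated.
for sub_batch_step in range(initial_sub_batch_step, initial_sub_batch_step + 64):
    should_accumulate = (sub_batch_step + 1) % gradient_accumulation_steps != 0
    if not should_accumulate:
        print(f"optimizer.step() after sub-batch {sub_batch_step}")
# Prints after sub-batches 191 and 223: two optimizer steps per 64 sub-batches.
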
src/training/utils/__init__.py
ADDED
@@ -0,0 +1,34 @@
"""
Utility package that contains functions for the training process, e.g. initialization, logging, etc.
"""

# For convenience, we export the initialization functions here
from .initialization import (
    initialize_configuration,
    initialize_dataloader,
    initialize_dataset,
    initialize_fabric,
    initialize_hf_checkpointing,
    initialize_logging,
    initialize_lr_scheduler,
    initialize_model,
    initialize_optimizer,
    initialize_run_dir,
    initialize_tokenizer,
    initialize_wandb,
)

__all__ = [
    "initialize_configuration",
    "initialize_dataloader",
    "initialize_dataset",
    "initialize_fabric",
    "initialize_hf_checkpointing",
    "initialize_logging",
    "initialize_lr_scheduler",
    "initialize_model",
    "initialize_optimizer",
    "initialize_run_dir",
    "initialize_tokenizer",
    "initialize_wandb",
]
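
A minimal sketch of how a training script might consume these re-exports; the wiring below is an illustration (using the demo config added elsewhere in this commit), not necessarily what scripts/train.py does:

from src.training.utils import initialize_configuration, initialize_run_dir

configs = initialize_configuration("configs/examples/demo.yaml")
run_dir = initialize_run_dir(configs["checkpointing"])
print(f"Run artifacts will be written to {run_dir}")
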
src/training/utils/data.py
ADDED
@@ -0,0 +1,35 @@
"""
Utilities for data loading and processing.
"""

from torch.utils.data import IterableDataset


class ShardedIterableDataset(IterableDataset):
    """
    A super simple implementation of a sharded iterable dataset that enables DataParallelism
    across multiple workers. Ensures that each worker gets a unique shard of the dataset.

    NOTE: Also works fine if there is only one worker.
    """

    def __init__(self, dataset, rank, world_size):
        self.dataset = dataset
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        iterator = iter(self.dataset)
        # NOTE: Start by skipping to this worker's shard
        for _ in range(self.rank):
            next(iterator)

        # NOTE: Yield every world_size-th item
        while True:
            try:
                yield next(iterator)
                # Skip other workers' samples
                for _ in range(self.world_size - 1):
                    next(iterator)
            except StopIteration:
                break
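
The round-robin interleaving is easy to verify with a toy iterable in place of a streaming dataset (in real use the rank and world size come from fabric.global_rank and fabric.world_size):

from src.training.utils.data import ShardedIterableDataset

items = list(range(10))
rank0 = list(ShardedIterableDataset(items, rank=0, world_size=2))
rank1 = list(ShardedIterableDataset(items, rank=1, world_size=2))

assert rank0 == [0, 2, 4, 6, 8]  # rank 0 sees the even-indexed items
assert rank1 == [1, 3, 5, 7, 9]  # rank 1 sees the odd-indexed items
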
src/training/utils/initialization.py
ADDED
@@ -0,0 +1,702 @@
"""
Utilities for initializing components of the training process.

Here, we initialize all of the components that are part of the learning process. From logging
and checkpointing to the optimizer to the dataset and the dataloader, this file contains the
logic for setting up the classes and functions that are used in the training loop.

As always, this code is meant to be basic. We hard-code the obvious defaults, and leave the
more experimental stuff to you.
"""

import logging
import math
import os
import warnings
from dataclasses import fields, is_dataclass
from datetime import datetime
from typing import Dict, Optional, Union

import lightning as L
import torch
import yaml
from datasets import Dataset, DownloadConfig, load_dataset
from datasets import config as datasets_config
from huggingface_hub import add_collection_item, create_branch, create_repo
from lightning.fabric.loggers import Logger as FabricLogger
from lightning.fabric.utilities.rank_zero import rank_zero_only
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

import wandb
from src.config import (
    CheckpointingConfig,
    DataConfig,
    EvaluationConfig,
    ModelConfig,
    MonitoringConfig,
    TrainingConfig,
)
from src.model import PicoDecoder
from src.training.utils.io import use_backoff
from wandb.integration.lightning.fabric import WandbLogger

warnings.filterwarnings(
    "ignore",
    message=".*This integration is tested and supported for lightning Fabric.*",
)
warnings.filterwarnings(
    "ignore",
    message=".*Please report any issues to.*",
)

########################################################
#
# Basic Initialization
#
########################################################


def _apply_config_overrides(config, overrides: dict):
    """Recursively apply configuration overrides to a dataclass config object.

    Args:
        config: Base configuration object (must be a dataclass)
        overrides: Dictionary of override values matching config structure

    Returns:
        Modified config object with the overrides applied.
    """
    for field in fields(config):
        field_value = getattr(config, field.name)
        if is_dataclass(field_value):
            _apply_config_overrides(field_value, overrides.get(field.name, {}))
        else:
            if field.name in overrides:
                setattr(config, field.name, overrides[field.name])
    return config

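
# A toy, self-contained check of the override logic above. The two dataclasses
# are hypothetical stand-ins for illustration only, not the real src.config
# definitions.
from dataclasses import dataclass, field


@dataclass
class _ToyOptimization:
    lr: float = 3e-4


@dataclass
class _ToyTraining:
    max_steps: int = 1000
    optimization: _ToyOptimization = field(default_factory=_ToyOptimization)


# Flat fields are overridden directly; nested dataclasses are recursed into.
_toy = _apply_config_overrides(
    _ToyTraining(), {"max_steps": 50, "optimization": {"lr": 1e-3}}
)
assert _toy.max_steps == 50 and _toy.optimization.lr == 1e-3
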
def initialize_configuration(
    config_path: Optional[str] = None,
) -> Dict[
    str,
    Union[
        DataConfig,
        ModelConfig,
        TrainingConfig,
        EvaluationConfig,
        MonitoringConfig,
        CheckpointingConfig,
    ],
]:
    """Initialize configuration objects with optional overrides from a YAML file.

    This function initializes all of the configuration objects, and then applies
    any overrides from the config_path file. If no config_path is provided,
    the function will use the default configuration objects.

    Args:
        config_path: Path to a YAML file containing configuration overrides.

    Returns:
        A dictionary containing the initialized configuration objects.
    """
    data_config = DataConfig()
    model_config = ModelConfig()
    training_config = TrainingConfig()
    evaluation_config = EvaluationConfig()
    monitoring_config = MonitoringConfig()
    checkpointing_config = CheckpointingConfig()

    if config_path:
        overrides = yaml.safe_load(open(config_path, "r"))
        data_config = _apply_config_overrides(data_config, overrides.get("data", {}))
        model_config = _apply_config_overrides(model_config, overrides.get("model", {}))
        training_config = _apply_config_overrides(
            training_config, overrides.get("training", {})
        )
        evaluation_config = _apply_config_overrides(
            evaluation_config, overrides.get("evaluation", {})
        )
        monitoring_config = _apply_config_overrides(
            monitoring_config, overrides.get("monitoring", {})
        )
        checkpointing_config = _apply_config_overrides(
            checkpointing_config, overrides.get("checkpointing", {})
        )

    configs = {
        "data": data_config,
        "model": model_config,
        "training": training_config,
        "evaluation": evaluation_config,
        "monitoring": monitoring_config,
        "checkpointing": checkpointing_config,
    }

    return configs


def initialize_run_dir(checkpointing_config: CheckpointingConfig) -> str:
    """Initialize a directory for the current training run.

    Creates a unique directory for storing training, evaluation, and logging artifacts.
    If no run name is specified in the config, generates a timestamp-based name.

    Args:
        checkpointing_config: Configuration object containing run settings.
            NOTE: Must have a 'run_name' attribute that can be None, in which case
            a timestamp-based name will be generated.

    Returns:
        str: The path to the run directory.
    """
    run_name = checkpointing_config.run_name
    if run_name is None:
        run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        checkpointing_config.run_name = run_name

    run_dir = os.path.join(checkpointing_config.runs_dir, run_name)

    os.makedirs(run_dir, exist_ok=True)
    return run_dir


def initialize_fabric(
    training_config: TrainingConfig, wandb_logger: Optional[FabricLogger] = None
):
    """Initialize Lightning Fabric for distributed training.

    Sets up a Lightning Fabric instance with the specified configuration for
    handling distributed training, mixed precision, and logging.

    Args:
        training_config: Configuration object containing fabric settings
            (accelerator, precision, devices, etc.).
        wandb_logger: Optional Weights and Biases logger instance for experiment tracking.

    Returns:
        L.Fabric: Initialized Lightning Fabric instance.

    Example:
        >>> fabric = initialize_fabric(training_config, wandb_logger)
    """

    total_devices = (
        training_config.fabric.num_devices * training_config.fabric.num_nodes
    )

    if total_devices > 1:
        strategy = "deepspeed_stage_2"
    else:
        strategy = "auto"  # Sets up SingleDevice Strategy by default

    # NOTE: The strategy is set to use either DeepSpeed (Zero Stage 2) on multi-GPU,
    # or SingleDevice Strategy on single-GPU setups. If you'd like to use a different strategy,
    # you can change the strategy flag in the fabric initialization, but be aware that this might
    # cause issues with checkpointing, evaluation, etc.

    fabric = L.Fabric(
        accelerator=training_config.fabric.accelerator,
        precision=training_config.fabric.precision,
        devices=training_config.fabric.num_devices,
        num_nodes=training_config.fabric.num_nodes,
        loggers=[wandb_logger] if wandb_logger is not None else None,
        strategy=strategy,
    )

    fabric.launch()

    return fabric

########################################################
#
# Dataset and Tokenization Initialization
#
########################################################


@use_backoff(max_retries=20)
def initialize_dataset(
    data_config: DataConfig,
    fabric: L.Fabric,
    initial_batch_step: Optional[int] = 0,
    return_fast_forward_steps: bool = False,
):
    """Initialize dataset based on the given config.

    This function will return a dataset object, and optionally a fast_forward_steps value.

    The fast_forward_steps value is the number of steps that we need to fast-forward an iterator by,
    so that we can continue from a certain batch of data we would have seen had training not previously
    stopped. Depending on how the dataset is loaded, the number of steps to fast-forward may be
    different from the initial_batch_step value.

    NOTE: This functionality is primarily useful for streaming datasets (which for large
    datasets is most of the time).

    Args:
        data_config: Configuration object containing dataset settings.
        fabric: A Lightning Fabric instance.
        initial_batch_step: The initial batch step to fast-forward to.
        return_fast_forward_steps: Whether to return the fast-forward steps value.

    Returns:
        Dataset: Initialized dataset object.
        Optional[int]: Number of steps to fast-forward the iterator by, if return_fast_forward_steps is True.
    """

    datasets_config.STREAMING_READ_MAX_RETRIES = 40  # default is 20
    datasets_config.STREAMING_READ_RETRY_INTERVAL = 10  # default is 5
    download_config = DownloadConfig(
        max_retries=20,  # default is 1 and can lead to premature HTTPS errors
    )

    fast_forward_steps = 0

    if data_config.dataset.name == "pico-lm/pretokenized-dolma":
        # NOTE: We know that the dataset is sharded into 10,000 shards, so we can easily compute
        # the data file that we need to load in that contains the batch of data at
        # initial_batch_step.

        if initial_batch_step is not None:
            examples_per_shard = 20_480
            total_shards = 10_000
            batches_per_shard = examples_per_shard // data_config.dataloader.batch_size
            shard_idx = initial_batch_step // batches_per_shard

            data_files = [
                f"data/train-{str(_shard_idx).zfill(5)}-of-{total_shards}.parquet"
                for _shard_idx in range(shard_idx, total_shards)
            ]

            fast_forward_steps = initial_batch_step % batches_per_shard
        else:
            data_files = None

        base_dataset = load_dataset(
            data_config.dataset.name,
            split="train",
            streaming=True,
            data_files=data_files,
            download_config=download_config,
        )
    else:
        # NOTE: For other datasets, you might want to add some custom loading logic, especially
        # to help with loading or fast-forwarding to the correct batch.

        base_dataset = load_dataset(
            data_config.dataset.name,
            split="train",
            streaming=True,
            download_config=download_config,
        )

    if data_config.dataset.name == "pico-lm/pretokenized-dolma":
        from .data import ShardedIterableDataset

        # NOTE: We wrap the dataset in a ShardedIterableDataset, which is a custom class that
        # allows us to shard an iterable dataset across multiple processes. This is useful for
        # distributed training, where we want data-parallelism.
        dataset = ShardedIterableDataset(
            base_dataset, fabric.global_rank, fabric.world_size
        )
    else:
        dataset = base_dataset

    if return_fast_forward_steps:
        return dataset, fast_forward_steps
    else:
        return dataset

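
# Worked example of the shard / fast-forward arithmetic above. The 20,480
# examples-per-shard figure is the one used in initialize_dataset; the batch
# size and resume step here are illustrative assumptions.
_examples_per_shard = 20_480
_batch_size = 1024
_batches_per_shard = _examples_per_shard // _batch_size  # 20 batches per shard

_initial_batch_step = 50
_shard_idx = _initial_batch_step // _batches_per_shard  # resume loading at shard 2
_fast_forward = _initial_batch_step % _batches_per_shard  # then skip 10 batches into it
assert (_shard_idx, _fast_forward) == (2, 10)
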
def initialize_tokenizer(data_config: DataConfig):
    """Initialize the tokenizer for text processing.

    This function can be extended to include custom tokenization logic.

    Args:
        data_config: Configuration object containing tokenizer settings.

    Returns:
        AutoTokenizer: A HuggingFace tokenizer instance.
    """

    return AutoTokenizer.from_pretrained(data_config.tokenizer.name)


def initialize_dataloader(
    data_config: DataConfig,
    training_config: TrainingConfig,
    fabric: L.Fabric,
    dataset: Dataset,
):
    """Initialize the DataLoader for efficient batch processing.

    Creates a PyTorch DataLoader that handles batching and data loading for training.
    Configured specifically for streaming tokenized text datasets.

    You might also want to extend this function to add a sampler, or some sort of custom
    collate function. For the default dataset, we don't need any of this, because the data are
    pre-shuffled, and pre-tokenized.

    Args:
        data_config: Configuration object containing dataloader settings.
        training_config: Configuration object containing training settings.
        fabric: A Lightning Fabric instance.
        dataset: A HuggingFace Dataset object containing tokenized text data.
            Expected to have 'input_ids' field in its items.

    Returns:
        DataLoader: PyTorch DataLoader instance configured for the dataset.
    """

    def _collate_fn(batch):
        return {"input_ids": [entry["input_ids"] for entry in batch]}

    sub_batch_size = data_config.dataloader.batch_size // (
        fabric.world_size * training_config.optimization.gradient_accumulation_steps
    )

    # NOTE: We use the sub-batch size for the dataloader, which is the full batch size
    # divided by the number of devices and the gradient accumulation steps. This ensures
    # that the effective (global) batch size is correct.

    return DataLoader(
        dataset,
        batch_size=sub_batch_size,
        shuffle=False,  # Keep sequential for streaming datasets
        pin_memory=True,  # Speeds up transfer to GPU
        collate_fn=_collate_fn,
    )

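
# Worked example of the sub-batch arithmetic above, with assumed values: a
# global batch size of 1024 split across 4 devices with 8 gradient
# accumulation steps.
_global_batch_size = 1024
_world_size = 4
_grad_accum_steps = 8

_sub_batch_size = _global_batch_size // (_world_size * _grad_accum_steps)
assert _sub_batch_size == 32
# Each rank loads 32 examples per sub-batch and accumulates 8 of them, so one
# optimizer step still covers 4 * 32 * 8 = 1024 examples in total.
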
########################################################
#
# Model Initialization
#
########################################################


def initialize_model(model_config: ModelConfig):
    """Initialize the model for training.

    Loads in a given model implemented in the `src.model` package and returns it.

    NOTE: out of the box we currently only support the PicoDecoder model (a causal transformer
    language model). If you'd like to implement your own model, you can do so by adding a new
    model class in the `src.model` package, and then adding a new entry here.

    Args:
        model_config: Configuration object containing model settings.

    Returns:
        PyTorch model instance.
    """
    if model_config.model_type == "pico_decoder":
        return PicoDecoder(model_config)
    else:
        raise ValueError(f"Invalid model type: {model_config.model_type}")


########################################################
#
# Optimizer and Scheduler
#
########################################################


def initialize_optimizer(training_config: TrainingConfig, model: torch.nn.Module):
    """Initialize the optimizer for model training.

    Creates an optimizer instance based on the configuration settings.

    Add whatever other optimizers you want here.

    Args:
        training_config: Configuration object containing optimizer settings.
            Must have:
            - optimization.optimizer (str): Name of the optimizer ("adamw")
            - optimization.lr (float): Learning rate for the optimizer
        model: PyTorch model whose parameters will be optimized.

    Returns:
        torch.optim.Optimizer: Configured optimizer instance.
    """

    if training_config.optimization.optimizer == "adamw":
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=training_config.optimization.lr
        )
    else:
        raise ValueError(f"Invalid optimizer: {training_config.optimization.optimizer}")

    return optimizer


def initialize_lr_scheduler(
    training_config: TrainingConfig, optimizer: torch.optim.Optimizer
):
    """Initialize a learning rate scheduler with warmup and decay.

    The default is a learning rate scheduler that implements a linear warmup followed by
    linear decay. The learning rate increases linearly from 0 to the initial lr
    during warmup, then decreases linearly to 0 during the remaining steps.

    Add other types of learning rate schedulers here.

    Args:
        training_config: Configuration object containing optimizer and scheduler settings.
        optimizer: PyTorch optimizer whose learning rate will be scheduled.

    Returns:
        torch.optim.lr_scheduler.LambdaLR: Learning rate scheduler instance.
    """

    if training_config.optimization.lr_scheduler == "linear_with_warmup":
        # Credit where credit is due:
        # https://github.com/huggingface/transformers/blob/e71a01a104dd663c730e494eb0b6467bb51df357/src/transformers/optimization.py#L102
        def _lr_lambda(curr_step, num_warmup_steps, max_steps):
            if curr_step < num_warmup_steps:
                return float(curr_step) / float(max(1, num_warmup_steps))
            else:
                return max(
                    0.0,
                    float(max_steps - curr_step)
                    / float(max(1, max_steps - num_warmup_steps)),
                )

        lr_lambda = lambda step: _lr_lambda(  # noqa: E731
            step,
            training_config.optimization.lr_warmup_steps,
            training_config.max_steps,
        )
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda,
        )
    elif training_config.optimization.lr_scheduler == "cosine":
        # Cosine decay with warmup: linear warmup followed by cosine decay
        # This provides sustained learning over long training runs
        def _cosine_lr_lambda(curr_step, num_warmup_steps, max_steps):
            if curr_step < num_warmup_steps:
                # Linear warmup
                return float(curr_step) / float(max(1, num_warmup_steps))
            else:
                # Cosine decay to 0.1 * initial_lr (not to 0)
                progress = float(curr_step - num_warmup_steps) / float(
                    max(1, max_steps - num_warmup_steps)
                )
                return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

        lr_lambda = lambda step: _cosine_lr_lambda(  # noqa: E731
            step,
            training_config.optimization.lr_warmup_steps,
            training_config.max_steps,
        )
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda,
        )
    else:
        raise ValueError(
            f"Invalid learning rate scheduler: {training_config.optimization.lr_scheduler}"
        )

    return lr_scheduler

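
# Quick numerical check of the two schedules above, with assumed values of 100
# warmup steps and 1,000 max steps. These are simplified re-statements of the
# lambdas (the max(1, ...) guards are dropped for readability); each returns a
# multiplier applied to the base learning rate.
def _linear_multiplier(step, warmup=100, max_steps=1000):
    if step < warmup:
        return step / warmup
    return max(0.0, (max_steps - step) / (max_steps - warmup))


def _cosine_multiplier(step, warmup=100, max_steps=1000):
    if step < warmup:
        return step / warmup
    progress = (step - warmup) / (max_steps - warmup)
    return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))


assert _linear_multiplier(50) == 0.5  # halfway through warmup
assert _linear_multiplier(1000) == 0.0  # linear decay goes all the way to zero
assert _cosine_multiplier(1000) == 0.1  # cosine decay floors at 0.1x the base lr
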
########################################################
#
# Experiment Monitoring (Logging, Experiment Tracking, etc.)
#
########################################################


def _initialize_log_file(checkpointing_config: CheckpointingConfig) -> str:
    """Create and initialize a timestamped log file in the run's log directory.

    Sets up a log file with a unique timestamp in the run's logging directory.
    Creates the necessary directory structure if it doesn't exist.

    Directory Structure:
        {checkpointing_config.runs_dir}/
        └── {checkpointing_config.run_name}/
            └── {checkpointing_config.logs_dir}/
                └── log_YYYYMMDD_HHMMSS.log

    Args:
        checkpointing_config: Configuration object containing checkpointing settings.

    Returns:
        str: Absolute path to the created log file.
    """

    run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name)
    logs_dir = os.path.join(run_dir, checkpointing_config.logs_dir)
    os.makedirs(logs_dir, exist_ok=True)

    # datetime stamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file_name = f"log_{timestamp}.log"
    log_file_path = os.path.join(logs_dir, log_file_name)

    open(log_file_path, "w").close()  # Create an empty log file

    return log_file_path


@use_backoff()
def initialize_wandb(
    monitoring_config: MonitoringConfig, checkpointing_config: CheckpointingConfig
):
    """Initialize Weights and Biases.

    This function initializes Weights and Biases based on the configuration settings.

    Args:
        monitoring_config: Configuration object containing monitoring settings.
        checkpointing_config: Configuration object containing checkpointing settings.

    Returns:
        Optional[WandbLogger]: An experiment tracker instance.
    """

    assert (
        monitoring_config.wandb.project is not None
        and monitoring_config.wandb.project != ""
    ), "Wandb project must be provided if wandb is to be used."
    assert (
        monitoring_config.wandb.entity is not None
        and monitoring_config.wandb.entity != ""
    ), "Wandb entity must be provided if wandb is to be used."

    _run_id = None
    if checkpointing_config.training.auto_resume:
        # If we are loading a checkpoint, we can try to find the run id of the previous run
        previous_runs = wandb.Api().runs(
            path=f"{monitoring_config.wandb.entity}/{monitoring_config.wandb.project}",
            filters={"display_name": checkpointing_config.run_name},
        )
        try:
            if len(previous_runs) == 1:
                _run_id = previous_runs[0].id
        except ValueError:
            pass

    wandb_logger = WandbLogger(
        project=monitoring_config.wandb.project,
        entity=monitoring_config.wandb.entity,
        id=_run_id,
        name=checkpointing_config.run_name,
    )

    return wandb_logger

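
# Sketch of what the auto-resume lookup above amounts to, with placeholder
# entity/project/run names (commented out because it hits the W&B API).
# Reusing the previous run's id makes W&B append new metrics to the existing
# run rather than starting a fresh one:
#
#     previous = wandb.Api().runs(
#         path="my-entity/my-project",
#         filters={"display_name": "my-run"},
#     )
#     run_id = previous[0].id if len(previous) == 1 else None
#     logger = WandbLogger(
#         project="my-project", entity="my-entity", id=run_id, name="my-run"
#     )
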
@rank_zero_only
def initialize_logging(
    monitoring_config: MonitoringConfig,
    checkpointing_config: CheckpointingConfig,
    fabric: L.Fabric,
):
    """Initialize logging system with default logging, to file and console.

    The default logging system uses a file handler and a stream handler.

    NOTE: this function is only called on rank 0.

    Args:
        monitoring_config: Configuration object containing monitoring settings.
        checkpointing_config: Configuration object containing checkpointing settings.
        fabric: A Lightning Fabric instance.

    Returns:
        logger: Standard Python logger configured for file and console output.
    """

    # ---- Standard Local Logger ---- #
    logger = logging.getLogger("pico-train")
    logger.setLevel(logging.INFO)

    # Create file handler
    log_file_path = _initialize_log_file(checkpointing_config)
    file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
    file_handler.setLevel(monitoring_config.logging.log_level)

    # Create formatter and add it to the handler
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    file_handler.setFormatter(formatter)

    # Add the handler to the logger
    logger.addHandler(file_handler)

    # Add a stream handler for console output
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(monitoring_config.logging.log_level)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    return logger


########################################################
#
# HuggingFace/Remote Checkpointing
#
########################################################


@rank_zero_only
@use_backoff()
def initialize_hf_checkpointing(
    checkpointing_config: CheckpointingConfig, fabric: L.Fabric
):
    """Initialize HuggingFace Checkpointing.

    Creates a HuggingFace repository if it doesn't exist, and creates a branch named after the run.

    NOTE: this function is only called on rank 0.

    Args:
        checkpointing_config: Configuration object containing checkpointing settings; must have
            a 'hf_checkpoint' attribute that specifies the HuggingFace repository id and
            collection slug (if applicable) to save the checkpoint to.
        fabric: A Lightning Fabric instance.

    Raises:
        RuntimeError: If unable to create HuggingFace repository after multiple attempts.
    """

    huggingface_repo_id = checkpointing_config.hf_checkpoint.repo_id
    assert (
        huggingface_repo_id is not None and huggingface_repo_id != ""
    ), "hf_checkpoint.repo_id must be provided."

    repo = create_repo(huggingface_repo_id, exist_ok=True)

    # We can create a repo without a specified namespace (it defaults to the username),
    # but the rest of the HF calls need the fully qualified name. The fully qualified
    # name is returned by create_repo, so we update the config for later calls.
    checkpointing_config.hf_checkpoint.repo_id = repo.repo_id
    huggingface_repo_id = repo.repo_id

    if checkpointing_config.hf_checkpoint.collection_slug:
        add_collection_item(
            checkpointing_config.hf_checkpoint.collection_slug,
            huggingface_repo_id,
            repo.repo_type,
            exists_ok=True,
        )

    create_branch(
        repo_id=huggingface_repo_id,
        branch=checkpointing_config.run_name,
        exist_ok=True,
    )
src/training/utils/io.py
ADDED
@@ -0,0 +1,52 @@
"""Defines a retry wrapper for io operations."""

import time
from functools import wraps


def use_backoff(max_retries=2, initial_delay=1, backoff_factor=2):
    """
    Universal retry wrapper with exponential backoff for any function, but primarily for loading
    and storing HuggingFace datasets and objects.

    Example usage:

    >>> @use_backoff(max_retries=10, initial_delay=1, backoff_factor=2)
    >>> def important_io_operation(x):
    >>>     return x + 1

    Args:
        max_retries: Maximum number of retry attempts (default: 2)
        initial_delay: Initial delay between retries in seconds (default: 1)
        backoff_factor: Multiplier for delay between retries (default: 2)

    Returns:
        A decorator that will retry the wrapped function up to max_retries times with
        exponential backoff

    Raises:
        Exception: If all retries fail
    """

    def _decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            current_delay = initial_delay
            last_exception = None

            for attempt in range(max_retries):
                try:
                    return fn(*args, **kwargs)
                except Exception as e:
                    last_exception = e
                    if attempt < max_retries - 1:  # Don't sleep on the last attempt
                        time.sleep(current_delay)
                        current_delay *= backoff_factor

            raise Exception(
                f"IO Operation failed after {max_retries} attempts: {str(last_exception)}"
            )

        return wrapper

    return _decorator
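
A short usage sketch for the decorator (the function name and body below are illustrative):

from src.training.utils.io import use_backoff

@use_backoff(max_retries=5, initial_delay=1, backoff_factor=2)
def flaky_io_operation():
    ...  # any call that can fail transiently, e.g. a dataset download

# If every attempt raises, the wrapper sleeps 1s, 2s, 4s, and 8s between the
# five attempts and then raises a summary Exception.
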
src/training/utils/logging.py
ADDED
@@ -0,0 +1,48 @@
"""
Miscellaneous logging utilities.
"""

from io import StringIO

import yaml
from lightning.fabric.utilities.rank_zero import rank_zero_only
from rich.console import Console
from rich.panel import Panel


@rank_zero_only
def pretty_print_yaml_config(logger, config: dict) -> None:
    """
    Pretty print config with rich formatting. Assumes that the config is already saved as a
    dictionary - this can be done by calling `asdict` on the dataclass or loading in the config
    from a yaml file.

    NOTE: this function is only called on rank 0.

    Args:
        logger: Logger object to log the formatted output to.
        config: Dictionary containing the config to pretty print.
    """
    # Create string buffer
    output = StringIO()
    console = Console(file=output, force_terminal=False)

    # Convert to YAML string first
    yaml_str = yaml.dump(
        config, default_flow_style=False, sort_keys=False, Dumper=yaml.SafeDumper
    )

    # Create formatted panel
    panel = Panel(
        yaml_str,
        border_style="blue",
        padding=(0, 1),  # Reduced padding
        expand=False,  # Don't expand to terminal width
    )

    # Print to buffer
    console.print(panel)

    # Log the formatted output
    for line in output.getvalue().splitlines():
        logger.info(line)
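
And a small usage sketch for the helper above (a bare-bones logger setup; the trainer passes its own logger from initialize_logging):

import logging

from src.training.utils.logging import pretty_print_yaml_config

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("pico-train")

pretty_print_yaml_config(logger, {"training": {"max_steps": 1000, "optimizer": "adamw"}})
# Each line of the rendered rich panel is emitted as a separate INFO record.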