khoicrtp committed on
Commit
12001a9
1 Parent(s): fdaa255
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. .gitignore +16 -0
  3. LICENSE +201 -0
  4. README.md +187 -3
  5. checkpoints/lit-llama-bak/7B/lit-llama.pth +3 -0
  6. checkpoints/lit-llama-bak/tokenizer.model +3 -0
  7. checkpoints/lit-llama/7B/lit-llama.pth +3 -0
  8. checkpoints/lit-llama/tokenizer.model +3 -0
  9. checkpoints/open-llama/7B/.gitattributes +35 -0
  10. checkpoints/open-llama/7B/LICENSE.txt +201 -0
  11. checkpoints/open-llama/7B/README.md +126 -0
  12. checkpoints/open-llama/7B/open_llama_7b_preview_300bt_easylm +3 -0
  13. checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/config.json +22 -0
  14. checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/generation_config.json +7 -0
  15. checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/pytorch_model-00001-of-00002.bin +3 -0
  16. checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/pytorch_model-00002-of-00002.bin +3 -0
  17. checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/pytorch_model.bin.index.json +330 -0
  18. checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/special_tokens_map.json +1 -0
  19. checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/tokenizer.model +3 -0
  20. checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/tokenizer_config.json +1 -0
  21. checkpoints/open-llama/7B/tokenizer.model +3 -0
  22. checkpoints/open-llama/7B/tokenizer.vocab +0 -0
  23. data/alpaca/alpaca_data_cleaned_archive.json +37 -0
  24. data/alpaca/alpaca_data_cleaned_archive_origin.json +3 -0
  25. data/alpaca/cloud_cpu_benchmark_report_nipacloud_c035f97b62.pdf +0 -0
  26. data/alpaca/test.pt +3 -0
  27. data/alpaca/train.pt +3 -0
  28. evaluate.py +145 -0
  29. evaluate_adapter.py +164 -0
  30. evaluate_full.py +145 -0
  31. evaluate_lora.py +173 -0
  32. finetune_adapter.py +253 -0
  33. finetune_full.py +214 -0
  34. finetune_lora.py +211 -0
  35. generate.py +162 -0
  36. generate_adapter.py +117 -0
  37. generate_full.py +160 -0
  38. generate_lora.py +131 -0
  39. howto/customize_paths.md +33 -0
  40. howto/download_weights.md +131 -0
  41. howto/finetune_adapter.md +102 -0
  42. howto/finetune_full.md +104 -0
  43. howto/finetune_lora.md +88 -0
  44. howto/inference.md +37 -0
  45. howto/tpus.md +51 -0
  46. howto/train_redpajama.md +133 -0
  47. lit_llama/__init__.py +2 -0
  48. lit_llama/__pycache__/__init__.cpython-311.pyc +0 -0
  49. lit_llama/__pycache__/lora.cpython-311.pyc +0 -0
  50. lit_llama/__pycache__/model.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ checkpoints/open-llama/7B/open_llama_7b_preview_300bt_easylm filter=lfs diff=lfs merge=lfs -text
36
+ data/alpaca/alpaca_data_cleaned_archive_origin.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,16 @@
1
+ __pycache__
2
+ .idea
3
+ .DS_Store
4
+ *.egg-info
5
+ build
6
+
7
+ # data
8
+ data
9
+ checkpoints
10
+ out
11
+ !data/shakespeare/prepare.py
12
+ wandb
13
+
14
+ # downloaded by our tests
15
+ original_model.py
16
+ original_adapter.py
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2023] Lightning AI
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,187 @@
1
- ---
2
- license: openrail
3
- ---
1
+ <div align="center">
2
+ <img src="https://pl-public-data.s3.amazonaws.com/assets_lightning/Lit_LLaMA_Badge3x.png" alt="Lit-LLaMA" width="128"/>
3
+
4
+ # ⚡ Lit-LLaMA ️
5
+
6
+ <!--
7
+ <p align="center">
8
+ <a href="https://www.lightning.ai/">Lightning.ai</a> •
9
+ <a href="https://lightning.ai/docs/pytorch/stable/">PyTorch Lightning</a> •
10
+ <a href="https://lightning.ai/docs/fabric/stable/">Fabric</a>
11
+ </p>
12
+ -->
13
+
14
+ ![cpu-tests](https://github.com/lightning-AI/lit-llama/actions/workflows/cpu-tests.yml/badge.svg) [![Build Status](https://dev.azure.com/Lightning-AI/lit%20Models/_apis/build/status%2FLightning-AI.lit-LLaMA?branchName=main)](https://dev.azure.com/Lightning-AI/lit%20Models/_build/latest?definitionId=49&branchName=main) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lit-llama/blob/master/LICENSE) [![Discord](https://img.shields.io/discord/1077906959069626439?style=plastic)](https://discord.gg/VptPCZkGNa)
15
+
16
+ <img src="https://pl-public-data.s3.amazonaws.com/assets_lightning/Llama_pineapple.gif" alt="Lit-LLaMA and pineapple pizza" width="500px"/>
17
+
18
+ </div>
19
+
20
+ # ⚡ Lit-LLaMA ️
21
+ Independent implementation of [LLaMA](<https://github.com/facebookresearch/llama>) that is fully open source under the **Apache 2.0 license.**
22
+
23
+ This implementation builds on [nanoGPT](<https://github.com/karpathy/nanoGPT>).
24
+
25
+ The original LLaMA weights are distributed by Meta under a [research-only license](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md#model-details).
26
+
27
+ New Apache 2.0 licensed weights are being released as part of the [Open LLaMA project](https://github.com/openlm-research/open_llama). Both can be [loaded in Lit-LLaMA](howto/download_weights.md).
28
+
29
+ ## Why?
30
+
31
+ We believe that AI should be fully open source and part of the collective knowledge.
32
+
33
+ The original [LLaMA code](https://github.com/facebookresearch/llama) is [GPL licensed](https://github.com/facebookresearch/llama/blob/main/LICENSE), which means any project using it must also be released under GPL.
34
+
35
+ This "taints" any other code and prevents integration with the rest of the ecosystem.
36
+
37
+ **Lit-LLaMA solves that for good.**
38
+
39
+ &nbsp;
40
+
41
+ ## Design principles
42
+ **Lit-LLaMA** is:
43
+
44
+ - **Simple:** Single-file implementation without boilerplate.
45
+ - **Correct:** Numerically equivalent to the original model.
46
+ - **Optimized:** Runs on consumer hardware or at scale.
47
+ - **Open-source:** No strings attached.
48
+
49
+ ## Get involved!
50
+ [Join our Discord](https://discord.gg/VptPCZkGNa) to build high-performance, truly open-source models for the common benefit of the community.
51
+
52
+ &nbsp;
53
+
54
+ ## Setup
55
+
56
+ Clone the repo
57
+
58
+ ```bash
59
+ git clone https://github.com/Lightning-AI/lit-llama
60
+ cd lit-llama
61
+ ```
62
+
63
+ Install dependencies
64
+
65
+ ```bash
66
+ pip install -r requirements.txt
67
+ ```
68
+
69
+ You are all set! 🎉
70
+
71
+ &nbsp;
72
+
73
+ ## Use the model
74
+
75
+ To generate text predictions, you need to download the model weights. **If you don't have them, check out our [guide](howto/download_weights.md).**
76
+
77
+ Run inference:
78
+
79
+ ```bash
80
+ python generate.py --prompt "Hello, my name is"
81
+ ```
82
+
83
+ This will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
84
+
85
+ [Full guide for generating samples from the model](howto/inference.md).
86
+
87
+ ### Run Lit-LLaMA on consumer devices
88
+
89
+ On GPUs with `bfloat16` support, the `generate.py` script will automatically convert the weights and consume about 14 GB.
90
+ For GPUs with less memory, or ones that don't support `bfloat16`, enable quantization (`--quantize llm.int8`):
91
+
92
+ ```bash
93
+ python generate.py --quantize llm.int8 --prompt "Hello, my name is"
94
+ ```
95
+
96
+ See `python generate.py --help` for more options.
97
+
98
+ You can also use GPTQ-style int4 quantization, but this requires converting the weights first:
99
+
100
+ ```bash
101
+ python quantize.py --checkpoint_path lit-llama.pth --tokenizer_path tokenizer.model --output_path llama-7b-gptq.4bit.pth --dtype bfloat16 --quantize gptq.int4
102
+ ```
103
+
104
+ With the generated quantized checkpoint, generation works as usual with `--quantize gptq.int4`, bringing GPU usage down to about 5 GB. Since only the weights of the Linear layers are quantized, it is still useful to pass `--dtype bfloat16` even with quantization enabled.
105
+
106
+ [Full guide for generating samples from the model](howto/inference.md).
107
+
108
+ ## Finetune the model
109
+
110
+ We provide simple training scripts, `finetune_lora.py` and `finetune_adapter.py`, that instruction-tune a pretrained model on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset using the techniques of [LoRA](https://arxiv.org/abs/2106.09685) and [Adapter](https://arxiv.org/abs/2303.16199).
111
+
112
+ 1. Download the data and generate an instruction tuning dataset:
113
+
114
+ ```bash
115
+ python scripts/prepare_alpaca.py
116
+ ```
117
+
118
+ 2. Run the finetuning script
119
+
120
+ ```bash
121
+ python finetune_lora.py
122
+ ```
123
+ or
124
+ ```bash
125
+ python finetune_adapter.py
126
+ ```
127
+
128
+ It is expected that you have downloaded the pretrained weights as described above.
129
+ Finetuning requires at least one GPU with ~24 GB of memory (e.g. an RTX 3090). Follow the instructions in the script to make efficient use of your GPU memory.
130
+ Note: For some GPU models you might need to set `torch.backends.cuda.enable_flash_sdp(False)` (see comments at the top of the script).
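For reference, a minimal sketch of that workaround (assuming PyTorch 2.x, placed near the top of the finetuning script before training starts):

```python
# Illustrative only: disable the flash scaled-dot-product-attention kernel
# if your GPU runs into issues with it (see the comments in the script).
import torch

torch.backends.cuda.enable_flash_sdp(False)
```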
131
+
132
+ More details about each finetuning method and how you can apply it to your own data can be found in our technical how-to guides.
133
+
134
+ ### Finetuning How-To Guides
135
+
136
+ These technical tutorials illustrate how to run the finetuning code.
137
+
138
+ - [Finetune with LoRA](howto/finetune_lora.md)
139
+ - [Finetune with Adapters](howto/finetune_adapter.md)
140
+
141
+ ### Understanding Finetuning -- Conceptual Tutorials
142
+
143
+ Looking for conceptual tutorials and explanations? We have some additional articles below:
144
+
145
+ - [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/)
146
+
147
+ ## Pre-training
148
+
149
+ We provide a simple training script based on Fabric if you want to venture into pre-training on RedPajama, a reproduction of the original LLaMA dataset.
150
+ Conversion scripts for our optimized streaming `PackedDataset` are included.
151
+
152
+ Follow this guide to start pre-training on the RedPajama dataset:
153
+
154
+ - [Pretrain on RedPajama](howto/train_redpajama.md)
155
+
156
+ ## Get involved!
157
+
158
+ We are on a quest towards fully open source AI.
159
+
160
+ <img align="right" src="https://pl-public-data.s3.amazonaws.com/assets_lightning/Lit_LLaMA_Illustration3x.png" alt="Lit-LLaMA" width="128"/>
161
+
162
+ Join us and start contributing, especially in the following areas:
163
+
164
+ - [ ] [Pre-training](https://github.com/Lightning-AI/lit-llama/labels/pre-training)
165
+ - [ ] [Fine-tuning (full and LoRA)](https://github.com/Lightning-AI/lit-llama/labels/fine-tuning)
166
+ - [ ] [Quantization](https://github.com/Lightning-AI/lit-llama/labels/quantization)
167
+ - [ ] [Sparsification](https://github.com/Lightning-AI/lit-llama/labels/sparsification)
168
+
169
+ Look at `train.py` for a starting point towards pre-training / fine-tuning using [Lightning Fabric](https://lightning.ai/docs/fabric/stable/).
170
+
171
+ We welcome all individual contributors, regardless of their level of experience or hardware. Your contributions are valuable, and we are excited to see what you can accomplish in this collaborative and supportive environment.
172
+
173
+ Unsure about contributing? Check out our [Contributing to Lit-LLaMA: A Hitchhiker’s Guide to the Quest for Fully Open-Source AI](https://lightning.ai/pages/community/tutorial/contributing-to-lit-llama-a-hitchhikers-guide-to-the-quest-for-fully-open-source-ai/) guide.
174
+
175
+ Don't forget to [join our Discord](https://discord.gg/VptPCZkGNa)!
176
+
177
+ ## Acknowledgements
178
+
179
+ - [@karpathy](https://github.com/karpathy) for [nanoGPT](https://github.com/karpathy/nanoGPT)
180
+ - [@FacebookResearch](https://github.com/facebookresearch) for the original [LLaMA implementation](https://github.com/facebookresearch/llama)
181
+ - [@TimDettmers](https://github.com/TimDettmers) for [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
182
+ - [@Microsoft](https://github.com/microsoft) for [LoRA](https://github.com/microsoft/LoRA)
183
+ - [@IST-DASLab](https://github.com/IST-DASLab) for [GPTQ](https://github.com/IST-DASLab/gptq)
184
+
185
+ ## License
186
+
187
+ Lit-LLaMA is released under the [Apache 2.0](https://github.com/Lightning-AI/lightning-llama/blob/main/LICENSE) license.
checkpoints/lit-llama-bak/7B/lit-llama.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09b6a8994e8bc8ed517e600e355a19cbe41eaf8338532ff4f88d43df6b95e3cd
3
+ size 26953750909
checkpoints/lit-llama-bak/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc820fc43f4173d6362c16658c409ed423929a807e55a984af96cce1277d39a4
3
+ size 772031
checkpoints/lit-llama/7B/lit-llama.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09b6a8994e8bc8ed517e600e355a19cbe41eaf8338532ff4f88d43df6b95e3cd
3
+ size 26953750909
checkpoints/lit-llama/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc820fc43f4173d6362c16658c409ed423929a807e55a984af96cce1277d39a4
3
+ size 772031
checkpoints/open-llama/7B/.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ open_llama_7b_preview_300bt_easylm filter=lfs diff=lfs merge=lfs -text
checkpoints/open-llama/7B/LICENSE.txt ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
checkpoints/open-llama/7B/README.md ADDED
@@ -0,0 +1,126 @@
1
+ ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - togethercomputer/RedPajama-Data-1T
5
+ ---
6
+
7
+
8
+ # OpenLLaMA: An Open Reproduction of LLaMA
9
+
10
+ In this repo, we release a permissively licensed open source reproduction of Meta AI's [LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) large language model. This release is a public preview of the 7B OpenLLaMA model trained on 200 billion tokens. We provide PyTorch and JAX weights of the pre-trained OpenLLaMA models, as well as evaluation results and a comparison against the original LLaMA models. Stay tuned for our updates.
11
+
12
+ **JAX and PyTorch Weights on Huggingface Hub**
13
+ - [200B Checkpoint](https://huggingface.co/openlm-research/open_llama_7b_preview_200bt)
14
+ - [300B Checkpoint](https://huggingface.co/openlm-research/open_llama_7b_preview_300bt)
15
+
16
+
17
+ ## Update 5/3/2023
18
+ We have released a new checkpoint of OpenLLaMA 7B trained on 300B tokens. In communicating
19
+ with our users, we have realized that many existing implementations of LLaMA do not
20
+ prepend the BOS token (id=1) at generation time. Our 200B checkpoint is sensitive
21
+ to this and may produce degraded results without the BOS token at the beginning. Hence,
22
+ we recommend always prepending the BOS token when using our 200B checkpoint.
23
+
24
+ In an effort to make our model broadly compatible with existing implementations, we have now
25
+ released a new 300B checkpoint, which is less sensitive to the BOS token and can be used
26
+ either way.
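As a minimal illustration of the point above (a sketch, not part of this commit), one way to check that the BOS token (id=1) is present when tokenizing a prompt with the Hugging Face tokenizer shipped alongside these weights:

```python
# Sketch: verify that the BOS token (id=1) is prepended before generation.
# The 200B-token checkpoint is sensitive to this; the 300B one is more forgiving.
import torch
from transformers import LlamaTokenizer

ckpt = "checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights"
tokenizer = LlamaTokenizer.from_pretrained(ckpt)

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
if input_ids[0, 0].item() != tokenizer.bos_token_id:  # prepend BOS if the pipeline did not
    bos = torch.tensor([[tokenizer.bos_token_id]])
    input_ids = torch.cat([bos, input_ids], dim=1)
```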
27
+
28
+
29
+ ## Dataset and Training
30
+
31
+ We train our models on the [RedPajama](https://www.together.xyz/blog/redpajama) dataset released by [Together](https://www.together.xyz/), which is a reproduction of the LLaMA training dataset containing over 1.2 trillion tokens. We follow exactly the same preprocessing steps and training hyperparameters as the original LLaMA paper, including model architecture, context length, training steps, learning rate schedule, and optimizer. The only difference between our setting and the original one is the dataset used: OpenLLaMA employs the RedPajama dataset rather than the one utilized by the original LLaMA.
32
+
33
+ We train the models on cloud TPU-v4s using [EasyLM](https://github.com/young-geng/EasyLM), a JAX-based training pipeline we developed for training and fine-tuning language models. We employ a combination of normal data parallelism and [fully sharded data parallelism (also known as ZeRO stage 3)](https://engineering.fb.com/2021/07/15/open-source/fsdp/) to balance training throughput and memory usage. Overall, we reach a throughput of over 1900 tokens / second / TPU-v4 chip in our training run.
34
+
35
+
36
+ ## Evaluation
37
+
38
+ We evaluated OpenLLaMA on a wide range of tasks using [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). The LLaMA results are generated by running the original LLaMA model on the same evaluation metrics. We note that our results for the LLaMA model differ slightly from the original LLaMA paper, which we believe is a result of different evaluation protocols. Similar differences have been reported in [this issue of lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/issues/443). Additionally, we present the results of GPT-J, a 6B parameter model trained on the [Pile](https://pile.eleuther.ai/) dataset by [EleutherAI](https://www.eleuther.ai/).
39
+
40
+ The original LLaMA model was trained for 1 trillion tokens and GPT-J was trained for 500 billion tokens, whereas OpenLLaMA was trained on 200 billion tokens. We present the results in the table below. OpenLLaMA exhibits comparable performance to the original LLaMA and GPT-J across a majority of tasks, and outperforms them in some tasks. We expect that the performance of OpenLLaMA, after completing its training on 1 trillion tokens, will be enhanced even further.
41
+
42
+
43
+ | **Task/Metric** | **GPT-J 6B** | **LLaMA 7B** | **Open LLaMA 7B Preview 200B Tokens** |
44
+ | ---------------------- | ------------ | ------------ | ------------------------------------- |
45
+ | anli_r1/acc | 0.32 | 0.35 | 0.34 |
46
+ | anli_r2/acc | 0.34 | 0.34 | 0.35 |
47
+ | anli_r3/acc | 0.35 | 0.37 | 0.34 |
48
+ | arc_challenge/acc | 0.34 | 0.39 | 0.31 |
49
+ | arc_challenge/acc_norm | 0.37 | 0.41 | 0.34 |
50
+ | arc_easy/acc | 0.67 | 0.68 | 0.66 |
51
+ | arc_easy/acc_norm | 0.62 | 0.52 | 0.59 |
52
+ | boolq/acc | 0.66 | 0.75 | 0.67 |
53
+ | cb/acc | 0.36 | 0.36 | 0.38 |
54
+ | cb/f1 | 0.26 | 0.24 | 0.29 |
55
+ | hellaswag/acc | 0.50 | 0.56 | 0.47 |
56
+ | hellaswag/acc_norm | 0.66 | 0.73 | 0.63 |
57
+ | openbookqa/acc | 0.29 | 0.29 | 0.26 |
58
+ | openbookqa/acc_norm | 0.38 | 0.41 | 0.37 |
59
+ | piqa/acc | 0.75 | 0.78 | 0.74 |
60
+ | piqa/acc_norm | 0.76 | 0.78 | 0.74 |
61
+ | record/em | 0.88 | 0.91 | 0.87 |
62
+ | record/f1 | 0.89 | 0.91 | 0.88 |
63
+ | rte/acc | 0.54 | 0.56 | 0.53 |
64
+ | truthfulqa_mc/mc1 | 0.20 | 0.21 | 0.21 |
65
+ | truthfulqa_mc/mc2 | 0.36 | 0.34 | 0.34 |
66
+ | wic/acc | 0.50 | 0.50 | 0.50 |
67
+ | winogrande/acc | 0.64 | 0.68 | 0.62 |
68
+ | wsc/acc | 0.37 | 0.35 | 0.57 |
69
+ | Average | 0.50 | 0.52 | 0.50 |
70
+
71
+
72
+
73
+
74
+ ## Preview Weights Release and Usage
75
+
76
+ To encourage feedback from the community, we release a preview checkpoint of our weights. The checkpoint can be downloaded from the [HuggingFace Hub](https://huggingface.co/openlm-research/open_llama_7b_preview_200bt). We release the weights in two formats: an EasyLM format to be used with our [EasyLM framework](https://github.com/young-geng/EasyLM), and a PyTorch format to be used with the [Huggingface Transformers](https://huggingface.co/docs/transformers/index) library.
77
+
78
+ For using the weights in our EasyLM framework, please refer to the [LLaMA documentation of EasyLM](https://github.com/young-geng/EasyLM/blob/main/docs/llama.md). Note that unlike the original LLaMA model, our OpenLLaMA tokenizer and weights are trained completely from scratch, so you no longer need to obtain the original LLaMA tokenizer and weights. For using the weights in the transformers library, please follow the [transformers LLaMA documentation](https://huggingface.co/docs/transformers/main/model_doc/llama). Note that we use the BOS (beginning of sentence) token (id=1) during training, so it is important to prepend this token for the best performance during few-shot evaluation.
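For reference, a hedged sketch of loading the converted PyTorch weights in this repo with the standard Transformers LLaMA classes and generating a short completion (`device_map="auto"` assumes the `accelerate` package is installed; adjust dtype and device to your hardware):

```python
# Sketch: load the converted weights with Hugging Face Transformers and generate.
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

ckpt = "checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights"
tokenizer = LlamaTokenizer.from_pretrained(ckpt)
model = LlamaForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("Q: What is the largest animal?\nA:", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```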
79
+
80
+ Both our training framework EasyLM and the preview checkpoint weights are licensed permissively under the Apache 2.0 license.
81
+
82
+
83
+ ## Future Plans
84
+
85
+ The current release is only a preview of what the complete OpenLLaMA release will offer. We are currently focused on completing the training process on the entire RedPajama dataset. This will give us a good apples-to-apples comparison between the original LLaMA and our OpenLLaMA. In addition to the 7B model, we are also training a smaller 3B model in the hope of facilitating language model usage in low-resource use cases. Please stay tuned for our upcoming releases.
86
+
87
+
88
+
89
+ ## Contact
90
+
91
+ We would love to get feedback from the community. If you have any questions, please open an issue or contact us.
92
+
93
+ OpenLLaMA is developed by:
94
+ [Xinyang Geng](https://young-geng.xyz/)* and [Hao Liu](https://www.haoliu.site/)* from Berkeley AI Research.
95
+ *Equal Contribution
96
+
97
+
98
+ ## Reference
99
+
100
+ If you found OpenLLaMA useful in your research or applications, please cite using the following BibTeX:
101
+ ```
102
+ @software{openlm2023openllama,
103
+ author = {Geng, Xinyang and Liu, Hao},
104
+ title = {OpenLLaMA: An Open Reproduction of LLaMA},
105
+ month = May,
106
+ year = 2023,
107
+ url = {https://github.com/openlm-research/open_llama}
108
+ }
109
+ ```
110
+ ```
111
+ @software{together2023redpajama,
112
+ author = {Together Computer},
113
+ title = {RedPajama-Data: An Open Source Recipe to Reproduce LLaMA training dataset},
114
+ month = April,
115
+ year = 2023,
116
+ url = {https://github.com/togethercomputer/RedPajama-Data}
117
+ }
118
+ ```
119
+ ```
120
+ @article{touvron2023llama,
121
+ title={Llama: Open and efficient foundation language models},
122
+ author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and others},
123
+ journal={arXiv preprint arXiv:2302.13971},
124
+ year={2023}
125
+ }
126
+ ```
checkpoints/open-llama/7B/open_llama_7b_preview_300bt_easylm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63ab9d652aaf4e0e47f1d9a0321ef565b62c02921ce0b18a781ba0daac2ebb98
3
+ size 13476851687
checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/config.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "bos_token_id": 1,
6
+ "eos_token_id": 2,
7
+ "hidden_act": "silu",
8
+ "hidden_size": 4096,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 11008,
11
+ "max_position_embeddings": 2048,
12
+ "model_type": "llama",
13
+ "num_attention_heads": 32,
14
+ "num_hidden_layers": 32,
15
+ "pad_token_id": 0,
16
+ "rms_norm_eps": 1e-06,
17
+ "tie_word_embeddings": false,
18
+ "torch_dtype": "float16",
19
+ "transformers_version": "4.28.0.dev0",
20
+ "use_cache": true,
21
+ "vocab_size": 32000
22
+ }
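As a quick sanity check (illustrative only, not part of the commit), the hyperparameters in this config imply roughly 6.7B parameters, consistent with a LLaMA-7B-class model:

```python
# Rough parameter count implied by the config above (embeddings are untied).
hidden, layers, vocab, inter = 4096, 32, 32000, 11008

embed = vocab * hidden            # token embedding matrix
lm_head = vocab * hidden          # output projection (tie_word_embeddings = false)
attn = 4 * hidden * hidden        # q, k, v, o projections per layer
mlp = 3 * hidden * inter          # gate, up, down projections per layer
norms = 2 * hidden                # the two RMSNorm weight vectors per layer

total = embed + lm_head + layers * (attn + mlp + norms) + hidden  # + final norm
print(f"{total / 1e9:.2f}B parameters")  # ~6.74B
```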
checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.28.0.dev0"
7
+ }
checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/pytorch_model-00001-of-00002.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:adce75e45dbad20967c0e96a83b318720a767e4b8f77aabcd01cd2b38e8f0b2e
3
+ size 9976634558
checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/pytorch_model-00002-of-00002.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a80d748a6ab528f0db2249013a5d3fea17e039ad9fa1bf3e170e9070ec30f938
3
+ size 3500315539
checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/pytorch_model.bin.index.json ADDED
@@ -0,0 +1,330 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 13476839424
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "pytorch_model-00002-of-00002.bin",
7
+ "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
8
+ "model.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
9
+ "model.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
10
+ "model.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
11
+ "model.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
12
+ "model.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
13
+ "model.layers.0.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
14
+ "model.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
15
+ "model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
16
+ "model.layers.0.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
17
+ "model.layers.0.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
18
+ "model.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
19
+ "model.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
20
+ "model.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
21
+ "model.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
22
+ "model.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
23
+ "model.layers.1.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
24
+ "model.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
25
+ "model.layers.1.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
26
+ "model.layers.1.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
27
+ "model.layers.1.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
28
+ "model.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
29
+ "model.layers.10.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
30
+ "model.layers.10.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
31
+ "model.layers.10.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
32
+ "model.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
33
+ "model.layers.10.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
34
+ "model.layers.10.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
35
+ "model.layers.10.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
36
+ "model.layers.10.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
37
+ "model.layers.10.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
38
+ "model.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
39
+ "model.layers.11.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
40
+ "model.layers.11.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
41
+ "model.layers.11.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
42
+ "model.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
43
+ "model.layers.11.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
44
+ "model.layers.11.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
45
+ "model.layers.11.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
46
+ "model.layers.11.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
47
+ "model.layers.11.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
48
+ "model.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
49
+ "model.layers.12.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
50
+ "model.layers.12.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
51
+ "model.layers.12.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
52
+ "model.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
53
+ "model.layers.12.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
54
+ "model.layers.12.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
55
+ "model.layers.12.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
56
+ "model.layers.12.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
57
+ "model.layers.12.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
58
+ "model.layers.13.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
59
+ "model.layers.13.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
60
+ "model.layers.13.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
61
+ "model.layers.13.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
62
+ "model.layers.13.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
63
+ "model.layers.13.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
64
+ "model.layers.13.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
65
+ "model.layers.13.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
66
+ "model.layers.13.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
67
+ "model.layers.13.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
68
+ "model.layers.14.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
69
+ "model.layers.14.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
70
+ "model.layers.14.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
71
+ "model.layers.14.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
72
+ "model.layers.14.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
73
+ "model.layers.14.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
74
+ "model.layers.14.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
75
+ "model.layers.14.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
76
+ "model.layers.14.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
77
+ "model.layers.14.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
78
+ "model.layers.15.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
79
+ "model.layers.15.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
80
+ "model.layers.15.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
81
+ "model.layers.15.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
82
+ "model.layers.15.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
83
+ "model.layers.15.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
84
+ "model.layers.15.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
85
+ "model.layers.15.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
86
+ "model.layers.15.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
87
+ "model.layers.15.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
88
+ "model.layers.16.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
89
+ "model.layers.16.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
90
+ "model.layers.16.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
91
+ "model.layers.16.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
92
+ "model.layers.16.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
93
+ "model.layers.16.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
94
+ "model.layers.16.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
95
+ "model.layers.16.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
96
+ "model.layers.16.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
97
+ "model.layers.16.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
98
+ "model.layers.17.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
99
+ "model.layers.17.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
100
+ "model.layers.17.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
101
+ "model.layers.17.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
102
+ "model.layers.17.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
103
+ "model.layers.17.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
104
+ "model.layers.17.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
105
+ "model.layers.17.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
106
+ "model.layers.17.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
107
+ "model.layers.17.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
108
+ "model.layers.18.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
109
+ "model.layers.18.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
110
+ "model.layers.18.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
111
+ "model.layers.18.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
112
+ "model.layers.18.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
113
+ "model.layers.18.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
114
+ "model.layers.18.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
115
+ "model.layers.18.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
116
+ "model.layers.18.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
117
+ "model.layers.18.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
118
+ "model.layers.19.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
119
+ "model.layers.19.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
120
+ "model.layers.19.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
121
+ "model.layers.19.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
122
+ "model.layers.19.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
123
+ "model.layers.19.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
124
+ "model.layers.19.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
125
+ "model.layers.19.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
126
+ "model.layers.19.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
127
+ "model.layers.19.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
128
+ "model.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
129
+ "model.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
130
+ "model.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
131
+ "model.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
132
+ "model.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
133
+ "model.layers.2.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
134
+ "model.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
135
+ "model.layers.2.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
136
+ "model.layers.2.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
137
+ "model.layers.2.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
138
+ "model.layers.20.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
139
+ "model.layers.20.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
140
+ "model.layers.20.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
141
+ "model.layers.20.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
142
+ "model.layers.20.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
143
+ "model.layers.20.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
144
+ "model.layers.20.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
145
+ "model.layers.20.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
146
+ "model.layers.20.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
147
+ "model.layers.20.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
148
+ "model.layers.21.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
149
+ "model.layers.21.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
150
+ "model.layers.21.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
151
+ "model.layers.21.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
152
+ "model.layers.21.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
153
+ "model.layers.21.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
154
+ "model.layers.21.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
155
+ "model.layers.21.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
156
+ "model.layers.21.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
157
+ "model.layers.21.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
158
+ "model.layers.22.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
159
+ "model.layers.22.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
160
+ "model.layers.22.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
161
+ "model.layers.22.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
162
+ "model.layers.22.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
163
+ "model.layers.22.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
164
+ "model.layers.22.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
165
+ "model.layers.22.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
166
+ "model.layers.22.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
167
+ "model.layers.22.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
168
+ "model.layers.23.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
169
+ "model.layers.23.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
170
+ "model.layers.23.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
171
+ "model.layers.23.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
172
+ "model.layers.23.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
173
+ "model.layers.23.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
174
+ "model.layers.23.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
175
+ "model.layers.23.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
176
+ "model.layers.23.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
177
+ "model.layers.23.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
178
+ "model.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
179
+ "model.layers.24.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
180
+ "model.layers.24.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
181
+ "model.layers.24.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
182
+ "model.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
183
+ "model.layers.24.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
184
+ "model.layers.24.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
185
+ "model.layers.24.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
186
+ "model.layers.24.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
187
+ "model.layers.24.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
188
+ "model.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
189
+ "model.layers.25.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
190
+ "model.layers.25.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
191
+ "model.layers.25.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
192
+ "model.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
193
+ "model.layers.25.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
194
+ "model.layers.25.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
195
+ "model.layers.25.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
196
+ "model.layers.25.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
197
+ "model.layers.25.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
198
+ "model.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
199
+ "model.layers.26.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
200
+ "model.layers.26.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
201
+ "model.layers.26.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
202
+ "model.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
203
+ "model.layers.26.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
204
+ "model.layers.26.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
205
+ "model.layers.26.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
206
+ "model.layers.26.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
207
+ "model.layers.26.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
208
+ "model.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
209
+ "model.layers.27.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
210
+ "model.layers.27.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
211
+ "model.layers.27.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
212
+ "model.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
213
+ "model.layers.27.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
214
+ "model.layers.27.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
215
+ "model.layers.27.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
216
+ "model.layers.27.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
217
+ "model.layers.27.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
218
+ "model.layers.28.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
219
+ "model.layers.28.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
220
+ "model.layers.28.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
221
+ "model.layers.28.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
222
+ "model.layers.28.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
223
+ "model.layers.28.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
224
+ "model.layers.28.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
225
+ "model.layers.28.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
226
+ "model.layers.28.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
227
+ "model.layers.28.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
228
+ "model.layers.29.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
229
+ "model.layers.29.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
230
+ "model.layers.29.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
231
+ "model.layers.29.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
232
+ "model.layers.29.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
233
+ "model.layers.29.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
234
+ "model.layers.29.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
235
+ "model.layers.29.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
236
+ "model.layers.29.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
237
+ "model.layers.29.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
238
+ "model.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
239
+ "model.layers.3.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
240
+ "model.layers.3.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
241
+ "model.layers.3.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
242
+ "model.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
243
+ "model.layers.3.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
244
+ "model.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
245
+ "model.layers.3.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
246
+ "model.layers.3.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
247
+ "model.layers.3.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
248
+ "model.layers.30.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
249
+ "model.layers.30.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
250
+ "model.layers.30.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
251
+ "model.layers.30.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
252
+ "model.layers.30.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
253
+ "model.layers.30.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
254
+ "model.layers.30.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
255
+ "model.layers.30.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
256
+ "model.layers.30.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
257
+ "model.layers.30.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
258
+ "model.layers.31.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
259
+ "model.layers.31.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
260
+ "model.layers.31.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
261
+ "model.layers.31.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
262
+ "model.layers.31.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
263
+ "model.layers.31.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
264
+ "model.layers.31.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
265
+ "model.layers.31.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
266
+ "model.layers.31.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
267
+ "model.layers.31.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
268
+ "model.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
269
+ "model.layers.4.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
270
+ "model.layers.4.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
271
+ "model.layers.4.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
272
+ "model.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
273
+ "model.layers.4.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
274
+ "model.layers.4.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
275
+ "model.layers.4.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
276
+ "model.layers.4.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
277
+ "model.layers.4.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
278
+ "model.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
279
+ "model.layers.5.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
280
+ "model.layers.5.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
281
+ "model.layers.5.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
282
+ "model.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
283
+ "model.layers.5.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
284
+ "model.layers.5.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
285
+ "model.layers.5.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
286
+ "model.layers.5.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
287
+ "model.layers.5.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
288
+ "model.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
289
+ "model.layers.6.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
290
+ "model.layers.6.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
291
+ "model.layers.6.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
292
+ "model.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
293
+ "model.layers.6.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
294
+ "model.layers.6.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
295
+ "model.layers.6.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
296
+ "model.layers.6.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
297
+ "model.layers.6.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
298
+ "model.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
299
+ "model.layers.7.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
300
+ "model.layers.7.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
301
+ "model.layers.7.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
302
+ "model.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
303
+ "model.layers.7.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
304
+ "model.layers.7.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
305
+ "model.layers.7.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
306
+ "model.layers.7.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
307
+ "model.layers.7.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
308
+ "model.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
309
+ "model.layers.8.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
310
+ "model.layers.8.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
311
+ "model.layers.8.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
312
+ "model.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
313
+ "model.layers.8.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
314
+ "model.layers.8.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
315
+ "model.layers.8.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
316
+ "model.layers.8.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
317
+ "model.layers.8.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
318
+ "model.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
319
+ "model.layers.9.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
320
+ "model.layers.9.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
321
+ "model.layers.9.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
322
+ "model.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
323
+ "model.layers.9.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
324
+ "model.layers.9.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
325
+ "model.layers.9.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
326
+ "model.layers.9.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
327
+ "model.layers.9.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
328
+ "model.norm.weight": "pytorch_model-00002-of-00002.bin"
329
+ }
330
+ }
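
The weight map above is the standard Hugging Face sharded-checkpoint index: every parameter name points at the shard file that stores it. A minimal sketch of how such an index can be consumed to rebuild a single state dict (the file names come from this index; the loading loop itself is an illustration, not code from this repository):

```python
import json
from pathlib import Path

import torch

ckpt_dir = Path("checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights")
index = json.loads((ckpt_dir / "pytorch_model.bin.index.json").read_text())

state_dict = {}
for shard_name in sorted(set(index["weight_map"].values())):
    # each shard is a regular torch checkpoint holding a subset of the parameters
    shard = torch.load(ckpt_dir / shard_name, map_location="cpu")
    state_dict.update(shard)

# every key listed in the index should now be present
assert set(index["weight_map"]) <= set(state_dict)
```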
checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {}
checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc820fc43f4173d6362c16658c409ed423929a807e55a984af96cce1277d39a4
3
+ size 772031
checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights/tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"bos_token": "", "eos_token": "", "model_max_length": 1000000000000000019884624838656, "tokenizer_class": "LlamaTokenizer", "unk_token": ""}
checkpoints/open-llama/7B/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc820fc43f4173d6362c16658c409ed423929a807e55a984af96cce1277d39a4
3
+ size 772031
checkpoints/open-llama/7B/tokenizer.vocab ADDED
The diff for this file is too large to render. See raw diff
 
data/alpaca/alpaca_data_cleaned_archive.json ADDED
@@ -0,0 +1,37 @@
1
+ [
2
+ {
3
+ "instruction": "Describe DDOS attack",
4
+ "input": "",
5
+ "output": "A Distributed Denial of Service (DDoS) attack is a malicious attempt to disrupt the normal functioning of a network, service, or website by overwhelming it with a flood of internet traffic. In a DDoS attack, multiple compromised computers or devices, often referred to as a botnet, are used to generate a massive amount of traffic that exceeds the target's capacity to handle it."
6
+ },
7
+ {
8
+ "instruction": "What are DDOS technique?",
9
+ "input": "",
10
+ "output": "Traffic flooding: The attacker overwhelms the target with a massive volume of network traffic, consuming all available resources. \nApplication layer attacks: The attacker targets specific applications or services, attempting to exhaust their resources by exploiting vulnerabilities or by initiating resource-intensive operations. \nProtocol attacks: The attacker exploits weaknesses in network protocols, such as the TCP/IP stack, to consume server resources or disrupt communication.\nReflective/amplified attacks: The attacker uses a technique where they send requests to a large number of vulnerable systems that respond by sending a much larger response back to the target, amplifying the attack's impact."
11
+ },
12
+ {
13
+ "instruction": "How to avoid DDOS attack",
14
+ "input": "",
15
+ "output": "DDoS mitigation services: Consider using a reputable DDoS mitigation service that specializes in detecting and mitigating attacks. These services can help filter out malicious traffic and ensure your network remains accessible during an attack. \nNetwork monitoring: Implement robust network monitoring systems to detect any unusual traffic patterns or sudden spikes in traffic. Early detection can allow you to respond quickly and mitigate the impact of an attack."
16
+ },
17
+ {
18
+ "instruction": "what to do when getting ddos-ed",
19
+ "input": "",
20
+ "output": "Contact your DDoS mitigation service: If you have a DDoS mitigation service in place, immediately contact them to activate their services and inform them about the ongoing attack. They can help filter out malicious traffic and provide guidance on mitigation strategies.\nMonitor network traffic: Use network monitoring tools or services to analyze and monitor the incoming traffic. Identify the characteristics of the attack, such as the source IP addresses, traffic patterns, and attack vectors. This information can be useful for further analysis and potential legal action."
21
+ },
22
+ {
23
+ "instruction": "What is DDOS attack",
24
+ "input": "",
25
+ "output": "is a malicious with a flood of internet traffic. In a DDoS attack, multiple compromised computers or devices, often referred to as a botnet, are used to generate a massive amount of traffic that exceeds the target's capacity to handle it."
26
+ },
27
+ {
28
+ "instruction": "Which are DDOS technique?",
29
+ "input": "Traffic flooding, Scanning, Viewing",
30
+ "output": "Traffic flooding"
31
+ },
32
+ {
33
+ "instruction": "Avoid DDOS attack",
34
+ "input": "",
35
+ "output": "DDoS mitigation services: Consider using a reputable DDoS mitigation service that specializes in detecting and mitigating attacks. These services can help filter out malicious traffic and ensure your network remains accessible during an attack. \nNetwork monitoring: Implement robust network monitoring systems to detect any unusual traffic patterns or sudden spikes in traffic. Early detection can allow you to respond quickly and mitigate the impact of an attack."
36
+ }
37
+ ]
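
These records follow the Alpaca instruction/input/output schema that the finetuning scripts consume. A minimal sketch of turning one record into a prompt string, assuming the standard Alpaca template (the repository's own `scripts/prepare_alpaca.py` provides the authoritative `generate_prompt`):

```python
# Minimal sketch of the standard Alpaca prompt template; illustrative only.
def build_prompt(example: dict) -> str:
    if example.get("input"):
        return (
            "Below is an instruction that describes a task, paired with an input "
            "that provides further context. Write a response that appropriately "
            "completes the request.\n\n"
            f"### Instruction:\n{example['instruction']}\n\n"
            f"### Input:\n{example['input']}\n\n### Response:\n"
        )
    return (
        "Below is an instruction that describes a task. Write a response that "
        "appropriately completes the request.\n\n"
        f"### Instruction:\n{example['instruction']}\n\n### Response:\n"
    )

sample = {"instruction": "Describe DDOS attack", "input": ""}
print(build_prompt(sample))
```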
data/alpaca/alpaca_data_cleaned_archive_origin.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00c26b8da597c1aaa5a0bac023bdb8f26bbaa37a9ead7837df4aa7e51ad57459
3
+ size 23573609
data/alpaca/cloud_cpu_benchmark_report_nipacloud_c035f97b62.pdf ADDED
Binary file (501 kB). View file
 
data/alpaca/test.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51be5410edd01a0351eae4535769ff441ae4278d17a6a643bed3cfcad5888c1d
3
+ size 4607
data/alpaca/train.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4026f9d8b9f342b2d6938202e017bd5e8f81716e2c79b78a7b0de92861f15050
3
+ size 10902
evaluate.py ADDED
@@ -0,0 +1,145 @@
1
+ # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
2
+ # Thanks to E. Frantar et al., "GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers", arXiv:2210.17323
3
+ import math
4
+ import sys
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import lightning as L
10
+ import torch
11
+ import tqdm
12
+
13
+ from lit_llama import LLaMA, Tokenizer
14
+ from lit_llama.utils import EmptyInitOnDevice, llama_model_lookup
15
+
16
+ from datasets import load_dataset
17
+
18
+
19
+ def load_eval_data(dataset_name: str) -> str:
20
+ # this mimics gptq datautils
21
+ if dataset_name == "wikitext":
22
+ # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
23
+ testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
24
+ testdata = "\n\n".join(testdata["text"])
25
+ elif dataset_name == "ptb":
26
+ testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
27
+ testdata = "\n\n".join(testdata["sentence"])
28
+ elif dataset_name == "c4":
29
+ testdata = load_dataset(
30
+ "allenai/c4",
31
+ "allenai--c4",
32
+ data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
33
+ split="validation",
34
+ )
35
+ testdata = " ".join(testdata[:1100]["text"])
36
+
37
+ else:
38
+ raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
39
+ return testdata
40
+
41
+
42
+ def main(
43
+ datasets: str = "wikitext,ptb,c4",
44
+ *,
45
+ # compilation fails as it does not support torch.complex64 for RoPE
46
+ # compile: bool = False,
47
+ accelerator: str = "auto",
48
+ checkpoint_path: Optional[Path] = None,
49
+ tokenizer_path: Optional[Path] = None,
50
+ dtype: str = "float32",
51
+ quantize: Optional[str] = None,
52
+ ) -> None:
53
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer.
54
+
55
+ Args:
56
+ datasets: The datasets to use as a comma separated string
57
+ # compile: Whether to compile the model.
58
+ accelerator: The hardware to run on. Possible choices are:
59
+ ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
60
+ checkpoint_path: The checkpoint path to load.
61
+ tokenizer_path: The tokenizer path to load.
62
+ quantize: Whether to quantize the model and using which method:
63
+ ``"llm.int8"``: LLM.int8() mode,
64
+ ``"gptq.int4"``: GPTQ 4-bit mode.
65
+ """
66
+ if not checkpoint_path:
67
+ checkpoint_path = Path(f"./checkpoints/lit-llama/7B/lit-llama.pth")
68
+ if not tokenizer_path:
69
+ tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
70
+ assert checkpoint_path.is_file()
71
+ assert tokenizer_path.is_file()
72
+
73
+ fabric = L.Fabric(accelerator=accelerator, devices=1)
74
+
75
+ dt = getattr(torch, dtype, None)
76
+ if not isinstance(dt, torch.dtype):
77
+ raise ValueError(f"{dtype} is not a valid dtype.")
78
+ dtype = dt
79
+
80
+ with EmptyInitOnDevice(
81
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
82
+ ):
83
+ print("Loading model ...", file=sys.stderr)
84
+ t0 = time.time()
85
+ checkpoint = torch.load(checkpoint_path)
86
+ name = llama_model_lookup(checkpoint)
87
+ model = LLaMA.from_name(name)
88
+ model.load_state_dict(checkpoint)
89
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
90
+
91
+ model.eval()
92
+
93
+ # if compile:
94
+ # model = torch.compile(model)
95
+
96
+ total_toks = 0
97
+ model = fabric.setup_module(model)
98
+
99
+ tokenizer = Tokenizer(tokenizer_path)
100
+
101
+ for dsname in datasets.split(","):
102
+ test_string = load_eval_data(dsname)
103
+ encoded_text = tokenizer.encode(
104
+ test_string, bos=True, eos=False, device=fabric.device
105
+ )
106
+ encoded_text = encoded_text[
107
+ None, : 256 * model.config.block_size
108
+ ] # add batch dimension, trim like gptq implementation
109
+ t0 = time.perf_counter()
110
+
111
+ nlls = 0
112
+ toks = 0
113
+ with torch.inference_mode():
114
+ block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
115
+ for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
116
+ inp = encoded_text[:, i : i + block_size]
117
+ logits = model(inp)[0]
118
+ nll = torch.nn.functional.cross_entropy(
119
+ logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
120
+ )
121
+ toks += inp.size(1) - 1
122
+ nlls += nll.item()
123
+
124
+ print(encoded_text.shape, logits.shape)
125
+ encoded_text = encoded_text[:, : logits.shape[0]]
126
+ ppl = math.exp(nlls / toks)
127
+ print(f"Perplexity on {dsname}: {ppl:.2f}")
128
+ total_toks += toks
129
+
130
+ t = time.perf_counter() - t0
131
+ print(
132
+ f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
133
+ file=sys.stderr,
134
+ )
135
+ print(
136
+ f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
137
+ file=sys.stderr,
138
+ )
139
+
140
+
141
+ if __name__ == "__main__":
142
+ from jsonargparse import CLI
143
+
144
+ torch.set_float32_matmul_precision("high")
145
+ CLI(main)
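
The evaluation loop in `evaluate.py` accumulates the summed cross-entropy over 2048-token blocks and reports exp(total NLL / token count). A stripped-down sketch of the same metric, with names chosen for illustration:

```python
import math

import torch
import torch.nn.functional as F


@torch.inference_mode()
def perplexity(model, token_ids: torch.Tensor, block_size: int = 2048) -> float:
    """token_ids has shape (1, T); mirrors the summed-NLL / exp() metric used above."""
    nll, n_tokens = 0.0, 0
    for i in range(0, token_ids.shape[1], block_size):
        block = token_ids[:, i : i + block_size]
        logits = model(block)[0]  # (T, vocab) after dropping the batch dimension
        nll += F.cross_entropy(
            logits[:-1], block[0, 1:].long(), reduction="sum"
        ).item()
        n_tokens += block.size(1) - 1
    return math.exp(nll / n_tokens)
```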
evaluate_adapter.py ADDED
@@ -0,0 +1,164 @@
1
+ # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
2
+ # Thanks to E. Frantar et al., "GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers", arXiv:2210.17323
3
+ import math
4
+ import sys
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import lightning as L
10
+ import torch
11
+ import tqdm
12
+
13
+ from lit_llama import Tokenizer
14
+ from lit_llama.adapter import LLaMA
15
+ from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
16
+ from scripts.prepare_alpaca import generate_prompt
17
+
18
+ from datasets import load_dataset
19
+
20
+
21
+ def load_eval_data(dataset_name: str) -> str:
22
+ # this mimics gptq datautils
23
+ if dataset_name == "wikitext":
24
+ # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
25
+ testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
26
+ testdata = "\n\n".join(testdata["text"])
27
+ elif dataset_name == "ptb":
28
+ testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
29
+ testdata = "\n\n".join(testdata["sentence"])
30
+ elif dataset_name == "c4":
31
+ testdata = load_dataset(
32
+ "allenai/c4",
33
+ "allenai--c4",
34
+ data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
35
+ split="validation",
36
+ )
37
+ testdata = " ".join(testdata[:1100]["text"])
38
+
39
+ else:
40
+ raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
41
+ return testdata
42
+
43
+
44
+ def main(
45
+ datasets: str = "wikitext,ptb,c4",
46
+ *,
47
+ # compilation fails as it does not support torch.complex64 for RoPE
48
+ # compile: bool = False,
49
+ accelerator: str = "auto",
50
+ adapter_path: Optional[Path] = None,
51
+ checkpoint_path: Optional[Path] = None,
52
+ tokenizer_path: Optional[Path] = None,
53
+ dtype: str = "float32",
54
+ quantize: Optional[str] = None,
55
+ ) -> None:
56
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer.
57
+
58
+ Args:
59
+ datasets: The datasets to use as a comma separated string
60
+ # compile: Whether to compile the model.
61
+ accelerator: The hardware to run on. Possible choices are:
62
+ ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
63
+ adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
64
+ `finetune_adapter.py`.
65
+ checkpoint_path: The checkpoint path to load.
66
+ tokenizer_path: The tokenizer path to load.
67
+ quantize: Whether to quantize the model and using which method:
68
+ ``"llm.int8"``: LLM.int8() mode,
69
+ ``"gptq.int4"``: GPTQ 4-bit mode.
70
+ """
71
+ if not adapter_path:
72
+ adapter_path = Path("out/adapter/alpaca/lit-llama-adapter-finetuned.pth")
73
+ if not checkpoint_path:
74
+ checkpoint_path = Path(f"./checkpoints/lit-llama/7B/lit-llama.pth")
75
+ if not tokenizer_path:
76
+ tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
77
+
78
+ assert adapter_path.is_file()
79
+ assert checkpoint_path.is_file()
80
+ assert tokenizer_path.is_file()
81
+
82
+ fabric = L.Fabric(accelerator=accelerator, devices=1)
83
+
84
+ dt = getattr(torch, dtype, None)
85
+ if not isinstance(dt, torch.dtype):
86
+ raise ValueError(f"{dtype} is not a valid dtype.")
87
+ dtype = dt
88
+
89
+ with EmptyInitOnDevice(
90
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
91
+ ):
92
+ print("Loading model ...", file=sys.stderr)
93
+ t0 = time.time()
94
+ pretrained_checkpoint = lazy_load(checkpoint_path)
95
+ adapter_checkpoint = lazy_load(adapter_path)
96
+ name = llama_model_lookup(pretrained_checkpoint)
97
+ model = LLaMA.from_name(name)
98
+
99
+ # 1. Load the pretrained weights
100
+ model.load_state_dict(pretrained_checkpoint, strict=False)
101
+ # 2. Load the fine-tuned adapter weights
102
+ model.load_state_dict(adapter_checkpoint, strict=False)
103
+
104
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
105
+
106
+ model.eval()
107
+
108
+ # if compile:
109
+ # model = torch.compile(model)
110
+
111
+ total_toks = 0
112
+ model = fabric.setup_module(model)
113
+
114
+ tokenizer = Tokenizer(tokenizer_path)
115
+
116
+ for dsname in datasets.split(","):
117
+ test_string = load_eval_data(dsname)
118
+
119
+ sample = {"instruction": test_string, "input": input}
120
+ test_string = generate_prompt(sample)
121
+
122
+ encoded_text = tokenizer.encode(
123
+ test_string, bos=True, eos=False, device=fabric.device
124
+ )
125
+ encoded_text = encoded_text[
126
+ None, : 256 * model.config.block_size
127
+ ] # add batch dimension, trim like gptq implementation
128
+ t0 = time.perf_counter()
129
+
130
+ nlls = 0
131
+ toks = 0
132
+ with torch.inference_mode():
133
+ block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
134
+ for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
135
+ inp = encoded_text[:, i : i + block_size]
136
+ logits = model(inp)[0]
137
+ nll = torch.nn.functional.cross_entropy(
138
+ logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
139
+ )
140
+ toks += inp.size(1) - 1
141
+ nlls += nll.item()
142
+
143
+ print(encoded_text.shape, logits.shape)
144
+ encoded_text = encoded_text[:, : logits.shape[0]]
145
+ ppl = math.exp(nlls / toks)
146
+ print(f"Perplexity on {dsname}: {ppl:.2f}")
147
+ total_toks += toks
148
+
149
+ t = time.perf_counter() - t0
150
+ print(
151
+ f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
152
+ file=sys.stderr,
153
+ )
154
+ print(
155
+ f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
156
+ file=sys.stderr,
157
+ )
158
+
159
+
160
+ if __name__ == "__main__":
161
+ from jsonargparse import CLI
162
+
163
+ torch.set_float32_matmul_precision("high")
164
+ CLI(main)
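
`evaluate_adapter.py` loads two checkpoints into one module: the frozen pretrained weights first, then the much smaller adapter state dict, both with `strict=False` so that keys missing from either file are tolerated. A generic sketch of that pattern (class and paths are placeholders, not the repository's API):

```python
import torch


def load_base_and_adapter(model: torch.nn.Module,
                          base_ckpt: str, adapter_ckpt: str) -> torch.nn.Module:
    # 1. pretrained weights: adapter parameters are missing -> strict=False
    model.load_state_dict(torch.load(base_ckpt, map_location="cpu"), strict=False)
    # 2. fine-tuned adapter weights: base parameters are missing -> strict=False
    model.load_state_dict(torch.load(adapter_ckpt, map_location="cpu"), strict=False)
    return model.eval()
```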
evaluate_full.py ADDED
@@ -0,0 +1,145 @@
1
+ # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
2
+ # Thanks to E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
3
+ import math
4
+ import sys
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import lightning as L
10
+ import torch
11
+ import tqdm
12
+
13
+ from lit_llama import LLaMA, Tokenizer
14
+ from lit_llama.utils import EmptyInitOnDevice
15
+
16
+ from datasets import load_dataset
17
+
18
+
19
+ def load_eval_data(dataset_name: str) -> str:
20
+ # this mimics gptq datautils
21
+ if dataset_name == "wikitext":
22
+ # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
23
+ testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
24
+ testdata = "\n\n".join(testdata["text"])
25
+ elif dataset_name == "ptb":
26
+ testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
27
+ testdata = "\n\n".join(testdata["sentence"])
28
+ elif dataset_name == "c4":
29
+ testdata = load_dataset(
30
+ "allenai/c4",
31
+ "allenai--c4",
32
+ data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
33
+ split="validation",
34
+ )
35
+ testdata = " ".join(testdata[:1100]["text"])
36
+
37
+ else:
38
+ raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
39
+ return testdata
40
+
41
+
42
+ def main(
43
+ datasets: str = "wikitext,ptb,c4",
44
+ *,
45
+ # compilation fails as it does not support torch.complex64 for RoPE
46
+ # compile: bool = False,
47
+ accelerator: str = "auto",
48
+ checkpoint_path: Optional[Path] = None,
49
+ tokenizer_path: Optional[Path] = None,
50
+ model_size: str = "7B",
51
+ dtype: str = "float32",
52
+ quantize: Optional[str] = None,
53
+ ) -> None:
54
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer.
55
+
56
+ Args:
57
+ datasets: The datasets to use as a comma separated string
58
+ # compile: Whether to compile the model.
59
+ accelerator: The hardware to run on. Possible choices are:
60
+ ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
61
+ checkpoint_path: The checkpoint path to load.
62
+ tokenizer_path: The tokenizer path to load.
63
+ quantize: Whether to quantize the model and using which method:
64
+ ``"llm.int8"``: LLM.int8() mode,
65
+ ``"gptq.int4"``: GPTQ 4-bit mode.
66
+ """
67
+ if not checkpoint_path:
68
+ checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/lit-llama.pth")
69
+ if not tokenizer_path:
70
+ tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
71
+ assert checkpoint_path.is_file()
72
+ assert tokenizer_path.is_file()
73
+
74
+ fabric = L.Fabric(accelerator=accelerator, devices=1)
75
+
76
+ dt = getattr(torch, dtype, None)
77
+ if not isinstance(dt, torch.dtype):
78
+ raise ValueError(f"{dtype} is not a valid dtype.")
79
+ dtype = dt
80
+
81
+ with EmptyInitOnDevice(
82
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
83
+ ):
84
+ print("Loading model ...", file=sys.stderr)
85
+ t0 = time.time()
86
+ model = LLaMA.from_name(model_size)
87
+ checkpoint = torch.load(checkpoint_path)
88
+ model.load_state_dict(checkpoint)
89
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
90
+
91
+ model.eval()
92
+
93
+ # if compile:
94
+ # model = torch.compile(model)
95
+
96
+ total_toks = 0
97
+ model = fabric.setup_module(model)
98
+
99
+ tokenizer = Tokenizer(tokenizer_path)
100
+
101
+ for dsname in datasets.split(","):
102
+ test_string = load_eval_data(dsname)
103
+ encoded_text = tokenizer.encode(
104
+ test_string, bos=True, eos=False, device=fabric.device
105
+ )
106
+ encoded_text = encoded_text[
107
+ None, : 256 * model.config.block_size
108
+ ] # add batch dimension, trim like gptq implementation
109
+ t0 = time.perf_counter()
110
+
111
+ nlls = 0
112
+ toks = 0
113
+ with torch.inference_mode():
114
+ block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
115
+ for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
116
+ inp = encoded_text[:, i : i + block_size]
117
+ logits = model(inp)[0]
118
+ nll = torch.nn.functional.cross_entropy(
119
+ logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
120
+ )
121
+ toks += inp.size(1) - 1
122
+ nlls += nll.item()
123
+
124
+ print(encoded_text.shape, logits.shape)
125
+ encoded_text = encoded_text[:, : logits.shape[0]]
126
+ ppl = math.exp(nlls / toks)
127
+ print(f"Perplexity on {dsname}: {ppl:.2f}")
128
+ total_toks += toks
129
+
130
+ t = time.perf_counter() - t0
131
+ print(
132
+ f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
133
+ file=sys.stderr,
134
+ )
135
+ print(
136
+ f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
137
+ file=sys.stderr,
138
+ )
139
+
140
+
141
+ if __name__ == "__main__":
142
+ from jsonargparse import CLI
143
+
144
+ torch.set_float32_matmul_precision("high")
145
+ CLI(main)
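
All of these evaluation scripts resolve the `--dtype` string by looking the name up on the `torch` module and rejecting anything that is not a `torch.dtype`. The same check, factored out as a small helper, under the assumption the dtype name comes from the command line:

```python
import torch


def resolve_dtype(name: str) -> torch.dtype:
    dt = getattr(torch, name, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{name} is not a valid dtype.")
    return dt


assert resolve_dtype("bfloat16") is torch.bfloat16
```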
evaluate_lora.py ADDED
@@ -0,0 +1,173 @@
1
+ # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
2
+ # Thanks to E. Frantar et al., "GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers", arXiv:2210.17323
3
+ import math
4
+ import sys
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import lightning as L
10
+ import torch
11
+ import tqdm
12
+
13
+ from lit_llama import LLaMA, Tokenizer
14
+ from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
15
+ from lit_llama.lora import lora
16
+ from scripts.prepare_alpaca import generate_prompt
17
+
18
+ from datasets import load_dataset
19
+
20
+ lora_r = 8
21
+ lora_alpha = 16
22
+ lora_dropout = 0.05
23
+
24
+
25
+ def load_eval_data(dataset_name: str) -> str:
26
+ # this mimics gptq datautils
27
+ if dataset_name == "wikitext":
28
+ # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
29
+ testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
30
+ testdata = "\n\n".join(testdata["text"])
31
+ elif dataset_name == "ptb":
32
+ testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
33
+ testdata = "\n\n".join(testdata["sentence"])
34
+ elif dataset_name == "c4":
35
+ testdata = load_dataset(
36
+ "allenai/c4",
37
+ "allenai--c4",
38
+ data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
39
+ split="validation",
40
+ )
41
+ testdata = " ".join(testdata[:1100]["text"])
42
+
43
+ else:
44
+ raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
45
+ return testdata
46
+
47
+
48
+ def main(
49
+ datasets: str = "wikitext,ptb,c4",
50
+ *,
51
+ # compilation fails as it does not support torch.complex64 for RoPE
52
+ # compile: bool = False,
53
+ accelerator: str = "auto",
54
+ lora_path: Optional[Path] = None,
55
+ checkpoint_path: Optional[Path] = None,
56
+ tokenizer_path: Optional[Path] = None,
57
+ dtype: str = "float32",
58
+ quantize: Optional[str] = None,
59
+ ) -> None:
60
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer
61
+ finetuned with LoRA.
62
+
63
+ Args:
64
+ datasets: The datasets to use as a comma separated string
65
+ # compile: Whether to compile the model.
66
+ accelerator: The hardware to run on. Possible choices are:
67
+ ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
68
+ lora_path: Path to the checkpoint with trained LoRA weights, which are the output of
69
+ `finetune_lora.py`.
70
+ checkpoint_path: The checkpoint path to load.
71
+ tokenizer_path: The tokenizer path to load.
72
+ quantize: Whether to quantize the model and using which method:
73
+ ``"llm.int8"``: LLM.int8() mode,
74
+ ``"gptq.int4"``: GPTQ 4-bit mode.
75
+ """
76
+ if not lora_path:
77
+ lora_path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth")
78
+ if not checkpoint_path:
79
+ checkpoint_path = Path(f"./checkpoints/lit-llama/7B/lit-llama.pth")
80
+ if not tokenizer_path:
81
+ tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
82
+ assert lora_path.is_file()
83
+ assert checkpoint_path.is_file()
84
+ assert tokenizer_path.is_file()
85
+
86
+ if quantize is not None:
87
+ raise NotImplementedError("Quantization in LoRA is not supported yet")
88
+
89
+ fabric = L.Fabric(accelerator=accelerator, devices=1)
90
+
91
+ dt = getattr(torch, dtype, None)
92
+ if not isinstance(dt, torch.dtype):
93
+ raise ValueError(f"{dtype} is not a valid dtype.")
94
+ dtype = dt
95
+
96
+ print("Loading model ...", file=sys.stderr)
97
+ t0 = time.time()
98
+
99
+ pretrained_checkpoint = lazy_load(checkpoint_path)
100
+ adapter_checkpoint = lazy_load(lora_path)
101
+ name = llama_model_lookup(pretrained_checkpoint)
102
+
103
+ with EmptyInitOnDevice(
104
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
105
+ ), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
106
+ model = LLaMA.from_name(name)
107
+
108
+ # 1. Load the pretrained weights
109
+ model.load_state_dict(pretrained_checkpoint, strict=False)
110
+ # 2. Load the fine-tuned adapter weights
111
+ model.load_state_dict(adapter_checkpoint, strict=False)
112
+
113
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
114
+
115
+ model.eval()
116
+
117
+ # if compile:
118
+ # model = torch.compile(model)
119
+
120
+ total_toks = 0
121
+ model = fabric.setup_module(model)
122
+
123
+ tokenizer = Tokenizer(tokenizer_path)
124
+
125
+ for dsname in datasets.split(","):
126
+ test_string = load_eval_data(dsname)
127
+
128
+ sample = {"instruction": test_string, "input": input}
129
+ test_string = generate_prompt(sample)
130
+
131
+ encoded_text = tokenizer.encode(
132
+ test_string, bos=True, eos=False, device=fabric.device
133
+ )
134
+ encoded_text = encoded_text[
135
+ None, : 256 * model.config.block_size
136
+ ] # add batch dimension, trim like gptq implementation
137
+ t0 = time.perf_counter()
138
+
139
+ nlls = 0
140
+ toks = 0
141
+ with torch.inference_mode():
142
+ block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
143
+ for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
144
+ inp = encoded_text[:, i : i + block_size]
145
+ logits = model(inp)[0]
146
+ nll = torch.nn.functional.cross_entropy(
147
+ logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
148
+ )
149
+ toks += inp.size(1) - 1
150
+ nlls += nll.item()
151
+
152
+ print(encoded_text.shape, logits.shape)
153
+ encoded_text = encoded_text[:, : logits.shape[0]]
154
+ ppl = math.exp(nlls / toks)
155
+ print(f"Perplexity on {dsname}: {ppl:.2f}")
156
+ total_toks += toks
157
+
158
+ t = time.perf_counter() - t0
159
+ print(
160
+ f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
161
+ file=sys.stderr,
162
+ )
163
+ print(
164
+ f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
165
+ file=sys.stderr,
166
+ )
167
+
168
+
169
+ if __name__ == "__main__":
170
+ from jsonargparse import CLI
171
+
172
+ torch.set_float32_matmul_precision("high")
173
+ CLI(main)
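
`evaluate_lora.py` rebuilds the model inside the `lora(r=8, alpha=16, dropout=0.05)` context so the LoRA-augmented linear layers exist before the two state dicts are loaded. For reference, a generic LoRA linear layer (not the `lit_llama.lora` implementation) computes `W x + (alpha / r) * B A x`:

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Generic low-rank adapter around a frozen linear layer (illustrative only)."""

    def __init__(self, in_features: int, out_features: int, r: int = 8,
                 alpha: int = 16, dropout: float = 0.05):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad = False           # frozen pretrained weight
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # zero-init: no change at start
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.dropout(x) @ self.lora_A.T @ self.lora_B.T * self.scaling


x = torch.randn(2, 32)
print(LoRALinear(32, 64)(x).shape)   # torch.Size([2, 64])
```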
finetune_adapter.py ADDED
@@ -0,0 +1,253 @@
1
+ """
2
+ Instruction-tuning with LLaMA-Adapter on the Alpaca dataset following the paper
3
+
4
+ LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
5
+ https://arxiv.org/abs/2303.16199
6
+
7
+ This script runs on a single GPU by default. You can adjust the `micro_batch_size` to fit your GPU memory.
8
+ You can finetune within 1 hour as done in the original paper using DeepSpeed Zero-2 on 8 A100 GPUs by setting the
9
+ devices variable to `devices = 8` and `micro_batch_size = 8` (or higher).
10
+
11
+ Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
12
+ `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
13
+ """
14
+ import os
15
+ import time
16
+ from pathlib import Path
17
+ import shutil
18
+
19
+ import lightning as L
20
+ import numpy as np
21
+ import torch
22
+
23
+ from generate import generate
24
+ from lit_llama.adapter import LLaMA, LLaMAConfig, mark_only_adapter_as_trainable, adapter_state_from_state_dict
25
+ from lit_llama.tokenizer import Tokenizer
26
+ from scripts.prepare_alpaca import generate_prompt
27
+ from lightning.fabric.strategies import DeepSpeedStrategy
28
+
29
+
30
+ eval_interval = 600
31
+ save_interval = 1000
32
+ eval_iters = 100
33
+ log_interval = 1
34
+ devices = 1
35
+
36
+ # Hyperparameters
37
+ learning_rate = 9e-3
38
+ batch_size = 64 / devices
39
+ micro_batch_size = 4
40
+ gradient_accumulation_steps = batch_size // micro_batch_size
41
+ epoch_size = 50000 # train dataset size
42
+ num_epochs = 5
43
+ max_iters = num_epochs * epoch_size // devices
44
+ weight_decay = 0.02
45
+ max_seq_length = 256 # see scripts/prepare_alpaca.py
46
+ warmup_steps = epoch_size * 2 // micro_batch_size // devices # 2 epochs
47
+
48
+ ds_config = {
49
+ "train_micro_batch_size_per_gpu": micro_batch_size,
50
+ "gradient_accumulation_steps": gradient_accumulation_steps,
51
+ "zero_optimization": {"stage": 2},
52
+ }
53
+
54
+
55
+ def main(
56
+ data_dir: str = "data/alpaca",
57
+ pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
58
+ out_dir: str = "out/adapter/alpaca",
59
+ ):
60
+
61
+ fabric = L.Fabric(
62
+ accelerator="cuda",
63
+ devices=devices,
64
+ strategy=(DeepSpeedStrategy(config=ds_config) if devices > 1 else "auto"),
65
+ precision="bf16-true",
66
+ )
67
+ fabric.launch()
68
+ fabric.seed_everything(1337 + fabric.global_rank)
69
+
70
+ if fabric.global_rank == 0:
71
+ os.makedirs(out_dir, exist_ok=True)
72
+
73
+ train_data, val_data = load_datasets(data_dir=data_dir)
74
+
75
+ config = LLaMAConfig(block_size=max_seq_length)
76
+
77
+ if not os.path.isfile(pretrained_path):
78
+ raise FileNotFoundError(
79
+ f"Can't find the pretrained weights at {pretrained_path}."
80
+ " Please follow the instructions in the README to download them."
81
+ )
82
+ checkpoint = torch.load(pretrained_path)
83
+
84
+ with fabric.init_module():
85
+ model = LLaMA(config)
86
+ # strict=False because of missing keys: the adapter weights are not contained in the pretrained state dict
87
+ model.load_state_dict(checkpoint, strict=False)
88
+
89
+ mark_only_adapter_as_trainable(model)
90
+
91
+ num_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
92
+ print(f"Number of trainable parameters: {num_params}")
93
+
94
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
95
+ model, optimizer = fabric.setup(model, optimizer)
96
+ train(fabric, model, optimizer, train_data, val_data, out_dir)
97
+
98
+ # Save the final checkpoint at the end of training
99
+ save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-adapter-finetuned.pth"))
100
+
101
+
102
+ def train(
103
+ fabric: L.Fabric,
104
+ model: torch.nn.Module,
105
+ optimizer: torch.optim.Optimizer,
106
+ train_data: np.ndarray,
107
+ val_data: np.ndarray,
108
+ out_dir: str,
109
+ ) -> None:
110
+ """The training loop.
111
+
112
+ Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
113
+ """
114
+ step_count = 0
115
+
116
+ for iter_num in range(max_iters):
117
+
118
+ if step_count <= warmup_steps:
119
+ # linear warmup
120
+ lr = learning_rate * step_count / warmup_steps
121
+ for param_group in optimizer.param_groups:
122
+ param_group['lr'] = lr
123
+
124
+ t0 = time.time()
125
+
126
+ input_ids, targets = get_batch(fabric, train_data)
127
+ logits = model(input_ids)
128
+ loss = loss_fn(logits, targets)
129
+ with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_steps != 0)):
130
+ fabric.backward(loss / gradient_accumulation_steps)
131
+
132
+ if (iter_num + 1) % gradient_accumulation_steps == 0:
133
+ optimizer.step()
134
+ optimizer.zero_grad()
135
+ step_count += 1
136
+
137
+ if step_count % eval_interval == 0:
138
+ val_loss = validate(fabric, model, val_data)
139
+ fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
140
+ fabric.barrier()
141
+
142
+ if step_count % save_interval == 0:
143
+ print(f"Saving adapter weights to {out_dir}")
144
+ # TODO: Provide a function/script to merge the adapter weights with pretrained weights
145
+ save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}.pth"))
146
+
147
+ dt = time.time() - t0
148
+ if iter_num % log_interval == 0:
149
+ fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
150
+
151
+
152
+ def generate_response(model, instruction, input=""):
153
+ tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
154
+ sample = {"instruction": instruction, "input": input}
155
+ prompt = generate_prompt(sample)
156
+ encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
157
+
158
+ output = generate(
159
+ model,
160
+ idx=encoded,
161
+ max_seq_length=max_seq_length,
162
+ max_new_tokens=100,
163
+ temperature=0.8,
164
+ )
165
+ output = tokenizer.decode(output)
166
+ return output # output.split("### Response:")[1].strip()
167
+
168
+
169
+ @torch.no_grad()
170
+ def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
171
+ fabric.print("Validating ...")
172
+ model.eval()
173
+ losses = torch.zeros(eval_iters)
174
+ for k in range(eval_iters):
175
+ input_ids, targets = get_batch(fabric, val_data)
176
+ logits = model(input_ids)
177
+ loss = loss_fn(logits, targets)
178
+ losses[k] = loss.item()
179
+ val_loss = losses.mean()
180
+
181
+ # produce an example:
182
+ instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
183
+ output = generate_response(model, instruction)
184
+ fabric.print(instruction)
185
+ fabric.print(output)
186
+
187
+ model.train()
188
+ return val_loss.item()
189
+
190
+ def loss_fn(logits, targets):
191
+ # shift the targets such that output n predicts token n+1
192
+ logits = logits[..., :-1, :].contiguous()
193
+ targets = targets[..., 1:].contiguous()
194
+ loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
195
+ return loss
196
+
197
+
198
+ def get_batch(fabric: L.Fabric, data: list):
199
+ ix = torch.randint(len(data), (micro_batch_size,))
200
+
201
+ input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
202
+ labels = [data[i]["labels"].type(torch.int64) for i in ix]
203
+
204
+ max_len = max(len(s) for s in input_ids)
205
+
206
+ def pad_right(x, pad_id):
207
+ # pad right based on the longest sequence
208
+ n = max_len - len(x)
209
+ return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
210
+
211
+ x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
212
+ y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
213
+ x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
214
+ return x, y
215
+
216
+
217
+ def load_datasets(data_dir):
218
+ train_data = torch.load(os.path.join(data_dir, "train.pt"))
219
+ val_data = torch.load(os.path.join(data_dir, "test.pt"))
220
+ return train_data, val_data
221
+
222
+
223
+ def save_model_checkpoint(fabric, model, file_path):
224
+ file_path = Path(file_path)
225
+
226
+ if isinstance(fabric.strategy, DeepSpeedStrategy):
227
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
228
+
229
+ tmp_path = file_path.with_suffix(".tmp")
230
+ fabric.save(tmp_path, {"model": model})
231
+ fabric.barrier()
232
+ if fabric.global_rank == 0:
233
+ # Create a consolidated checkpoint with the same name next to the deepspeed checkpoint
234
+ # and only keep the adapter weights
235
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(tmp_path)
236
+ state_dict = adapter_state_from_state_dict(state_dict)
237
+ torch.save(state_dict, file_path)
238
+ shutil.rmtree(tmp_path)
239
+ else:
240
+ state_dict = adapter_state_from_state_dict(model.state_dict())
241
+ if fabric.global_rank == 0:
242
+ torch.save(state_dict, file_path)
243
+ fabric.barrier()
244
+
245
+
246
+ if __name__ == "__main__":
247
+ # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
248
+ # torch.backends.cuda.enable_flash_sdp(False)
249
+ torch.set_float32_matmul_precision("high")
250
+
251
+ from jsonargparse.cli import CLI
252
+
253
+ CLI(main)
finetune_full.py ADDED
@@ -0,0 +1,214 @@
1
+ """
2
+ Instruction-tuning on the Alpaca dataset using a regular finetuning procedure (updating all layers).
3
+
4
+ Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
5
+ `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
6
+ """
7
+ import os
8
+ import time
9
+ from functools import partial
10
+
11
+ import lightning as L
12
+ from lightning.fabric.strategies import FSDPStrategy
13
+ import numpy as np
14
+ import torch
15
+ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
16
+
17
+ from generate import generate
18
+ from lit_llama.model import Block, LLaMA, LLaMAConfig
19
+ from lit_llama.tokenizer import Tokenizer
20
+ from lit_llama.utils import save_model_checkpoint
21
+ from scripts.prepare_alpaca import generate_prompt
22
+
23
+
24
+ eval_interval = 1000
25
+ save_interval = 1000
26
+ eval_iters = 100
27
+ log_interval = 100
28
+ devices = 4
29
+
30
+ # Hyperparameters
31
+ learning_rate = 3e-5
32
+ batch_size = 128 // devices
33
+ micro_batch_size = 4
34
+ gradient_accumulation_steps = batch_size // micro_batch_size
35
+ epoch_size = 50000 # train dataset size
36
+ num_epochs = 5
37
+ max_iters = num_epochs * epoch_size // micro_batch_size // devices
38
+ weight_decay = 0.0
39
+ block_size = 512
40
+ warmup_steps = 100
41
+
42
+
43
+ def main(
44
+ data_dir: str = "data/alpaca",
45
+ pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
46
+ out_dir: str = "out/full/alpaca",
47
+ ):
48
+
49
+ auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
50
+ strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block)
51
+
52
+ fabric = L.Fabric(accelerator="cuda", devices=devices, precision="bf16-mixed", strategy=strategy)
53
+ fabric.launch()
54
+ fabric.seed_everything(1337 + fabric.global_rank)
55
+
56
+ if fabric.global_rank == 0:
57
+ os.makedirs(out_dir, exist_ok=True)
58
+
59
+ train_data, val_data = load_datasets(data_dir=data_dir)
60
+
61
+ config = LLaMAConfig.from_name("7B")
62
+ config.block_size = block_size
63
+
64
+ checkpoint = torch.load(pretrained_path)
65
+
66
+ with fabric.device:
67
+ torch.set_default_tensor_type(torch.HalfTensor)
68
+ model = LLaMA(config).bfloat16()
69
+ torch.set_default_tensor_type(torch.FloatTensor)
70
+ model.load_state_dict(checkpoint, strict=False)
71
+
72
+ model = fabric.setup_module(model)
73
+
74
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
75
+ optimizer = fabric.setup_optimizers(optimizer)
76
+
77
+ train(fabric, model, optimizer, train_data, val_data, out_dir)
78
+
79
+ # Save the final checkpoint at the end of training
80
+ save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-full-finetuned.pth"))
81
+
82
+
83
+ def train(
84
+ fabric: L.Fabric,
85
+ model: torch.nn.Module,
86
+ optimizer: torch.optim.Optimizer,
87
+ train_data: np.ndarray,
88
+ val_data: np.ndarray,
89
+ out_dir: str,
90
+ ) -> None:
91
+ """The training loop.
92
+
93
+ Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
94
+ """
95
+ step_count = 0
96
+ model.train()
97
+
98
+ for iter_num in range(max_iters):
99
+
100
+ is_accumulating = (iter_num + 1) % gradient_accumulation_steps != 0
101
+
102
+ if step_count <= warmup_steps:
103
+ # linear warmup
104
+ lr = learning_rate * step_count / warmup_steps
105
+ for param_group in optimizer.param_groups:
106
+ param_group['lr'] = lr
107
+
108
+ t0 = time.time()
109
+
110
+ with fabric.no_backward_sync(model, enabled=is_accumulating):
111
+ input_ids, targets = get_batch(fabric, train_data)
112
+ logits = model(input_ids)
113
+ loss = loss_fn(logits, targets)
114
+ fabric.backward(loss)
115
+
116
+ if not is_accumulating:
117
+ optimizer.step()
118
+ optimizer.zero_grad()
119
+ step_count += 1
120
+
121
+ if step_count % eval_interval == 0:
122
+ val_loss = validate(fabric, model, val_data)
123
+ fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
124
+ fabric.barrier()
125
+
126
+ if step_count % save_interval == 0:
127
+ print(f"Saving weights to {out_dir}")
128
+ save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"))
129
+
130
+ dt = time.time() - t0
131
+ if iter_num % log_interval == 0:
132
+ fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
133
+
134
+
135
+ def generate_response(model, instruction):
136
+ tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
137
+ sample = {"instruction": instruction, "input": ""}
138
+ prompt = generate_prompt(sample)
139
+ encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
140
+
141
+ output = generate(
142
+ model,
143
+ idx=encoded,
144
+ max_seq_length=block_size,
145
+ max_new_tokens=100,
146
+ )
147
+ output = tokenizer.decode(output)
148
+ return output # output.split("### Response:")[1].strip()
149
+
150
+
151
+ @torch.no_grad()
152
+ def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
153
+ fabric.print("Validating ...")
154
+ model.eval()
155
+ losses = torch.zeros(eval_iters)
156
+ for k in range(eval_iters):
157
+ input_ids, targets = get_batch(fabric, val_data)
158
+ logits = model(input_ids)
159
+ loss = loss_fn(logits, targets)
160
+ losses[k] = loss.item()
161
+ out = losses.mean()
162
+
163
+ # produce an example:
164
+ instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
165
+
166
+ output = generate_response(model, instruction)
167
+ fabric.print(instruction)
168
+ fabric.print(output)
169
+
170
+ model.train()
171
+ return out.item()
172
+
173
+
174
+ def loss_fn(logits, targets):
175
+ # shift the targets such that output n predicts token n+1
176
+ logits = logits[..., :-1, :].contiguous()
177
+ targets = targets[..., 1:].contiguous()
178
+ loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
179
+ return loss
180
+
181
+
182
+ def get_batch(fabric: L.Fabric, data: list):
183
+ ix = torch.randint(len(data), (micro_batch_size,))
184
+
185
+ input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
186
+ labels = [data[i]["labels"].type(torch.int64) for i in ix]
187
+
188
+ max_len = max(len(s) for s in input_ids)
189
+
190
+ def pad_right(x, pad_id):
191
+ # pad right based on the longest sequence
192
+ n = max_len - len(x)
193
+ return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
194
+
195
+ x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
196
+ y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
197
+ x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
198
+ return x, y
199
+
200
+
201
+ def load_datasets(data_dir):
202
+ train_data = torch.load(os.path.join(data_dir, "train.pt"))
203
+ val_data = torch.load(os.path.join(data_dir, "test.pt"))
204
+ return train_data, val_data
205
+
206
+
207
+ if __name__ == "__main__":
208
+ # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
209
+ # torch.backends.cuda.enable_flash_sdp(False)
210
+ torch.set_float32_matmul_precision("high")
211
+
212
+ from jsonargparse.cli import CLI
213
+
214
+ CLI(main)
finetune_lora.py ADDED
@@ -0,0 +1,211 @@
1
+ """
2
+ Instruction-tuning with LoRA on the Alpaca dataset.
3
+
4
+ Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
5
+ `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
6
+ """
7
+ import os
8
+ import time
9
+
10
+ import lightning as L
11
+ import numpy as np
12
+ import torch
13
+
14
+ from generate import generate
15
+ from lit_llama.lora import mark_only_lora_as_trainable, lora, lora_state_dict
16
+ from lit_llama.model import LLaMA, LLaMAConfig
17
+ from lit_llama.tokenizer import Tokenizer
18
+ from scripts.prepare_alpaca import generate_prompt
19
+
20
+
21
+ eval_interval = 100
22
+ save_interval = 100
23
+ eval_iters = 100
24
+ log_interval = 1
25
+
26
+ # Hyperparameters
27
+ learning_rate = 3e-4
28
+ batch_size = 128
29
+ micro_batch_size = 4
30
+ gradient_accumulation_steps = batch_size // micro_batch_size
31
+ max_iters = 2 #50000 * 3 // micro_batch_size
32
+ weight_decay = 0.0
33
+ max_seq_length = 256 # see scripts/prepare_alpaca.py
34
+ lora_r = 8
35
+ lora_alpha = 16
36
+ lora_dropout = 0.05
37
+ warmup_steps = 100
38
+
39
+
40
+ def main(
41
+ data_dir: str = "data/alpaca",
42
+ pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
43
+ out_dir: str = "out/lora/alpaca",
44
+ ):
45
+
46
+ #fabric = L.Fabric(accelerator="cuda", precision="bf16-true")
47
+ fabric = L.Fabric(accelerator="cpu", devices=2, precision="bf16-true")
48
+ fabric.launch()
49
+ fabric.seed_everything(1337 + fabric.global_rank)
50
+
51
+ if fabric.global_rank == 0:
52
+ os.makedirs(out_dir, exist_ok=True)
53
+ print("loading dataset ", data_dir)
54
+ train_data, val_data = load_datasets(data_dir=data_dir)
55
+ print("train data: ", len(train_data))
56
+ print("val data: ", len(val_data))
57
+ config = LLaMAConfig.from_name("7B")
58
+ config.block_size = max_seq_length
59
+ print("loading pretrained model ", pretrained_path)
60
+ checkpoint = torch.load(pretrained_path)
61
+
62
+ with fabric.init_module(), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
63
+ model = LLaMA(config)
64
+ # strict=False because missing keys due to LoRA weights not contained in checkpoint state
65
+ model.load_state_dict(checkpoint, strict=False)
66
+
67
+ mark_only_lora_as_trainable(model)
68
+
69
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
70
+ model, optimizer = fabric.setup(model, optimizer)
71
+ print("start training")
72
+ train(fabric, model, optimizer, train_data, val_data, out_dir)
73
+
74
+ # Save the final LoRA checkpoint at the end of training
75
+ print(f"Saving LoRA weights to {out_dir}")
76
+ checkpoint = lora_state_dict(model)
77
+ fabric.save(os.path.join(out_dir, "lit-llama-lora-finetuned.pth"), checkpoint)
78
+
79
+
80
+ def train(
81
+ fabric: L.Fabric,
82
+ model: torch.nn.Module,
83
+ optimizer: torch.optim.Optimizer,
84
+ train_data: np.ndarray,
85
+ val_data: np.ndarray,
86
+ out_dir: str,
87
+ ) -> None:
88
+ """The training loop.
89
+
90
+ Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
91
+ """
92
+ step_count = 0
93
+ print("max iters:", max_iters )
94
+
95
+ for iter_num in range(max_iters):
96
+ print("iter_num", iter_num)
97
+ if step_count <= warmup_steps:
98
+ # linear warmup
99
+ lr = learning_rate * step_count / warmup_steps
100
+ for param_group in optimizer.param_groups:
101
+ param_group['lr'] = lr
102
+
103
+ t0 = time.time()
104
+
105
+ input_ids, targets = get_batch(fabric, train_data)
106
+ logits = model(input_ids)
107
+ print("calculate loss")
108
+ loss = loss_fn(logits, targets)
109
+ print("backward")
110
+ fabric.backward(loss)
111
+
112
+ if (iter_num + 1) % gradient_accumulation_steps == 0:
113
+ print("step optimizer")
114
+ optimizer.step()
115
+ optimizer.zero_grad()
116
+ step_count += 1
117
+ if step_count % eval_interval == 0:
118
+ val_loss = validate(fabric, model, val_data)
119
+ fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
120
+ fabric.barrier()
121
+
122
+ if step_count % save_interval == 0:
123
+ print(f"Saving LoRA weights to {out_dir}")
124
+ # We are only saving the LoRA weights
125
+ # TODO: Provide a function/script to merge the LoRA weights with pretrained weights
126
+ checkpoint = lora_state_dict(model)
127
+ fabric.save(os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"), checkpoint)
128
+
129
+ dt = time.time() - t0
130
+ if iter_num % log_interval == 0:
131
+ fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
132
+
133
+
134
+ def generate_response(model, instruction):
135
+ tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
136
+ sample = {"instruction": instruction, "input": ""}
137
+ prompt = generate_prompt(sample)
138
+ encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
139
+
140
+ output = generate(
141
+ model,
142
+ idx=encoded,
143
+ max_seq_length=max_seq_length,
144
+ max_new_tokens=100,
145
+ )
146
+ output = tokenizer.decode(output)
147
+ return output # output.split("### Response:")[1].strip()
148
+
149
+
150
+ @torch.no_grad()
151
+ def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
152
+ fabric.print("Validating ...")
153
+ model.eval()
154
+ losses = torch.zeros(eval_iters)
155
+ for k in range(eval_iters):
156
+ input_ids, targets = get_batch(fabric, val_data)
157
+ logits = model(input_ids)
158
+ loss = loss_fn(logits, targets)
159
+ losses[k] = loss.item()
160
+ out = losses.mean()
161
+
162
+ # produce an example:
163
+ instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
164
+ output = generate_response(model, instruction)
165
+ fabric.print(instruction)
166
+ fabric.print(output)
167
+
168
+ model.train()
169
+ return out.item()
170
+
171
+ def loss_fn(logits, targets):
172
+ # shift the targets such that output n predicts token n+1
173
+ logits = logits[..., :-1, :].contiguous()
174
+ targets = targets[..., 1:].contiguous()
175
+ loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
176
+ return loss
177
+
178
+
179
+ def get_batch(fabric: L.Fabric, data: list):
180
+ ix = torch.randint(len(data), (micro_batch_size,))
181
+
182
+ input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
183
+ labels = [data[i]["labels"].type(torch.int64) for i in ix]
184
+
185
+ max_len = max(len(s) for s in input_ids)
186
+
187
+ def pad_right(x, pad_id):
188
+ # pad right based on the longest sequence
189
+ n = max_len - len(x)
190
+ return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
191
+
192
+ x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
193
+ y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
194
+ x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
195
+ return x, y
196
+
197
+
198
+ def load_datasets(data_dir):
199
+ train_data = torch.load(os.path.join(data_dir, "train.pt"))
200
+ val_data = torch.load(os.path.join(data_dir, "test.pt"))
201
+ return train_data, val_data
202
+
203
+
204
+ if __name__ == "__main__":
205
+ # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
206
+ # torch.backends.cuda.enable_flash_sdp(False)
207
+ torch.set_float32_matmul_precision("high")
208
+
209
+ from jsonargparse.cli import CLI
210
+
211
+ CLI(main)
generate.py ADDED
@@ -0,0 +1,162 @@
1
+ import sys
2
+ import time
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import lightning as L
8
+ import torch
9
+
10
+ from lit_llama import LLaMA, Tokenizer
11
+ from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
12
+
13
+ @torch.no_grad()
14
+ def generate(
15
+ model: torch.nn.Module,
16
+ idx: torch.Tensor,
17
+ max_new_tokens: int,
18
+ max_seq_length: int,
19
+ temperature: float = 1.0,
20
+ top_k: Optional[int] = None,
21
+ eos_id: Optional[int] = None,
22
+ ) -> torch.Tensor:
23
+ """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
24
+
25
+ The implementation of this function is modified from A. Karpathy's nanoGPT.
26
+
27
+ Args:
28
+ model: The model to use.
29
+ idx: Tensor of shape (T) with indices of the prompt sequence.
30
+ max_new_tokens: The number of new tokens to generate.
31
+ max_seq_length: The maximum sequence length allowed.
32
+ temperature: Scales the predicted logits by 1 / temperature
33
+ top_k: If specified, only sample among the tokens with the k highest probabilities
34
+ eos_id: If specified, stop generating any more token once the <eos> token is triggered
35
+ """
36
+ # create an empty tensor of the expected final shape and fill in the current tokens
37
+ T = idx.size(0)
38
+ T_new = T + max_new_tokens
39
+ empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
40
+ empty[:T] = idx
41
+ idx = empty
42
+
43
+ # generate max_new_tokens tokens
44
+ for t in range(T, T_new):
45
+ # ignore the not-filled-yet tokens
46
+ idx_cond = idx[:t]
47
+ # if the sequence context is growing too long we must crop it at max_seq_length
48
+ idx_cond = idx_cond if T <= max_seq_length else idx_cond[-max_seq_length:]
49
+
50
+ # forward
51
+ logits = model(idx_cond.view(1, -1))
52
+ logits = logits[0, -1] / temperature
53
+
54
+ # optionally crop the logits to only the top k options
55
+ if top_k is not None:
56
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
57
+ logits[logits < v[[-1]]] = -float("Inf")
58
+
59
+ probs = torch.nn.functional.softmax(logits, dim=-1)
60
+ idx_next = torch.multinomial(probs, num_samples=1)
61
+
62
+ # concatenate the new generation
63
+ idx[t] = idx_next
64
+
65
+ # if <eos> token is triggered, return the output (stop generation)
66
+ if idx_next == eos_id:
67
+ return idx[:t + 1] # include the EOS token
68
+
69
+ return idx
70
+
71
+
72
+ def main(
73
+ prompt: str = "Hello, my name is",
74
+ *,
75
+ num_samples: int = 1,
76
+ max_new_tokens: int = 50,
77
+ top_k: int = 200,
78
+ temperature: float = 0.8,
79
+ checkpoint_path: Optional[Path] = None,
80
+ tokenizer_path: Optional[Path] = None,
81
+ quantize: Optional[str] = None,
82
+ ) -> None:
83
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer.
84
+
85
+ Args:
86
+ prompt: The prompt string to use for generating the samples.
87
+ num_samples: The number of text samples to generate.
88
+ max_new_tokens: The number of generation steps to take.
89
+ top_k: The number of top most probable tokens to consider in the sampling process.
90
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
91
+ samples.
92
+ checkpoint_path: The checkpoint path to load.
93
+ tokenizer_path: The tokenizer path to load.
94
+ quantize: Whether to quantize the model and using which method:
95
+ ``"llm.int8"``: LLM.int8() mode,
96
+ ``"gptq.int4"``: GPTQ 4-bit mode.
97
+ """
98
+ if not checkpoint_path:
99
+ checkpoint_path = Path(f"./checkpoints/lit-llama/7B/lit-llama.pth")
100
+ if not tokenizer_path:
101
+ tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
102
+ assert checkpoint_path.is_file(), checkpoint_path
103
+ assert tokenizer_path.is_file(), tokenizer_path
104
+
105
+ fabric = L.Fabric(devices="auto", accelerator="cuda")
106
+ #fabric = L.Fabric(accelerator="cpu")
107
+ fabric.launch()
108
+ dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
109
+
110
+ print("Loading model ...", file=sys.stderr)
111
+ t0 = time.time()
112
+ with lazy_load(checkpoint_path) as checkpoint:
113
+ name = llama_model_lookup(checkpoint)
114
+
115
+ with EmptyInitOnDevice(
116
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
117
+ ):
118
+ model = LLaMA.from_name(name)
119
+
120
+ model.load_state_dict(checkpoint)
121
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
122
+
123
+ model.eval()
124
+ model = fabric.setup_module(model)
125
+
126
+ tokenizer = Tokenizer(tokenizer_path)
127
+ encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
128
+
129
+ L.seed_everything(1234)
130
+ for i in range(num_samples):
131
+ t0 = time.perf_counter()
132
+ y = generate(
133
+ model,
134
+ encoded_prompt,
135
+ max_new_tokens,
136
+ model.config.block_size, # type: ignore[union-attr,arg-type]
137
+ temperature=temperature,
138
+ top_k=top_k,
139
+ )
140
+ t = time.perf_counter() - t0
141
+ print(tokenizer.decode(y))
142
+ print(f"Time for inference {i + 1}: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
143
+ if fabric.device.type == "cuda":
144
+ print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
145
+
146
+
147
+ if __name__ == "__main__":
148
+ from jsonargparse import CLI
149
+ torch.backends.cuda.max_split_size_mb = 16
150
+ torch.quantization.quantize_dynamic
151
+ torch.set_float32_matmul_precision("high")
152
+ warnings.filterwarnings(
153
+ # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
154
+ "ignore",
155
+ message="ComplexHalf support is experimental and many operators don't support it yet"
156
+ )
157
+ warnings.filterwarnings(
158
+ # Triggered in bitsandbytes/autograd/_functions.py:298
159
+ "ignore",
160
+ message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
161
+ )
162
+ CLI(main)
generate_adapter.py ADDED
@@ -0,0 +1,117 @@
1
+ import sys
2
+ import time
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import lightning as L
8
+ import torch
9
+
10
+ from generate import generate
11
+ from lit_llama import Tokenizer
12
+ from lit_llama.adapter import LLaMA
13
+ from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
14
+ from scripts.prepare_alpaca import generate_prompt
15
+
16
+
17
+ def main(
18
+ prompt: str = "What food do llamas eat?",
19
+ input: str = "",
20
+ adapter_path: Optional[Path] = None,
21
+ pretrained_path: Optional[Path] = None,
22
+ tokenizer_path: Optional[Path] = None,
23
+ quantize: Optional[str] = None,
24
+ max_new_tokens: int = 100,
25
+ top_k: int = 200,
26
+ temperature: float = 0.8,
27
+ ) -> None:
28
+ """Generates a response based on a given instruction and an optional input.
29
+ This script will only work with checkpoints from the instruction-tuned LLaMA-Adapter model.
30
+ See `finetune_adapter.py`.
31
+
32
+ Args:
33
+ prompt: The prompt/instruction (Alpaca style).
34
+ adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
35
+ `finetune_adapter.py`.
36
+ input: Optional input (Alpaca style).
37
+ pretrained_path: The path to the checkpoint with pretrained LLaMA weights.
38
+ tokenizer_path: The tokenizer path to load.
39
+ quantize: Whether to quantize the model and using which method:
40
+ ``"llm.int8"``: LLM.int8() mode,
41
+ ``"gptq.int4"``: GPTQ 4-bit mode.
42
+ max_new_tokens: The number of generation steps to take.
43
+ top_k: The number of top most probable tokens to consider in the sampling process.
44
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
45
+ samples.
46
+ """
47
+ if not adapter_path:
48
+ adapter_path = Path("out/adapter/alpaca/lit-llama-adapter-finetuned.pth")
49
+ if not pretrained_path:
50
+ pretrained_path = Path(f"./checkpoints/lit-llama/7B/lit-llama.pth")
51
+ if not tokenizer_path:
52
+ tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
53
+
54
+ assert adapter_path.is_file()
55
+ assert pretrained_path.is_file()
56
+ assert tokenizer_path.is_file()
57
+
58
+ fabric = L.Fabric(devices=1)
59
+ dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
60
+
61
+ print("Loading model ...", file=sys.stderr)
62
+ t0 = time.time()
63
+ with (lazy_load(pretrained_path) as pretrained_checkpoint,
64
+ lazy_load(adapter_path) as adapter_checkpoint):
65
+ name = llama_model_lookup(pretrained_checkpoint)
66
+
67
+ with EmptyInitOnDevice(
68
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
69
+ ):
70
+ model = LLaMA.from_name(name)
71
+
72
+ # 1. Load the pretrained weights
73
+ model.load_state_dict(pretrained_checkpoint, strict=False)
74
+ # 2. Load the fine-tuned adapter weights
75
+ model.load_state_dict(adapter_checkpoint, strict=False)
76
+
77
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
78
+
79
+ model.eval()
80
+ model = fabric.setup_module(model)
81
+
82
+ tokenizer = Tokenizer(tokenizer_path)
83
+ sample = {"instruction": prompt, "input": input}
84
+ prompt = generate_prompt(sample)
85
+ encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
86
+
87
+ t0 = time.perf_counter()
88
+ output = generate(
89
+ model,
90
+ idx=encoded,
91
+ max_seq_length=max_new_tokens,
92
+ max_new_tokens=max_new_tokens,
93
+ temperature=temperature,
94
+ top_k=top_k,
95
+ eos_id=tokenizer.eos_id
96
+ )
97
+ t = time.perf_counter() - t0
98
+
99
+ output = tokenizer.decode(output)
100
+ output = output.split("### Response:")[1].strip()
101
+ print(output)
102
+
103
+ print(f"\n\nTime for inference: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
104
+ if fabric.device.type == "cuda":
105
+ print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
106
+
107
+
108
+ if __name__ == "__main__":
109
+ from jsonargparse import CLI
110
+
111
+ torch.set_float32_matmul_precision("high")
112
+ warnings.filterwarnings(
113
+ # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
114
+ "ignore",
115
+ message="ComplexHalf support is experimental and many operators don't support it yet"
116
+ )
117
+ CLI(main)
generate_full.py ADDED
@@ -0,0 +1,160 @@
1
+ import sys
2
+ import time
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import lightning as L
8
+ import torch
9
+
10
+ from lit_llama import LLaMA, Tokenizer
11
+ from lit_llama.utils import EmptyInitOnDevice
12
+
13
+
14
+ @torch.no_grad()
15
+ def generate(
16
+ model: torch.nn.Module,
17
+ idx: torch.Tensor,
18
+ max_new_tokens: int,
19
+ max_seq_length: int,
20
+ temperature: float = 1.0,
21
+ top_k: Optional[int] = None,
22
+ eos_id: Optional[int] = None,
23
+ ) -> torch.Tensor:
24
+ """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
25
+
26
+ The implementation of this function is modified from A. Karpathy's nanoGPT.
27
+
28
+ Args:
29
+ model: The model to use.
30
+ idx: Tensor of shape (T) with indices of the prompt sequence.
31
+ max_new_tokens: The number of new tokens to generate.
32
+ max_seq_length: The maximum sequence length allowed.
33
+ temperature: Scales the predicted logits by 1 / temperature
34
+ top_k: If specified, only sample among the tokens with the k highest probabilities
35
+ eos_id: If specified, stop generating any more token once the <eos> token is triggered
36
+ """
37
+ # create an empty tensor of the expected final shape and fill in the current tokens
38
+ T = idx.size(0)
39
+ T_new = T + max_new_tokens
40
+ empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
41
+ empty[:T] = idx
42
+ idx = empty
43
+
44
+ # generate max_new_tokens tokens
45
+ for t in range(T, T_new):
46
+ # ignore the not-filled-yet tokens
47
+ idx_cond = idx[:t]
48
+ # if the sequence context is growing too long we must crop it at max_seq_length
49
+ idx_cond = idx_cond if T <= max_seq_length else idx_cond[-max_seq_length:]
50
+
51
+ # forward
52
+ logits = model(idx_cond.view(1, -1))
53
+ logits = logits[0, -1] / temperature
54
+
55
+ # optionally crop the logits to only the top k options
56
+ if top_k is not None:
57
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
58
+ logits[logits < v[[-1]]] = -float("Inf")
59
+
60
+ probs = torch.nn.functional.softmax(logits, dim=-1)
61
+ idx_next = torch.multinomial(probs, num_samples=1)
62
+
63
+ # concatenate the new generation
64
+ idx[t] = idx_next
65
+
66
+ # if <eos> token is triggered, return the output (stop generation)
67
+ if idx_next == eos_id:
68
+ return idx[:t + 1] # include the EOS token
69
+
70
+ return idx
71
+
72
+
73
+ def main(
74
+ prompt: str = "Hello, my name is",
75
+ *,
76
+ num_samples: int = 1,
77
+ max_new_tokens: int = 50,
78
+ top_k: int = 200,
79
+ temperature: float = 0.8,
80
+ checkpoint_path: Optional[Path] = None,
81
+ tokenizer_path: Optional[Path] = None,
82
+ model_size: str = "7B",
83
+ quantize: Optional[str] = None,
84
+ ) -> None:
85
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer.
86
+
87
+ Args:
88
+ prompt: The prompt string to use for generating the samples.
89
+ num_samples: The number of text samples to generate.
90
+ max_new_tokens: The number of generation steps to take.
91
+ top_k: The number of top most probable tokens to consider in the sampling process.
92
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
93
+ samples.
94
+ checkpoint_path: The checkpoint path to load.
95
+ tokenizer_path: The tokenizer path to load.
96
+ model_size: The model size to load.
97
+ quantize: Whether to quantize the model and using which method:
98
+ ``"llm.int8"``: LLM.int8() mode,
99
+ ``"gptq.int4"``: GPTQ 4-bit mode.
100
+ """
101
+ if not checkpoint_path:
102
+ checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/lit-llama.pth")
103
+ if not tokenizer_path:
104
+ tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
105
+ assert checkpoint_path.is_file(), checkpoint_path
106
+ assert tokenizer_path.is_file(), tokenizer_path
107
+
108
+ fabric = L.Fabric(devices=1)
109
+ dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
110
+
111
+ print("Loading model ...", file=sys.stderr)
112
+ t0 = time.time()
113
+ with EmptyInitOnDevice(
114
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
115
+ ):
116
+ model = LLaMA.from_name(model_size)
117
+
118
+ checkpoint = torch.load(checkpoint_path)
119
+ model.load_state_dict(checkpoint)
120
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
121
+
122
+ model.eval()
123
+ model = fabric.setup_module(model)
124
+
125
+ tokenizer = Tokenizer(tokenizer_path)
126
+ encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
127
+
128
+ L.seed_everything(1234)
129
+ for i in range(num_samples):
130
+ t0 = time.perf_counter()
131
+ y = generate(
132
+ model,
133
+ encoded_prompt,
134
+ max_new_tokens,
135
+ model.config.block_size, # type: ignore[union-attr,arg-type]
136
+ temperature=temperature,
137
+ top_k=top_k,
138
+ )
139
+ t = time.perf_counter() - t0
140
+ print(tokenizer.decode(y))
141
+ print(f"Time for inference {i + 1}: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
142
+ if fabric.device.type == "cuda":
143
+ print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
144
+
145
+
146
+ if __name__ == "__main__":
147
+ from jsonargparse import CLI
148
+
149
+ torch.set_float32_matmul_precision("high")
150
+ warnings.filterwarnings(
151
+ # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
152
+ "ignore",
153
+ message="ComplexHalf support is experimental and many operators don't support it yet"
154
+ )
155
+ warnings.filterwarnings(
156
+ # Triggered in bitsandbytes/autograd/_functions.py:298
157
+ "ignore",
158
+ message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
159
+ )
160
+ CLI(main)
generate_lora.py ADDED
@@ -0,0 +1,131 @@
1
+ import sys
2
+ import time
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import lightning as L
8
+ import torch
9
+
10
+ from generate import generate
11
+ from lit_llama import Tokenizer, LLaMA
12
+ from lit_llama.lora import lora
13
+ from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
14
+ from scripts.prepare_alpaca import generate_prompt
15
+
16
+ lora_r = 8
17
+ lora_alpha = 16
18
+ lora_dropout = 0.05
19
+
20
+
21
+ def main(
22
+ prompt: str = "What food do llamas eat?",
23
+ input: str = "",
24
+ lora_path: Optional[Path] = None,
25
+ pretrained_path: Optional[Path] = None,
26
+ tokenizer_path: Optional[Path] = None,
27
+ quantize: Optional[str] = None,
28
+ dtype: str = "float32",
29
+ max_new_tokens: int = 100,
30
+ top_k: int = 200,
31
+ temperature: float = 0.8,
32
+ ) -> None:
33
+ """Generates a response based on a given instruction and an optional input.
34
+ This script will only work with checkpoints from the instruction-tuned LoRA model.
35
+ See `finetune_lora.py`.
36
+
37
+ Args:
38
+ prompt: The prompt/instruction (Alpaca style).
39
+ lora_path: Path to the checkpoint with trained LoRA weights, which are the output of
40
+ `finetune_lora.py`.
41
+ input: Optional input (Alpaca style).
42
+ pretrained_path: The path to the checkpoint with pretrained LLaMA weights.
43
+ tokenizer_path: The tokenizer path to load.
44
+ quantize: Whether to quantize the model and using which method:
45
+ ``"llm.int8"``: LLM.int8() mode,
46
+ ``"gptq.int4"``: GPTQ 4-bit mode.
47
+ dtype: The dtype to use during generation.
48
+ max_new_tokens: The number of generation steps to take.
49
+ top_k: The number of top most probable tokens to consider in the sampling process.
50
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
51
+ samples.
52
+ """
53
+ if not lora_path:
54
+ lora_path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth")
55
+ if not pretrained_path:
56
+ pretrained_path = Path(f"./checkpoints/lit-llama/7B/lit-llama.pth")
57
+ if not tokenizer_path:
58
+ tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
59
+
60
+ assert lora_path.is_file()
61
+ assert pretrained_path.is_file()
62
+ assert tokenizer_path.is_file()
63
+
64
+ if quantize is not None:
65
+ raise NotImplementedError("Quantization in LoRA is not supported yet")
66
+
67
+ fabric = L.Fabric(devices=1)
68
+
69
+ dt = getattr(torch, dtype, None)
70
+ if not isinstance(dt, torch.dtype):
71
+ raise ValueError(f"{dtype} is not a valid dtype.")
72
+ dtype = dt
73
+
74
+ print("Loading model ...", file=sys.stderr)
75
+ t0 = time.time()
76
+
77
+ with (lazy_load(pretrained_path) as pretrained_checkpoint,
78
+ lazy_load(lora_path) as adapter_checkpoint):
79
+ name = llama_model_lookup(pretrained_checkpoint)
80
+
81
+ with EmptyInitOnDevice(
82
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
83
+ ), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
84
+ model = LLaMA.from_name(name)
85
+
86
+ # 1. Load the pretrained weights
87
+ model.load_state_dict(pretrained_checkpoint, strict=False)
88
+ # 2. Load the fine-tuned adapter weights
89
+ model.load_state_dict(adapter_checkpoint, strict=False)
90
+
91
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
92
+
93
+ model.eval()
94
+ model = fabric.setup_module(model)
95
+
96
+ tokenizer = Tokenizer(tokenizer_path)
97
+ sample = {"instruction": prompt, "input": input}
98
+ prompt = generate_prompt(sample)
99
+ encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
100
+
101
+ t0 = time.perf_counter()
102
+ output = generate(
103
+ model,
104
+ idx=encoded,
105
+ max_seq_length=max_new_tokens,
106
+ max_new_tokens=max_new_tokens,
107
+ temperature=temperature,
108
+ top_k=top_k,
109
+ eos_id=tokenizer.eos_id
110
+ )
111
+ t = time.perf_counter() - t0
112
+
113
+ output = tokenizer.decode(output)
114
+ output = output.split("### Response:")[1].strip()
115
+ print(output)
116
+
117
+ print(f"\n\nTime for inference: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
118
+ if fabric.device.type == "cuda":
119
+ print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
120
+
121
+
122
+ if __name__ == "__main__":
123
+ from jsonargparse import CLI
124
+
125
+ torch.set_float32_matmul_precision("high")
126
+ warnings.filterwarnings(
127
+ # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
128
+ "ignore",
129
+ message="ComplexHalf support is experimental and many operators don't support it yet"
130
+ )
131
+ CLI(main)
howto/customize_paths.md ADDED
@@ -0,0 +1,33 @@
1
+ ## Customize paths
2
+
3
+ The project is set up to use specific paths for reading the original weights and saving checkpoints, etc.
4
+
5
+ For all scripts, you can run
6
+
7
+ ```shell
8
+ python script.py -h
9
+ ```
10
+
11
+ to get a list of available options. For instance, here's how you would modify the checkpoint dir:
12
+
13
+ ```shell
14
+ python scripts/convert_checkpoint.py --checkpoint_dir "data/checkpoints/foo"
15
+ ```
16
+
17
+ Note that this change will need to be passed along to subsequent steps, for example:
18
+
19
+ ```shell
20
+ python generate.py \
21
+ --checkpoint_path "data/checkpoints/foo/7B/lit-llama.pth" \
22
+ --tokenizer_path "data/checkpoints/foo/tokenizer.model"
23
+ ```
24
+
25
+ and
26
+
27
+ ```shell
28
+ python scripts/quantize.py \
29
+ --checkpoint_path "data/checkpoints/foo/7B/lit-llama.pth" \
30
+ --tokenizer_path "data/checkpoints/foo/tokenizer.model"
31
+ ```
32
+
33
+ To avoid this, you can create symbolic links so the default paths point at your custom locations, sparing you from passing the paths to every script.
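+ 
+ For example, here is a hypothetical way to set up such a link from Python (the directory names below are made up for illustration; adapt them to your layout):
+ 
+ ```python
+ from pathlib import Path
+ 
+ target = Path("data/checkpoints/foo").resolve()   # your custom checkpoint directory
+ link = Path("checkpoints/custom-llama")           # hypothetical default-style location
+ link.parent.mkdir(parents=True, exist_ok=True)
+ if not link.exists():
+     link.symlink_to(target, target_is_directory=True)
+ ```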
howto/download_weights.md ADDED
@@ -0,0 +1,131 @@
1
+ ## Downloading pretrained weights
2
+
3
+ Unless you are training from scratch, you will need the pretrained weights from Meta.
4
+
5
+ ### Original Meta weights
6
+
7
+ Download the model weights following the instructions on the official [LLaMA repository](https://github.com/facebookresearch/llama).
8
+
9
+ Once downloaded, you should have a folder like this:
10
+
11
+ ```text
12
+ checkpoints/llama
13
+ ├── 7B
14
+ │ ├── ...
15
+ │ └── consolidated.00.pth
16
+ ├── 13B
17
+ │ ...
18
+ └── tokenizer.model
19
+ ```
20
+
21
+ Convert the weights to the Lit-LLaMA format:
22
+
23
+ ```bash
24
+ python scripts/convert_checkpoint.py --model_size 7B
25
+ ```
26
+
27
+ > **Note**
28
+ > All scripts support argument [customization](customize_paths.md)
29
+
30
+ ### OpenLLaMA
31
+
32
+ OpenLM Research has released **Apache 2.0 licensed** weights obtained by training LLaMA on the 1.2 trillion token open-source [RedPajama](https://github.com/togethercomputer/RedPajama-Data) dataset.
33
+
34
+ Preview weights have been released at intermediate token counts (200B and 300B at the time of writing). To download them, run:
35
+
36
+ ```bash
37
+ # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
38
+ git clone https://huggingface.co/openlm-research/open_llama_7b_preview_300bt checkpoints/open-llama/7B
39
+ ```
40
+
41
+ Or if you don't have `git-lfs` installed:
42
+
43
+ ```bash
44
+ python scripts/download.py --repo_id openlm-research/open_llama_7b_preview_300bt --local_dir checkpoints/open-llama/7B
45
+ ```
46
+
47
+ Once downloaded, you should have a folder like this:
48
+
49
+ ```text
50
+ checkpoints/open-llama/
51
+ └── 7B
52
+ └── open_llama_7b_preview_300bt_transformers_weights
53
+ ├── ...
54
+ ├── pytorch_model-00001-of-00002.bin
55
+ ├── pytorch_model-00002-of-00002.bin
56
+ ├── pytorch_model.bin.index.json
57
+ └── tokenizer.model
58
+ ```
59
+
60
+ Convert the weights to the Lit-LLaMA format:
61
+
62
+ ```bash
63
+ python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/open-llama/7B/open_llama_7b_preview_300bt_transformers_weights --model_size 7B
64
+ ```
65
+
66
+ > **Note**
67
+ > All scripts support argument [customization](customize_paths.md)
68
+
69
+ Once converted, you should have a folder like this:
70
+
71
+ ```text
72
+ checkpoints/lit-llama/
73
+ ├── 7B
74
+ │ └── lit-llama.pth
75
+ └── tokenizer.model
76
+ ```
77
+
78
+ You are all set. Now you can continue with inference or finetuning.
79
+
80
+ Try running [`generate.py` to test the imported weights](inference.md).
81
+
82
+
83
+ ### Alternative sources
84
+
85
+ You might find LLaMA weights hosted online on the Hugging Face Hub. Beware that this infringes the original weights' license.
86
+ You could try downloading them by running the following command with a specific repo id:
87
+
88
+ ```bash
89
+ # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
90
+ git clone REPO_ID checkpoints/hf-llama/7B
91
+ ```
92
+
93
+ Or if you don't have `git-lfs` installed:
94
+
95
+ ```bash
96
+ python scripts/download.py --repo_id REPO_ID --local_dir checkpoints/hf-llama/7B
97
+ ```
98
+
99
+ Once downloaded, you should have a folder like this:
100
+
101
+ ```text
102
+ checkpoints/hf-llama/
103
+ └── 7B
104
+ ├── ...
105
+ ├── pytorch_model-00001-of-00002.bin
106
+ ├── pytorch_model-00002-of-00002.bin
107
+ ├── pytorch_model.bin.index.json
108
+ └── tokenizer.model
109
+ ```
110
+
111
+ Convert the weights to the Lit-LLaMA format:
112
+
113
+ ```bash
114
+ python scripts/convert_hf_checkpoint.py --model_size 7B
115
+ ```
116
+
117
+ > **Note**
118
+ > All scripts support argument [customization](customize_paths.md)
119
+
120
+ Once converted, you should have a folder like this:
121
+
122
+ ```text
123
+ checkpoints/lit-llama/
124
+ ├── 7B
125
+ │ └── lit-llama.pth
126
+ └── tokenizer.model
127
+ ```
128
+
129
+ You are all set. Now you can continue with inference or finetuning.
130
+
131
+ Try running [`generate.py` to test the imported weights](inference.md).
howto/finetune_adapter.md ADDED
@@ -0,0 +1,102 @@
1
+ # Finetuning with Adapter
2
+
3
+ [LLaMA-Adapter](https://arxiv.org/abs/2303.16199) is a form of prefix-tuning that prepends a learnable adaption-prompt to the inputs of the attention blocks in LLaMA. In total, there are only 1.2M parameters to update during finetuning, which significantly reduces the memory footprint and speeds up training.
4
+
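+ For intuition, the sketch below is a simplified approximation of the idea used by `finetune_adapter.py`, not the actual `lit_llama/adapter.py` implementation: freeze every pretrained weight and leave only the adapter parameters trainable. The name check on `"adapter"` is an assumption made for illustration.
+ 
+ ```python
+ import torch.nn as nn
+ 
+ def mark_only_adapter_as_trainable_sketch(model: nn.Module) -> None:
+     # Freeze all parameters, then re-enable gradients only for adapter parameters.
+     for name, param in model.named_parameters():
+         param.requires_grad = "adapter" in name
+ 
+ def count_trainable_parameters(model: nn.Module) -> int:
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+ ```
+ 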
5
+ We demonstrate instruction-finetuning Lit-LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24 GB) GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour.
6
+
7
+ If you are new to LLaMA-Adapter and are interested in learning more about how it works before proceeding with the finetuning guide below, you might find our article [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) helpful.
8
+
9
+ ## Preparation
10
+
11
+ The steps here only need to be done once:
12
+
13
+ 1. Follow the instructions in the [README](README.md) to install the dependencies.
14
+ 2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
15
+ 3. If you want to utilize more than one GPU, you should `pip install deepspeed`.
16
+ 4. Download the data and generate the Alpaca instruction tuning dataset:
17
+
18
+ ```bash
19
+ python scripts/prepare_alpaca.py
20
+ ```
21
+
22
+ or [prepare your own dataset](#tune-on-your-own-dataset).
23
+
24
+ ## Running the finetuning
25
+
26
+ ```bash
27
+ python finetune_adapter.py
28
+ ```
29
+
30
+ The finetuning requires at least one GPU with ~24 GB memory (RTX 3090).
31
+ You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available.
32
+ Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.
33
+
34
+ For example, the following settings will let you finetune the model in under 1 hour using DeepSpeed Zero-2:
35
+ ```python
36
+ devices = 8
37
+ micro_batch_size = 8
38
+ ```
39
+
40
+ This script will save checkpoints periodically to the folder `out/`.
41
+
42
+ > **Note**
43
+ > All scripts support argument [customization](customize_paths.md)
44
+
45
+ ## Test the model
46
+
47
+ You can test the finetuned model with your own instructions by running:
48
+
49
+ ```bash
50
+ python generate_adapter.py \
51
+ --prompt "Recommend a movie to watch on the weekend." \
52
+ --quantize llm.int8
53
+ ```
54
+ Output:
55
+ ```
56
+ A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...
57
+ ```
58
+ If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB.
59
+
60
+ ## Tune on your dataset
61
+
62
+ With only a few modifications, you can prepare and train on your own instruction dataset.
63
+
64
+ 1. Create a json file in which each row holds one instruction-response pair.
65
+ A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional and can be
66
+ the empty string if the instruction doesn't require a context. Below is an example json file:
67
+
68
+ ```
69
+ [
70
+ {
71
+ "instruction": "Arrange the given numbers in ascending order.",
72
+ "input": "2, 4, 0, 8, 3",
73
+ "output": "0, 2, 3, 4, 8"
74
+ },
75
+ ...
76
+ ]
77
+ ```
78
+
79
+ 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
80
+
81
+ ```bash
82
+ cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
83
+ ```
84
+
85
+ 3. Modify `scripts/prepare_mydata.py` to read the json data file (a minimal reading sketch follows this list).
86
+ 4. Run the script to generate the preprocessed, tokenized train-val split:
87
+
88
+ ```bash
89
+ python scripts/prepare_mydata.py --destination_path data/mydata/
90
+ ```
91
+
92
+ 5. Run `finetune_adapter.py` by passing in the location of your data (and optionally other parameters):
93
+
94
+ ```bash
95
+ python finetune_adapter.py --data_dir data/mydata/ --out_dir out/myexperiment
96
+ ```
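+ 
+ As a starting point for step 3, the instruction JSON can be read with the standard library alone. This is a hypothetical sketch; the file name and helper are invented for illustration and are not part of `scripts/prepare_alpaca.py`:
+ 
+ ```python
+ import json
+ from pathlib import Path
+ 
+ def load_instruction_samples(json_path: str = "data/mydata/mydata.json") -> list:
+     """Load a list of {'instruction', 'input', 'output'} records."""
+     with open(Path(json_path), encoding="utf-8") as f:
+         samples = json.load(f)
+     for sample in samples:
+         # 'input' is optional, so default it to the empty string.
+         sample.setdefault("input", "")
+     return samples
+ ```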
97
+
98
+
99
+ ## Troubleshooting
100
+
101
+ If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
102
+ `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
howto/finetune_full.md ADDED
@@ -0,0 +1,104 @@
1
+ # Full Finetuning
2
+
3
+ Full finetuning updates all layers in the pretrained LLaMA model. This *regular* finetuning procedure is typically considered the baseline for parameter-efficient alternatives such as Low-Rank Adaptation (LoRA) or LLaMA-Adapter.
4
+
5
+ The provided [finetune_full.py](../finetune_full.py) script uses 4 A100 GPUs with a fully-sharded data parallel strategy to finetune Lit-LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset. Each A100 GPU has 40 GB of memory, although finetuning may fit in less.
6
+
7
+
8
+
9
+ ## Preparation
10
+
11
+ The steps here only need to be done once:
12
+
13
+ 1. Follow the instructions in the [README](README.md) to install the dependencies.
14
+
15
+ 2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
16
+
17
+ 3. Download the data and generate the Alpaca instruction tuning dataset:
18
+
19
+ ```bash
20
+ python scripts/prepare_alpaca.py
21
+ ```
22
+
23
+ or [prepare your own dataset](#tune-on-your-own-dataset).
24
+
25
+ ## Running the finetuning
26
+
27
+ ```bash
28
+ python finetune_full.py
29
+ ```
30
+
31
+
32
+ You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available or increase the `batch_size`.
33
+ Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.
34
+
35
+ For example, the following settings will let you finetune the model in 32 hours using a fully-sharded data parallel strategy:
36
+ ```python
37
+ devices = 4
38
+ batch_size = 128 // devices
39
+ micro_batch_size = 4
40
+ ```
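+ 
+ As a quick sanity check of how these numbers combine (mirroring the definitions in `finetune_full.py`), each optimizer step accumulates gradients over several micro batches on every device:
+ 
+ ```python
+ devices = 4
+ batch_size = 128 // devices                                    # per-device batch size: 32
+ micro_batch_size = 4                                           # samples per forward/backward pass
+ gradient_accumulation_steps = batch_size // micro_batch_size   # 8, as in finetune_full.py
+ 
+ # Samples contributing to one optimizer step across all devices:
+ effective_batch = micro_batch_size * gradient_accumulation_steps * devices
+ assert effective_batch == 128
+ ```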
41
+
42
+ This script will save checkpoints periodically to the folder `out/`.
43
+
44
+ > **Note**
45
+ > All scripts support argument [customization](customize_paths.md)
46
+
47
+ ## Test the model
48
+
49
+ You can test the finetuned model with your own instructions by running:
50
+
51
+ ```bash
52
+ python generate_full.py \
53
+ --prompt "Recommend a movie to watch on the weekend." \
54
+ --quantize llm.int8
55
+ ```
56
+ Output:
57
+ ```
58
+ A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...
59
+ ```
60
+ If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB.
61
+
62
+ ## Tune on your dataset
63
+
64
+ With only a few modifications, you can prepare and train on your own instruction dataset.
65
+
66
+ 1. Create a json file in which each row holds one instruction-response pair.
67
+ A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional and can be
68
+ the empty string if the instruction doesn't require a context. Below is an example json file:
69
+
70
+ ```
71
+ [
72
+ {
73
+ "instruction": "Arrange the given numbers in ascending order.",
74
+ "input": "2, 4, 0, 8, 3",
75
+ "output": "0, 2, 3, 4, 8"
76
+ },
77
+ ...
78
+ ]
79
+ ```
80
+
81
+ 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
82
+
83
+ ```bash
84
+ cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
85
+ ```
86
+
87
+ 3. Modify `scripts/prepare_mydata.py` to read the json data file.
88
+ 4. Run the script to generate the preprocessed, tokenized train-val split:
89
+
90
+ ```bash
91
+ python scripts/prepare_mydata.py --destination_path data/mydata/
92
+ ```
93
+
94
+ 5. Run `finetune_full.py` by passing in the location of your data (and optionally other parameters):
95
+
96
+ ```bash
97
+ python finetune_full.py --data_dir data/mydata/ --out_dir out/myexperiment
98
+ ```
99
+
100
+
101
+ ## Troubleshooting
102
+
103
+ If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
104
+ `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
howto/finetune_lora.md ADDED
@@ -0,0 +1,88 @@
1
+ # Finetuning with LoRA
2
+
3
+ [Low-rank adaption (LoRA)](https://arxiv.org/abs/2106.09685) is a technique to approximate the update to the linear layers in a LLM with a low-rank matrix factorization. This significantly reduces the number of trainable parameters and speeds up training with little impact on the final performance of the model.
4
+ We demonstrate this method by instruction-finetuning LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24 GB) GPU**.
5
+
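+ As a rough illustration of the idea (a self-contained sketch, not the `lit_llama/lora.py` implementation), a LoRA layer keeps the pretrained weight frozen and learns only a low-rank update `B @ A`, scaled by `alpha / r`:
+ 
+ ```python
+ import torch
+ import torch.nn as nn
+ 
+ class LoRALinearSketch(nn.Module):
+     """Minimal LoRA sketch: y = W x + (alpha / r) * B(A(x)), with W frozen."""
+     def __init__(self, in_features: int, out_features: int, r: int = 8, alpha: int = 16):
+         super().__init__()
+         self.linear = nn.Linear(in_features, out_features, bias=False)
+         self.linear.weight.requires_grad = False    # pretrained weight stays frozen
+         self.lora_A = nn.Linear(in_features, r, bias=False)
+         self.lora_B = nn.Linear(r, out_features, bias=False)
+         nn.init.zeros_(self.lora_B.weight)          # start as a no-op update
+         self.scaling = alpha / r
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x))
+ ```
+ 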
6
+ ## Preparation
7
+
8
+ The steps here only need to be done once:
9
+
10
+ 1. Follow the instructions in the [README](README.md) to install the dependencies.
11
+ 2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
12
+ 3. Download the data and generate the instruction tuning dataset:
13
+
14
+ ```bash
15
+ python scripts/prepare_alpaca.py
16
+ ```
17
+
18
+ ## Running the finetuning
19
+
20
+ ```bash
21
+ python finetune_lora.py
22
+ ```
23
+
24
+ The finetuning requires at least one GPU with ~24 GB memory (RTX 3090).
25
+
26
+ This script will save checkpoints periodically to the folder `out/`.
27
+
28
+ > **Note**
29
+ > All scripts support argument [customization](customize_paths.md)
30
+
31
+
32
+ ## Test the model
33
+
34
+ You can test the finetuned model with your own instructions by running:
35
+
36
+ ```bash
37
+ python generate_lora.py --prompt "Recommend a movie to watch on the weekend."
38
+ ```
39
+ Output:
40
+ ```
41
+ I would recommend the movie The Martian (2015). It is a sci-fi movie starring Matt Damon that follows the story of...
42
+ ```
43
+
44
+ If your GPU supports `bfloat16`, you can additionally pass `--dtype bfloat16` to bring the memory consumption down to ~14 GB.
45
+
46
+ ## Tune on your dataset
47
+
48
+ With only a few modifications, you can prepare and train on your own instruction dataset.
49
+
50
+ 1. Create a json file in which each row holds one instruction-response pair.
51
+ A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional and can be
52
+ the empty string if the instruction doesn't require a context. Below is an example json file:
53
+
54
+ ```
55
+ [
56
+ {
57
+ "instruction": "Arrange the given numbers in ascending order.",
58
+ "input": "2, 4, 0, 8, 3",
59
+ "output": "0, 2, 3, 4, 8"
60
+ },
61
+ ...
62
+ ]
63
+ ```
64
+
65
+ 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
66
+
67
+ ```bash
68
+ cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
69
+ ```
70
+
71
+ 3. Modify `scripts/prepare_mydata.py` to read the json data file.
72
+ 4. Run the script to generate the preprocessed, tokenized train-val split:
73
+
74
+ ```bash
75
+ python scripts/prepare_mydata.py --destination_path data/mydata/
76
+ ```
77
+
78
+ 5. Run `finetune_lora.py` by passing in the location of your data (and optionally other parameters):
79
+
80
+ ```bash
81
+ python finetune_lora.py --data_dir data/mydata/ --out_dir out/myexperiment
82
+ ```
83
+
84
+
85
+ ## Troubleshooting
86
+
87
+ If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
88
+ `torch.backends.cuda.enable_flash_sdp(False)` in the finetune script (see https://github.com/Lightning-AI/lit-llama/issues/101).
howto/inference.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference
2
+
3
+ We demonstrate how to run inference (next token prediction) with the LLaMA base model in the [`generate.py`](generate.py) script:
4
+
5
+ ```bash
6
+ python generate.py --prompt "Hello, my name is"
7
+ ```
8
+ Output:
9
+ ```
10
+ Hello my name is TJ. I have a passion for the outdoors, love hiking and exploring. I also enjoy traveling and learning new things. I especially enjoy long walks, good conversation and a friendly smile.
11
+ ```
12
+
13
+ The script assumes you have downloaded and converted the weights and saved them in the `./checkpoints` folder as described [here](download_weights.md).
14
+
15
+ > **Note**
16
+ > All scripts support argument [customization](customize_paths.md)
17
+
18
+ With the default settings, this will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
19
+
20
+ ## Run Lit-LLaMA on consumer devices
21
+
22
+ On GPUs with `bfloat16` support, the `generate.py` script will automatically convert the weights and consume about 14 GB.
23
+ For GPUs with less memory, or ones that don't support `bfloat16`, enable quantization (`--quantize llm.int8`):
24
+
25
+ ```bash
26
+ python generate.py --quantize llm.int8 --prompt "Hello, my name is"
27
+ ```
28
+ This will consume about 10 GB of GPU memory, or about 8 GB if also using `bfloat16`.
29
+ See `python generate.py --help` for more options.
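+
+ These numbers roughly follow from the size of the weights alone; a back-of-envelope estimate (ignoring activations, the KV cache, and framework overhead):
+
+ ```python
+ n_params = 7e9
+ for dtype, bytes_per_param in [("float32", 4), ("bfloat16", 2), ("int8", 1)]:
+     print(f"{dtype}: ~{n_params * bytes_per_param / 1e9:.0f} GB")
+ # float32: ~28 GB, bfloat16: ~14 GB, int8: ~7 GB
+ ```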
30
+
31
+ You can also use GPTQ-style int4 quantization, but this needs conversions of the weights first:
32
+
33
+ ```bash
34
+ python quantize.py --checkpoint_path lit-llama.pth --tokenizer_path tokenizer.model --output_path llama-7b-gptq.4bit.pt --dtype bfloat16 --quantize gptq.int4
35
+ ```
36
+
37
+ With the generated quantized checkpoint, generation works as usual with `--quantize gptq.int4`, bringing GPU usage down to about 5 GB. As only the weights of the Linear layers are quantized, it is useful to use `--dtype bfloat16` even with the quantization enabled.
howto/tpus.md ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TPU support
2
+
3
+ Lit-LLaMA uses `lightning.Fabric` under the hood, which itself supports TPUs (via [PyTorch XLA](https://github.com/pytorch/xla)).
4
+
5
+ The following commands will allow you to set up a `Google Cloud` instance with a [TPU v4](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) VM:
6
+
7
+ ```shell
8
+ gcloud compute tpus tpu-vm create lit-llama --version=tpu-vm-v4-pt-2.0 --accelerator-type=v4-8 --zone=us-central2-b
9
+ gcloud compute tpus tpu-vm ssh lit-llama --zone=us-central2-b
10
+ ```
11
+
12
+ Now that you are on the machine, let's clone the repository and install the dependencies:
13
+
14
+ ```shell
15
+ git clone https://github.com/Lightning-AI/lit-llama
16
+ cd lit-llama
17
+ pip install -r requirements.txt
18
+ ```
19
+
20
+ By default, computations will run using the new (and experimental) PjRT runtime. Still, it's recommended that you set the following environment variables
21
+
22
+ ```shell
23
+ export PJRT_DEVICE=TPU
24
+ export ALLOW_MULTIPLE_LIBTPU_LOAD=1
25
+ ```
26
+
27
+ > **Note**
28
+ > You can find an extensive guide on how to get set up and all the available options [here](https://cloud.google.com/tpu/docs/v4-users-guide).
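+
+ Once the dependencies are installed, you can quickly check that Fabric sees the TPU. A minimal smoke test (assuming `torch_xla` is set up correctly; it uses a single TPU core):
+
+ ```python
+ import torch
+ import lightning as L
+
+ fabric = L.Fabric(accelerator="tpu", devices=1)
+ fabric.launch()
+ x = fabric.to_device(torch.ones(2, 2))
+ print(fabric.device, x.sum())  # should report an XLA device and tensor(4.)
+ ```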
29
+
30
+ Since you created a new machine, you'll probably need to download the weights. You could scp them into the machine with `gcloud compute tpus tpu-vm scp` or you can follow the steps described in our [downloading guide](download_weights.md).
31
+
32
+ ## Inference
33
+
34
+ Generation works out-of-the-box with TPUs:
35
+
36
+ ```shell
37
+ python3 generate.py --prompt "Hello, my name is" --num_samples 2
38
+ ```
39
+
40
+ This command will take a long time as XLA needs to compile the graph (~13 min) before running the model.
41
+ In fact, you'll notice that the second sample takes considerably less time (~12 sec).
42
+
43
+ ## Finetuning
44
+
45
+ Coming soon.
46
+
47
+ > **Warning**
48
+ > When you are done, remember to delete your instance
49
+ > ```shell
50
+ > gcloud compute tpus tpu-vm delete lit-llama --zone=us-central2-b
51
+ > ```
howto/train_redpajama.md ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Pre-train LLaMA on RedPajama
2
+
3
+ This howto will walk you through setting up the RedPajama dataset and launching the pre-training script.
4
+
5
+ ## What's RedPajama
6
+
7
+ [RedPajama](https://github.com/togethercomputer/RedPajama-Data) is an open-source reproduction of the original LLaMA training dataset.
8
+
9
+ It contains a total of 1.2 trillion tokens, divided into
10
+
11
+ ```text
12
+ Commoncrawl 878B
13
+ C4 175B
14
+ GitHub 59B
15
+ Books 26B
16
+ ArXiv 28B
17
+ Wikipedia 24B
18
+ StackExchange 20B
19
+ ```
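+
+ As a quick check, the splits above add up to the quoted total:
+
+ ```python
+ # billions of tokens per split, as listed above
+ splits = {"Commoncrawl": 878, "C4": 175, "GitHub": 59, "Books": 26,
+           "ArXiv": 28, "Wikipedia": 24, "StackExchange": 20}
+ print(sum(splits.values()))  # 1210 -> roughly 1.2 trillion tokens
+ ```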
20
+
21
+ The [RedPajama repo](https://github.com/togethercomputer/RedPajama-Data) contains the source code for collecting and preparing
22
+ the dataset, and it is Apache 2.0 licensed.
23
+
24
+ The data itself is licensed according to the original licenses with which its individual parts were released.
25
+ The GitHub datasets are limited to MIT, BSD, or Apache 2.0 repositories.
26
+
27
+ Along with the full [RedPajama-1T dataset](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T),
28
+ the [RedPajama-1T-Sample](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample) 1B sample dataset
29
+ is also available for development.
30
+
31
+ You can download the data using git lfs:
32
+
33
+ ```bash
34
+ # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
35
+ git clone https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T data/RedPajama-Data-1T
36
+ ```
37
+
38
+ ```bash
39
+ # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
40
+ git clone https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample data/RedPajama-Data-1T-Sample
41
+ ```
42
+
43
+ ## Prepare RedPajama for training
44
+
45
+ The dataset consists of 2084 `jsonl` files (the sample dataset contains 11). In order to start pre-training lit-llama
46
+ on it, you need to read, tokenize, and write the data in binary chunks. This will leverage the `PackedDataset`
47
+ streaming dataset that comes with lit-llama.
48
+
49
+ To do so, run
50
+
51
+ ```bash
52
+ python scripts/prepare_redpajama.py --source_path data/RedPajama-Data-1T --tokenizer_path checkpoints/lit-llama/tokenizer.model --destination_path data/lit-redpajama
53
+ ```
54
+
55
+ or
56
+
57
+ ```bash
58
+ python scripts/prepare_redpajama.py --source_path data/RedPajama-Data-1T-Sample --tokenizer_path checkpoints/lit-llama/tokenizer.model --destination_path data/lit-redpajama-sample --sample True
59
+ ```
60
+
61
+ for the sample dataset.
62
+
63
+ In the above, we assume you are using the same tokenizer as LLaMA, but any trained [SentencePiece](https://github.com/google/sentencepiece) tokenizer with a 32,000-token vocabulary will do here.
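+
+ Conceptually, the preparation step reads each `jsonl` file, tokenizes the `text` field of every document, and appends the resulting token ids to fixed-size binary chunks. A simplified sketch of the read/tokenize part only (the actual chunk writing is handled by `scripts/prepare_redpajama.py` and `PackedDataset`):
+
+ ```python
+ import json
+ from pathlib import Path
+
+ import sentencepiece as spm
+
+ sp = spm.SentencePieceProcessor(model_file="checkpoints/lit-llama/tokenizer.model")
+
+ for jsonl_file in sorted(Path("data/RedPajama-Data-1T-Sample").rglob("*.jsonl")):
+     with open(jsonl_file, encoding="utf-8") as f:
+         for line in f:
+             tokens = sp.encode(json.loads(line)["text"])  # token ids for one document
+             # ... the real script appends `tokens` to a binary chunk on disk ...
+ ```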
64
+
65
+ The script will take a while to run, so time for :tea:
66
+
67
+ ## Pre-training
68
+
69
+ Running the pre-training script requires at least 4 GPUs with 40GB+ each (A100).
70
+
71
+ ```bash
72
+ python train_redpajama.py --devices 4 --train_data_dir data/lit-redpajama
73
+ ```
74
+
75
+ For running on the sample dataset:
76
+
77
+ ```bash
78
+ python train_redpajama.py --devices 4 --train_data_dir data/lit-redpajama-sample
79
+ ```
80
+
81
+ The script will save checkpoints periodically to the folder `out/`.
82
+
83
+ The `train_redpajama.py` script will pre-train the LLaMA 7B model with FSDP in
84
+ `bfloat16` precision, using gradient accumulation.
85
+
86
+ You can easily change the size of the model by passing a different string to
87
+
88
+ ```python
89
+ config = LLaMAConfig.from_name("7B")
90
+ ```
91
+
92
+ in the `main` function.
93
+
94
+ Keep in mind that the original LLaMA training for the 7B model required 83k A100 80GB
95
+ hours, so you'll need access to a cluster.
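+
+ As a rough sense of scale (simple arithmetic, ignoring multi-node inefficiencies):
+
+ ```python
+ gpu_hours = 83_000
+ for n_gpus in (4, 64, 1024):
+     print(f"{n_gpus:>4} GPUs: ~{gpu_hours / n_gpus / 24:.0f} days")
+ # 4 GPUs: ~865 days, 64 GPUs: ~54 days, 1024 GPUs: ~3 days
+ ```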
96
+
97
+ Once you're in a cluster, you can follow [these instructions](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html)
98
+ to launch the script across machines:
99
+
100
+ - [SLURM cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html)
101
+ - [Barebones cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/barebones.html)
102
+ - [MPI](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html)
103
+
104
+ The script contains several configurations and hyperparameters you can tweak:
105
+
106
+ ```python
107
+ out_dir = "out/training"
108
+ save_interval = 1000
109
+ eval_interval = 1000
110
+ eval_iters = 100
111
+ log_interval = 1
112
+
113
+ # Hyperparameters
114
+ learning_rate = 6e-4
115
+ batch_size = 125
116
+ micro_batch_size = 5
117
+ max_iters = 600000 # num_epochs * epoch_size // devices
118
+ weight_decay = 1e-1
119
+ beta1 = 0.9
120
+ beta2 = 0.95
121
+ grad_clip = 1.0
122
+ decay_lr = True
123
+ warmup_iters = 2000
124
+ lr_decay_iters = max_iters
125
+ min_lr = 6e-5
126
+ ```
127
+
128
+ In particular, `micro_batch_size` should be adjusted so the process makes full use of the available
129
+ GPU memory.
130
+
131
+ Lastly, logging is kept minimal in the script. To use a particular logger,
132
+ please refer to <https://lightning.ai/docs/fabric/stable/api/loggers.html> or
133
+ call a logging client library like `wandb` directly.
lit_llama/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from lit_llama.model import LLaMAConfig, LLaMA, RMSNorm, build_rope_cache, apply_rope
2
+ from lit_llama.tokenizer import Tokenizer
lit_llama/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (412 Bytes). View file
 
lit_llama/__pycache__/lora.cpython-311.pyc ADDED
Binary file (12.7 kB). View file
 
lit_llama/__pycache__/model.cpython-311.pyc ADDED
Binary file (15.4 kB). View file