flopml committed
Commit 34dc8ec · verified · 1 Parent(s): 63616c2

Upload 32 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ llama2.c/data/TinyStories_all_data/data00.json filter=lfs diff=lfs merge=lfs -text
llama2.c/.github/workflows/build.yml ADDED
@@ -0,0 +1,193 @@
1
+ name: Continuous Integration
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - master
7
+ paths: ['.github/workflows/**', '**/Makefile', '**/*.c', '**/*.h', '**/*.py']
8
+ pull_request:
9
+ types: [opened, synchronize, reopened]
10
+ paths: ['**/Makefile', '**/*.c', '**/*.h', '**/*.py']
11
+ # for manual triggering
12
+ workflow_dispatch:
13
+
14
+ env:
15
+ BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
16
+
17
+ jobs:
18
+ # check basic builds to avoid breaking changes
19
+ ubuntu-focal-make:
20
+ runs-on: ubuntu-latest
21
+
22
+ steps:
23
+ - name: Clone
24
+ id: checkout
25
+ uses: actions/checkout@v3
26
+
27
+ - name: Dependencies
28
+ id: depends
29
+ run: |
30
+ sudo apt-get update
31
+ sudo apt-get install build-essential -y
32
+
33
+ - name: Set up Python 3.10
34
+ uses: actions/setup-python@v3
35
+ with:
36
+ python-version: "3.10"
37
+
38
+ - name: Pip setup
39
+ run: |
40
+ python -m pip install --upgrade pip
41
+ if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
42
+
43
+ - name: Build
44
+ id: make_build
45
+ run: |
46
+ make
47
+
48
+ - name: Build runfast
49
+ id: make_build_runfast
50
+ run: |
51
+ make runfast
52
+
53
+ - name: Test with pytest
54
+ run: |
55
+ pytest
56
+
57
+ macOS-latest-make:
58
+ runs-on: macos-latest
59
+
60
+ steps:
61
+ - name: Clone
62
+ id: checkout
63
+ uses: actions/checkout@v3
64
+
65
+ - name: Dependencies
66
+ id: depends
67
+ continue-on-error: true
68
+ run: |
69
+ brew update
70
+
71
+ - name: Set up Python 3.10
72
+ uses: actions/setup-python@v3
73
+ with:
74
+ python-version: "3.10"
75
+
76
+ - name: Pip setup
77
+ run: |
78
+ python -m pip install --upgrade pip
79
+ if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
80
+
81
+ - name: Build clang
82
+ id: make_build_clang
83
+ run: |
84
+ make run CC=clang
85
+
86
+ - name: Build
87
+ id: make_build
88
+ run: |
89
+ make
90
+
91
+ - name: Build runfast
92
+ id: make_build_runfast
93
+ run: |
94
+ make runfast
95
+
96
+ - name: Test with pytest
97
+ run: pytest
98
+
99
+
100
+
101
+
102
+ windows-latest-make:
103
+ runs-on: windows-latest
104
+
105
+ strategy:
106
+ fail-fast: false #necessary, otherwise the matrix breaks
107
+ matrix:
108
+ arch:
109
+ - amd64
110
+ - amd64_x86
111
+ - amd64_arm64
112
+
113
+ steps:
114
+ - name: Clone
115
+ id: checkout
116
+ uses: actions/checkout@v3
117
+
118
+ - name: Setup MSBuild
119
+ uses: microsoft/setup-msbuild@v1
120
+
121
+ - name: Setup MSVC ${{ matrix.arch }}
122
+ uses: ilammy/msvc-dev-cmd@v1
123
+ with:
124
+ arch: ${{ matrix.arch }}
125
+
126
+ - name: Set up Python 3.10
127
+ if: matrix.arch != 'amd64_arm64'
128
+ uses: actions/setup-python@v3
129
+ with:
130
+ python-version: "3.10"
131
+
132
+ - name: Pip setup
133
+ if: matrix.arch != 'amd64_arm64'
134
+ run: |
135
+ python -m pip install --upgrade pip
136
+ if (Test-Path requirements.txt) {
137
+ pip install -r requirements.txt
138
+ }
139
+
140
+ - name: Build ${{ matrix.arch }}
141
+ id: build_msvc
142
+ run: |
143
+ .\build_msvc.bat
144
+
145
+ # cross-compiled, cannot be run on the host
146
+ - name: Test with pytest
147
+ if: matrix.arch != 'amd64_arm64'
148
+ run: pytest
149
+
150
+ windows-latest-mingw:
151
+ runs-on: windows-latest
152
+
153
+ defaults:
154
+ run:
155
+ shell: msys2 {0}
156
+
157
+ strategy:
158
+ matrix:
159
+ include:
160
+ - { sys: mingw64, env: x86_64 }
161
+
162
+ steps:
163
+ - name: Checkout
164
+ id: checkout
165
+ uses: actions/checkout@v3
166
+
167
+ - uses: msys2/setup-msys2@v2
168
+ id: setup-msys2
169
+ with:
170
+ msystem: ${{ matrix.sys }}
171
+ install: mingw-w64-${{matrix.env}}-gcc make
172
+
173
+ - name: Build ${{ matrix.sys }} ${{ matrix.env }}
174
+ id: build_mingw
175
+ run: |
176
+ make win64
177
+
178
+ - name: Set up Python 3.10
179
+ uses: actions/setup-python@v3
180
+ with:
181
+ python-version: "3.10"
182
+
183
+ - name: Pip setup
184
+ shell: powershell
185
+ run: |
186
+ python -m pip install --upgrade pip
187
+ if (Test-Path requirements.txt) {
188
+ pip install -r requirements.txt
189
+ }
190
+
191
+ - name: Test with pytest
192
+ shell: powershell
193
+ run: pytest
llama2.c/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Andrej
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
llama2.c/Makefile ADDED
@@ -0,0 +1,76 @@
1
+ # choose your compiler, e.g. gcc/clang
2
+ # example override to clang: make run CC=clang
3
+ CC = gcc
4
+
5
+ # the most basic way of building that is most likely to work on most systems
6
+ .PHONY: run
7
+ run: run.c
8
+ $(CC) -O3 -o run run.c -lm
9
+ $(CC) -O3 -o runq runq.c -lm
10
+
11
+ # useful for a debug build, can then e.g. analyze with valgrind, example:
12
+ # $ valgrind --leak-check=full ./run out/model.bin -n 3
13
+ rundebug: run.c
14
+ $(CC) -g -o run run.c -lm
15
+ $(CC) -g -o runq runq.c -lm
16
+
17
+ # https://gcc.gnu.org/onlinedocs/gcc/Optimize-Options.html
18
+ # https://simonbyrne.github.io/notes/fastmath/
19
+ # -Ofast enables all -O3 optimizations.
20
+ # Disregards strict standards compliance.
21
+ # It also enables optimizations that are not valid for all standard-compliant programs.
22
+ # It turns on -ffast-math, -fallow-store-data-races and the Fortran-specific
23
+ # -fstack-arrays, unless -fmax-stack-var-size is specified, and -fno-protect-parens.
24
+ # It turns off -fsemantic-interposition.
25
+ # In our specific application this is *probably* okay to use
26
+ .PHONY: runfast
27
+ runfast: run.c
28
+ $(CC) -Ofast -o run run.c -lm
29
+ $(CC) -Ofast -o runq runq.c -lm
30
+
31
+ # additionally compiles with OpenMP, allowing multithreaded runs
32
+ # make sure to also enable multiple threads when running, e.g.:
33
+ # OMP_NUM_THREADS=4 ./run out/model.bin
34
+ .PHONY: runomp
35
+ runomp: run.c
36
+ $(CC) -Ofast -fopenmp -march=native run.c -lm -o run
37
+ $(CC) -Ofast -fopenmp -march=native runq.c -lm -o runq
38
+
39
+ .PHONY: win64
40
+ win64:
41
+ x86_64-w64-mingw32-gcc -Ofast -D_WIN32 -o run.exe -I. run.c win.c
42
+ x86_64-w64-mingw32-gcc -Ofast -D_WIN32 -o runq.exe -I. runq.c win.c
43
+
44
+ # compiles with gnu11 standard flags for Amazon Linux, CoreOS, etc. compatibility
45
+ .PHONY: rungnu
46
+ rungnu:
47
+ $(CC) -Ofast -std=gnu11 -o run run.c -lm
48
+ $(CC) -Ofast -std=gnu11 -o runq runq.c -lm
49
+
50
+ .PHONY: runompgnu
51
+ runompgnu:
52
+ $(CC) -Ofast -fopenmp -std=gnu11 run.c -lm -o run
53
+ $(CC) -Ofast -fopenmp -std=gnu11 runq.c -lm -o runq
54
+
55
+ # run all tests
56
+ .PHONY: test
57
+ test:
58
+ pytest
59
+
60
+ # run only tests for run.c C implementation (is a bit faster if only C code changed)
61
+ .PHONY: testc
62
+ testc:
63
+ pytest -k runc
64
+
65
+ # run the C tests, without touching pytest / python
66
+ # to increase verbosity level run e.g. as `make testcc VERBOSITY=1`
67
+ VERBOSITY ?= 0
68
+ .PHONY: testcc
69
+ testcc:
70
+ $(CC) -DVERBOSITY=$(VERBOSITY) -O3 -o testc test.c -lm
71
+ ./testc
72
+
73
+ .PHONY: clean
74
+ clean:
75
+ rm -f run
76
+ rm -f runq
llama2.c/README.md ADDED
@@ -0,0 +1,415 @@
1
+ ## llama2.c
2
+
3
+ <p align="center">
4
+ <img src="assets/llama_cute.jpg" width="300" height="300" alt="Cute Llama">
5
+ </p>
6
+
7
+ Have you ever wanted to inference a baby [Llama 2](https://ai.meta.com/llama/) model in pure C? No? Well, now you can!
8
+
9
+ Train the Llama 2 LLM architecture in PyTorch, then inference it with one simple 700-line C file ([run.c](run.c)). You might think that you need many billion parameter LLMs to do anything useful, but in fact very small LLMs can have surprisingly strong performance if you make the domain narrow enough (ref: [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) paper). This repo is a "fullstack" train + inference solution for the Llama 2 LLM, with a focus on minimalism and simplicity.
10
+
11
+ As the architecture is identical, you can also load and inference Meta's Llama 2 models. However, the current code only inferences models in fp32, so you will most likely not be able to productively load models larger than 7B. Work on model quantization is currently ongoing.
12
+
13
+ Please note that this repo started recently as a fun weekend project: I took my earlier [nanoGPT](https://github.com/karpathy/nanoGPT), tuned it to implement the Llama-2 architecture instead of GPT-2, and the meat of it was writing the C inference engine in [run.c](run.c). So the project is young and moving quickly. Hat tip to the awesome [llama.cpp](https://github.com/ggerganov/llama.cpp) for inspiring this project. Compared to llama.cpp, I wanted something super simple, minimal, and educational so I chose to hard-code the Llama 2 architecture and just roll one inference file of pure C with no dependencies.
14
+
15
+ ## feel the magic
16
+
17
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/karpathy/llama2.c/blob/master/run.ipynb)
18
+
19
+ First, navigate to the folder where you keep your projects and clone this repository to this folder:
20
+
21
+ ```bash
22
+ git clone https://github.com/karpathy/llama2.c.git
23
+ ```
24
+
25
+ Then, open the repository folder:
26
+
27
+ ```bash
28
+ cd llama2.c
29
+ ```
30
+
31
+ Now, let's just run a baby Llama 2 model in C. You need a model checkpoint. Download this 15M parameter model I trained on the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset (~60MB download):
32
+
33
+ ```bash
34
+ wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin
35
+ ```
36
+
37
+ Compile and run the C code:
38
+
39
+ ```bash
40
+ make run
41
+ ./run stories15M.bin
42
+ ```
43
+
44
+ You'll see the text stream a sample. On my M1 MacBook Air this runs at ~110 tokens/s. See [performance](#performance) or the Makefile for compile flags that can significantly speed this up. We can also try a bit bigger 42M parameter model:
45
+
46
+ ```bash
47
+ wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin
48
+ ./run stories42M.bin
49
+ ```
50
+
51
+ This still runs at interactive rates and samples more coherent and diverse stories:
52
+
53
+ > Once upon a time, there was a little girl named Lily. She loved playing with her toys on top of her bed. One day, she decided to have a tea party with her stuffed animals. She poured some tea into a tiny teapot and put it on top of the teapot. Suddenly, her little brother Max came into the room and wanted to join the tea party too. Lily didn't want to share her tea and she told Max to go away. Max started to cry and Lily felt bad. She decided to yield her tea party to Max and they both shared the teapot. But then, something unexpected happened. The teapot started to shake and wiggle. Lily and Max were scared and didn't know what to do. Suddenly, the teapot started to fly towards the ceiling and landed on the top of the bed. Lily and Max were amazed and they hugged each other. They realized that sharing was much more fun than being selfish. From that day on, they always shared their tea parties and toys.
54
+
55
+ You can also prompt the model with a prefix or a number of additional command line arguments, e.g. to sample at temperature 0.8 for 256 steps and with a prompt:
56
+
57
+ ```bash
58
+ ./run stories42M.bin -t 0.8 -n 256 -i "One day, Lily met a Shoggoth"
59
+ ```
60
+
61
+ > One day, Lily met a Shoggoth. He was very shy, but was also very generous. Lily said “Hello Shoggy! Can I be your friend?” Shoggy was happy to have a friend and said “Yes, let’s explore the universe together!” So they set off on a journey to explore the universe. As they travelled, Shoggy was happy to explain to Lily about all the wonderful things in the universe. At the end of the day, Lily and Shoggy had gathered lots of wonderful things from the universe, and they both felt very proud. They promised to explore the universe as one big pair and to never stop being generous to each other.
62
+
63
+ There is also an even better 110M param model available, see [models](#models).
64
+
65
+ A quick note on sampling: the recommendation for ~best results is to sample with `-t 1.0 -p 0.9`, i.e. temperature 1.0 (default) but also top-p sampling at 0.9 (default). Intuitively, top-p ensures that tokens with tiny probabilities do not get sampled, so we can't get "unlucky" during sampling, and we are less likely to go "off the rails" afterwards. More generally, to control the diversity of samples use either the temperature (i.e. vary `-t` between 0 and 1 and keep top-p off with `-p 0`) or the top-p value (i.e. vary `-p` between 0 and 1 and keep `-t 1`), but not both. Nice explainers on LLM sampling strategies include [this](https://peterchng.com/blog/2023/05/02/token-selection-strategies-top-k-top-p-and-temperature/), [this](https://docs.cohere.com/docs/controlling-generation-with-top-k-top-p) or [this](https://huggingface.co/blog/how-to-generate).
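To make the interplay of temperature and top-p concrete, here is a minimal NumPy sketch of this sampling scheme. It is not the actual C implementation in run.c; the function name and the toy logits are made up for illustration.

```python
# Hypothetical illustration of temperature + top-p (nucleus) sampling, not run.c's code.
import numpy as np

def sample_token(logits, temperature=1.0, top_p=0.9, rng=np.random.default_rng(0)):
    if temperature == 0.0:
        return int(np.argmax(logits))                 # greedy / deterministic
    z = logits / temperature
    probs = np.exp(z - z.max())                       # softmax with temperature
    probs /= probs.sum()
    # keep the smallest set of top tokens whose cumulative probability reaches top_p
    order = np.argsort(-probs)
    keep = order[: np.searchsorted(np.cumsum(probs[order]), top_p) + 1]
    p = probs[keep] / probs[keep].sum()               # renormalize over the kept set
    return int(rng.choice(keep, p=p))

print(sample_token(np.array([2.0, 1.0, 0.1, -1.0])))  # usually one of the top tokens
```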
66
+
67
+ ## Meta's Llama 2 models
68
+
69
+ As the neural net architecture is identical, we can also inference the Llama 2 models released by Meta. Sadly there is a bit of friction here due to licensing (I can't directly upload the checkpoints, I think). So Step 1, get the Llama 2 checkpoints by following the [Meta instructions](https://github.com/facebookresearch/llama). Once we have those checkpoints, we have to convert them into the llama2.c format.
70
+ For this we need to install the python dependencies (`pip install -r requirements.txt`) and then use the `export.py` file, e.g. for 7B model:
71
+
72
+ ```bash
73
+ python export.py llama2_7b.bin --meta-llama path/to/llama/model/7B
74
+ ```
75
+
76
+ The export will take ~10 minutes or so and generate a 26GB file (the weights of the 7B model in float32) called `llama2_7b.bin` in the current directory. I would not attempt to run anything above 7B right now, for two reasons: first, it has been [reported](https://github.com/karpathy/llama2.c/pull/85) that, despite efforts, 13B+ currently doesn't work because of integer overflow in pointer arithmetic, which is yet to be fixed, and second, even if it were fixed, this repo is doing float32 inference right now, so it would be fairly unusably slow. Once the export is done, we can run it:
77
+
78
+ ```bash
79
+ ./run llama2_7b.bin
80
+ ```
81
+
82
+ This ran at about 4 tokens/s compiled with [OpenMP](#OpenMP) on 96 threads on my CPU Linux box in the cloud. (On my MacBook Air M1, currently it's closer to 30 seconds per token if you just build with `make runfast`.) Example output:
83
+
84
+ > The purpose of this document is to highlight the state-of-the-art of CoO generation technologies, both recent developments and those in commercial use. The focus is on the technologies with the highest merit to become the dominating processes of the future and therefore to be technologies of interest to S&amp;T ... R&amp;D. As such, CoO generation technologies developed in Russia, Japan and Europe are described in some depth. The document starts with an introduction to cobalt oxides as complex products and a short view on cobalt as an essential material. The document continues with the discussion of the available CoO generation processes with respect to energy and capital consumption as well as to environmental damage.
85
+
86
+ base models... ¯\\_(ツ)_/¯. Since we can inference the base model, it should be possible to also inference the chat model quite easily, and have a conversation with it. And if we can find a way to run 7B more efficiently, we can start adding LoRA to our training script, and going wild with finetunes all within the repo!
87
+
88
+ You can also chat with the Llama Chat models. Export the chat model exactly as above:
89
+
90
+ ```bash
91
+ python export.py llama2_7b_chat.bin --meta-llama /path/to/7B-chat
92
+ ```
93
+
94
+ Then chat with it by specifying the chat mode using the `-m` flag, e.g.:
95
+
96
+ ```bash
97
+ ./run llama2_7b_chat.bin -m chat
98
+ ```
99
+
100
+ You can also try Meta's Code Llama models even if support for them is incomplete. In particular, some hyperparameters changed (e.g. the constant in RoPE layer), so the inference is not exactly correct and a bit buggy right now. Looking into fixes. Make sure to build the tokenizer for the plain and instruct variants and pass it when doing inference.
101
+
102
+ ```bash
103
+ python export.py codellama2_7b.bin --meta-llama /path/to/CodeLlama-7b
104
+ python tokenizer.py --tokenizer-model=/path/to/CodeLlama-7b/tokenizer.model
105
+ ./run codellama2_7b.bin -z /path/to/CodeLlama-7b/tokenizer.bin
106
+ ```
107
+
108
+ Chat with Code Llama Instruct:
109
+
110
+ ```bash
111
+ python export.py codellama2_7b_instruct.bin --meta-llama /path/to/CodeLlama-7b-Instruct
112
+ python tokenizer.py --tokenizer-model=/path/to/CodeLlama-7b-Instruct/tokenizer.model
113
+ ./run codellama2_7b_instruct.bin -m chat -z /path/to/CodeLlama-7b-Instruct/tokenizer.bin
114
+ ```
115
+
116
+ ## int8 quantization
117
+
118
+ The (default) script [run.c](run.c), above, uses a float32 forward pass, where the entire calculation of the forward pass is kept in fp32. This is very easy to understand as far as reference code goes, but it has the following downsides: the model checkpoint files are very large (it takes 4 bytes per every individual weight), and the forward pass is relatively slow. The (very) common inference optimization employed in practice is to quantize the model parameters to lower precision, giving up a little bit of correctness in return for smaller checkpoint sizes and faster forward passes (as most of the inference uses integer arithmetic). Empirically, LLMs can tolerate precisions as low as 4-bit (or even lower), but we use int8 here because it is a "safe" setting that gets us the benefits but doesn't sacrifice too much of the model accuracy. Only the weights that participate in matmuls are quantized. All the other parameters (e.g. especially the scale and bias in RMSNorm) are kept in float32, because these layers are very sensitive. Now, if all you're after is reduction in checkpoint sizes, you could quantize the weights, save the checkpoint, and then dequantize them in run.c, and do float32 inference as normal and call it a day. This is totally fine. But here, we go one step further (as is standard practice) and additionally quantize the activations in the forward pass. This requires us to dynamically quantize and dequantize between float32 and int8 at runtime, which adds overhead. But the benefit is that now the majority of the calculations (the matmuls especially!) are using pure integer arithmetic, where both weights and activations enter as int8. This is where the speedups can fundamentally come from. The version we use is the "Q8_0" quantization (llama.cpp terminology), where the 0 means that the weight quantization is symmetric around 0, quantizing to the range [-127, 127].
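As a rough sketch of what symmetric, group-wise "Q8_0"-style quantization means (one fp32 scale per group, values clipped to [-127, 127]), here is some illustrative Python. The group size and helper names are assumptions for the sketch, not the exact layout used by export.py / runq.c.

```python
# Illustrative symmetric int8 ("Q8_0"-style) quantization of a flat fp32 weight tensor.
import numpy as np

def quantize_q80(w, group_size=64):
    g = w.astype(np.float32).reshape(-1, group_size)
    scale = np.abs(g).max(axis=1, keepdims=True) / 127.0 + 1e-12  # one fp32 scale per group
    q = np.clip(np.round(g / scale), -127, 127).astype(np.int8)   # symmetric around 0
    return q, scale.astype(np.float32)

def dequantize_q80(q, scale):
    return (q.astype(np.float32) * scale).reshape(-1)

w = np.random.randn(4096).astype(np.float32)
q, s = quantize_q80(w)
print("max abs error:", np.abs(dequantize_q80(q, s) - w).max())   # small, but nonzero
```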
119
+
120
+ The quantized forward pass is implemented in [runq.c](runq.c). To use it, we have to export the model in the quantized format. For example, the float32 version of Llama 2 7B was exported as:
121
+
122
+ ```
123
+ python export.py llama2_7b.bin --meta-llama path/to/llama/model/7B
124
+ ```
125
+
126
+ This creates a 26GB file, because each one of 7B parameters is 4 bytes (fp32). To export it quantized, we instead use version 2 export:
127
+
128
+ ```
129
+ python export.py llama2_7b_q80.bin --version 2 --meta-llama path/to/llama/model/7B
130
+ ```
131
+
132
+ This runs for a few minutes, but now creates only a 6.7GB file. For exporting non-meta checkpoints you would use the --checkpoint arg instead of --meta-llama arg (more docs on this later, below). Now let's inference them. I like to use OMP here because these are big models, so e.g. on my Linux box:
133
+
134
+ ```
135
+ make runomp
136
+ OMP_NUM_THREADS=64 ./run llama2_7b.bin -n 40
137
+ OMP_NUM_THREADS=64 ./runq llama2_7b_q80.bin -n 40
138
+ ```
139
+
140
+ This runs 40 steps just to get a timing. The float32 version for me runs at 4.6 tok/s, and the int8 version at 14 tok/s. So we achieved a 3X speedup while reducing the checkpoint size by 4X. However, the forward pass is quantized to int8, and therefore silently very slightly lower quality.
141
+
142
+ ## huggingface models
143
+
144
+ We can load any huggingface models that use the Llama 2 architecture. See the script [export.py](export.py) and the `--hf` flag to export the model .bin file.
145
+
146
+ ## models
147
+
148
+ For the sake of examples of smaller, from-scratch models, I trained a small model series on TinyStories. All of these trained in a few hours on my training setup (4X A100 40GB GPUs). The 110M took around 24 hours. I am hosting them on huggingface hub [tinyllamas](https://huggingface.co/karpathy/tinyllamas), both in the original PyTorch .pt, and also in the llama2.c format .bin:
149
+
150
+ | model | dim | n_layers | n_heads | n_kv_heads | max context length | parameters | val loss | download
151
+ | --- | --- | --- | --- | --- | --- | --- | --- | --- |
152
+ | 260K | 64 | 5 | 8 | 4 | 512 | 260K | 1.297 | [stories260K](https://huggingface.co/karpathy/tinyllamas/tree/main/stories260K)
153
+ | OG | 288 | 6 | 6 | 6 | 256 | 15M | 1.072 | [stories15M.bin](https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin) |
154
+ | 42M| 512 | 8 | 8 | 8 | 1024 | 42M | 0.847 | [stories42M.bin](https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin) |
155
+ | 110M| 768 | 12 | 12 | 12 | 1024 | 110M | 0.760 | [stories110M.bin](https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin) |
156
+
157
+ You'll notice that the 110M model is equivalent to GPT-1 in size. Alternatively, this is also the smallest model in the GPT-2 series (`GPT-2 small`), except the max context length is only 1024 instead of 2048. The only notable changes from the GPT-1/2 architecture are that Llama uses RoPE relative positional embeddings instead of absolute/learned positional embeddings, a slightly fancier SwiGLU non-linearity in the MLP, RMSNorm instead of LayerNorm, bias=False on all Linear layers, and optional multiquery attention.
158
+
159
+ ## training
160
+
161
+ Let's see how we can train a baby Llama 2 from scratch using the code in this repo. First let's download and pretokenize some source dataset, e.g. I like [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) so this is the only example currently available in this repo. But it should be very easy to add datasets, see the code.
162
+
163
+ ```bash
164
+ python tinystories.py download
165
+ python tinystories.py pretokenize
166
+ ```
167
+
168
+ Then train our model:
169
+
170
+ ```bash
171
+ python train.py
172
+ ```
173
+
174
+ **brief training guide**. See the train.py script for more exotic launches and hyperparameter overrides. Here is a brief guide to how to set the parameters. Look at the table at the very end of the [Chinchilla paper](https://arxiv.org/abs/2203.15556) to get a sense of how the Transformer parameters (dim, n_layers, n_heads) grow or shrink together. Extrapolate/interpolate this pattern to get bigger or smaller transformers. Set the max context length however you wish, depending on the problem: this should be the max number of tokens that matter to predict the next token. E.g. Llama 2 uses 2048. Next, you want the _total_ batch size per update (printed by the script as "tokens per iteration will be:") to be somewhere around 100K tokens for medium-sized applications. For tiny applications it could be lower, for large training (e.g. GPTs/LLamas) it is usually ~0.5M, or even more. You get there by first maxing out the batch_size to whatever your system allows (e.g. mine was 16 in a recent run because after that my GPU runs out of memory), and then you want to increase gradient_accumulation_steps to be as high as necessary to reach the total batch size of ~100K. Finally, you want to tune your learning_rate (LR). You want this to be as high as your training allows. Very small networks can get away with a large LR (e.g. 1e-3 or even higher). Large networks need lower LRs. 3e-4 is a safe choice in most medium-sized applications, but can be too low for small networks, so try to increase it! Finally, max_iters is the length of training. Play with different settings. I mostly only ever tune these parameters and leave most of the others unchanged. Here is an example of how I trained the 110M model, which I don't think is anywhere near optimal, but looked sensible to me: dim 768, n_layers 12, n_heads 12 (so size of each head is 768 / 12 = 64 channels), seq len of 1024, batch size 16 (this is the most that fit my A100 40GB GPU), gradient_accumulation_steps = 8 was needed to get total tokens batch size to be 16 batch size * 1024 tokens in sequence * 8 grad_accum = 131,072 tokens per update. Good. Learning rate 4e-4 (probably a little too low). max_iters 200K (probably a bit too high). Dropout 0.1, as that usually helps a bit at medium size. That was it. I ran using Distributed Data Parallel (DDP) on 4 GPUs on my cloud machine, training took ~day or so.
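For example, the total batch size arithmetic from the 110M run above works out like this (a tiny sanity check; the multi-GPU scaling comment is an assumption, check the value train.py prints):

```python
# The tokens-per-update arithmetic from the example run above.
batch_size = 16                     # max that fit the example A100 40GB
max_seq_len = 1024
gradient_accumulation_steps = 8
tokens_per_update = batch_size * max_seq_len * gradient_accumulation_steps
print(tokens_per_update)            # 131072, i.e. ~131K tokens per update
# Note: with multi-GPU DDP the effective total per update presumably also scales
# with the number of GPUs (an assumption; compare against train.py's printed value).
```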
175
+
176
+ Totally understandable if you want to skip model training; for a simple demo just download one of the pretrained models (see the [models](#models) section), e.g.:
177
+
178
+ ```bash
179
+ wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin
180
+ ```
181
+
182
+ Once we have the model.bin file, we can inference in C. Compile the C code first:
183
+
184
+ ```bash
185
+ make run
186
+ ```
187
+
188
+ You can now run it simply as
189
+
190
+ ```bash
191
+ ./run stories15M.bin
192
+ ```
193
+
194
+ Watch the tokens stream by, fun! We can also run the PyTorch inference script for a comparison. Download one of the models again from huggingface hub and point the `sample.py` script at it:
195
+
196
+ ```bash
197
+ wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt -P out15M
198
+ python sample.py --checkpoint=out15M/stories15M.pt
199
+ ```
200
+
201
+ This gives the same results.
202
+
203
+ ## custom tokenizers
204
+
205
+ In everything above, we've assumed the default Llama 2 tokenizer with 32,000 tokens. However, in many boutique LLMs, a vocabulary this big might be overkill. If you have a small application in mind, you might be much better off training your own tokenizer. This can make everything nicer - with smaller vocabs your model has fewer parameters (because the token embedding table is a lot smaller), the inference is faster (because there are fewer tokens to predict), and your average sequence length per example could also get smaller (because the compression is a lot more efficient on your data). So let's see how we train a custom tokenizer.
206
+
207
+ By default, to pretokenize the tinystories dataset we had to run, in order:
208
+
209
+ ```
210
+ python tinystories.py download
211
+ python tinystories.py pretokenize
212
+ ```
213
+
214
+ The `pretokenize` stage here loads the Llama 2 tokenizer (vocab size 32,000) and uses it to convert the downloaded text into integers, and saves that to file. We now change this as follows, to train an example 4096-token tokenizer:
215
+
216
+ ```
217
+ python tinystories.py download
218
+ python tinystories.py train_vocab --vocab_size=4096
219
+ python tinystories.py pretokenize --vocab_size=4096
220
+ ```
221
+
222
+ The `train_vocab` stage will call the `sentencepiece` library to train the tokenizer, storing it in a new file `data/tok4096.model`. I tried to reproduce as well as I could the settings that (I think) Meta used to train their vocabulary. This uses the Byte Pair Encoding algorithm that starts out with raw utf8 byte sequences of the text data and then iteratively merges the most common consecutive pairs of tokens to form the vocabulary. Inspect the `tinystories.py` file - the custom tokenizers are stored in a special directory structure indexed by the vocab size.
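To make the "iteratively merges the most common consecutive pairs" idea concrete, here is a toy, self-contained sketch of a single BPE merge step over raw UTF-8 bytes. It is purely illustrative; sentencepiece does the real training, and the corpus and helper name here are made up.

```python
# Toy illustration of one BPE merge step: find the most frequent adjacent pair
# of tokens across the corpus and merge it into a single new token.
from collections import Counter

def merge_step(seqs):
    pairs = Counter(p for seq in seqs for p in zip(seq, seq[1:]))
    if not pairs:
        return seqs, None
    best = pairs.most_common(1)[0][0]
    merged = []
    for seq in seqs:
        out, i = [], 0
        while i < len(seq):
            if i + 1 < len(seq) and (seq[i], seq[i + 1]) == best:
                out.append(seq[i] + seq[i + 1])   # merged token (concatenated bytes)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        merged.append(out)
    return merged, best

corpus = [[bytes([b]) for b in s.encode("utf-8")] for s in ["low", "lower", "lowest"]]
corpus, pair = merge_step(corpus)
print(pair)   # the most frequent byte pair in this toy corpus, e.g. (b'l', b'o')
```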
223
+
224
+ A quick note of interest is that vocab size of 4096 trained specifically on tinystories creates integer sequences with about the same sequence length per example as the default Llama 2 tokenizer of 32000 tokens! This means that our custom, tailored tokenizer is a lot better adapted to our specific text, and can compress it very effectively. So our trained models are smaller and faster.
225
+
226
+ Now that we have pretokenized the dataset with our custom tokenizer, we can train the model. The training script `train.py` doesn't care about the exact tokens, it only cares about the vocabulary size so it can correctly initialize the model. So when training your model, make sure to pass in
227
+
228
+ ```
229
+ python train.py --vocab_source=custom --vocab_size=4096
230
+ ```
231
+
232
+ (The defaults are `llama2` and `32000` respectively, which indicates the default Llama 2 tokenizer). This trains the model. Finally we are ready to run inference with our `run.c` script. For that we need two things. Number one, we have to export our tokenizer in the `.bin` format, do that with:
233
+
234
+ ```
235
+ python tokenizer.py --tokenizer-model=data/tok4096.model
236
+ ```
237
+
238
+ This writes the tokenizer to `data/tok4096.bin`. Now we can run inference, pointing it to this tokenizer using the `-z` flag:
239
+
240
+ ```
241
+ ./run out/model.bin -z data/tok4096.bin
242
+ ```
243
+
244
+ This should print the samples. If you leave out the `-z` flag, it will use the default Llama 2 tokenizer, which would generate a good sequence of integers, but they would get translated using a different vocabulary to text, so it would look like gibberish.
245
+
246
+ ## performance
247
+
248
+ There are many ways to potentially speed up this code depending on your system. Have a look at the [Makefile](Makefile), which contains a lot of notes. The `make run` command currently uses the `-O3` optimization by default, i.e.:
249
+
250
+ ```bash
251
+ gcc -O3 -o run run.c -lm
252
+ ```
253
+
254
+ `-O3` includes optimizations that are expensive in terms of compile time and memory usage, including vectorization, loop unrolling, and branch prediction.
255
+
256
+ To get much better performance, try compiling with `make runfast`. This turns on the `-Ofast` flag, which includes additional optimizations that may break compliance with the C/IEEE specifications, in addition to `-O3`. See [the GCC docs](https://gcc.gnu.org/onlinedocs/gcc/Optimize-Options.html) for more information.
257
+
258
+ Try `-march=native` to compile the program to use the architecture of the machine you're compiling on rather than a more generic CPU. This may enable additional optimizations and hardware-specific tuning such as improved vector instructions/width.
259
+
260
+ The fastest throughput I have seen so far on my MacBook Air (M1) is with `make runfast`.
261
+
262
+ You can also experiment with replacing `gcc` with `clang`.
263
+
264
+ If compiling with gcc, try experimenting with `-funroll-all-loops`, see PR [#183](https://github.com/karpathy/llama2.c/pull/183)
265
+
266
+ **OpenMP**. Big improvements can also be achieved by compiling with OpenMP, which "activates" the `#pragma omp parallel for` inside the matmul and attention, allowing the work in the loops to be split up over multiple processors.
267
+ You'll need to install the OpenMP library and the clang compiler first (e.g. `apt install clang libomp-dev` on ubuntu). Then you can compile with `make runomp`, which does:
268
+
269
+ ```bash
270
+ clang -Ofast -fopenmp -march=native run.c -lm -o run
271
+ ```
272
+
273
+ When you run inference make sure to use OpenMP flags to set the number of threads, e.g.:
274
+
275
+ ```bash
276
+ OMP_NUM_THREADS=4 ./run out/model.bin
277
+ ```
278
+
279
+ Depending on your system resources you may want to tweak these hyperparameters and use more threads. But more is not always better; performance is usually somewhat U-shaped in the number of threads. In particular, if your CPU has SMT (multithreading), try setting the number of threads to the number of physical cores rather than logical cores. The performance difference can be large due to cache thrashing and communication overhead. The PyTorch documentation [CPU specific optimizations
280
+ ](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#cpu-specific-optimizations) has some good information that applies here too.
281
+
282
+ ## platforms
283
+
284
+ On **Windows**, use `build_msvc.bat` in a Visual Studio Command Prompt to build with MSVC, or use `make win64` with the MinGW compiler toolchain (from Linux or Windows) to build the Windows target. The MSVC build will automatically use OpenMP and the maximum number of threads appropriate for your CPU unless you set the `OMP_NUM_THREADS` environment variable.
285
+
286
+ On **CentOS 7** and **Amazon Linux 2018**, use the `rungnu` Makefile target: `make rungnu`, or `make runompgnu` to use OpenMP.
287
+
288
+ On **Mac**, use clang from Homebrew for the OpenMP build. Install it with `brew install llvm` and use the installed clang binary to compile with OpenMP: `make runomp CC=/opt/homebrew/opt/llvm/bin/clang`
289
+
290
+ ## tests
291
+
292
+ You can run tests simply with pytest:
293
+
294
+ ```bash
295
+ $ pip install pytest
296
+ $ pytest
297
+ ```
298
+
299
+ This will currently invoke two tests inside `test_all.py`, which forward the model in both C and Python for 200 steps and check the output against a known good expected output. The tests currently run in only a few seconds, but will have to download and cache the stories260K models in a temporary `test` directory (only ~2MB download).
300
+
301
+ There are also some tests in C, in the file [test.c](test.c). You can run these with `make testcc`, or to see more stuff printed:
302
+
303
+ ```
304
+ make testcc VERBOSITY=1
305
+ ```
306
+
307
+ Call for help: please help add more tests.
308
+
309
+ ## ack
310
+
311
+ I trained the llama2.c storyteller models on a 4X A100 40GB box graciously provided by the excellent [Lambda labs](https://lambdalabs.com/service/gpu-cloud), thank you.
312
+
313
+ ## discord
314
+
315
+ Figured it's possible to reuse my existing discord channel (that I use for my [zero to hero youtube series](https://karpathy.ai/zero-to-hero.html)), see #llama2c channel on [discord](https://discord.gg/3zy8kqD9Cp), for any quick questions, related discussions, etc.
316
+
317
+ ## contributing
318
+
319
+ A few words on this repo and the kinds of PRs that are likely to be accepted. What is the goal of this repo? Basically I think there will be a lot of interest in training or finetuning custom micro-LLMs (think ~100M - ~1B params, but let's say up to ~10B params) across a large diversity of applications, and deploying them in edge-adjacent environments (think MCUs, phones, web browsers, laptops, etc.). I'd like this repo to be the simplest, smallest, most hackable repo to support this workflow, both training and inference. In particular, this repo is not a complex framework with 1000 knobs controlling inscrutable code across a nested directory structure of hundreds of files. Instead, I expect most applications will wish to create a fork of this repo and hack it to their specific needs and deployment platforms.
320
+
321
+ People who care about deployment efficiency above all else should look at [llama.cpp](https://github.com/ggerganov/llama.cpp). This repo still cares about efficiency, but not at the cost of simplicity, readability or portability. Basically, I expect that a lot of people come to this repo because the training code is 2 readable .py files and the inference code is 500 lines of C. So I'd like this to continue to be a kind of simplest "reference implementation" that can be easily hacked in a separate fork into whatever downstream application people are excited about. It shouldn't be full-featured. It shouldn't take 100 different options or settings. It shouldn't be the most efficient. A few examples:
322
+
323
+ - someone re-ordered two loops to improve data locality for a small efficiency win => instant merge.
324
+ - someone added the one line "pragma omp parallel for", which allows you to compile with OpenMP and dramatically speed up the code, or acts as just a comment if you don't compile it that way => instant merge.
325
+ - bug fixes and touchups etc. => happy to merge
326
+
327
+ A few examples of PRs that are not an excellent fit:
328
+
329
+ - adding more than several #ifdefs all over the place in code. If they are localized / few, might be okay.
330
+ - adding a lot of code that is very specific to some specific platform (e.g. MCUs, or some special version of linux or processor). These may be a better fit for forks of the project, and I am very happy to maintain a list of these forks in section below.
331
+ - adding hundreds of lines of code to run.c that are only active in specific scenarios or platforms.
332
+
333
+ If your candidate PRs have elements of these it doesn't mean they won't get merged, it just means they will make it into the gray territory. TLDR: I am eager to merge any mostly small, mostly localized, broadly applicable, clean changes that improve the efficiency and portability of the repo, while keeping its hackability and readability. I appreciate all PRs seeking to help me improve the project, thank you! <3.
334
+
335
+ ## notable forks
336
+
337
+ - Rust
338
+ - [llama2.rs](https://github.com/gaxler/llama2.rs) by @[gaxler](https://github.com/gaxler): a Rust port of this project
339
+ - [llama2.rs](https://github.com/leo-du/llama2.rs) by @[leo-du](https://github.com/leo-du): A Rust port of this project
340
+ - [llama2-rs](https://github.com/danielgrittner/llama2-rs) by @[danielgrittner](https://github.com/danielgrittner): a Rust port of this project
341
+ - [llama2.rs](https://github.com/lintian06/llama2.rs) by @[lintian06](https://github.com/lintian06): A Rust port of this project
342
+ - [pecca.rs](https://github.com/rahoua/pecca-rs) by @[rahoua](https://github.com/rahoua): A Rust port leveraging [ndarray](https://github.com/rust-ndarray/ndarray), supports BLAS.
343
+ - [llama2.rs](https://github.com/flaneur2020/llama2.rs) by @[flaneur2020](https://github.com/flaneur2020): A Rust port of this project.
344
+ - [llama2-burn](https://github.com/code-cp/llama2-burn): A Rust port of this project leveraging [Burn](https://github.com/tracel-ai/burn)
345
+ - Go
346
+ - [go-llama2](https://github.com/tmc/go-llama2) by @[tmc](https://github.com/tmc): a Go port of this project
347
+ - [llama2.go](https://github.com/nikolaydubina/llama2.go) by @[nikolaydubina](https://github.com/nikolaydubina): a Go port of this project
348
+ - [llama2.go](https://github.com/haormj/llama2.go) by @[haormj](https://github.com/haormj): a Go port of this project
349
+ - [llama2.go](https://github.com/saracen/llama2.go) by @[saracen](https://github.com/saracen): a Go port of this project
350
+ - Android
351
+ - [llama2.c-android](https://github.com/Manuel030/llama2.c-android): by @[Manuel030](https://github.com/Manuel030): adds Android binaries of this project
352
+ - [llama2.c-android-wrapper](https://github.com/celikin/llama2.c-android-wrapper): by @[celikin](https://github.com/celikin): added JNI wrapper, PoC
353
+ - C
354
+ - [llama3.c](https://github.com/jameswdelancey/llama3.c): by @[jameswdelancey](https://github.com/jameswdelancey): a LLaMA 3 8B Base and Instruct port of this project
355
+ - C++
356
+ - [llama2.cpp](https://github.com/leloykun/llama2.cpp) by @[leloykun](https://github.com/leloykun): a C++ port of this project
357
+ - [llama2.cpp](https://github.com/coldlarry/llama2.cpp) by @[coldlarry](https://github.com/coldlarry): a C++ port of this project
358
+ - JavaScript
359
+ - [llama2.js](https://github.com/epicure/llama2.js) by @[epicure](https://github.com/epicure): a JavaScript port of this project
360
+ - [llamajs](https://github.com/agershun/llamajs) by @[agershun](https://github.com/agershun): a JavaScript port of this project
361
+ - [llama2.ts](https://github.com/wizzard0/llama2.ts) by @[oleksandr_now](https://twitter.com/oleksandr_now): a TypeScript port of this project. Full Llama2-7B capable.
362
+ - [llama2.c-emscripten](https://github.com/gohai/llama2.c-emscripten) by @[gohai](https://github.com/gohai): Emscripten (JavaScript) port, based on @ggerganov's initial prototype
363
+ - Zig
364
+ - [llama2.zig](https://github.com/cgbur/llama2.zig) by @[cgbur](https://github.com/cgbur): A Zig port of this project
365
+ - [llama2.zig](https://github.com/vodkaslime/llama2.zig) by @[vodkaslime](https://github.com/vodkaslime): a Zig port of this project
366
+ - [llama2.zig](https://github.com/clebert/llama2.zig) by @[clebert](https://github.com/clebert): a Zig port of this project
367
+ - Julia
368
+ - [llama2.jl](https://github.com/juvi21/llama2.jl) by @[juvi21](https://github.com/juvi21): a Julia port of this project
369
+ - Scala
370
+ - [llama2.scala](https://github.com/jrudolph/llama2.scala) by @[jrudolph](https://github.com/jrudolph): a Scala port of this project
371
+ - Java
372
+ - [llama2.java](https://github.com/mukel/llama2.java) by @[mukel](https://github.com/mukel): a Java port of this project
373
+ - [llama2.java](https://github.com/neoremind/llama2.java) by @[neoremind](https://github.com/neoremind): a Java port of this project
374
+ - [llama2.tornadovm.java](https://github.com/mikepapadim/llama2.tornadovm.java) by @[mikepapadim](https://github.com/mikepapadim): an extension of the llama2.java with GPU-support through [TornadoVM](https://github.com/beehive-lab/TornadoVM).
375
+ - Kotlin
376
+ - [llama2.kt](https://github.com/madroidmaq/llama2.kt) by @[madroidmaq](https://github.com/madroidmaq): a Kotlin port of this project
377
+ - [llama2-kmp](https://github.com/stepango/llama2-kmp) by @[stepango](https://github.com/stepango): a Kotlin multiplatform(KMP) port of this project
378
+ - Python
379
+ - [llama2.py](https://github.com/tairov/llama2.py) by @[tairov](https://github.com/tairov): a simple one file pure Python port of this project with zero dependencies
380
+ - C#
381
+ - [llama2.cs](https://github.com/trrahul/llama2.cs) by @[trrahul](https://github.com/trrahul): a C# port of this project
382
+ - F#
383
+ - [llama2.fs](https://github.com/micsh/llama2.fs) by @[micsh](https://github.com/micsh): a F# port of this project
384
+ - Dart
385
+ - [llama2.dart](https://github.com/yiminghan/llama2.dart) by @[yiminghan](https://github.com/yiminghan/llama2.dart): one-file dart port of this project, works with Flutter!
386
+ - Web
387
+ - [llama2c-web](https://github.com/dmarcos/llama2.c-web) by @[dmarcos](https://github.com/dmarcos): Super simple way to build unmodified llama2.c to WASM and run it in the browser. [Demo](https://diegomarcos.com/llama2.c-web/)
388
+ - [llama2.rs.wasm](https://github.com/mtb0x1/llama2.rs.wasm) by @[mtb0x1](https://github.com/mtb0x1/) : a [Demo](https://mtb0x1.github.io/llama2.rs.wasm/) of all listed rust ports to WASM, all in one web page.
389
+ - WebAssembly
390
+ - [icpp-llm](https://github.com/icppWorld/icpp-llm): LLMs for the Internet Computer
391
+ - Fortran
392
+ - [llama2.f90](https://github.com/rbitr/llama2.f90): a Fortran port of this project
393
+ - Mojo
394
+ - [llama2.🔥](https://github.com/tairov/llama2.mojo) by @[tairov](https://github.com/tairov): pure Mojo port of this project
395
+ - OCaml
396
+ - [llama2.ml](https://github.com/jackpeck/llama2.ml) by @[jackpeck](https://github.com/jackpeck): an OCaml port of this project
397
+ - Hare
398
+ - [llama2.ha](https://sr.ht/~dvshkn/llama2.ha) by @[dvshkn](https://git.sr.ht/~dvshkn): a Hare port of this project
399
+ - [llama2.c - Llama 2 Everywhere](https://github.com/trholding/llama2.c) by @[trholding](https://github.com/trholding): Standalone, Bootable & Portable Binary Llama 2
400
+ - [llama2.c-zh - Bilingual Chinese and English](https://github.com/chenyangMl/llama2.c-zh) by @[chenyangMl](https://github.com/chenyangMl): Expand tokenizer to support training and inference in both Chinese and English
401
+ - Haskell
402
+ - [llama2.hs](https://github.com/chris-ch/llama2.hs) by @[chris-ch](https://github.com/chris-ch): a Haskell port of this project
403
+
404
+ ## unsorted todos
405
+
406
+ - add support in run.c of reading version 1+ files from export, later deprecate "version 0"
407
+ - run.cu (CUDA) investigate and merge
408
+ - add more tests inside [test.c](test.c)
409
+ - add Engine class for use in sample.py that does efficient inference in PyTorch, e.g. KV cache keeping
410
+ - make it easier to add a new dataset with not too much pain
411
+ - (LoRA) finetuning and export of Llama 2 models
412
+
413
+ ## License
414
+
415
+ MIT
llama2.c/__pycache__/export.cpython-310.pyc ADDED
Binary file (16.5 kB).
llama2.c/__pycache__/model.cpython-310.pyc ADDED
Binary file (12.4 kB).
llama2.c/__pycache__/tinystories.cpython-310.pyc ADDED
Binary file (8.27 kB).
llama2.c/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (2.41 kB).
llama2.c/assets/llama_cute.jpg ADDED
llama2.c/build_msvc.bat ADDED
@@ -0,0 +1 @@
1
+ cl.exe /fp:fast /Ox /openmp /I. run.c win.c
llama2.c/configurator.py ADDED
@@ -0,0 +1,47 @@
1
+ """
2
+ Poor Man's Configurator. Probably a terrible idea. Example usage:
3
+ $ python train.py config/override_file.py --batch_size=32
4
+ this will first run config/override_file.py, then override batch_size to 32
5
+
6
+ The code in this file will be run as follows from e.g. train.py:
7
+ >>> exec(open('configurator.py').read())
8
+
9
+ So it's not a Python module, it's just shuttling this code away from train.py
10
+ The code in this script then overrides the globals()
11
+
12
+ I know people are not going to love this, I just really dislike configuration
13
+ complexity and having to prepend config. to every single variable. If someone
14
+ comes up with a better simple Python solution I am all ears.
15
+ """
16
+
17
+ import sys
18
+ from ast import literal_eval
19
+
20
+ for arg in sys.argv[1:]:
21
+ if '=' not in arg:
22
+ # assume it's the name of a config file
23
+ assert not arg.startswith('--')
24
+ config_file = arg
25
+ print(f"Overriding config with {config_file}:")
26
+ with open(config_file) as f:
27
+ print(f.read())
28
+ exec(open(config_file).read())
29
+ else:
30
+ # assume it's a --key=value argument
31
+ assert arg.startswith('--')
32
+ key, val = arg.split('=')
33
+ key = key[2:]
34
+ if key in globals():
35
+ try:
36
+ # attempt to eval it (e.g. if bool, number, etc.)
37
+ attempt = literal_eval(val)
38
+ except (SyntaxError, ValueError):
39
+ # if that goes wrong, just use the string
40
+ attempt = val
41
+ # ensure the types match ok
42
+ assert type(attempt) == type(globals()[key])
43
+ # cross fingers
44
+ print(f"Overriding: {key} = {attempt}")
45
+ globals()[key] = attempt
46
+ else:
47
+ raise ValueError(f"Unknown config key: {key}")
llama2.c/data/TinyStories_all_data.tar.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75a94f6a0c4c93898f650fd8becfc2a2051d9e6880e7c749c990bf2a986ee15f
3
+ size 32401692
llama2.c/data/TinyStories_all_data/data00.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8fed3d0b55e1b47bbb0a854cd5988d6a4bcc2a2d48e3bee44359ce4ae84d26f1
3
+ size 41317962
llama2.c/data/TinyStories_all_data/data00.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f70e719a1dc8c4a0108cf88007abaff4665716d5191759a6c93200ba71b10074
3
+ size 140424235
llama2.c/doc/stories260K.md ADDED
@@ -0,0 +1,58 @@
1
+ # stories260K
2
+
3
+ [Stories260K huggingface link](https://huggingface.co/karpathy/tinyllamas)
4
+
5
+ The 260K model is a tiny model used for testing, and was trained as follows:
6
+
7
+ ```
8
+ python train.py \
9
+ --out_dir="outmini" \
10
+ --batch_size=128 \
11
+ --max_seq_len=512 \
12
+ --gradient_accumulation_steps=1 \
13
+ --vocab_source="custom" \
14
+ --vocab_size=512 \
15
+ --dim=64 \
16
+ --n_layers=5 \
17
+ --n_heads=8 \
18
+ --n_kv_heads=4 \
19
+ --multiple_of=4 \
20
+ --learning_rate=1e-3 \
21
+ --dropout=0.05 \
22
+ --weight_decay=0.01 \
23
+ --max_iters=100000 \
24
+ --beta2=0.99 \
25
+ --warmup_iters=1000 \
26
+ --eval_interval=2000 \
27
+ --eval_iters=100 \
28
+ --compile=True
29
+ ```
30
+
31
+ You'll notice that `n_kv_heads` is 4 while `n_heads` is 8, so two heads at a time share their key,value projections, i.e. this model is 2X multiquery. You'll also notice that we're using a custom tokenizer with 512 tokens. The model trained for ~10 minutes (?) on my A100 and achieves validation loss of 1.2968.
32
+
33
+ Sampling this model at temperature 0.0 (i.e. deterministic greedy argmax sampling) gives:
34
+
35
+ ```
36
+ $ ./run stories260K/stories260K.bin -z stories260K/tok512.bin -t 0.0
37
+ Once upon a time, there was a little girl named Lily. She loved to play outside in the park. One day, she saw a big, red ball. She wanted to play with it, but it was too high.
38
+ Lily's mom said, "Lily, let's go to the park." Lily was sad and didn't know what to do. She said, "I want to play with your ball, but I can't find it."
39
+ Lily was sad and didn't know what to do. She said, "I'm sorry, Lily. I didn't know what to do."
40
+ Lily didn't want to help her mom, so she said, "I'm sorry, mom. I didn't know what to do." Her mom said, "Don't worry, Lily. We can help you.
41
+ ```
42
+
43
+ You can reproduce the same in Python by running `sample.py`:
44
+
45
+ ```
46
+ $ python sample.py --checkpoint=stories260K/stories260K.pt --tokenizer=stories260K/tok512.model --temperature=0.0 --max_new_tokens=257
47
+ ```
48
+
49
+ I hardcoded max tokens to be 257 manually because the `sample.py` script doesn't currently terminate on the special BOS token like the run.c script does. Sampling at 1.0 with topp of 0.9 gives a bit more reasonable samples:
50
+
51
+ ```
52
+ $ ./run stories260K/stories260K.bin -z stories260K/tok512.bin -t 1.0 -p 0.9 -s 133742
53
+ Once upon a time, there was a little boy named Timmy. Timmy loved to play with his toys and eat sandwiches. One day, Timmy's mom told him it was time to rest for a while. Timmy's friend Billy came over and took him a down.
54
+ Timmy's mom saw that Timmy was sad, but Timmy said, "I didn't understand what is it! We need to find some leafs." Timmy thought about it and took a deep breath on a spoon. He hoped it was important to be kind and continued to find its image next time.
55
+ After they finished getting, Timmy's dad came up to his house and promised to help Timmy.
56
+ ```
57
+
58
+ Hey you can't expect too much from a 260K parameter model. I'm even mildly shocked we get this far :D
llama2.c/doc/train_llama_tokenizer.md ADDED
@@ -0,0 +1,99 @@
1
+ # training llama tokenizer
2
+
3
+ How does Meta train their sentencepiece tokenizer? You can print the config as follows:
4
+
5
+ ```python
6
+ import sentencepiece.sentencepiece_model_pb2
7
+ mp = sentencepiece.sentencepiece_model_pb2.ModelProto()
8
+ mp.ParseFromString(open("tokenizer.model", "rb").read())
9
+ print(mp.trainer_spec)
10
+ print(mp.normalizer_spec)
11
+ ```
12
+
13
+ this gives:
14
+
15
+ ```
16
+ trainer_spec {
17
+ input: "/large_experiments/theorem/datasets/MERGED/all.test1.merged"
18
+ model_prefix: "spm_model_32k_200M_charcov099995_allowWSO__v2"
19
+ model_type: BPE
20
+ vocab_size: 32000
21
+ self_test_sample_size: 0
22
+ input_format: "text"
23
+ character_coverage: 0.9999499917030334
24
+ input_sentence_size: 200000000
25
+ seed_sentencepiece_size: 1000000
26
+ shrinking_factor: 0.75
27
+ num_threads: 80
28
+ num_sub_iterations: 2
29
+ max_sentence_length: 4192
30
+ shuffle_input_sentence: true
31
+ max_sentencepiece_length: 16
32
+ split_by_unicode_script: true
33
+ split_by_whitespace: true
34
+ split_by_number: true
35
+ treat_whitespace_as_suffix: false
36
+ split_digits: true
37
+ allow_whitespace_only_pieces: true
38
+ vocabulary_output_piece_score: true
39
+ hard_vocab_limit: true
40
+ use_all_vocab: false
41
+ byte_fallback: true
42
+ required_chars: ""
43
+ unk_id: 0
44
+ bos_id: 1
45
+ eos_id: 2
46
+ pad_id: -1
47
+ unk_surface: " \342\201\207 "
48
+ unk_piece: "<unk>"
49
+ bos_piece: "<s>"
50
+ eos_piece: "</s>"
51
+ pad_piece: "<pad>"
52
+ train_extremely_large_corpus: false
53
+ enable_differential_privacy: false
54
+ differential_privacy_noise_level: 0.0
55
+ differential_privacy_clipping_threshold: 0
56
+ }
57
+ normalizer_spec {
58
+ name: "identity"
59
+ precompiled_charsmap: ""
60
+ add_dummy_prefix: true
61
+ remove_extra_whitespaces: false
62
+ normalization_rule_tsv: ""
63
+ }
64
+ ```
65
+
66
+ We can use the sentencepiece spm_train to train the same models, but optionally smaller. Here are their [options docs](https://github.com/google/sentencepiece/blob/master/doc/options.md) we can refer to. It's not much but it helps.
67
+
68
+ We'll depart from one setting: I recommend changing `character_coverage` -> 1.0. We also want to make sure to note the following important settings that come up in the paper and are not necessarily the default sentencepiece settings:
69
+
70
+ ```
71
+ --split_digits = true
72
+ --allow_whitespace_only_pieces = true
73
+ --byte_fallback = true
74
+ --normalization_rule_name = identity
75
+ ```
76
+
77
+ With this in mind, we can train a sentencepiece vocab in what I believe is probably the same way Meta trained theirs:
78
+
79
+ ```
80
+ spm_train --input="$input" \
81
+ --model_prefix="$model_prefix" \
82
+ --model_type=bpe \
83
+ --vocab_size="$vocab_size" \
84
+ --self_test_sample_size=0 \
85
+ --input_format="text" \
86
+ --character_coverage=1.0 \
87
+ --num_threads="$(nproc)" \
88
+ --split_digits=true \
89
+ --allow_whitespace_only_pieces=true \
90
+ --byte_fallback=true \
91
+ --unk_surface=" \342\201\207 " \
92
+ --normalization_rule_name=identity \
93
+ ```
94
+
95
+ Where $input is the input file, $model_prefix is the output path prefix, $vocab_size is the desired vocabulary size, and by default we take over all of the machine's CPU threads.
96
+
97
+ Lastly note that sentencepiece is weird and expects "sentences" delimited by newlines as the input. You can't just put in a massive block of text. And they have a hyperparameter that controls the maximum size of a "sentence". Fwiw I really dislike this design choice around a weird concept of a "sentence". It should just be a block of text with no assumptions. But here we are.
98
+
99
+ Look into the file `tinystories.py`, where we train the vocab in the same way but using the Python bindings instead; a rough sketch of such a call is below.
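+
+ For reference, a minimal sketch of what an equivalent call looks like through the Python bindings follows. The input path and vocab size here are placeholders, not the exact values `tinystories.py` wires up:
+
+ ```python
+ import os
+ import sentencepiece as spm
+
+ spm.SentencePieceTrainer.train(
+     input="tiny.txt",           # newline-delimited "sentences" (placeholder path)
+     model_prefix="tok4096",     # writes tok4096.model and tok4096.vocab
+     model_type="bpe",
+     vocab_size=4096,
+     self_test_sample_size=0,
+     input_format="text",
+     character_coverage=1.0,
+     num_threads=os.cpu_count(),
+     split_digits=True,
+     allow_whitespace_only_pieces=True,
+     byte_fallback=True,
+     unk_surface=r" \342\201\207 ",
+     normalization_rule_name="identity",
+ )
+ ```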
llama2.c/export.py ADDED
@@ -0,0 +1,567 @@
1
+ """
2
+ This script has functions and utilities for model export.
3
+ Basically, we have a bunch of versions of the model, and we
4
+ want to export them to .bin files to be read from and inferenced in C.
5
+
6
+ Among the "input" versions of PyTorch files/models:
7
+ - Official Llama 2 weights released by Meta
8
+ - Huggingface weights available on the hub
9
+ - llama2.c (this repo) trained models
10
+
11
+ Among the "output" versions of .bin files:
12
+ - v0: Legacy files of the original llama2.c repo (will eventually be DEPRECATED)
13
+ - v1-vN: Improved .bin files with a proper header, cache alignment, etc.
14
+
15
+ This script aspires to provide all of these conversions.
16
+ """
17
+ import os
18
+ import gzip
19
+ import shutil
20
+ import struct
21
+ import argparse
22
+ import json
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ import torch
27
+ from torch import nn
28
+
29
+ from model import ModelArgs, Transformer
30
+
31
+ # -----------------------------------------------------------------------------
32
+ # common utilities
33
+
34
+ def serialize_fp32(file, tensor):
35
+ """ writes one fp32 tensor to file that is open in wb mode """
36
+ d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
37
+ b = struct.pack(f'{len(d)}f', *d)
38
+ file.write(b)
39
+
40
+ def serialize_int8(file, tensor):
41
+ """ writes one int8 tensor to file that is open in wb mode """
42
+ d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
43
+ b = struct.pack(f'{len(d)}b', *d)
44
+ file.write(b)
45
+
46
+ def quantize_q80(w, group_size):
47
+ """
48
+ takes a tensor and returns the Q8_0 quantized version
49
+ i.e. symmetric quantization into int8, range [-127,127]
50
+ """
51
+ assert w.numel() % group_size == 0
52
+ ori_shape = w.shape
53
+ w = w.float() # convert to float32
54
+ w = w.reshape(-1, group_size)
55
+ # find the max in each group
56
+ wmax = torch.abs(w).max(dim=1).values
57
+ # calculate the scaling factor such that float = quant * scale
58
+ scale = wmax / 127.0
59
+ # scale into range [-127, 127]
60
+ quant = w / scale[:,None]
61
+ # round to nearest integer
62
+ int8val = torch.round(quant).to(torch.int8)
63
+ # dequantize by rescaling
64
+ fp32val = (int8val.float() * scale[:,None]).view(-1)
65
+ fp32valr = fp32val.reshape(-1, group_size)
66
+ # calculate the max error in each group
67
+ err = torch.abs(fp32valr - w).max(dim=1).values
68
+ # find the max error across all groups
69
+ maxerr = err.max().item()
70
+ return int8val, scale, maxerr
71
+
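+ # Editor's sketch (illustrative, not used elsewhere): the returned tensors reconstruct
+ # the original weight up to the reported max error, e.g. for a toy tensor:
+ #   q, s, err = quantize_q80(torch.randn(4, 64), group_size=64)
+ #   w_hat = (q.float() * s[:, None]).reshape(4, 64)  # approximate reconstruction of w
+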
72
+ # -----------------------------------------------------------------------------
73
+ # legacy
74
+
75
+ def legacy_export(model, filepath):
76
+ """ Original export of llama2.c bin files, i.e. version v0 """
77
+ out_file = open(filepath, 'wb')
78
+
79
+ # first write out the header
80
+ hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
81
+ p = model.params
82
+ shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
83
+ # legacy format uses negative/positive vocab size as a shared classifier flag
84
+ if not shared_classifier:
85
+ p.vocab_size = -p.vocab_size
86
+ n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
87
+ header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
88
+ n_kv_heads, p.vocab_size, p.max_seq_len)
89
+ out_file.write(header)
90
+
91
+ # next write out the embedding weights
92
+ serialize_fp32(out_file, model.tok_embeddings.weight)
93
+
94
+ # now all the layers
95
+ # attention weights
96
+ for layer in model.layers:
97
+ serialize_fp32(out_file, layer.attention_norm.weight)
98
+ for layer in model.layers:
99
+ serialize_fp32(out_file, layer.attention.wq.weight)
100
+ for layer in model.layers:
101
+ serialize_fp32(out_file, layer.attention.wk.weight)
102
+ for layer in model.layers:
103
+ serialize_fp32(out_file, layer.attention.wv.weight)
104
+ for layer in model.layers:
105
+ serialize_fp32(out_file, layer.attention.wo.weight)
106
+ # ffn weights
107
+ for layer in model.layers:
108
+ serialize_fp32(out_file, layer.ffn_norm.weight)
109
+ for layer in model.layers:
110
+ serialize_fp32(out_file, layer.feed_forward.w1.weight)
111
+ for layer in model.layers:
112
+ serialize_fp32(out_file, layer.feed_forward.w2.weight)
113
+ for layer in model.layers:
114
+ serialize_fp32(out_file, layer.feed_forward.w3.weight)
115
+ # final rmsnorm
116
+ serialize_fp32(out_file, model.norm.weight)
117
+ # freqs_cis
118
+ serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
119
+ serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])
120
+
121
+ # final classifier weights
122
+ if not shared_classifier:
123
+ serialize_fp32(out_file, model.output.weight)
124
+
125
+ # write to binary file
126
+ out_file.close()
127
+ print(f"wrote {filepath}")
128
+
129
+ # -----------------------------------------------------------------------------
130
+ # new version
131
+
132
+ def version1_export(model, filepath):
133
+ """
134
+ Export the model weights in full float32 .bin file to be read from C.
135
+ This is same as legacy_export, but with a proper header.
136
+ """
137
+ version = 1
138
+
139
+ out_file = open(filepath, 'wb')
140
+ # first write out the header. the header will be 256 bytes
141
+ # 1) write magic, which will be uint32 of "ak42" in ASCII
142
+ out_file.write(struct.pack('I', 0x616b3432))
143
+ # 2) write version, which will be int
144
+ out_file.write(struct.pack('i', version))
145
+ # 3) write the params, which will be 7 ints
146
+ p = model.params
147
+ hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
148
+ n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
149
+ header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
150
+ n_kv_heads, p.vocab_size, p.max_seq_len)
151
+ out_file.write(header)
152
+ # 4) write some other flags
153
+ shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
154
+ out_file.write(struct.pack('B', int(shared_classifier)))
155
+ pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
156
+ assert pad >= 0
157
+ out_file.write(b'\0' * pad)
158
+
159
+ # now let's write out all the params
160
+ weights = [
161
+ *[layer.attention_norm.weight for layer in model.layers],
162
+ *[layer.ffn_norm.weight for layer in model.layers],
163
+ model.norm.weight,
164
+ model.tok_embeddings.weight,
165
+ *[layer.attention.wq.weight for layer in model.layers],
166
+ *[layer.attention.wk.weight for layer in model.layers],
167
+ *[layer.attention.wv.weight for layer in model.layers],
168
+ *[layer.attention.wo.weight for layer in model.layers],
169
+ *[layer.feed_forward.w1.weight for layer in model.layers],
170
+ *[layer.feed_forward.w2.weight for layer in model.layers],
171
+ *[layer.feed_forward.w3.weight for layer in model.layers],
172
+ ]
173
+ if not shared_classifier:
174
+ weights.append(model.output.weight)
175
+ for w in weights:
176
+ serialize_fp32(out_file, w)
177
+
178
+ # write to binary file
179
+ out_file.close()
180
+ print(f"wrote {filepath}")
181
+
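+ # Editor's sketch (not used by this script): reading the 256-byte v1 header back in
+ # Python, assuming the field order written above; "model.bin" is a placeholder path.
+ #
+ # with open("model.bin", "rb") as f:
+ #     header = f.read(256)
+ # magic, version = struct.unpack_from("Ii", header, 0)  # expect 0x616b3432, 1
+ # dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len = \
+ #     struct.unpack_from("7i", header, 8)
+ # shared_classifier = bool(struct.unpack_from("B", header, 36)[0])
+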
182
+ def version2_export(model, filepath, group_size=64):
183
+ """
184
+ Export the model weights in Q8_0 into .bin file to be read from C.
185
+ That is:
186
+ - quantize all weights to symmetric int8, in range [-127, 127]
187
+ - all other tensors (the rmsnorm params) are kept and exported in fp32
188
+ - quantization is done in groups of group_size to reduce the effects of any outliers
189
+ """
190
+ version = 2
191
+
192
+ # let's first do some validation for this export type
193
+ while model.params.dim % group_size != 0:
194
+ group_size //= 2
195
+ print(f"BACKOFF: reducing group size to {group_size} to fit hidden_dim")
196
+ weights = [
197
+ model.tok_embeddings.weight,
198
+ *[layer.attention.wq.weight for layer in model.layers],
199
+ *[layer.attention.wk.weight for layer in model.layers],
200
+ *[layer.attention.wv.weight for layer in model.layers],
201
+ *[layer.attention.wo.weight for layer in model.layers],
202
+ *[layer.feed_forward.w1.weight for layer in model.layers],
203
+ *[layer.feed_forward.w2.weight for layer in model.layers],
204
+ *[layer.feed_forward.w3.weight for layer in model.layers],
205
+ ]
206
+ shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
207
+ if not shared_classifier:
208
+ weights.append(model.output.weight)
209
+ for i, w in enumerate(weights):
210
+ assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"
211
+
212
+ # write
213
+ out_file = open(filepath, 'wb')
214
+ # first write out the header. the header will be 256 bytes
215
+ # 1) write magic, which will be uint32 of "ak42" in ASCII
216
+ out_file.write(struct.pack('I', 0x616b3432))
217
+ # 2) write version, which will be int
218
+ out_file.write(struct.pack('i', version))
219
+ # 3) write the params, which will be 7 ints
220
+ p = model.params
221
+ hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
222
+ n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
223
+ header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
224
+ n_kv_heads, p.vocab_size, p.max_seq_len)
225
+ out_file.write(header)
226
+ # 4) write some other flags
227
+ out_file.write(struct.pack('B', int(shared_classifier)))
228
+ out_file.write(struct.pack('i', group_size)) # group size used for quantization
229
+ pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
230
+ assert pad >= 0
231
+ out_file.write(b'\0' * pad)
232
+ # now that the header is done, let's write out the model
233
+
234
+ # first let's write out all the params that we are keeping in fp32: the norms
235
+ for layer in model.layers: # attention norms
236
+ serialize_fp32(out_file, layer.attention_norm.weight)
237
+ for layer in model.layers: # MLP norms
238
+ serialize_fp32(out_file, layer.ffn_norm.weight)
239
+ serialize_fp32(out_file, model.norm.weight) # final pre-classifier norm
240
+
241
+ # now let's write out all the params that we are quantizing to Q8_0
242
+ # note we skip classifier weights, which are shared with the embedding
243
+ ew = []
244
+ for i, w in enumerate(weights):
245
+ # quantize this weight
246
+ q, s, err = quantize_q80(w, group_size)
247
+ # save the int8 weights to file
248
+ serialize_int8(out_file, q) # save the tensor in int8
249
+ serialize_fp32(out_file, s) # save scale factors
250
+ # logging
251
+ ew.append((err, w.shape))
252
+ print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}")
253
+
254
+ # print the highest error across all weights, should be very small, e.g. O(~0.001)
255
+ ew.sort(reverse=True)
256
+ print(f"max quantization group error across all weights: {ew[0][0]}")
257
+
258
+ # write to binary file
259
+ out_file.close()
260
+ print(f"wrote {filepath}")
261
+
262
+ def hf_export(llama_model, filepath, group_size=64, dtype=torch.float32):
263
+ """ Generate the pytorch_model.bin state_dict and config.json for HuggingFace """
264
+
265
+ try:
266
+ from transformers.models.llama.configuration_llama import LlamaConfig
267
+ except ImportError:
268
+ print("Error: transformers package is required to load huggingface models")
269
+ print("Please run `pip install transformers` to install it")
270
+ return None
271
+
272
+ # Generate LlamaModel state_dict
273
+ hf_state_dict = {}
274
+
275
+ # Sometimes we have repeated key values for the heads
276
+ dim = llama_model.params.dim
277
+ num_key_value_heads = llama_model.params.n_kv_heads
278
+ n_rep = llama_model.params.n_heads // num_key_value_heads
279
+ key_value_dim = dim // n_rep
280
+
281
+ # HuggingFace needs the weights permuted.
282
+ # See: https://github.com/huggingface/transformers/blob/b132c1703eb1c8bd9dfa4ad6a9be2bfd6ef819e9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122
283
+ def permute_original(w, n_heads=llama_model.params.n_heads, dim1=dim, dim2=dim):
284
+ return w.view(dim1, dim2).reshape(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
285
+
286
+ # Transfer weights from llama model to the HF state dictionary format
287
+ hf_state_dict['model.embed_tokens.weight'] = llama_model.tok_embeddings.weight.clone().to(dtype)
288
+ hf_state_dict['model.norm.weight'] = llama_model.norm.weight.clone().to(dtype)
289
+
290
+ # Add each layer's weights to the HF state dictionary
291
+ for i, layer in enumerate(llama_model.layers):
292
+ layer_id = layer.layer_id
293
+ hf_state_dict[f'model.layers.{i}.input_layernorm.weight'] = llama_model.layers[layer_id].attention_norm.weight.clone().to(dtype)
294
+ hf_state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wq.weight.clone()).to(dtype)
295
+ hf_state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wk.weight.clone(), num_key_value_heads, key_value_dim, dim).to(dtype)
296
+ hf_state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = llama_model.layers[layer_id].attention.wv.weight.clone().to(dtype)
297
+ hf_state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = llama_model.layers[layer_id].attention.wo.weight.clone().to(dtype)
298
+ hf_state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = llama_model.layers[layer_id].ffn_norm.weight.clone().to(dtype)
299
+ hf_state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = llama_model.layers[layer_id].feed_forward.w1.weight.clone().to(dtype)
300
+ hf_state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = llama_model.layers[layer_id].feed_forward.w2.weight.clone().to(dtype)
301
+ hf_state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = llama_model.layers[layer_id].feed_forward.w3.weight.clone().to(dtype)
302
+
303
+ # llama2.c usually uses tied weights -> reference the embed_tokens.weight instead
304
+ hf_state_dict['lm_head.weight'] = hf_state_dict['model.embed_tokens.weight']
305
+
306
+ # We check that the embeddings are tied, else use manual output weights
307
+ _embeddings_are_tied: bool = torch.equal(llama_model.tok_embeddings.weight, llama_model.output.weight)
308
+ if not _embeddings_are_tied:
309
+ hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype)
310
+
311
+
312
+ # Generate LlamaConfig (seen in transformers.models.llama.configuration_llama)
313
+
314
+ # Extract necessary attributes from llama.c model
315
+ vocab_size = llama_model.params.vocab_size
316
+ hidden_size = llama_model.params.dim
317
+ intermediate_size = llama_model.layers[0].feed_forward.w1.weight.shape[0]
318
+ num_hidden_layers = llama_model.params.n_layers
319
+ num_attention_heads = llama_model.params.n_heads
320
+ num_key_value_heads = llama_model.params.n_kv_heads
321
+ max_position_embeddings = llama_model.params.max_seq_len
322
+ rms_norm_eps = llama_model.params.norm_eps
323
+
324
+ # TODO check values for:
325
+ # pretraining_tp, initializer_range, use_cache,
326
+ # rope_theta, and rope_scaling.
327
+
328
+ config = LlamaConfig(
329
+ vocab_size=vocab_size,
330
+ hidden_size=hidden_size,
331
+ intermediate_size=intermediate_size,
332
+ num_hidden_layers=num_hidden_layers,
333
+ num_attention_heads=num_attention_heads,
334
+ num_key_value_heads=num_key_value_heads,
335
+ max_position_embeddings=max_position_embeddings,
336
+ rms_norm_eps=rms_norm_eps,
337
+ tie_word_embeddings=_embeddings_are_tied,
338
+ # Manual
339
+ architectures=["LlamaForCausalLM"],
340
+ hidden_act="silu",
341
+ )
342
+
343
+
344
+ # Save files in directory filepath
345
+ # First make the directory if it doesn't exist
346
+ os.makedirs(filepath, exist_ok=True)
347
+
348
+ # Save the state dictionary in .bin format, and config as .json
349
+ torch.save(hf_state_dict, os.path.join(filepath, "pytorch_model.bin"))
350
+ config.save_pretrained(filepath)
351
+
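+ # Editor's sketch: the directory written above should load back with transformers
+ # ("out_hf_dir" is a placeholder for the filepath argument):
+ #
+ # from transformers import AutoModelForCausalLM
+ # hf_model = AutoModelForCausalLM.from_pretrained("out_hf_dir")
+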
352
+
353
+ # -----------------------------------------------------------------------------
354
+ # Load / import functions
355
+
356
+ def load_checkpoint(checkpoint):
357
+
358
+ # load the provided model checkpoint
359
+ checkpoint_dict = torch.load(checkpoint, map_location='cpu')
360
+ gptconf = ModelArgs(**checkpoint_dict['model_args'])
361
+ model = Transformer(gptconf)
362
+ state_dict = checkpoint_dict['model']
363
+ unwanted_prefix = '_orig_mod.'
364
+ for k,v in list(state_dict.items()):
365
+ if k.startswith(unwanted_prefix):
366
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
367
+ model.load_state_dict(state_dict, strict=False)
368
+ model.eval()
369
+ return model
370
+
371
+ def load_meta_model(model_path):
372
+ params_path = os.path.join(model_path, 'params.json')
373
+ with open(params_path) as f:
374
+ params = json.load(f)
375
+ print(params)
376
+
377
+ model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
378
+ models = [torch.load(p, map_location='cpu') for p in model_paths]
379
+
380
+ def concat_weights(models):
381
+ state_dict = {}
382
+ for name in list(models[0]):
383
+ tensors = [model[name] for model in models]
384
+ if len(tensors) == 1 or len(tensors[0].shape) == 1:
385
+ state_dict[name] = tensors[0]
386
+ continue
387
+ is_axis_1 = (
388
+ name.startswith('tok_embeddings.')
389
+ or name.endswith('.attention.wo.weight')
390
+ or name.endswith('.feed_forward.w2.weight')
391
+ )
392
+ axis = 1 if is_axis_1 else 0
393
+ state_dict[name] = torch.cat(tensors, dim=axis)
394
+ for model in models:
395
+ del model[name]
396
+ return state_dict
397
+
398
+ state_dict = concat_weights(models)
399
+ del models
400
+
401
+ # set ModelArgs
402
+ config = ModelArgs()
403
+ config.dim = params["dim"]
404
+ config.n_layers = params["n_layers"]
405
+ config.n_heads = params["n_heads"]
406
+ config.n_kv_heads = params.get('n_kv_heads') or params['n_heads']
407
+ config.multiple_of = params["multiple_of"]
408
+ config.norm_eps = params["norm_eps"]
409
+
410
+ config.vocab_size = state_dict['tok_embeddings.weight'].shape[0]
411
+ config.max_seq_len = 2048
412
+
413
+
414
+ # create a new Transformer object and set weights
415
+ model = Transformer(config)
416
+
417
+ model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight'])
418
+ model.norm.weight = nn.Parameter(state_dict['norm.weight'])
419
+
420
+ for layer in model.layers:
421
+ i = layer.layer_id
422
+ layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight'])
423
+ layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight'])
424
+ layer.attention.wk.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight'])
425
+ layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight'])
426
+ layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight'])
427
+ layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight'])
428
+ layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight'])
429
+ layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight'])
430
+ layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight'])
431
+
432
+ # final classifier
433
+ model.output.weight = nn.Parameter(state_dict['output.weight'])
434
+ model.eval()
435
+ return model
436
+
437
+ def load_hf_model(model_path):
438
+
439
+ try:
440
+ from transformers import AutoModelForCausalLM
441
+ except ImportError:
442
+ print("Error: transformers package is required to load huggingface models")
443
+ print("Please run `pip install transformers` to install it")
444
+ return None
445
+
446
+ # load HF model
447
+ hf_model = AutoModelForCausalLM.from_pretrained(model_path)
448
+ hf_dict = hf_model.state_dict()
449
+
450
+ # convert LlamaConfig to ModelArgs
451
+ config = ModelArgs()
452
+ config.dim = hf_model.config.hidden_size
453
+ config.n_layers = hf_model.config.num_hidden_layers
454
+ config.n_heads = hf_model.config.num_attention_heads
455
+ config.n_kv_heads = hf_model.config.num_attention_heads
456
+ config.vocab_size = hf_model.config.vocab_size
457
+ config.hidden_dim = hf_model.config.intermediate_size
458
+ config.norm_eps = hf_model.config.rms_norm_eps
459
+ config.max_seq_len = hf_model.config.max_position_embeddings
460
+
461
+ # create a new Transformer object and set weights
462
+ model = Transformer(config)
463
+
464
+ model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight'])
465
+ model.norm.weight = nn.Parameter(hf_dict['model.norm.weight'])
466
+
467
+ # huggingface permutes WQ and WK, this function reverses it
468
+ def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim):
469
+ return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
470
+
471
+ for layer in model.layers:
472
+ i = layer.layer_id
473
+ layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight'])
474
+ layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight']))
475
+ layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight']))
476
+ layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight'])
477
+ layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight'])
478
+ layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight'])
479
+ layer.feed_forward.w1.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.gate_proj.weight'])
480
+ layer.feed_forward.w2.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.down_proj.weight'])
481
+ layer.feed_forward.w3.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.up_proj.weight'])
482
+
483
+ # final classifier
484
+ model.output.weight = nn.Parameter(hf_dict['lm_head.weight'])
485
+ model.eval()
486
+ return model
487
+
488
+
489
+ # -----------------------------------------------------------------------------
490
+ # API entrypoint
491
+
492
+ def model_export(model, filepath, version, dtype=torch.float32):
493
+ """
494
+ Versions docs:
495
+ v-1: huggingface export, i.e. intended for use outside of this repo, in HF
496
+ v0: legacy llama2.c float format, DEPRECATED
497
+ v1: float32 export
498
+ v2: int8 quantized Q8_0 export, similar to llama.cpp, in groups
499
+ # TODO: add dtype export support for other versions (?)
500
+ """
501
+ if version == 0:
502
+ legacy_export(model, filepath)
503
+ elif version == 1:
504
+ version1_export(model, filepath)
505
+ elif version == 2:
506
+ version2_export(model, filepath)
507
+ elif version == -1:
508
+ hf_export(model, filepath, dtype=dtype)
509
+ else:
510
+ raise ValueError(f"unknown version {version}")
511
+
512
+ def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
513
+ """
514
+ (This was submitted via a PR earlier. Leaving it here, but "orphaned" for now)
515
+ Saves the model as a TorchScript.
516
+ The resulting file can be loaded in C++ code and then used for training or
517
+ inference with:
518
+ #include <torch/script.h>
519
+ torch::jit::Module module = torch::jit::load("model.pt")
520
+ Note that the serialized model includes the initial parameters and with the default
521
+ ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
522
+ the model parameters separately you can zero out the parameters before saving it and
523
+ it will gzip down to 780K.
524
+ """
525
+
526
+ # If requested zero params before saving the model. This is useful in
527
+ # conjunction with gzip_output.
528
+ if zero_params:
529
+ for p in model.parameters():
530
+ p.detach().zero_()
531
+
532
+ torch.jit.save(torch.jit.script(model), filepath)
533
+
534
+ if gzip_output:
535
+ with open(filepath, "rb") as f_in:
536
+ with gzip.open(f"{filepath}.gz", "wb") as f_out:
537
+ shutil.copyfileobj(f_in, f_out)
538
+ os.unlink(filepath)
539
+
540
+ # -----------------------------------------------------------------------------
541
+ # CLI entrypoint
542
+
543
+ if __name__ == "__main__":
544
+
545
+ parser = argparse.ArgumentParser()
546
+ parser.add_argument("filepath", type=str, help="the output filepath")
547
+ parser.add_argument("--version", default=0, type=int, help="the version to export with")
548
+ parser.add_argument("--dtype", type=str, help="dtype of the model (fp16, fp32)", default="fp32")
549
+ group = parser.add_mutually_exclusive_group(required=True)
550
+ group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
551
+ group.add_argument("--meta-llama", type=str, help="meta llama model path")
552
+ group.add_argument("--hf", type=str, help="huggingface model path")
553
+ args = parser.parse_args()
554
+ dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype]
555
+
556
+ if args.checkpoint:
557
+ model = load_checkpoint(args.checkpoint)
558
+ elif args.meta_llama:
559
+ model = load_meta_model(args.meta_llama)
560
+ elif args.hf:
561
+ model = load_hf_model(args.hf)
562
+
563
+ if model is None:
564
+ parser.error("Can't load input model!")
565
+
566
+ # export
567
+ model_export(model, args.filepath, args.version, dtype)
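+
+ # -----------------------------------------------------------------------------
+ # Editor's sketch of a programmatic invocation (paths are placeholders; the CLI
+ # above is the supported entry point):
+ #
+ # from export import load_checkpoint, model_export
+ # model = load_checkpoint("out/ckpt.pt")            # a llama2.c training checkpoint
+ # model_export(model, "model_q80.bin", version=2)   # Q8_0 group-quantized export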
llama2.c/model.py ADDED
@@ -0,0 +1,343 @@
1
+ import math
2
+ import struct
3
+ import inspect
4
+ from dataclasses import dataclass
5
+ from typing import Any, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+
12
+ @dataclass
13
+ class ModelArgs:
14
+ # default hyperparameters for the Llama 7B model
15
+ dim: int = 4096
16
+ n_layers: int = 32
17
+ n_heads: int = 32
18
+ n_kv_heads: Optional[int] = None
19
+ vocab_size: int = 32000
20
+ hidden_dim: Optional[int] = None
21
+ multiple_of: int = 256 # MLP hidden layer size will be multiple of
22
+ norm_eps: float = 1e-5
23
+ max_seq_len: int = 2048
24
+ dropout: float = 0.0
25
+
26
+
27
+ class RMSNorm(torch.nn.Module):
28
+ def __init__(self, dim: int, eps: float):
29
+ super().__init__()
30
+ self.eps = eps
31
+ self.weight = nn.Parameter(torch.ones(dim))
32
+
33
+ def _norm(self, x):
34
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
35
+
36
+ def forward(self, x):
37
+ output = self._norm(x.float()).type_as(x)
38
+ return output * self.weight
39
+
40
+
41
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
42
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
43
+ t = torch.arange(end, device=freqs.device) # type: ignore
44
+ freqs = torch.outer(t, freqs).float() # type: ignore
45
+ freqs_cos = torch.cos(freqs) # real part
46
+ freqs_sin = torch.sin(freqs) # imaginary part
47
+ return freqs_cos, freqs_sin
48
+
49
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
50
+ ndim = x.ndim
51
+ assert 0 <= 1 < ndim
52
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
53
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
54
+ return freqs_cis.view(shape)
55
+
56
+ def apply_rotary_emb(
57
+ xq: torch.Tensor,
58
+ xk: torch.Tensor,
59
+ freqs_cos: torch.Tensor,
60
+ freqs_sin: torch.Tensor
61
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
62
+
63
+ # reshape xq and xk to match the complex representation
64
+ xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
65
+ xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
66
+
67
+ # reshape freqs_cos and freqs_sin for broadcasting
68
+ freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
69
+ freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
70
+
71
+ # apply rotation using real numbers
72
+ xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
73
+ xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
74
+ xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
75
+ xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
76
+
77
+ # flatten last two dimensions
78
+ xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
79
+ xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
80
+
81
+ return xq_out.type_as(xq), xk_out.type_as(xk)
82
+
83
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
84
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
85
+ bs, slen, n_kv_heads, head_dim = x.shape
86
+ if n_rep == 1:
87
+ return x
88
+ return (
89
+ x[:, :, :, None, :]
90
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
91
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
92
+ )
93
+
94
+ class Attention(nn.Module):
95
+ def __init__(self, args: ModelArgs):
96
+ super().__init__()
97
+ self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
98
+ assert args.n_heads % self.n_kv_heads == 0
99
+ model_parallel_size = 1
100
+ self.n_local_heads = args.n_heads // model_parallel_size
101
+ self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
102
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
103
+ self.head_dim = args.dim // args.n_heads
104
+ self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
105
+ self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
106
+ self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
107
+ self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
108
+ self.attn_dropout = nn.Dropout(args.dropout)
109
+ self.resid_dropout = nn.Dropout(args.dropout)
110
+ self.dropout = args.dropout
111
+
112
+ # use flash attention or a manual implementation?
113
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
114
+ if not self.flash:
115
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
116
+ mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
117
+ mask = torch.triu(mask, diagonal=1)
118
+ self.register_buffer("mask", mask)
119
+
120
+ def forward(
121
+ self,
122
+ x: torch.Tensor,
123
+ freqs_cos: torch.Tensor,
124
+ freqs_sin: torch.Tensor,
125
+ ):
126
+ bsz, seqlen, _ = x.shape
127
+
128
+ # QKV
129
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
130
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
131
+ xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
132
+ xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
133
+
134
+ # RoPE relative positional embeddings
135
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
136
+
137
+ # grouped multiquery attention: expand out keys and values
138
+ xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
139
+ xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
140
+
141
+ # make heads into a batch dimension
142
+ xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
143
+ xk = xk.transpose(1, 2)
144
+ xv = xv.transpose(1, 2)
145
+
146
+ # flash implementation
147
+ if self.flash:
148
+ output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
149
+ else:
150
+ # manual implementation
151
+ scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
152
+ assert hasattr(self, 'mask')
153
+ scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
154
+ scores = F.softmax(scores.float(), dim=-1).type_as(xq)
155
+ scores = self.attn_dropout(scores)
156
+ output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
157
+
158
+ # restore time as batch dimension and concat heads
159
+ output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
160
+
161
+ # final projection into the residual stream
162
+ output = self.wo(output)
163
+ output = self.resid_dropout(output)
164
+ return output
165
+
166
+
167
+ class FeedForward(nn.Module):
168
+ def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
169
+ super().__init__()
170
+ if hidden_dim is None:
171
+ hidden_dim = 4 * dim
172
+ hidden_dim = int(2 * hidden_dim / 3)
173
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
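+ # e.g. with the Llama 7B defaults (dim=4096, multiple_of=256): 4*4096 = 16384,
+ # 2/3 of that is 10922, rounded up to a multiple of 256 gives 11008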
174
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
175
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
176
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
177
+ self.dropout = nn.Dropout(dropout)
178
+
179
+ def forward(self, x):
180
+ return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
181
+
182
+
183
+ class TransformerBlock(nn.Module):
184
+ def __init__(self, layer_id: int, args: ModelArgs):
185
+ super().__init__()
186
+ self.n_heads = args.n_heads
187
+ self.dim = args.dim
188
+ self.head_dim = args.dim // args.n_heads
189
+ self.attention = Attention(args)
190
+ self.feed_forward = FeedForward(
191
+ dim=args.dim,
192
+ hidden_dim=args.hidden_dim,
193
+ multiple_of=args.multiple_of,
194
+ dropout=args.dropout,
195
+ )
196
+ self.layer_id = layer_id
197
+ self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
198
+ self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
199
+
200
+ def forward(self, x, freqs_cos, freqs_sin):
201
+ h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
202
+ out = h + self.feed_forward.forward(self.ffn_norm(h))
203
+ return out
204
+
205
+
206
+ class Transformer(nn.Module):
207
+ last_loss: Optional[torch.Tensor]
208
+
209
+ def __init__(self, params: ModelArgs):
210
+ super().__init__()
211
+ self.params = params
212
+ self.vocab_size = params.vocab_size
213
+ self.n_layers = params.n_layers
214
+
215
+ self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
216
+ self.dropout = nn.Dropout(params.dropout)
217
+ self.layers = torch.nn.ModuleList()
218
+ for layer_id in range(params.n_layers):
219
+ self.layers.append(TransformerBlock(layer_id, params))
220
+ self.norm = RMSNorm(params.dim, eps=params.norm_eps)
221
+ self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
222
+
223
+ # share the unembedding parameters with the embedding parameters
224
+ self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
225
+
226
+ # some useful precompute for the RoPE relative positional embeddings
227
+ freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
228
+ self.register_buffer("freqs_cos", freqs_cos, persistent=False)
229
+ self.register_buffer("freqs_sin", freqs_sin, persistent=False)
230
+
231
+ # init all weights
232
+ self.apply(self._init_weights)
233
+ # apply special scaled init to the residual projections, per GPT-2 paper
234
+ for pn, p in self.named_parameters():
235
+ if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
236
+ torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))
237
+
238
+ # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
239
+ self.last_loss = None
240
+
241
+ def _init_weights(self, module):
242
+ if isinstance(module, nn.Linear):
243
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
244
+ if module.bias is not None:
245
+ torch.nn.init.zeros_(module.bias)
246
+ elif isinstance(module, nn.Embedding):
247
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
248
+
249
+ def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
250
+ _bsz, seqlen = tokens.shape
251
+ h = self.tok_embeddings(tokens)
252
+ h = self.dropout(h)
253
+ freqs_cos = self.freqs_cos[:seqlen]
254
+ freqs_sin = self.freqs_sin[:seqlen]
255
+
256
+ for layer in self.layers:
257
+ h = layer(h, freqs_cos, freqs_sin)
258
+ h = self.norm(h)
259
+
260
+ if targets is not None:
261
+ # if we are given some desired targets also calculate the loss
262
+ logits = self.output(h)
263
+ self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
264
+ else:
265
+ # inference-time mini-optimization: only forward the output on the very last position
266
+ logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
267
+ self.last_loss = None
268
+
269
+ return logits
270
+
271
+ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
272
+ # start with all of the candidate parameters
273
+ param_dict = {pn: p for pn, p in self.named_parameters()}
274
+ # filter out those that do not require grad
275
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
276
+ # create optim groups. Any parameter that is 2D will be weight decayed, otherwise not.
277
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
278
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
279
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
280
+ optim_groups = [
281
+ {'params': decay_params, 'weight_decay': weight_decay},
282
+ {'params': nodecay_params, 'weight_decay': 0.0}
283
+ ]
284
+ num_decay_params = sum(p.numel() for p in decay_params)
285
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
286
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
287
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
288
+ # Create AdamW optimizer and use the fused version if it is available
289
+ fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
290
+ use_fused = fused_available and device_type == 'cuda'
291
+ extra_args = dict(fused=True) if use_fused else dict()
292
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
293
+ print(f"using fused AdamW: {use_fused}")
294
+
295
+ return optimizer
296
+
297
+ def estimate_mfu(self, fwdbwd_per_iter, dt):
298
+ """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
299
+ # first estimate the number of flops we do per iteration.
300
+ # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
301
+ N = sum(p.numel() for p in self.parameters())
302
+ cfg = self.params
303
+ L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len
304
+ flops_per_token = 6*N + 12*L*H*Q*T
305
+ flops_per_fwdbwd = flops_per_token * T
306
+ flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
307
+ # express our flops throughput as ratio of A100 bfloat16 peak flops
308
+ flops_achieved = flops_per_iter * (1.0/dt) # per second
309
+ flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
310
+ mfu = flops_achieved / flops_promised
311
+ return mfu
312
+
313
+ @torch.inference_mode()
314
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
315
+ """
316
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
317
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
318
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
319
+ Also note this is a super inefficient version of sampling with no key/value cache.
320
+ """
321
+ for _ in range(max_new_tokens):
322
+ # if the sequence context is growing too long we must crop it at block_size
323
+ idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
324
+ # forward the model to get the logits for the index in the sequence
325
+ logits = self(idx_cond)
326
+ logits = logits[:, -1, :] # crop to just the final time step
327
+ if temperature == 0.0:
328
+ # "sample" the single most likely index
329
+ _, idx_next = torch.topk(logits, k=1, dim=-1)
330
+ else:
331
+ # pluck the logits at the final step and scale by desired temperature
332
+ logits = logits / temperature
333
+ # optionally crop the logits to only the top k options
334
+ if top_k is not None:
335
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
336
+ logits[logits < v[:, [-1]]] = -float('Inf')
337
+ # apply softmax to convert logits to (normalized) probabilities
338
+ probs = F.softmax(logits, dim=-1)
339
+ idx_next = torch.multinomial(probs, num_samples=1)
340
+ # append sampled index to the running sequence and continue
341
+ idx = torch.cat((idx, idx_next), dim=1)
342
+
343
+ return idx
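+
+ # -----------------------------------------------------------------------------
+ # Editor's sketch of how these pieces fit together for a tiny toy configuration
+ # (all sizes below are illustrative placeholders):
+ #
+ # args = ModelArgs(dim=64, n_layers=2, n_heads=4, vocab_size=512, max_seq_len=128)
+ # model = Transformer(args)
+ # tokens = torch.randint(0, args.vocab_size, (1, 16))   # (batch, time)
+ # logits = model(tokens, targets=tokens)                # also sets model.last_loss
+ # sample = model.generate(tokens, max_new_tokens=8, temperature=0.8)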
llama2.c/requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ numpy==1.23.5
2
+ pytest==7.4.0
3
+ Requests==2.31.0
4
+ sentencepiece==0.1.99
5
+ torch==2.0.1
6
+ tqdm==4.64.1
7
+ wandb==0.15.5
llama2.c/run.c ADDED
@@ -0,0 +1,973 @@
1
+ /* Inference for Llama-2 Transformer model in pure C */
2
+
3
+ #include <stdio.h>
4
+ #include <stdlib.h>
5
+ #include <ctype.h>
6
+ #include <time.h>
7
+ #include <math.h>
8
+ #include <string.h>
9
+ #include <fcntl.h>
10
+ #if defined _WIN32
11
+ #include "win.h"
12
+ #else
13
+ #include <unistd.h>
14
+ #include <sys/mman.h>
15
+ #endif
16
+ // ----------------------------------------------------------------------------
17
+ // Transformer model
18
+
19
+ typedef struct {
20
+ int dim; // transformer dimension
21
+ int hidden_dim; // for ffn layers
22
+ int n_layers; // number of layers
23
+ int n_heads; // number of query heads
24
+ int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
25
+ int vocab_size; // vocabulary size, usually 256 (byte-level)
26
+ int seq_len; // max sequence length
27
+ } Config;
28
+
29
+ typedef struct {
30
+ // token embedding table
31
+ float* token_embedding_table; // (vocab_size, dim)
32
+ // weights for rmsnorms
33
+ float* rms_att_weight; // (layer, dim) rmsnorm weights
34
+ float* rms_ffn_weight; // (layer, dim)
35
+ // weights for matmuls. note dim == n_heads * head_size
36
+ float* wq; // (layer, dim, n_heads * head_size)
37
+ float* wk; // (layer, dim, n_kv_heads * head_size)
38
+ float* wv; // (layer, dim, n_kv_heads * head_size)
39
+ float* wo; // (layer, n_heads * head_size, dim)
40
+ // weights for ffn
41
+ float* w1; // (layer, hidden_dim, dim)
42
+ float* w2; // (layer, dim, hidden_dim)
43
+ float* w3; // (layer, hidden_dim, dim)
44
+ // final rmsnorm
45
+ float* rms_final_weight; // (dim,)
46
+ // (optional) classifier weights for the logits, on the last layer
47
+ float* wcls;
48
+ } TransformerWeights;
49
+
50
+ typedef struct {
51
+ // current wave of activations
52
+ float *x; // activation at current time stamp (dim,)
53
+ float *xb; // same, but inside a residual branch (dim,)
54
+ float *xb2; // an additional buffer just for convenience (dim,)
55
+ float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
56
+ float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
57
+ float *q; // query (dim,)
58
+ float *k; // key (dim,)
59
+ float *v; // value (dim,)
60
+ float *att; // buffer for scores/attention values (n_heads, seq_len)
61
+ float *logits; // output logits
62
+ // kv cache
63
+ float* key_cache; // (layer, seq_len, dim)
64
+ float* value_cache; // (layer, seq_len, dim)
65
+ } RunState;
66
+
67
+ typedef struct {
68
+ Config config; // the hyperparameters of the architecture (the blueprint)
69
+ TransformerWeights weights; // the weights of the model
70
+ RunState state; // buffers for the "wave" of activations in the forward pass
71
+ // some more state needed to properly clean up the memory mapping (sigh)
72
+ int fd; // file descriptor for memory mapping
73
+ float* data; // memory mapped data pointer
74
+ ssize_t file_size; // size of the checkpoint file in bytes
75
+ } Transformer;
76
+
77
+ void malloc_run_state(RunState* s, Config* p) {
78
+ // we calloc instead of malloc to keep valgrind happy
79
+ int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
80
+ s->x = calloc(p->dim, sizeof(float));
81
+ s->xb = calloc(p->dim, sizeof(float));
82
+ s->xb2 = calloc(p->dim, sizeof(float));
83
+ s->hb = calloc(p->hidden_dim, sizeof(float));
84
+ s->hb2 = calloc(p->hidden_dim, sizeof(float));
85
+ s->q = calloc(p->dim, sizeof(float));
86
+ s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
87
+ s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
88
+ s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
89
+ s->logits = calloc(p->vocab_size, sizeof(float));
90
+ // ensure all mallocs went fine
91
+ if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
92
+ || !s->key_cache || !s->value_cache || !s->att || !s->logits) {
93
+ fprintf(stderr, "malloc failed!\n");
94
+ exit(EXIT_FAILURE);
95
+ }
96
+ }
97
+
98
+ void free_run_state(RunState* s) {
99
+ free(s->x);
100
+ free(s->xb);
101
+ free(s->xb2);
102
+ free(s->hb);
103
+ free(s->hb2);
104
+ free(s->q);
105
+ free(s->att);
106
+ free(s->logits);
107
+ free(s->key_cache);
108
+ free(s->value_cache);
109
+ }
110
+
111
+ void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
112
+ int head_size = p->dim / p->n_heads;
113
+ // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
114
+ unsigned long long n_layers = p->n_layers;
115
+ w->token_embedding_table = ptr;
116
+ ptr += p->vocab_size * p->dim;
117
+ w->rms_att_weight = ptr;
118
+ ptr += n_layers * p->dim;
119
+ w->wq = ptr;
120
+ ptr += n_layers * p->dim * (p->n_heads * head_size);
121
+ w->wk = ptr;
122
+ ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
123
+ w->wv = ptr;
124
+ ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
125
+ w->wo = ptr;
126
+ ptr += n_layers * (p->n_heads * head_size) * p->dim;
127
+ w->rms_ffn_weight = ptr;
128
+ ptr += n_layers * p->dim;
129
+ w->w1 = ptr;
130
+ ptr += n_layers * p->dim * p->hidden_dim;
131
+ w->w2 = ptr;
132
+ ptr += n_layers * p->hidden_dim * p->dim;
133
+ w->w3 = ptr;
134
+ ptr += n_layers * p->dim * p->hidden_dim;
135
+ w->rms_final_weight = ptr;
136
+ ptr += p->dim;
137
+ ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE)
138
+ ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_imag (for RoPE)
139
+ w->wcls = shared_weights ? w->token_embedding_table : ptr;
140
+ }
141
+
142
+ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
143
+ int* fd, float** data, ssize_t* file_size) {
144
+ FILE *file = fopen(checkpoint, "rb");
145
+ if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
146
+ // read in the config header
147
+ if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); }
148
+ // negative vocab size is hacky way of signaling unshared weights. bit yikes.
149
+ int shared_weights = config->vocab_size > 0 ? 1 : 0;
150
+ config->vocab_size = abs(config->vocab_size);
151
+ // figure out the file size
152
+ fseek(file, 0, SEEK_END); // move file pointer to end of file
153
+ *file_size = ftell(file); // get the file size, in bytes
154
+ fclose(file);
155
+ // memory map the Transformer weights into the data pointer
156
+ *fd = open(checkpoint, O_RDONLY); // open in read only mode
157
+ if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
158
+ *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
159
+ if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
160
+ float* weights_ptr = *data + sizeof(Config)/sizeof(float);
161
+ memory_map_weights(weights, config, weights_ptr, shared_weights);
162
+ }
163
+
164
+ void build_transformer(Transformer *t, char* checkpoint_path) {
165
+ // read in the Config and the Weights from the checkpoint
166
+ read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
167
+ // allocate the RunState buffers
168
+ malloc_run_state(&t->state, &t->config);
169
+ }
170
+
171
+ void free_transformer(Transformer* t) {
172
+ // close the memory mapping
173
+ if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); }
174
+ if (t->fd != -1) { close(t->fd); }
175
+ // free the RunState buffers
176
+ free_run_state(&t->state);
177
+ }
178
+
179
+ // ----------------------------------------------------------------------------
180
+ // neural net blocks; the dynamics of the Transformer
181
+
182
+ void rmsnorm(float* o, float* x, float* weight, int size) {
183
+ // calculate sum of squares
184
+ float ss = 0.0f;
185
+ for (int j = 0; j < size; j++) {
186
+ ss += x[j] * x[j];
187
+ }
188
+ ss /= size;
189
+ ss += 1e-5f;
190
+ ss = 1.0f / sqrtf(ss);
191
+ // normalize and scale
192
+ for (int j = 0; j < size; j++) {
193
+ o[j] = weight[j] * (ss * x[j]);
194
+ }
195
+ }
196
+
197
+ void softmax(float* x, int size) {
198
+ // find max value (for numerical stability)
199
+ float max_val = x[0];
200
+ for (int i = 1; i < size; i++) {
201
+ if (x[i] > max_val) {
202
+ max_val = x[i];
203
+ }
204
+ }
205
+ // exp and sum
206
+ float sum = 0.0f;
207
+ for (int i = 0; i < size; i++) {
208
+ x[i] = expf(x[i] - max_val);
209
+ sum += x[i];
210
+ }
211
+ // normalize
212
+ for (int i = 0; i < size; i++) {
213
+ x[i] /= sum;
214
+ }
215
+ }
216
+
217
+ void matmul(float* xout, float* x, float* w, int n, int d) {
218
+ // W (d,n) @ x (n,) -> xout (d,)
219
+ // by far the most amount of time is spent inside this little function
220
+ int i;
221
+ #pragma omp parallel for private(i)
222
+ for (i = 0; i < d; i++) {
223
+ float val = 0.0f;
224
+ for (int j = 0; j < n; j++) {
225
+ val += w[i * n + j] * x[j];
226
+ }
227
+ xout[i] = val;
228
+ }
229
+ }
230
+
231
+ float* forward(Transformer* transformer, int token, int pos) {
232
+
233
+ // a few convenience variables
234
+ Config* p = &transformer->config;
235
+ TransformerWeights* w = &transformer->weights;
236
+ RunState* s = &transformer->state;
237
+ float *x = s->x;
238
+ int dim = p->dim;
239
+ int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
240
+ int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
241
+ int hidden_dim = p->hidden_dim;
242
+ int head_size = dim / p->n_heads;
243
+
244
+ // copy the token embedding into x
245
+ float* content_row = w->token_embedding_table + token * dim;
246
+ memcpy(x, content_row, dim*sizeof(*x));
247
+
248
+ // forward all the layers
249
+ for(unsigned long long l = 0; l < p->n_layers; l++) {
250
+
251
+ // attention rmsnorm
252
+ rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
253
+
254
+ // key and value point to the kv cache
255
+ int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
256
+ s->k = s->key_cache + loff + pos * kv_dim;
257
+ s->v = s->value_cache + loff + pos * kv_dim;
258
+
259
+ // qkv matmuls for this position
260
+ matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
261
+ matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
262
+ matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
263
+
264
+ // RoPE relative positional encoding: complex-valued rotate q and k in each head
265
+ for (int i = 0; i < dim; i+=2) {
266
+ int head_dim = i % head_size;
267
+ float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
268
+ float val = pos * freq;
269
+ float fcr = cosf(val);
270
+ float fci = sinf(val);
271
+ int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
272
+ for (int v = 0; v < rotn; v++) {
273
+ float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
274
+ float v0 = vec[i];
275
+ float v1 = vec[i+1];
276
+ vec[i] = v0 * fcr - v1 * fci;
277
+ vec[i+1] = v0 * fci + v1 * fcr;
278
+ }
279
+ }
280
+
281
+ // multihead attention. iterate over all heads
282
+ int h;
283
+ #pragma omp parallel for private(h)
284
+ for (h = 0; h < p->n_heads; h++) {
285
+ // get the query vector for this head
286
+ float* q = s->q + h * head_size;
287
+ // attention scores for this head
288
+ float* att = s->att + h * p->seq_len;
289
+ // iterate over all timesteps, including the current one
290
+ for (int t = 0; t <= pos; t++) {
291
+ // get the key vector for this head and at this timestep
292
+ float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
293
+ // calculate the attention score as the dot product of q and k
294
+ float score = 0.0f;
295
+ for (int i = 0; i < head_size; i++) {
296
+ score += q[i] * k[i];
297
+ }
298
+ score /= sqrtf(head_size);
299
+ // save the score to the attention buffer
300
+ att[t] = score;
301
+ }
302
+
303
+ // softmax the scores to get attention weights, from 0..pos inclusively
304
+ softmax(att, pos + 1);
305
+
306
+ // weighted sum of the values, store back into xb
307
+ float* xb = s->xb + h * head_size;
308
+ memset(xb, 0, head_size * sizeof(float));
309
+ for (int t = 0; t <= pos; t++) {
310
+ // get the value vector for this head and at this timestep
311
+ float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
312
+ // get the attention weight for this timestep
313
+ float a = att[t];
314
+ // accumulate the weighted value into xb
315
+ for (int i = 0; i < head_size; i++) {
316
+ xb[i] += a * v[i];
317
+ }
318
+ }
319
+ }
320
+
321
+ // final matmul to get the output of the attention
322
+ matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);
323
+
324
+ // residual connection back into x
325
+ for (int i = 0; i < dim; i++) {
326
+ x[i] += s->xb2[i];
327
+ }
328
+
329
+ // ffn rmsnorm
330
+ rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
331
+
332
+ // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
333
+ // first calculate self.w1(x) and self.w3(x)
334
+ matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
335
+ matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
336
+
337
+ // SwiGLU non-linearity
338
+ for (int i = 0; i < hidden_dim; i++) {
339
+ float val = s->hb[i];
340
+ // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
341
+ val *= (1.0f / (1.0f + expf(-val)));
342
+ // elementwise multiply with w3(x)
343
+ val *= s->hb2[i];
344
+ s->hb[i] = val;
345
+ }
346
+
347
+ // final matmul to get the output of the ffn
348
+ matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
349
+
350
+ // residual connection
351
+ for (int i = 0; i < dim; i++) {
352
+ x[i] += s->xb[i];
353
+ }
354
+ }
355
+
356
+ // final rmsnorm
357
+ rmsnorm(x, x, w->rms_final_weight, dim);
358
+
359
+ // classifier into logits
360
+ matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
361
+ return s->logits;
362
+ }
363
+
364
+ // ----------------------------------------------------------------------------
365
+ // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
366
+
367
+ typedef struct {
368
+ char *str;
369
+ int id;
370
+ } TokenIndex;
371
+
372
+ typedef struct {
373
+ char** vocab;
374
+ float* vocab_scores;
375
+ TokenIndex *sorted_vocab;
376
+ int vocab_size;
377
+ unsigned int max_token_length;
378
+ unsigned char byte_pieces[512]; // stores all single-byte strings
379
+ } Tokenizer;
380
+
381
+ int compare_tokens(const void *a, const void *b) {
382
+ return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
383
+ }
384
+
385
+ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
386
+ // i should have written the vocab_size into the tokenizer file... sigh
387
+ t->vocab_size = vocab_size;
388
+ // malloc space to hold the scores and the strings
389
+ t->vocab = (char**)malloc(vocab_size * sizeof(char*));
390
+ t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
391
+ t->sorted_vocab = NULL; // initialized lazily
392
+ for (int i = 0; i < 256; i++) {
393
+ t->byte_pieces[i * 2] = (unsigned char)i;
394
+ t->byte_pieces[i * 2 + 1] = '\0';
395
+ }
396
+ // read in the file
397
+ FILE *file = fopen(tokenizer_path, "rb");
398
+ if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
399
+ if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
400
+ int len;
401
+ for (int i = 0; i < vocab_size; i++) {
402
+ if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
403
+ if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
404
+ t->vocab[i] = (char *)malloc(len + 1);
405
+ if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
406
+ t->vocab[i][len] = '\0'; // add the string terminating token
407
+ }
408
+ fclose(file);
409
+ }
410
+
411
+ void free_tokenizer(Tokenizer* t) {
412
+ for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
413
+ free(t->vocab);
414
+ free(t->vocab_scores);
415
+ free(t->sorted_vocab);
416
+ }
417
+
418
+ char* decode(Tokenizer* t, int prev_token, int token) {
419
+ char *piece = t->vocab[token];
420
+ // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
421
+ if (prev_token == 1 && piece[0] == ' ') { piece++; }
422
+ // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
423
+ // parse this and convert and return the actual byte
424
+ unsigned char byte_val;
425
+ if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
426
+ piece = (char*)t->byte_pieces + byte_val * 2;
427
+ }
428
+ return piece;
429
+ }
430
+
431
+ void safe_printf(char *piece) {
432
+ // piece might be a raw byte token, and we only want to print printable chars or whitespace
433
+ // because some of the other bytes can be various control codes, backspace, etc.
434
+ if (piece == NULL) { return; }
435
+ if (piece[0] == '\0') { return; }
436
+ if (piece[1] == '\0') {
437
+ unsigned char byte_val = piece[0];
438
+ if (!(isprint(byte_val) || isspace(byte_val))) {
439
+ return; // bad byte, don't print it
440
+ }
441
+ }
442
+ printf("%s", piece);
443
+ }
444
+
445
+ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
446
+ // efficiently find the perfect match for str in vocab, return its index or -1 if not found
447
+ TokenIndex tok = { .str = str }; // acts as the key to search for
448
+ TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
449
+ return res != NULL ? res->id : -1;
450
+ }
451
+
452
+ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
453
+ // encode the string text (input) into an upper-bound preallocated tokens[] array
454
+ // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
455
+ if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
456
+
457
+ if (t->sorted_vocab == NULL) {
458
+ // lazily malloc and sort the vocabulary
459
+ t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
460
+ for (int i = 0; i < t->vocab_size; i++) {
461
+ t->sorted_vocab[i].str = t->vocab[i];
462
+ t->sorted_vocab[i].id = i;
463
+ }
464
+ qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
465
+ }
466
+
467
+ // create a temporary buffer that will store merge candidates of always two consecutive tokens
468
+ // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
469
+ char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
470
+ size_t str_len = 0;
471
+
472
+ // start at 0 tokens
473
+ *n_tokens = 0;
474
+
475
+ // add optional BOS (=1) token, if desired
476
+ if (bos) tokens[(*n_tokens)++] = 1;
477
+
478
+ // add_dummy_prefix is true by default
479
+ // so prepend a dummy prefix token to the input string, but only if text != ""
480
+ // TODO: pretty sure this isn't correct in the general case but I don't have the
481
+ // energy to read more of the sentencepiece code to figure out what it's doing
482
+ if (text[0] != '\0') {
483
+ int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
484
+ tokens[(*n_tokens)++] = dummy_prefix;
485
+ }
486
+
487
+ // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
488
+ // Code point ↔ UTF-8 conversion
489
+ // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
490
+ // U+0000 U+007F 0xxxxxxx
491
+ // U+0080 U+07FF 110xxxxx 10xxxxxx
492
+ // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
493
+ // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
494
+
495
+ // process the raw (UTF-8) byte sequence of the input string
496
+ for (char *c = text; *c != '\0'; c++) {
497
+
498
+ // reset buffer if the current byte is ASCII or a leading byte
499
+ // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
500
+ // 0x80 is 10000000
501
+ // in UTF-8, all continuation bytes start with "10" in first two bits
502
+ // so in English this is: "if this byte is not a continuation byte"
503
+ if ((*c & 0xC0) != 0x80) {
504
+ // this byte must be either a leading byte (11...) or an ASCII char (0x...)
505
+ // => reset our location, as we're starting a new UTF-8 codepoint
506
+ str_len = 0;
507
+ }
508
+
509
+ // append the current byte to the buffer
510
+ str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
511
+ str_buffer[str_len] = '\0';
512
+
513
+ // while the next character is a continuation byte, continue appending
514
+ // but if there are too many of them, just stop to avoid overrunning the str_buffer size.
515
+ if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
516
+ continue;
517
+ }
518
+
519
+ // ok c+1 is not a continuation byte, so we've read in a full codepoint
520
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
521
+
522
+ if (id != -1) {
523
+ // we found this codepoint in vocab, add it as a token
524
+ tokens[(*n_tokens)++] = id;
525
+ } else {
526
+ // byte_fallback encoding: just encode each byte as a token
527
+ // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
528
+ // so the individual bytes only start at index 3
529
+ for (int i=0; i < str_len; i++) {
530
+ tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
531
+ }
532
+ }
533
+ str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
534
+ }
535
+
536
+ // merge the best consecutive pair each iteration, according to the scores in vocab_scores
537
+ while (1) {
538
+ float best_score = -1e10;
539
+ int best_id = -1;
540
+ int best_idx = -1;
541
+
542
+ for (int i=0; i < (*n_tokens-1); i++) {
543
+ // check if we can merge the pair (tokens[i], tokens[i+1])
544
+ sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
545
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
546
+ if (id != -1 && t->vocab_scores[id] > best_score) {
547
+ // this merge pair exists in vocab! record its score and position
548
+ best_score = t->vocab_scores[id];
549
+ best_id = id;
550
+ best_idx = i;
551
+ }
552
+ }
553
+
554
+ if (best_idx == -1) {
555
+ break; // we couldn't find any more pairs to merge, so we're done
556
+ }
557
+
558
+ // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
559
+ tokens[best_idx] = best_id;
560
+ // delete token at position best_idx+1, shift the entire sequence back 1
561
+ for (int i = best_idx+1; i < (*n_tokens-1); i++) {
562
+ tokens[i] = tokens[i+1];
563
+ }
564
+ (*n_tokens)--; // token length decreased
565
+ }
566
+
567
+ // add optional EOS (=2) token, if desired
568
+ if (eos) tokens[(*n_tokens)++] = 2;
569
+
570
+ free(str_buffer);
571
+ }
572
+
573
+ // ----------------------------------------------------------------------------
574
+ // The Sampler, which takes logits and returns a sampled token
575
+ // sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
576
+
577
+ typedef struct {
578
+ float prob;
579
+ int index;
580
+ } ProbIndex; // struct used when sorting probabilities during top-p sampling
581
+
582
+ typedef struct {
583
+ int vocab_size;
584
+ ProbIndex* probindex; // buffer used in top-p sampling
585
+ float temperature;
586
+ float topp;
587
+ unsigned long long rng_state;
588
+ } Sampler;
589
+
590
+ int sample_argmax(float* probabilities, int n) {
591
+ // return the index that has the highest probability
592
+ int max_i = 0;
593
+ float max_p = probabilities[0];
594
+ for (int i = 1; i < n; i++) {
595
+ if (probabilities[i] > max_p) {
596
+ max_i = i;
597
+ max_p = probabilities[i];
598
+ }
599
+ }
600
+ return max_i;
601
+ }
602
+
603
+ int sample_mult(float* probabilities, int n, float coin) {
604
+ // sample index from probabilities (they must sum to 1!)
605
+ // coin is a random number in [0, 1), usually from random_f32()
606
+ float cdf = 0.0f;
607
+ for (int i = 0; i < n; i++) {
608
+ cdf += probabilities[i];
609
+ if (coin < cdf) {
610
+ return i;
611
+ }
612
+ }
613
+ return n - 1; // in case of rounding errors
614
+ }
615
+
616
+ int compare(const void* a, const void* b) {
617
+ ProbIndex* a_ = (ProbIndex*) a;
618
+ ProbIndex* b_ = (ProbIndex*) b;
619
+ if (a_->prob > b_->prob) return -1;
620
+ if (a_->prob < b_->prob) return 1;
621
+ return 0;
622
+ }
623
+
624
+ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
625
+ // top-p sampling (or "nucleus sampling") samples from the smallest set of
626
+ // tokens that exceed probability topp. This way we never sample tokens that
627
+ // have very low probabilities and are less likely to go "off the rails".
628
+ // coin is a random number in [0, 1), usually from random_f32()
629
+
630
+ int n0 = 0;
631
+ // quicksort indices in descending order of probabilities
632
+ // values smaller than (1 - topp) / (n - 1) cannot be part of the result
633
+ // so for efficiency we crop these out as candidates before sorting
634
+ const float cutoff = (1.0f - topp) / (n - 1);
635
+ for (int i = 0; i < n; i++) {
636
+ if (probabilities[i] >= cutoff) {
637
+ probindex[n0].index = i;
638
+ probindex[n0].prob = probabilities[i];
639
+ n0++;
640
+ }
641
+ }
642
+ qsort(probindex, n0, sizeof(ProbIndex), compare);
643
+
644
+ // truncate the list where cumulative probability exceeds topp
645
+ float cumulative_prob = 0.0f;
646
+ int last_idx = n0 - 1; // in case of rounding errors consider all elements
647
+ for (int i = 0; i < n0; i++) {
648
+ cumulative_prob += probindex[i].prob;
649
+ if (cumulative_prob > topp) {
650
+ last_idx = i;
651
+ break; // we've exceeded topp by including last_idx
652
+ }
653
+ }
654
+
655
+ // sample from the truncated list
656
+ float r = coin * cumulative_prob;
657
+ float cdf = 0.0f;
658
+ for (int i = 0; i <= last_idx; i++) {
659
+ cdf += probindex[i].prob;
660
+ if (r < cdf) {
661
+ return probindex[i].index;
662
+ }
663
+ }
664
+ return probindex[last_idx].index; // in case of rounding errors
665
+ }
666
+
667
+ void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
668
+ sampler->vocab_size = vocab_size;
669
+ sampler->temperature = temperature;
670
+ sampler->topp = topp;
671
+ sampler->rng_state = rng_seed;
672
+ // buffer only used with nucleus sampling; may not be needed, but it's ~small
673
+ sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
674
+ }
675
+
676
+ void free_sampler(Sampler* sampler) {
677
+ free(sampler->probindex);
678
+ }
679
+
680
+ unsigned int random_u32(unsigned long long *state) {
681
+ // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
682
+ *state ^= *state >> 12;
683
+ *state ^= *state << 25;
684
+ *state ^= *state >> 27;
685
+ return (*state * 0x2545F4914F6CDD1Dull) >> 32;
686
+ }
687
+ float random_f32(unsigned long long *state) { // random float32 in [0,1)
688
+ return (random_u32(state) >> 8) / 16777216.0f;
689
+ }
690
+
691
+ int sample(Sampler* sampler, float* logits) {
692
+ // sample the token given the logits and some hyperparameters
693
+ int next;
694
+ if (sampler->temperature == 0.0f) {
695
+ // greedy argmax sampling: take the token with the highest probability
696
+ next = sample_argmax(logits, sampler->vocab_size);
697
+ } else {
698
+ // apply the temperature to the logits
699
+ for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
700
+ // apply softmax to the logits to get the probabilities for next token
701
+ softmax(logits, sampler->vocab_size);
702
+ // flip a (float) coin (this is our source of entropy for sampling)
703
+ float coin = random_f32(&sampler->rng_state);
704
+ // we sample from this distribution to get the next token
705
+ if (sampler->topp <= 0 || sampler->topp >= 1) {
706
+ // simply sample from the predicted probability distribution
707
+ next = sample_mult(logits, sampler->vocab_size, coin);
708
+ } else {
709
+ // top-p (nucleus) sampling, clamping the least likely tokens to zero
710
+ next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
711
+ }
712
+ }
713
+ return next;
714
+ }
715
+
716
+ // ----------------------------------------------------------------------------
717
+ // utilities: time
718
+
719
+ long time_in_ms() {
720
+ // return time in milliseconds, for benchmarking the model speed
721
+ struct timespec time;
722
+ clock_gettime(CLOCK_REALTIME, &time);
723
+ return time.tv_sec * 1000 + time.tv_nsec / 1000000;
724
+ }
725
+
726
+ // ----------------------------------------------------------------------------
727
+ // generation loop
728
+
729
+ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
730
+ char *empty_prompt = "";
731
+ if (prompt == NULL) { prompt = empty_prompt; }
732
+
733
+ // encode the (string) prompt into tokens sequence
734
+ int num_prompt_tokens = 0;
735
+ int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
736
+ encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
737
+ if (num_prompt_tokens < 1) {
738
+ fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
739
+ exit(EXIT_FAILURE);
740
+ }
741
+
742
+ // start the main loop
743
+ long start = 0; // used to time our code, only initialized after first iteration
744
+ int next; // will store the next token in the sequence
745
+ int token = prompt_tokens[0]; // kick off with the first token in the prompt
746
+ int pos = 0; // position in the sequence
747
+ while (pos < steps) {
748
+
749
+ // forward the transformer to get logits for the next token
750
+ float* logits = forward(transformer, token, pos);
751
+
752
+ // advance the state machine
753
+ if (pos < num_prompt_tokens - 1) {
754
+ // if we are still processing the input prompt, force the next prompt token
755
+ next = prompt_tokens[pos + 1];
756
+ } else {
757
+ // otherwise sample the next token from the logits
758
+ next = sample(sampler, logits);
759
+ }
760
+ pos++;
761
+
762
+ // data-dependent terminating condition: the BOS (=1) token delimits sequences
763
+ if (next == 1) { break; }
764
+
765
+ // print the token as string, decode it with the Tokenizer object
766
+ char* piece = decode(tokenizer, token, next);
767
+ safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
768
+ fflush(stdout);
769
+ token = next;
770
+
771
+ // init the timer here because the first iteration can be slower
772
+ if (start == 0) { start = time_in_ms(); }
773
+ }
774
+ printf("\n");
775
+
776
+ // report achieved tok/s (pos-1 because the timer starts after first iteration)
777
+ if (pos > 1) {
778
+ long end = time_in_ms();
779
+ fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
780
+ }
781
+
782
+ free(prompt_tokens);
783
+ }
784
+
785
+ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
786
+ // read a line from stdin, up to but not including \n
787
+ printf("%s", guide);
788
+ if (fgets(buffer, bufsize, stdin) != NULL) {
789
+ size_t len = strlen(buffer);
790
+ if (len > 0 && buffer[len - 1] == '\n') {
791
+ buffer[len - 1] = '\0'; // strip newline
792
+ }
793
+ }
794
+ }
795
+
796
+ // ----------------------------------------------------------------------------
797
+ // chat loop
798
+ // I manually inspected the tokens for a few chat conversations compared to the
799
+ // python reference and that seemed ok, but this was not thoroughly tested and
800
+ // is not safely implemented; it's more a proof of concept atm.
801
+
802
+ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
803
+ char *cli_user_prompt, char *cli_system_prompt, int steps) {
804
+
805
+ // buffers for reading the system prompt and user prompt from stdin
806
+ // you'll notice they are somewhat haphazardly and unsafely set atm
807
+ char system_prompt[512];
808
+ char user_prompt[512];
809
+ char rendered_prompt[1152];
810
+ int num_prompt_tokens = 0;
811
+ int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
812
+ int user_idx;
813
+
814
+ // start the main loop
815
+ int8_t user_turn = 1; // user starts
816
+ int next; // will store the next token in the sequence
817
+ int token; // stores the current token to feed into the transformer
818
+ int prev_token;
819
+ int pos = 0; // position in the sequence
820
+ while (pos < steps) {
821
+
822
+ // when it is the user's turn to contribute tokens to the dialog...
823
+ if (user_turn) {
824
+ // get the (optional) system prompt at position 0
825
+ if (pos == 0) {
826
+ // at position 0, the user can also contribute a system prompt
827
+ if (cli_system_prompt == NULL) {
828
+ // system prompt was not passed in, attempt to get it from stdin
829
+ read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
830
+ } else {
831
+ // system prompt was passed in, use it
832
+ strcpy(system_prompt, cli_system_prompt);
833
+ }
834
+ }
835
+ // get the user prompt
836
+ if (pos == 0 && cli_user_prompt != NULL) {
837
+ // user prompt for position 0 was passed in, use it
838
+ strcpy(user_prompt, cli_user_prompt);
839
+ } else {
840
+ // otherwise get user prompt from stdin
841
+ read_stdin("User: ", user_prompt, sizeof(user_prompt));
842
+ }
843
+ // render user/system prompts into the Llama 2 Chat schema
844
+ if (pos == 0 && system_prompt[0] != '\0') {
845
+ char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
846
+ sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
847
+ } else {
848
+ char user_template[] = "[INST] %s [/INST]";
849
+ sprintf(rendered_prompt, user_template, user_prompt);
850
+ }
851
+ // encode the rendered prompt into tokens
852
+ encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
853
+ user_idx = 0; // reset the user index
854
+ user_turn = 0;
855
+ printf("Assistant: ");
856
+ }
857
+
858
+ // determine the token to pass into the transformer next
859
+ if (user_idx < num_prompt_tokens) {
860
+ // if we are still processing the input prompt, force the next prompt token
861
+ token = prompt_tokens[user_idx++];
862
+ } else {
863
+ // otherwise use the next token sampled from previous turn
864
+ token = next;
865
+ }
866
+ // EOS (=2) token ends the Assistant turn
867
+ if (token == 2) { user_turn = 1; }
868
+
869
+ // forward the transformer to get logits for the next token
870
+ float* logits = forward(transformer, token, pos);
871
+ next = sample(sampler, logits);
872
+ pos++;
873
+
874
+ if (user_idx >= num_prompt_tokens && next != 2) {
875
+ // the Assistant is responding, so print its output
876
+ char* piece = decode(tokenizer, token, next);
877
+ safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
878
+ fflush(stdout);
879
+ }
880
+ if (next == 2) { printf("\n"); }
881
+ }
882
+ printf("\n");
883
+ free(prompt_tokens);
884
+ }
885
+
886
+
887
+ // ----------------------------------------------------------------------------
888
+ // CLI, include only if not testing
889
+ #ifndef TESTING
890
+
891
+ void error_usage() {
892
+ fprintf(stderr, "Usage: run <checkpoint> [options]\n");
893
+ fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
894
+ fprintf(stderr, "Options:\n");
895
+ fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
896
+ fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
897
+ fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
898
+ fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
899
+ fprintf(stderr, " -i <string> input prompt\n");
900
+ fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
901
+ fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
902
+ fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
903
+ exit(EXIT_FAILURE);
904
+ }
905
+
906
+ int main(int argc, char *argv[]) {
907
+
908
+ // default parameters
909
+ char *checkpoint_path = NULL; // e.g. out/model.bin
910
+ char *tokenizer_path = "tokenizer.bin";
911
+ float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
912
+ float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
913
+ int steps = 256; // number of steps to run for
914
+ char *prompt = NULL; // prompt string
915
+ unsigned long long rng_seed = 0; // seed rng with time by default
916
+ char *mode = "generate"; // generate|chat
917
+ char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
918
+
919
+ // poor man's C argparse so we can override the defaults above from the command line
920
+ if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
921
+ for (int i = 2; i < argc; i+=2) {
922
+ // do some basic validation
923
+ if (i + 1 >= argc) { error_usage(); } // must have arg after flag
924
+ if (argv[i][0] != '-') { error_usage(); } // must start with dash
925
+ if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
926
+ // read in the args
927
+ if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
928
+ else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
929
+ else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
930
+ else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
931
+ else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
932
+ else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
933
+ else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
934
+ else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
935
+ else { error_usage(); }
936
+ }
937
+
938
+ // parameter validation/overrides
939
+ if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
940
+ if (temperature < 0.0) temperature = 0.0;
941
+ if (topp < 0.0 || 1.0 < topp) topp = 0.9;
942
+ if (steps < 0) steps = 0;
943
+
944
+ // build the Transformer via the model .bin file
945
+ Transformer transformer;
946
+ build_transformer(&transformer, checkpoint_path);
947
+ if (steps == 0 || steps > transformer.config.seq_len) steps = transformer.config.seq_len; // override to ~max length
948
+
949
+ // build the Tokenizer via the tokenizer .bin file
950
+ Tokenizer tokenizer;
951
+ build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);
952
+
953
+ // build the Sampler
954
+ Sampler sampler;
955
+ build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
956
+
957
+ // run!
958
+ if (strcmp(mode, "generate") == 0) {
959
+ generate(&transformer, &tokenizer, &sampler, prompt, steps);
960
+ } else if (strcmp(mode, "chat") == 0) {
961
+ chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
962
+ } else {
963
+ fprintf(stderr, "unknown mode: %s\n", mode);
964
+ error_usage();
965
+ }
966
+
967
+ // memory and file handles cleanup
968
+ free_sampler(&sampler);
969
+ free_tokenizer(&tokenizer);
970
+ free_transformer(&transformer);
971
+ return 0;
972
+ }
973
+ #endif
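
The sampler in run.c above chooses between greedy argmax, plain multinomial sampling, and top-p (nucleus) sampling. As a minimal standalone sketch of the nucleus step (not part of the uploaded files; the five logits and the fixed coin value are made up for illustration, and the (1 - topp) / (n - 1) pre-filter that run.c applies purely as an optimization is skipped), the following program applies the same softmax, descending sort, and cumulative-probability cutoff to a toy distribution:

/* Toy nucleus-sampling sketch mirroring softmax() + sample_topp() from run.c above. */
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

typedef struct { float prob; int index; } ProbIndex;

static int cmp_desc(const void* a, const void* b) {
    float pa = ((const ProbIndex*)a)->prob, pb = ((const ProbIndex*)b)->prob;
    return (pa < pb) - (pa > pb); // sort by probability, descending
}

int main(void) {
    float logits[5] = {2.0f, 1.0f, 0.5f, 0.1f, -1.0f}; // made-up logits
    int n = 5;
    float topp = 0.9f;
    float coin = 0.37f; // stand-in for random_f32(&rng_state)

    // softmax with max-subtraction for numerical stability, as in run.c
    float max_val = logits[0];
    for (int i = 1; i < n; i++) if (logits[i] > max_val) max_val = logits[i];
    float sum = 0.0f;
    for (int i = 0; i < n; i++) { logits[i] = expf(logits[i] - max_val); sum += logits[i]; }
    for (int i = 0; i < n; i++) logits[i] /= sum;

    // sort candidates and truncate where cumulative probability exceeds topp
    ProbIndex probindex[5];
    for (int i = 0; i < n; i++) { probindex[i].prob = logits[i]; probindex[i].index = i; }
    qsort(probindex, n, sizeof(ProbIndex), cmp_desc);
    float cumulative = 0.0f;
    int last_idx = n - 1;
    for (int i = 0; i < n; i++) {
        cumulative += probindex[i].prob;
        if (cumulative > topp) { last_idx = i; break; }
    }

    // sample from the truncated list, renormalized by the kept mass
    float r = coin * cumulative;
    float cdf = 0.0f;
    int sampled = probindex[last_idx].index;
    for (int i = 0; i <= last_idx; i++) {
        cdf += probindex[i].prob;
        if (r < cdf) { sampled = probindex[i].index; break; }
    }
    printf("kept %d candidates, sampled token index %d\n", last_idx + 1, sampled);
    return 0;
}

Built with something like gcc -O2 sketch.c -lm, it prints how many candidates survive the cutoff and which index was drawn; pushing topp toward 1.0 keeps more of the tail, which is the same knob exposed by the -p flag in main() above.
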
llama2.c/run.ipynb ADDED
@@ -0,0 +1,130 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "HLdoj4cz-xal"
7
+ },
8
+ "source": [
9
+ "# Run.c\n",
10
+ "\n",
11
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/karpathy/llama2.c/blob/master/run.ipynb)\n",
12
+ "\n",
13
+ "More details can be found in the [README.md](README.md) ."
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "metadata": {
20
+ "id": "Une3Ozlnu1B7"
21
+ },
22
+ "outputs": [],
23
+ "source": [
24
+ "#@title Clone Project\n",
25
+ "\n",
26
+ "!git clone https://github.com/karpathy/llama2.c.git\n",
27
+ "%cd llama2.c"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "#@title Build\n",
37
+ "\n",
38
+ "!make runfast"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {
45
+ "id": "thm0ZBrtSgoC"
46
+ },
47
+ "outputs": [],
48
+ "source": [
49
+ "#@title Pick Your Model\n",
50
+ "\n",
51
+ "#@markdown Choose model\n",
52
+ "model = \"stories15M\" #@param [\"stories15M\", \"stories42M\", \"stories110M\"]\n",
53
+ "\n",
54
+ "download_url = \"\"\n",
55
+ "\n",
56
+ "if(model == \"stories15M\"):\n",
57
+ " download_url = \"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin\"\n",
58
+ "if(model == \"stories42M\"):\n",
59
+ " download_url = \"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin\"\n",
60
+ "if(model == \"stories110M\"):\n",
61
+ " download_url = \"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin\"\n",
62
+ "\n",
63
+ "print(f\"download_url: {download_url}\")\n",
64
+ "\n",
65
+ "!wget $download_url\n",
66
+ "\n",
67
+ "model_file = model + \".bin\""
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "metadata": {
74
+ "id": "OgAc3KjuT-NM"
75
+ },
76
+ "outputs": [],
77
+ "source": [
78
+ "#@title Generate Stories\n",
79
+ "\n",
80
+ "# Generate args\n",
81
+ "max_token = 256 #@param {type:\"slider\", min:32, max:1024, step:32}\n",
82
+ "temperature = 0.8 #@param {type:\"slider\", min:0.0, max:1, step:0.05}\n",
83
+ "top_p = 0.9 #@param {type:\"slider\", min:0.0, max:1.0, step:0.05}\n",
84
+ "prompt = \"One day, Lily met a Shoggoth\" #@param {type:\"string\"}\n",
85
+ "\n",
86
+ "print(f\"model: {model_file}, max_token: {max_token}, temperature: {temperature}, top_p: {top_p}, prompt: {prompt}\")\n",
87
+ "print(f\"----------------------------\\n\")\n",
88
+ "\n",
89
+ "cmd = f'./run {model_file} -t {temperature} -p {top_p} -n {max_token} -i \"{prompt}\"'\n",
90
+ "!{cmd}"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "#@title Run Meta's Llama 2 models\n",
100
+ "\n",
101
+ "#@markdown input your huggingface [access token](https://huggingface.co/settings/tokens) to download Meta's Llama 2 models.\n",
102
+ "\n",
103
+ "from huggingface_hub import snapshot_download\n",
104
+ "\n",
105
+ "token = \"replace your huggingface access token\" #@param {type:\"string\"}\n",
106
+ "path = snapshot_download(repo_id=\"meta-llama/Llama-2-7b\",cache_dir=\"Llama-2-7b\", use_auth_token=token)\n",
107
+ "\n",
108
+ "!python export.py llama2_7b.bin --meta-llama $path\n",
109
+ "\n",
110
+ "print(\"./run llama2_7b.bin\\n\")\n",
111
+ "!./run llama2_7b.bin"
112
+ ]
113
+ }
114
+ ],
115
+ "metadata": {
116
+ "colab": {
117
+ "private_outputs": true,
118
+ "provenance": []
119
+ },
120
+ "kernelspec": {
121
+ "display_name": "Python 3",
122
+ "name": "python3"
123
+ },
124
+ "language_info": {
125
+ "name": "python"
126
+ }
127
+ },
128
+ "nbformat": 4,
129
+ "nbformat_minor": 0
130
+ }
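
runq.c, added below, swaps the fp32 matmuls for a grouped int8 scheme: each group of GS values shares a single scale factor max(|x|)/127, weights and activations are stored as int8, and the matmul accumulates integer products before rescaling. A minimal standalone sketch of the quantize/dequantize roundtrip (not part of the uploaded files; the group size of 4 and the sample values are made up for illustration, whereas the real GS is read from the version-2 checkpoint header):

/* Toy group-wise int8 quantization roundtrip, mirroring quantize()/dequantize() in runq.c below. */
#include <stdio.h>
#include <stdint.h>
#include <math.h>

#define GS 4   // toy group size for this sketch only

int main(void) {
    float x[8] = {0.12f, -0.50f, 0.33f, 0.05f, 1.20f, -0.75f, 0.00f, 0.60f};
    int8_t q[8];
    float s[8 / GS];

    // quantize: one scale per group of GS values, then round each value to int8
    for (int g = 0; g < 8 / GS; g++) {
        float wmax = 0.0f;
        for (int i = 0; i < GS; i++) {
            float v = fabsf(x[g * GS + i]);
            if (v > wmax) wmax = v;
        }
        s[g] = wmax / 127.0f;
        for (int i = 0; i < GS; i++) {
            q[g * GS + i] = (int8_t) roundf(x[g * GS + i] / s[g]);
        }
    }

    // dequantize and report the per-element roundtrip error
    for (int i = 0; i < 8; i++) {
        float xhat = q[i] * s[i / GS];
        printf("x=%+.3f  q=%+4d  xhat=%+.3f  err=%+.4f\n", x[i], q[i], xhat, xhat - x[i]);
    }
    return 0;
}

The error shrinks as the group size gets smaller (each scale then only has to cover nearby magnitudes), which is the trade-off the group_size field in the quantized checkpoint controls.
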
llama2.c/runq.c ADDED
@@ -0,0 +1,1092 @@
1
+ /* Inference for Llama-2 Transformer model in pure C, int8 quantized forward pass. */
2
+
3
+ #include <stdio.h>
4
+ #include <stdlib.h>
5
+ #include <ctype.h>
6
+ #include <stdint.h>
7
+ #include <time.h>
8
+ #include <math.h>
9
+ #include <string.h>
10
+ #include <fcntl.h>
11
+ #if defined _WIN32
12
+ #include "win.h"
13
+ #else
14
+ #include <unistd.h>
15
+ #include <sys/mman.h>
16
+ #endif
17
+ // ----------------------------------------------------------------------------
18
+ // Globals
19
+ int GS = 0; // group size global for quantization of the weights
20
+
21
+ // ----------------------------------------------------------------------------
22
+ // Transformer model
23
+
24
+ typedef struct {
25
+ int dim; // transformer dimension
26
+ int hidden_dim; // for ffn layers
27
+ int n_layers; // number of layers
28
+ int n_heads; // number of query heads
29
+ int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
30
+ int vocab_size; // vocabulary size, usually 256 (byte-level)
31
+ int seq_len; // max sequence length
32
+ } Config;
33
+
34
+ typedef struct {
35
+ int8_t* q; // quantized values
36
+ float* s; // scaling factors
37
+ } QuantizedTensor;
38
+
39
+ typedef struct {
40
+ // token embedding table
41
+ QuantizedTensor *q_tokens; // (vocab_size, dim)
42
+ float* token_embedding_table; // same, but dequantized
43
+
44
+ // weights for rmsnorms
45
+ float* rms_att_weight; // (layer, dim) rmsnorm weights
46
+ float* rms_ffn_weight; // (layer, dim)
47
+ // weights for matmuls. note dim == n_heads * head_size
48
+ QuantizedTensor *wq; // (layer, dim, n_heads * head_size)
49
+ QuantizedTensor *wk; // (layer, dim, n_kv_heads * head_size)
50
+ QuantizedTensor *wv; // (layer, dim, n_kv_heads * head_size)
51
+ QuantizedTensor *wo; // (layer, n_heads * head_size, dim)
52
+ // weights for ffn
53
+ QuantizedTensor *w1; // (layer, hidden_dim, dim)
54
+ QuantizedTensor *w2; // (layer, dim, hidden_dim)
55
+ QuantizedTensor *w3; // (layer, hidden_dim, dim)
56
+ // final rmsnorm
57
+ float* rms_final_weight; // (dim,)
58
+ // (optional) classifier weights for the logits, on the last layer
59
+ QuantizedTensor *wcls;
60
+ } TransformerWeights;
61
+
62
+ typedef struct {
63
+ // current wave of activations
64
+ float *x; // activation at current time stamp (dim,)
65
+ float *xb; // same, but inside a residual branch (dim,)
66
+ float *xb2; // an additional buffer just for convenience (dim,)
67
+ float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
68
+ float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
69
+ QuantizedTensor xq; // quantized x (dim,)
70
+ QuantizedTensor hq; // quantized hb (hidden_dim,)
71
+ float *q; // query (dim,)
72
+ float *k; // key (dim,)
73
+ float *v; // value (dim,)
74
+ float *att; // buffer for scores/attention values (n_heads, seq_len)
75
+ float *logits; // output logits
76
+ // kv cache
77
+ float* key_cache; // (layer, seq_len, dim)
78
+ float* value_cache; // (layer, seq_len, dim)
79
+ } RunState;
80
+
81
+ typedef struct {
82
+ Config config; // the hyperparameters of the architecture (the blueprint)
83
+ TransformerWeights weights; // the weights of the model
84
+ RunState state; // buffers for the "wave" of activations in the forward pass
85
+ // some more state needed to properly clean up the memory mapping (sigh)
86
+ int fd; // file descriptor for memory mapping
87
+ float* data; // memory mapped data pointer
88
+ ssize_t file_size; // size of the checkpoint file in bytes
89
+ } Transformer;
90
+
91
+ void malloc_run_state(RunState* s, Config* p) {
92
+ // we calloc instead of malloc to keep valgrind happy
93
+ int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
94
+ s->x = calloc(p->dim, sizeof(float));
95
+ s->xb = calloc(p->dim, sizeof(float));
96
+ s->xb2 = calloc(p->dim, sizeof(float));
97
+ s->hb = calloc(p->hidden_dim, sizeof(float));
98
+ s->hb2 = calloc(p->hidden_dim, sizeof(float));
99
+ s->xq = (QuantizedTensor) { .q = calloc(p->dim, sizeof(int8_t)), .s = calloc(p->dim, sizeof(float)) };
100
+ s->hq = (QuantizedTensor) { .q = calloc(p->hidden_dim, sizeof(int8_t)), .s = calloc(p->hidden_dim, sizeof(float)) };
101
+ s->q = calloc(p->dim, sizeof(float));
102
+ s->k = calloc(kv_dim, sizeof(float));
103
+ s->v = calloc(kv_dim, sizeof(float));
104
+ s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
105
+ s->logits = calloc(p->vocab_size, sizeof(float));
106
+ s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
107
+ s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
108
+ // ensure all mallocs went fine
109
+ if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
110
+ || !s->k || !s->v || !s->att || !s->logits || !s->key_cache
111
+ || !s->value_cache) {
112
+ fprintf(stderr, "malloc failed!\n");
113
+ exit(EXIT_FAILURE);
114
+ }
115
+ }
116
+
117
+ void free_run_state(RunState* s) {
118
+ free(s->x);
119
+ free(s->xb);
120
+ free(s->xb2);
121
+ free(s->hb);
122
+ free(s->hb2);
123
+ free(s->xq.q);
124
+ free(s->xq.s);
125
+ free(s->hq.q);
126
+ free(s->hq.s);
127
+ free(s->q);
128
+ free(s->k);
129
+ free(s->v);
130
+ free(s->att);
131
+ free(s->logits);
132
+ free(s->key_cache);
133
+ free(s->value_cache);
134
+ }
135
+
136
+ // ----------------------------------------------------------------------------
137
+ // Quantization functions
138
+
139
+ void dequantize(QuantizedTensor *qx, float* x, int n) {
140
+ for (int i = 0; i < n; i++) {
141
+ x[i] = qx->q[i] * qx->s[i / GS];
142
+ }
143
+ }
144
+
145
+ void quantize(QuantizedTensor *qx, float* x, int n) {
146
+ int num_groups = n / GS;
147
+ float Q_MAX = 127.0f;
148
+
149
+ for (int group = 0; group < num_groups; group++) {
150
+
151
+ // find the max absolute value in the current group
152
+ float wmax = 0.0;
153
+ for (int i = 0; i < GS; i++) {
154
+ float val = fabs(x[group * GS + i]);
155
+ if (val > wmax) {
156
+ wmax = val;
157
+ }
158
+ }
159
+
160
+ // calculate and write the scaling factor
161
+ float scale = wmax / Q_MAX;
162
+ qx->s[group] = scale;
163
+
164
+ // calculate and write the quantized values
165
+ for (int i = 0; i < GS; i++) {
166
+ float quant_value = x[group * GS + i] / scale; // scale
167
+ int8_t quantized = (int8_t) round(quant_value); // round and clamp
168
+ qx->q[group * GS + i] = quantized;
169
+ }
170
+ }
171
+ }
172
+
173
+ /* initialize `n` x quantized tensor (with `size_each` elements), starting from memory pointed at *ptr */
174
+ QuantizedTensor *init_quantized_tensors(void **ptr, int n, int size_each) {
175
+ void *p = *ptr;
176
+ QuantizedTensor *res = malloc(n * sizeof(QuantizedTensor));
177
+ for(int i=0; i<n; i++) {
178
+ /* map quantized int8 values*/
179
+ res[i].q = (int8_t*)p;
180
+ p = (int8_t*)p + size_each;
181
+ /* map scale factors */
182
+ res[i].s = (float*)p;
183
+ p = (float*)p + size_each / GS;
184
+ }
185
+ *ptr = p; // advance ptr to current position
186
+ return res;
187
+ }
188
+
189
+ void memory_map_weights(TransformerWeights *w, Config* p, void* ptr, uint8_t shared_classifier) {
190
+ int head_size = p->dim / p->n_heads;
191
+ // first are the parameters that are kept in fp32 (the rmsnorm (1D) weights)
192
+ float* fptr = (float*) ptr; // cast our pointer to float*
193
+ w->rms_att_weight = fptr;
194
+ fptr += p->n_layers * p->dim;
195
+ w->rms_ffn_weight = fptr;
196
+ fptr += p->n_layers * p->dim;
197
+ w->rms_final_weight = fptr;
198
+ fptr += p->dim;
199
+
200
+ // now read all the quantized weights
201
+ ptr = (void*)fptr; // now cast the pointer back to void*
202
+ w->q_tokens = init_quantized_tensors(&ptr, 1, p->vocab_size * p->dim);
203
+ // dequantize token embedding table
204
+ w->token_embedding_table = malloc(p->vocab_size * p->dim * sizeof(float));
205
+ dequantize(w->q_tokens, w->token_embedding_table, p->vocab_size * p->dim);
206
+
207
+ w->wq = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_heads * head_size));
208
+ w->wk = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
209
+ w->wv = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
210
+ w->wo = init_quantized_tensors(&ptr, p->n_layers, (p->n_heads * head_size) * p->dim);
211
+
212
+ w->w1 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
213
+ w->w2 = init_quantized_tensors(&ptr, p->n_layers, p->hidden_dim * p->dim);
214
+ w->w3 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
215
+
216
+ w->wcls = shared_classifier ? w->q_tokens : init_quantized_tensors(&ptr, 1, p->dim * p->vocab_size);
217
+ }
218
+
219
+ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
220
+ int* fd, float** data, ssize_t* file_size) {
221
+ FILE *file = fopen(checkpoint, "rb");
222
+ if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
223
+ // read in magic number (uint32), has to be 0x616b3432, i.e. "ak42" in ASCII
224
+ uint32_t magic_number;
225
+ if (fread(&magic_number, sizeof(uint32_t), 1, file) != 1) { exit(EXIT_FAILURE); }
226
+ if (magic_number != 0x616b3432) { fprintf(stderr, "Bad magic number\n"); exit(EXIT_FAILURE); }
227
+ // read in the version number (uint32), has to be 2
228
+ int version;
229
+ if (fread(&version, sizeof(int), 1, file) != 1) { exit(EXIT_FAILURE); }
230
+ if (version != 2) { fprintf(stderr, "Bad version %d, need version 2\n", version); exit(EXIT_FAILURE); }
231
+ int header_size = 256; // the header size for version 2 in bytes
232
+ // read in the Config
233
+ if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); }
234
+ // read in flags
235
+ uint8_t shared_classifier; // a byte to indicate if the classifier is shared
236
+ if (fread(&shared_classifier, sizeof(uint8_t), 1, file) != 1) { exit(EXIT_FAILURE); }
237
+ int group_size; // the group size used in quantization
238
+ if (fread(&group_size, sizeof(int), 1, file) != 1) { exit(EXIT_FAILURE); }
239
+ GS = group_size; // set as global, as it will be used in many places
240
+ // figure out the file size
241
+ fseek(file, 0, SEEK_END); // move file pointer to end of file
242
+ *file_size = ftell(file); // get the file size, in bytes
243
+ fclose(file);
244
+ // memory map the Transformer weights into the data pointer
245
+ *fd = open(checkpoint, O_RDONLY); // open in read only mode
246
+ if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
247
+ *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
248
+ if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
249
+ void* weights_ptr = ((char*)*data) + header_size; // skip header bytes. char is 1 byte
250
+ memory_map_weights(weights, config, weights_ptr, shared_classifier);
251
+ }
252
+
253
+ void build_transformer(Transformer *t, char* checkpoint_path) {
254
+ // read in the Config and the Weights from the checkpoint
255
+ read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
256
+ // allocate the RunState buffers
257
+ malloc_run_state(&t->state, &t->config);
258
+ }
259
+
260
+ void free_transformer(Transformer* t) {
261
+ // free QuantizedTensors
262
+ free(t->weights.q_tokens);
263
+ free(t->weights.token_embedding_table);
264
+ free(t->weights.wq);
265
+ free(t->weights.wk);
266
+ free(t->weights.wv);
267
+ free(t->weights.wo);
268
+ free(t->weights.w1);
269
+ free(t->weights.w2);
270
+ free(t->weights.w3);
271
+ if(t->weights.wcls != t->weights.q_tokens) { free(t->weights.wcls); }
272
+ // close the memory mapping
273
+ if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); }
274
+ if (t->fd != -1) { close(t->fd); }
275
+ // free the RunState buffers
276
+ free_run_state(&t->state);
277
+ }
278
+
279
+ // ----------------------------------------------------------------------------
280
+ // neural net blocks; the dynamics of the Transformer
281
+
282
+ void rmsnorm(float* o, float* x, float* weight, int size) {
283
+ // calculate sum of squares
284
+ float ss = 0.0f;
285
+ for (int j = 0; j < size; j++) {
286
+ ss += x[j] * x[j];
287
+ }
288
+ ss /= size;
289
+ ss += 1e-5f;
290
+ ss = 1.0f / sqrtf(ss);
291
+ // normalize and scale
292
+ for (int j = 0; j < size; j++) {
293
+ o[j] = weight[j] * (ss * x[j]);
294
+ }
295
+ }
296
+
297
+ void softmax(float* x, int size) {
298
+ // find max value (for numerical stability)
299
+ float max_val = x[0];
300
+ for (int i = 1; i < size; i++) {
301
+ if (x[i] > max_val) {
302
+ max_val = x[i];
303
+ }
304
+ }
305
+ // exp and sum
306
+ float sum = 0.0f;
307
+ for (int i = 0; i < size; i++) {
308
+ x[i] = expf(x[i] - max_val);
309
+ sum += x[i];
310
+ }
311
+ // normalize
312
+ for (int i = 0; i < size; i++) {
313
+ x[i] /= sum;
314
+ }
315
+ }
316
+
317
+ void matmul(float* xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d) {
318
+ // W (d,n) @ x (n,) -> xout (d,)
319
+ // by far the most time is spent inside this little function
320
+ // inputs to this function are both quantized
321
+
322
+ int i;
323
+ #pragma omp parallel for private(i)
324
+ for (i = 0; i < d; i++) {
325
+
326
+ float val = 0.0f;
327
+ int32_t ival = 0;
328
+ int in = i * n;
329
+
330
+ // do the matmul in groups of GS
331
+ int j;
332
+ for (j = 0; j <= n - GS; j += GS) {
333
+ for (int k = 0; k < GS; k++) {
334
+ ival += ((int32_t) x->q[j + k]) * ((int32_t) w->q[in + j + k]);
335
+ }
336
+ val += ((float) ival) * w->s[(in + j) / GS] * x->s[j / GS];
337
+ ival = 0;
338
+ }
339
+
340
+ xout[i] = val;
341
+ }
342
+ }
343
+
344
+ float* forward(Transformer* transformer, int token, int pos) {
345
+
346
+ // a few convenience variables
347
+ Config* p = &transformer->config;
348
+ TransformerWeights* w = &transformer->weights;
349
+ RunState* s = &transformer->state;
350
+ float *x = s->x;
351
+ int dim = p->dim;
352
+ int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
353
+ int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
354
+ int hidden_dim = p->hidden_dim;
355
+ int head_size = dim / p->n_heads;
356
+
357
+ // copy the token embedding into x
358
+ memcpy(x, w->token_embedding_table + token*dim, dim * sizeof(float));
359
+
360
+ // forward all the layers
361
+ for(int l = 0; l < p->n_layers; l++) {
362
+
363
+ // attention rmsnorm
364
+ rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
365
+
366
+ // qkv matmuls for this position
367
+ quantize(&s->xq, s->xb, dim);
368
+ matmul(s->q, &s->xq, w->wq + l, dim, dim);
369
+ matmul(s->k, &s->xq, w->wk + l, dim, kv_dim);
370
+ matmul(s->v, &s->xq, w->wv + l, dim, kv_dim);
371
+
372
+ // RoPE relative positional encoding: complex-valued rotate q and k in each head
373
+ for (int i = 0; i < dim; i+=2) {
374
+ int head_dim = i % head_size;
375
+ float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
376
+ float val = pos * freq;
377
+ float fcr = cosf(val);
378
+ float fci = sinf(val);
379
+ int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
380
+ for (int v = 0; v < rotn; v++) {
381
+ float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
382
+ float v0 = vec[i];
383
+ float v1 = vec[i+1];
384
+ vec[i] = v0 * fcr - v1 * fci;
385
+ vec[i+1] = v0 * fci + v1 * fcr;
386
+ }
387
+ }
388
+
389
+ // save key,value at this time step (pos) to our kv cache
390
+ int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
391
+ float* key_cache_row = s->key_cache + loff + pos * kv_dim;
392
+ float* value_cache_row = s->value_cache + loff + pos * kv_dim;
393
+ memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
394
+ memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));
395
+
396
+ // multihead attention. iterate over all heads
397
+ int h;
398
+ #pragma omp parallel for private(h)
399
+ for (h = 0; h < p->n_heads; h++) {
400
+ // get the query vector for this head
401
+ float* q = s->q + h * head_size;
402
+ // attention scores for this head
403
+ float* att = s->att + h * p->seq_len;
404
+ // iterate over all timesteps, including the current one
405
+ for (int t = 0; t <= pos; t++) {
406
+ // get the key vector for this head and at this timestep
407
+ float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
408
+ // calculate the attention score as the dot product of q and k
409
+ float score = 0.0f;
410
+ for (int i = 0; i < head_size; i++) {
411
+ score += q[i] * k[i];
412
+ }
413
+ score /= sqrtf(head_size);
414
+ // save the score to the attention buffer
415
+ att[t] = score;
416
+ }
417
+
418
+ // softmax the scores to get attention weights, from 0..pos inclusively
419
+ softmax(att, pos + 1);
420
+
421
+ // weighted sum of the values, store back into xb
422
+ float* xb = s->xb + h * head_size;
423
+ memset(xb, 0, head_size * sizeof(float));
424
+ for (int t = 0; t <= pos; t++) {
425
+ // get the value vector for this head and at this timestep
426
+ float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
427
+ // get the attention weight for this timestep
428
+ float a = att[t];
429
+ // accumulate the weighted value into xb
430
+ for (int i = 0; i < head_size; i++) {
431
+ xb[i] += a * v[i];
432
+ }
433
+ }
434
+ }
435
+
436
+ // final matmul to get the output of the attention
437
+ quantize(&s->xq, s->xb, dim);
438
+ matmul(s->xb2, &s->xq, w->wo + l, dim, dim);
439
+
440
+ // residual connection back into x
441
+ for (int i = 0; i < dim; i++) {
442
+ x[i] += s->xb2[i];
443
+ }
444
+
445
+ // ffn rmsnorm
446
+ rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
447
+
448
+ // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
449
+ // first calculate self.w1(x) and self.w3(x)
450
+ quantize(&s->xq, s->xb, dim);
451
+ matmul(s->hb, &s->xq, w->w1 + l, dim, hidden_dim);
452
+ matmul(s->hb2, &s->xq, w->w3 + l, dim, hidden_dim);
453
+
454
+ // SwiGLU non-linearity
455
+ for (int i = 0; i < hidden_dim; i++) {
456
+ float val = s->hb[i];
457
+ // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
458
+ val *= (1.0f / (1.0f + expf(-val)));
459
+ // elementwise multiply with w3(x)
460
+ val *= s->hb2[i];
461
+ s->hb[i] = val;
462
+ }
463
+
464
+ // final matmul to get the output of the ffn
465
+ quantize(&s->hq, s->hb, hidden_dim);
466
+ matmul(s->xb, &s->hq, w->w2 + l, hidden_dim, dim);
467
+
468
+ // residual connection
469
+ for (int i = 0; i < dim; i++) {
470
+ x[i] += s->xb[i];
471
+ }
472
+ }
473
+
474
+ // final rmsnorm
475
+ rmsnorm(x, x, w->rms_final_weight, dim);
476
+
477
+ // classifier into logits
478
+ quantize(&s->xq, x, dim);
479
+ matmul(s->logits, &s->xq, w->wcls, dim, p->vocab_size);
480
+ return s->logits;
481
+ }
482
+
483
+ // ----------------------------------------------------------------------------
484
+ // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
485
+
486
+ typedef struct {
487
+ char *str;
488
+ int id;
489
+ } TokenIndex;
490
+
491
+ typedef struct {
492
+ char** vocab;
493
+ float* vocab_scores;
494
+ TokenIndex *sorted_vocab;
495
+ int vocab_size;
496
+ unsigned int max_token_length;
497
+ unsigned char byte_pieces[512]; // stores all single-byte strings
498
+ } Tokenizer;
499
+
500
+ int compare_tokens(const void *a, const void *b) {
501
+ return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
502
+ }
503
+
504
+ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
505
+ // i should have written the vocab_size into the tokenizer file... sigh
506
+ t->vocab_size = vocab_size;
507
+ // malloc space to hold the scores and the strings
508
+ t->vocab = (char**)malloc(vocab_size * sizeof(char*));
509
+ t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
510
+ t->sorted_vocab = NULL; // initialized lazily
511
+ for (int i = 0; i < 256; i++) {
512
+ t->byte_pieces[i * 2] = (unsigned char)i;
513
+ t->byte_pieces[i * 2 + 1] = '\0';
514
+ }
515
+ // read in the file
516
+ FILE *file = fopen(tokenizer_path, "rb");
517
+ if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
518
+ if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
519
+ int len;
520
+ for (int i = 0; i < vocab_size; i++) {
521
+ if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
522
+ if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
523
+ t->vocab[i] = (char *)malloc(len + 1);
524
+ if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
525
+ t->vocab[i][len] = '\0'; // add the string terminating token
526
+ }
527
+ fclose(file);
528
+ }
529
+
530
+ void free_tokenizer(Tokenizer* t) {
531
+ for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
532
+ free(t->vocab);
533
+ free(t->vocab_scores);
534
+ free(t->sorted_vocab);
535
+ }
536
+
537
+ char* decode(Tokenizer* t, int prev_token, int token) {
538
+ char *piece = t->vocab[token];
539
+ // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
540
+ if (prev_token == 1 && piece[0] == ' ') { piece++; }
541
+ // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
542
+ // parse this and convert and return the actual byte
543
+ unsigned char byte_val;
544
+ if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
545
+ piece = (char*)t->byte_pieces + byte_val * 2;
546
+ }
547
+ return piece;
548
+ }
549
+
550
+ void safe_printf(char *piece) {
551
+ // piece might be a raw byte token, and we only want to print printable chars or whitespace
552
+ // because some of the other bytes can be various control codes, backspace, etc.
553
+ if (piece == NULL) { return; }
554
+ if (piece[0] == '\0') { return; }
555
+ if (piece[1] == '\0') {
556
+ unsigned char byte_val = piece[0];
557
+ if (!(isprint(byte_val) || isspace(byte_val))) {
558
+ return; // bad byte, don't print it
559
+ }
560
+ }
561
+ printf("%s", piece);
562
+ }
563
+
564
+ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
565
+ // efficiently find the perfect match for str in vocab, return its index or -1 if not found
566
+ TokenIndex tok = { .str = str }; // acts as the key to search for
567
+ TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
568
+ return res != NULL ? res->id : -1;
569
+ }
570
+
571
+ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
572
+ // encode the string text (input) into an upper-bound preallocated tokens[] array
573
+ // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
574
+ if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
575
+
576
+ if (t->sorted_vocab == NULL) {
577
+ // lazily malloc and sort the vocabulary
578
+ t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
579
+ for (int i = 0; i < t->vocab_size; i++) {
580
+ t->sorted_vocab[i].str = t->vocab[i];
581
+ t->sorted_vocab[i].id = i;
582
+ }
583
+ qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
584
+ }
585
+
586
+ // create a temporary buffer that will store merge candidates of always two consecutive tokens
587
+ // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
588
+ char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
589
+ size_t str_len = 0;
590
+
591
+ // start at 0 tokens
592
+ *n_tokens = 0;
593
+
594
+ // add optional BOS (=1) token, if desired
595
+ if (bos) tokens[(*n_tokens)++] = 1;
596
+
597
+ // add_dummy_prefix is true by default
598
+ // so prepend a dummy prefix token to the input string, but only if text != ""
599
+ // TODO: pretty sure this isn't correct in the general case but I don't have the
600
+ // energy to read more of the sentencepiece code to figure out what it's doing
601
+ if (text[0] != '\0') {
602
+ int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
603
+ tokens[(*n_tokens)++] = dummy_prefix;
604
+ }
605
+
606
+ // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
607
+ // Code point ↔ UTF-8 conversion
608
+ // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
609
+ // U+0000 U+007F 0xxxxxxx
610
+ // U+0080 U+07FF 110xxxxx 10xxxxxx
611
+ // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
612
+ // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
613
+
614
+ // process the raw (UTF-8) byte sequence of the input string
615
+ for (char *c = text; *c != '\0'; c++) {
616
+
617
+ // reset buffer if the current byte is ASCII or a leading byte
618
+ // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
619
+ // 0x80 is 10000000
620
+ // in UTF-8, all continuation bytes start with "10" in first two bits
621
+ // so in English this is: "if this byte is not a continuation byte"
622
+ if ((*c & 0xC0) != 0x80) {
623
+ // this byte must be either a leading byte (11...) or an ASCII char (0x...)
624
+ // => reset our location, as we're starting a new UTF-8 codepoint
625
+ str_len = 0;
626
+ }
627
+
628
+ // append the current byte to the buffer
629
+ str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
630
+ str_buffer[str_len] = '\0';
631
+
632
+ // while the next character is a continuation byte, continue appending
633
+ // but if there are too many of them, just stop to avoid overrunning str_buffer size.

634
+ if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
635
+ continue;
636
+ }
637
+
638
+ // ok c+1 is not a continuation byte, so we've read in a full codepoint
639
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
640
+
641
+ if (id != -1) {
642
+ // we found this codepoint in vocab, add it as a token
643
+ tokens[(*n_tokens)++] = id;
644
+ } else {
645
+ // byte_fallback encoding: just encode each byte as a token
646
+ // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
647
+ // so the individual bytes only start at index 3
648
+ for (int i=0; i < str_len; i++) {
649
+ tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
650
+ }
651
+ }
652
+ str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
653
+ }
654
+
655
+ // merge the best consecutive pair each iteration, according to the scores in vocab_scores
656
+ while (1) {
657
+ float best_score = -1e10;
658
+ int best_id = -1;
659
+ int best_idx = -1;
660
+
661
+ for (int i=0; i < (*n_tokens-1); i++) {
662
+ // check if we can merge the pair (tokens[i], tokens[i+1])
663
+ sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
664
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
665
+ if (id != -1 && t->vocab_scores[id] > best_score) {
666
+ // this merge pair exists in vocab! record its score and position
667
+ best_score = t->vocab_scores[id];
668
+ best_id = id;
669
+ best_idx = i;
670
+ }
671
+ }
672
+
673
+ if (best_idx == -1) {
674
+ break; // we couldn't find any more pairs to merge, so we're done
675
+ }
676
+
677
+ // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
678
+ tokens[best_idx] = best_id;
679
+ // delete token at position best_idx+1, shift the entire sequence back 1
680
+ for (int i = best_idx+1; i < (*n_tokens-1); i++) {
681
+ tokens[i] = tokens[i+1];
682
+ }
683
+ (*n_tokens)--; // token length decreased
684
+ }
685
+
686
+ // add optional EOS (=2) token, if desired
687
+ if (eos) tokens[(*n_tokens)++] = 2;
688
+
689
+ free(str_buffer);
690
+ }
691
+
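
The merge loop above is plain greedy BPE; below is a compact Python sketch of the same idea, where vocab, vocab_scores and the lookup dict are stand-ins for the Tokenizer fields and str_lookup (assumptions for illustration, not the C code path):

    # greedy BPE merging, mirroring the while(1) loop above; `lookup` maps piece string -> id
    def bpe_merge(tokens, vocab, vocab_scores, lookup):
        while True:
            best_score, best_id, best_idx = -1e10, -1, -1
            for i in range(len(tokens) - 1):
                tid = lookup.get(vocab[tokens[i]] + vocab[tokens[i + 1]], -1)
                if tid != -1 and vocab_scores[tid] > best_score:
                    best_score, best_id, best_idx = vocab_scores[tid], tid, i
            if best_idx == -1:
                return tokens  # nothing left to merge
            tokens[best_idx:best_idx + 2] = [best_id]  # replace the pair with the merged token
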
692
+ // ----------------------------------------------------------------------------
693
+ // The Sampler, which takes logits and returns a sampled token
694
+ // sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
695
+
696
+ typedef struct {
697
+ float prob;
698
+ int index;
699
+ } ProbIndex; // struct used when sorting probabilities during top-p sampling
700
+
701
+ typedef struct {
702
+ int vocab_size;
703
+ ProbIndex* probindex; // buffer used in top-p sampling
704
+ float temperature;
705
+ float topp;
706
+ unsigned long long rng_state;
707
+ } Sampler;
708
+
709
+ int sample_argmax(float* probabilities, int n) {
710
+ // return the index that has the highest probability
711
+ int max_i = 0;
712
+ float max_p = probabilities[0];
713
+ for (int i = 1; i < n; i++) {
714
+ if (probabilities[i] > max_p) {
715
+ max_i = i;
716
+ max_p = probabilities[i];
717
+ }
718
+ }
719
+ return max_i;
720
+ }
721
+
722
+ int sample_mult(float* probabilities, int n, float coin) {
723
+ // sample index from probabilities (they must sum to 1!)
724
+ // coin is a random number in [0, 1), usually from random_f32()
725
+ float cdf = 0.0f;
726
+ for (int i = 0; i < n; i++) {
727
+ cdf += probabilities[i];
728
+ if (coin < cdf) {
729
+ return i;
730
+ }
731
+ }
732
+ return n - 1; // in case of rounding errors
733
+ }
734
+
735
+ int compare(const void* a, const void* b) {
736
+ ProbIndex* a_ = (ProbIndex*) a;
737
+ ProbIndex* b_ = (ProbIndex*) b;
738
+ if (a_->prob > b_->prob) return -1;
739
+ if (a_->prob < b_->prob) return 1;
740
+ return 0;
741
+ }
742
+
743
+ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
744
+ // top-p sampling (or "nucleus sampling") samples from the smallest set of
745
+ // tokens that exceed probability topp. This way we never sample tokens that
746
+ // have very low probabilities and are less likely to go "off the rails".
747
+ // coin is a random number in [0, 1), usually from random_f32()
748
+
749
+ int n0 = 0;
750
+ // quicksort indices in descending order of probabilities
751
+ // values smaller than (1 - topp) / (n - 1) cannot be part of the result
752
+ // so for efficiency we crop these out as candidates before sorting
753
+ const float cutoff = (1.0f - topp) / (n - 1);
754
+ for (int i = 0; i < n; i++) {
755
+ if (probabilities[i] >= cutoff) {
756
+ probindex[n0].index = i;
757
+ probindex[n0].prob = probabilities[i];
758
+ n0++;
759
+ }
760
+ }
761
+ qsort(probindex, n0, sizeof(ProbIndex), compare);
762
+
763
+ // truncate the list where cumulative probability exceeds topp
764
+ float cumulative_prob = 0.0f;
765
+ int last_idx = n0 - 1; // in case of rounding errors consider all elements
766
+ for (int i = 0; i < n0; i++) {
767
+ cumulative_prob += probindex[i].prob;
768
+ if (cumulative_prob > topp) {
769
+ last_idx = i;
770
+ break; // we've exceeded topp by including last_idx
771
+ }
772
+ }
773
+
774
+ // sample from the truncated list
775
+ float r = coin * cumulative_prob;
776
+ float cdf = 0.0f;
777
+ for (int i = 0; i <= last_idx; i++) {
778
+ cdf += probindex[i].prob;
779
+ if (r < cdf) {
780
+ return probindex[i].index;
781
+ }
782
+ }
783
+ return probindex[last_idx].index; // in case of rounding errors
784
+ }
785
+
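
For reference, the same cutoff / sort / truncate / sample steps in NumPy; this is an illustrative restatement, not the code path run.c uses:

    # top-p (nucleus) sampling sketch mirroring sample_topp above
    import numpy as np

    def sample_topp_np(probs, topp, coin):
        n = probs.shape[0]
        cutoff = (1.0 - topp) / (n - 1)        # candidates below this can never be in the nucleus
        idx = np.nonzero(probs >= cutoff)[0]
        order = idx[np.argsort(-probs[idx])]   # candidate ids, descending probability
        cdf = np.cumsum(probs[order])
        over = np.nonzero(cdf > topp)[0]       # first position whose cumulative prob exceeds topp
        last = int(over[0]) if over.size else len(order) - 1
        r = coin * cdf[last]
        pick = np.nonzero(cdf[:last + 1] > r)[0]
        return int(order[pick[0]] if pick.size else order[last])
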
786
+ void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
787
+ sampler->vocab_size = vocab_size;
788
+ sampler->temperature = temperature;
789
+ sampler->topp = topp;
790
+ sampler->rng_state = rng_seed;
791
+ // buffer only used with nucleus sampling; may not be needed, but it's ~small
792
+ sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
793
+ }
794
+
795
+ void free_sampler(Sampler* sampler) {
796
+ free(sampler->probindex);
797
+ }
798
+
799
+ unsigned int random_u32(unsigned long long *state) {
800
+ // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
801
+ *state ^= *state >> 12;
802
+ *state ^= *state << 25;
803
+ *state ^= *state >> 27;
804
+ return (*state * 0x2545F4914F6CDD1Dull) >> 32;
805
+ }
806
+ float random_f32(unsigned long long *state) { // random float32 in [0,1)
807
+ return (random_u32(state) >> 8) / 16777216.0f;
808
+ }
809
+
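
The RNG above is a 64-bit xorshift*; a Python sketch that reproduces the same stream (useful, e.g., for checking sampler determinism for a given seed), with the 64-bit wrap-around emulated via a mask:

    # xorshift* matching random_u32/random_f32 above; state must stay within 64 bits
    MASK64 = (1 << 64) - 1

    def random_u32(state):
        state ^= state >> 12
        state ^= (state << 25) & MASK64
        state ^= state >> 27
        return state, ((state * 0x2545F4914F6CDD1D) & MASK64) >> 32

    def random_f32(state):
        state, u = random_u32(state)
        return state, (u >> 8) / 16777216.0  # float in [0, 1)
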
810
+ int sample(Sampler* sampler, float* logits) {
811
+ // sample the token given the logits and some hyperparameters
812
+ int next;
813
+ if (sampler->temperature == 0.0f) {
814
+ // greedy argmax sampling: take the token with the highest probability
815
+ next = sample_argmax(logits, sampler->vocab_size);
816
+ } else {
817
+ // apply the temperature to the logits
818
+ for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
819
+ // apply softmax to the logits to get the probabilities for next token
820
+ softmax(logits, sampler->vocab_size);
821
+ // flip a (float) coin (this is our source of entropy for sampling)
822
+ float coin = random_f32(&sampler->rng_state);
823
+ // we sample from this distribution to get the next token
824
+ if (sampler->topp <= 0 || sampler->topp >= 1) {
825
+ // simply sample from the predicted probability distribution
826
+ next = sample_mult(logits, sampler->vocab_size, coin);
827
+ } else {
828
+ // top-p (nucleus) sampling, clamping the least likely tokens to zero
829
+ next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
830
+ }
831
+ }
832
+ return next;
833
+ }
834
+
835
+ // ----------------------------------------------------------------------------
836
+ // utilities: time
837
+
838
+ long time_in_ms() {
839
+ // return time in milliseconds, for benchmarking the model speed
840
+ struct timespec time;
841
+ clock_gettime(CLOCK_REALTIME, &time);
842
+ return time.tv_sec * 1000 + time.tv_nsec / 1000000;
843
+ }
844
+
845
+ // ----------------------------------------------------------------------------
846
+ // generation loop
847
+
848
+ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
849
+ char *empty_prompt = "";
850
+ if (prompt == NULL) { prompt = empty_prompt; }
851
+
852
+ // encode the (string) prompt into tokens sequence
853
+ int num_prompt_tokens = 0;
854
+ int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
855
+ encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
856
+ if (num_prompt_tokens < 1) {
857
+ fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
858
+ exit(EXIT_FAILURE);
859
+ }
860
+
861
+ // start the main loop
862
+ long start = 0; // used to time our code, only initialized after first iteration
863
+ int next; // will store the next token in the sequence
864
+ int token = prompt_tokens[0]; // kick off with the first token in the prompt
865
+ int pos = 0; // position in the sequence
866
+ while (pos < steps) {
867
+
868
+ // forward the transformer to get logits for the next token
869
+ float* logits = forward(transformer, token, pos);
870
+
871
+ // advance the state machine
872
+ if (pos < num_prompt_tokens - 1) {
873
+ // if we are still processing the input prompt, force the next prompt token
874
+ next = prompt_tokens[pos + 1];
875
+ } else {
876
+ // otherwise sample the next token from the logits
877
+ next = sample(sampler, logits);
878
+ }
879
+ pos++;
880
+
881
+ // data-dependent terminating condition: the BOS (=1) token delimits sequences
882
+ if (next == 1) { break; }
883
+
884
+ // print the token as string, decode it with the Tokenizer object
885
+ char* piece = decode(tokenizer, token, next);
886
+ safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
887
+ fflush(stdout);
888
+ token = next;
889
+
890
+ // init the timer here because the first iteration can be slower
891
+ if (start == 0) { start = time_in_ms(); }
892
+ }
893
+ printf("\n");
894
+
895
+ // report achieved tok/s (pos-1 because the timer starts after first iteration)
896
+ if (pos > 1) {
897
+ long end = time_in_ms();
898
+ fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
899
+ }
900
+
901
+ free(prompt_tokens);
902
+ }
903
+
904
+ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
905
+ // read a line from stdin, up to but not including \n
906
+ printf("%s", guide);
907
+ if (fgets(buffer, bufsize, stdin) != NULL) {
908
+ size_t len = strlen(buffer);
909
+ if (len > 0 && buffer[len - 1] == '\n') {
910
+ buffer[len - 1] = '\0'; // strip newline
911
+ }
912
+ }
913
+ }
914
+
915
+ // ----------------------------------------------------------------------------
916
+ // chat loop
917
+ // I manually inspected the tokens for a few chat conversations compared to
918
+ // python reference and that seemed ok, but this was not thoroughly tested and
919
+ // is not safely implemented, it's more a proof of concept atm.
920
+
921
+ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
922
+ char *cli_user_prompt, char *cli_system_prompt, int steps) {
923
+
924
+ // buffers for reading the system prompt and user prompt from stdin
925
+ // you'll notice they are somewhat haphazardly and unsafely set atm
926
+ char system_prompt[512];
927
+ char user_prompt[512];
928
+ char rendered_prompt[1152];
929
+ int num_prompt_tokens = 0;
930
+ int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
931
+ int user_idx;
932
+
933
+ // start the main loop
934
+ int8_t user_turn = 1; // user starts
935
+ int next; // will store the next token in the sequence
936
+ int token; // stores the current token to feed into the transformer
937
+ int prev_token;
938
+ int pos = 0; // position in the sequence
939
+ while (pos < steps) {
940
+
941
+ // when it is the user's turn to contribute tokens to the dialog...
942
+ if (user_turn) {
943
+ // get the (optional) system prompt at position 0
944
+ if (pos == 0) {
945
+ // at position 0, the user can also contribute a system prompt
946
+ if (cli_system_prompt == NULL) {
947
+ // system prompt was not passed in, attempt to get it from stdin
948
+ read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
949
+ } else {
950
+ // system prompt was passed in, use it
951
+ strcpy(system_prompt, cli_system_prompt);
952
+ }
953
+ }
954
+ // get the user prompt
955
+ if (pos == 0 && cli_user_prompt != NULL) {
956
+ // user prompt for position 0 was passed in, use it
957
+ strcpy(user_prompt, cli_user_prompt);
958
+ } else {
959
+ // otherwise get user prompt from stdin
960
+ read_stdin("User: ", user_prompt, sizeof(user_prompt));
961
+ }
962
+ // render user/system prompts into the Llama 2 Chat schema
963
+ if (pos == 0 && system_prompt[0] != '\0') {
964
+ char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
965
+ sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
966
+ } else {
967
+ char user_template[] = "[INST] %s [/INST]";
968
+ sprintf(rendered_prompt, user_template, user_prompt);
969
+ }
970
+ // encode the rendered prompt into tokens
971
+ encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
972
+ user_idx = 0; // reset the user index
973
+ user_turn = 0;
974
+ printf("Assistant: ");
975
+ }
976
+
977
+ // determine the token to pass into the transformer next
978
+ if (user_idx < num_prompt_tokens) {
979
+ // if we are still processing the input prompt, force the next prompt token
980
+ token = prompt_tokens[user_idx++];
981
+ } else {
982
+ // otherwise use the next token sampled from previous turn
983
+ token = next;
984
+ }
985
+ // EOS (=2) token ends the Assistant turn
986
+ if (token == 2) { user_turn = 1; }
987
+
988
+ // forward the transformer to get logits for the next token
989
+ float* logits = forward(transformer, token, pos);
990
+ next = sample(sampler, logits);
991
+ pos++;
992
+
993
+ if (user_idx >= num_prompt_tokens && next != 2) {
994
+ // the Assistant is responding, so print its output
995
+ char* piece = decode(tokenizer, token, next);
996
+ safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
997
+ fflush(stdout);
998
+ }
999
+ if (next == 2) { printf("\n"); }
1000
+ }
1001
+ printf("\n");
1002
+ free(prompt_tokens);
1003
+ }
1004
+
1005
+
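
The two sprintf templates above implement the Llama 2 Chat schema; the same rendering as a small Python helper sketch (helper name is hypothetical):

    # Llama 2 Chat rendering used by the chat loop above (system prompt only on the first turn)
    def render_turn(user_prompt, system_prompt=None):
        if system_prompt:
            return f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_prompt} [/INST]"
        return f"[INST] {user_prompt} [/INST]"
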
1006
+ // ----------------------------------------------------------------------------
1007
+ // CLI, include only if not testing
1008
+ #ifndef TESTING
1009
+
1010
+ void error_usage() {
1011
+ fprintf(stderr, "Usage: run <checkpoint> [options]\n");
1012
+ fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
1013
+ fprintf(stderr, "Options:\n");
1014
+ fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
1015
+ fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
1016
+ fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
1017
+ fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
1018
+ fprintf(stderr, " -i <string> input prompt\n");
1019
+ fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
1020
+ fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
1021
+ fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
1022
+ exit(EXIT_FAILURE);
1023
+ }
1024
+
1025
+ int main(int argc, char *argv[]) {
1026
+
1027
+ // default parameters
1028
+ char *checkpoint_path = NULL; // e.g. out/model.bin
1029
+ char *tokenizer_path = "tokenizer.bin";
1030
+ float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
1031
+ float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
1032
+ int steps = 256; // number of steps to run for
1033
+ char *prompt = NULL; // prompt string
1034
+ unsigned long long rng_seed = 0; // seed rng with time by default
1035
+ char *mode = "generate"; // generate|chat
1036
+ char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
1037
+
1038
+ // poor man's C argparse so we can override the defaults above from the command line
1039
+ if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
1040
+ for (int i = 2; i < argc; i+=2) {
1041
+ // do some basic validation
1042
+ if (i + 1 >= argc) { error_usage(); } // must have arg after flag
1043
+ if (argv[i][0] != '-') { error_usage(); } // must start with dash
1044
+ if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
1045
+ // read in the args
1046
+ if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
1047
+ else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
1048
+ else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
1049
+ else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
1050
+ else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
1051
+ else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
1052
+ else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
1053
+ else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
1054
+ else { error_usage(); }
1055
+ }
1056
+
1057
+ // parameter validation/overrides
1058
+ if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
1059
+ if (temperature < 0.0) temperature = 0.0;
1060
+ if (topp < 0.0 || 1.0 < topp) topp = 0.9;
1061
+ if (steps < 0) steps = 0;
1062
+
1063
+ // build the Transformer via the model .bin file
1064
+ Transformer transformer;
1065
+ build_transformer(&transformer, checkpoint_path);
1066
+ if (steps == 0 || steps > transformer.config.seq_len) steps = transformer.config.seq_len; // override to ~max length
1067
+
1068
+ // build the Tokenizer via the tokenizer .bin file
1069
+ Tokenizer tokenizer;
1070
+ build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);
1071
+
1072
+ // build the Sampler
1073
+ Sampler sampler;
1074
+ build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
1075
+
1076
+ // run!
1077
+ if (strcmp(mode, "generate") == 0) {
1078
+ generate(&transformer, &tokenizer, &sampler, prompt, steps);
1079
+ } else if (strcmp(mode, "chat") == 0) {
1080
+ chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
1081
+ } else {
1082
+ fprintf(stderr, "unknown mode: %s\n", mode);
1083
+ error_usage();
1084
+ }
1085
+
1086
+ // memory and file handles cleanup
1087
+ free_sampler(&sampler);
1088
+ free_tokenizer(&tokenizer);
1089
+ free_transformer(&transformer);
1090
+ return 0;
1091
+ }
1092
+ #endif
llama2.c/sample.py ADDED
@@ -0,0 +1,79 @@
1
+ """
2
+ Sample from the trained model with PyTorch
3
+ """
4
+ import os
5
+ import pickle
6
+ from contextlib import nullcontext
7
+ import torch
8
+ from model import ModelArgs, Transformer
9
+ from tokenizer import Tokenizer
10
+
11
+ from tinystories import get_tokenizer_model_path
12
+
13
+ # -----------------------------------------------------------------------------
14
+ checkpoint = 'out/ckpt.pt'
15
+ start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
16
+ num_samples = 1 # number of samples to draw
17
+ max_new_tokens = 100 # number of tokens generated in each sample
18
+ temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
19
+ top_k = 300 # retain only the top_k most likely tokens, clamp others to have 0 probability
20
+ tokenizer = "" # override the tokenizer model path
21
+ seed = 1337
22
+ device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
23
+ #dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
24
+ dtype = "float32"
25
+ compile = False # use PyTorch 2.0 to compile the model to be faster
26
+ exec(open('configurator.py').read()) # overrides from command line or config file
27
+ # -----------------------------------------------------------------------------
28
+
29
+ torch.manual_seed(seed)
30
+ torch.cuda.manual_seed(seed)
31
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
32
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
33
+ device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
34
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
35
+ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
36
+
37
+ # init from a model saved in a specific directory
38
+ checkpoint_dict = torch.load(checkpoint, map_location=device)
39
+ gptconf = ModelArgs(**checkpoint_dict['model_args'])
40
+ model = Transformer(gptconf)
41
+ state_dict = checkpoint_dict['model']
42
+ unwanted_prefix = '_orig_mod.'
43
+ for k,v in list(state_dict.items()):
44
+ if k.startswith(unwanted_prefix):
45
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
46
+ model.load_state_dict(state_dict, strict=False)
47
+
48
+ model.eval()
49
+ model.to(device)
50
+ if compile:
51
+ print("Compiling the model...")
52
+ model = torch.compile(model) # requires PyTorch 2.0 (optional)
53
+
54
+ # load the tokenizer
55
+ vocab_source = checkpoint_dict["config"].get("vocab_source", "llama2")
56
+ vocab_size = gptconf.vocab_size
57
+ if tokenizer:
58
+ # a specific tokenizer is provided, use it
59
+ tokenizer_model = tokenizer
60
+ else:
61
+ # let's try to find the tokenizer model automatically. bit gross here...
62
+ query_vocab_size = 0 if vocab_source == "llama2" else vocab_size
63
+ tokenizer_model = get_tokenizer_model_path(vocab_size=query_vocab_size)
64
+ enc = Tokenizer(tokenizer_model=tokenizer_model)
65
+
66
+ # encode the beginning of the prompt
67
+ if start.startswith('FILE:'):
68
+ with open(start[5:], 'r', encoding='utf-8') as f:
69
+ start = f.read()
70
+ start_ids = enc.encode(start, bos=True, eos=False)
71
+ x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
72
+
73
+ # run generation
74
+ with torch.no_grad():
75
+ with ctx:
76
+ for k in range(num_samples):
77
+ y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
78
+ print(enc.decode(y[0].tolist()))
79
+ print('---------------')
llama2.c/test.c ADDED
@@ -0,0 +1,84 @@
1
+ #define TESTING
2
+ #include "run.c"
3
+
4
+ void assert_eq(int a, int b) {
5
+ if (a != b) {
6
+ printf("Assertion failed: %d != %d\n", a, b);
7
+ exit(EXIT_FAILURE);
8
+ }
9
+ }
10
+
11
+ void test_prompt_encoding(Tokenizer* tokenizer, char* prompt, int* expected_tokens, int num_expected_tokens) {
12
+ // encode
13
+ int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int));
14
+ int num_prompt_tokens = 0; // the total number of prompt tokens
15
+ encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
16
+
17
+ #if VERBOSITY == 1
18
+ // print maybe
19
+ printf("expected tokens:\n");
20
+ for (int i = 0; i < num_expected_tokens; i++) printf("%d ", expected_tokens[i]);
21
+ printf("\n");
22
+ printf("actual tokens:\n");
23
+ for (int i = 0; i < num_prompt_tokens; i++) printf("%d ", prompt_tokens[i]);
24
+ printf("\n");
25
+ #endif
26
+
27
+ // verify
28
+ assert_eq(num_prompt_tokens, num_expected_tokens);
29
+ for (int i = 0; i < num_prompt_tokens; i++) {
30
+ assert_eq(prompt_tokens[i], expected_tokens[i]);
31
+ }
32
+
33
+ #if VERBOSITY == 1
34
+ printf("OK\n");
35
+ printf("---\n");
36
+ #endif
37
+ free(prompt_tokens);
38
+ }
39
+
40
+ void test_prompt_encodings() {
41
+ // let's verify that the Tokenizer works as expected
42
+
43
+ char *tokenizer_path = "tokenizer.bin";
44
+ int vocab_size = 32000;
45
+ Tokenizer tokenizer;
46
+ build_tokenizer(&tokenizer, tokenizer_path, vocab_size);
47
+
48
+ // test 0 (test the empty string) (I added this as a simple case)
49
+ char *prompt0 = "";
50
+ int expected_tokens0[] = {1};
51
+ test_prompt_encoding(&tokenizer, prompt0, expected_tokens0, sizeof(expected_tokens0) / sizeof(int));
52
+
53
+ // the tests below are taken from the Meta Llama 2 repo example code
54
+ // https://github.com/facebookresearch/llama/blob/main/example_text_completion.py
55
+ // and the expected tokens come from me breaking in the debugger in Python
56
+
57
+ // test 1
58
+ char *prompt = "I believe the meaning of life is";
59
+ int expected_tokens[] = {1, 306, 4658, 278, 6593, 310, 2834, 338};
60
+ test_prompt_encoding(&tokenizer, prompt, expected_tokens, sizeof(expected_tokens) / sizeof(int));
61
+
62
+ // test 2
63
+ char* prompt2 = "Simply put, the theory of relativity states that ";
64
+ int expected_tokens2[] = {1, 3439, 17632, 1925, 29892, 278, 6368, 310, 14215, 537, 5922, 393, 29871};
65
+ test_prompt_encoding(&tokenizer, prompt2, expected_tokens2, sizeof(expected_tokens2) / sizeof(int));
66
+
67
+ // test 3
68
+ char* prompt3 = "A brief message congratulating the team on the launch:\n\n Hi everyone,\n\n I just ";
69
+ int expected_tokens3[] = {1, 319, 11473, 2643, 378, 629, 271, 18099, 278, 3815, 373, 278, 6826, 29901, 13, 13, 4706, 6324, 14332, 29892, 13, 13, 4706, 306, 925, 29871};
70
+ test_prompt_encoding(&tokenizer, prompt3, expected_tokens3, sizeof(expected_tokens3) / sizeof(int));
71
+
72
+ // test 4
73
+ char* prompt4 = "Translate English to French:\n\n sea otter => loutre de mer\n peppermint => menthe poivrée\n plush girafe => girafe peluche\n cheese =>";
74
+ int expected_tokens4[] = {1, 4103, 9632, 4223, 304, 5176, 29901, 13, 13, 4706, 7205, 4932, 357, 1149, 301, 449, 276, 316, 2778, 13, 4706, 1236, 407, 837, 524, 1149, 6042, 354, 772, 440, 29878, 1318, 13, 4706, 715, 1878, 330, 3055, 1725, 1149, 330, 3055, 1725, 4639, 28754, 13, 4706, 923, 968, 1149};
75
+ test_prompt_encoding(&tokenizer, prompt4, expected_tokens4, sizeof(expected_tokens4) / sizeof(int));
76
+
77
+ // memory and file handles cleanup
78
+ free_tokenizer(&tokenizer);
79
+ }
80
+
81
+ int main(int argc, char *argv[]) {
82
+ test_prompt_encodings();
83
+ printf("ALL OK\n");
84
+ }
llama2.c/test_all.py ADDED
@@ -0,0 +1,89 @@
1
+ """
2
+ Run simply with
3
+ $ pytest
4
+ """
5
+ import os
6
+ import pytest # pip install pytest
7
+ import requests
8
+ import subprocess
9
+
10
+
11
+ import torch
12
+ from model import ModelArgs, Transformer
13
+ from tokenizer import Tokenizer
14
+
15
+ # -----------------------------------------------------------------------------
16
+ # test utilities
17
+
18
+ test_ckpt_dir = "test"
19
+
20
+ def download_file(url, filename):
21
+ print(f"Downloading {url} to {filename}")
22
+ response = requests.get(url, stream=True)
23
+ response.raise_for_status() # Raise an HTTPError on bad status code
24
+ with open(filename, 'wb') as file:
25
+ for chunk in response.iter_content(chunk_size=8192):
26
+ file.write(chunk)
27
+
28
+ def attempt_download_files():
29
+ os.makedirs(test_ckpt_dir, exist_ok=True)
30
+ root_url = "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K"
31
+ need = ["stories260K.bin", "stories260K.pt", "tok512.bin", "tok512.model"]
32
+ for file in need:
33
+ url = root_url + '/' + file #os.path.join inserts \\ on windows
34
+ filename = os.path.join(test_ckpt_dir, file)
35
+ if not os.path.exists(filename):
36
+ download_file(url, filename)
37
+
38
+ expected_stdout = b'Once upon a time, there was a little girl named Lily. She loved to play outside in the park. One day, she saw a big, red ball. She wanted to play with it, but it was too high.\nLily\'s mom said, "Lily, let\'s go to the park." Lily was sad and didn\'t know what to do. She said, "I want to play with your ball, but I can\'t find it."\nLily was sad and didn\'t know what to do. She said, "I\'m sorry, Lily. I didn\'t know what to do."\nLily didn\'t want to help her mom, so she'
39
+
40
+ # -----------------------------------------------------------------------------
41
+ # actual tests
42
+
43
+ def test_runc():
44
+ """ Forwards a model against a known-good desired outcome in run.c for 200 steps"""
45
+ attempt_download_files()
46
+
47
+ model_path = os.path.join(test_ckpt_dir, "stories260K.bin")
48
+ tokenizer_path = os.path.join(test_ckpt_dir, "tok512.bin")
49
+ command = ["./run", model_path, "-z", tokenizer_path, "-t", "0.0", "-n", "200"]
50
+ with open('err.txt', mode='wb') as fe:
51
+ with open('stdout.txt', mode='wb') as fo:
52
+ proc = subprocess.Popen(command, stdout=fo, stderr=fe) #pipe in windows terminal does funny things like replacing \n with \r\n
53
+ proc.wait()
54
+
55
+ with open('stdout.txt', mode='r') as f:
56
+ stdout = f.read()
57
+ # strip the very last \n that is added by run.c for aesthetic reasons
58
+ stdout = stdout[:-1].encode('ascii')
59
+
60
+ assert stdout == expected_stdout
61
+
62
+ def test_python():
63
+ """ Forwards a model against a known-good desired outcome in sample.py for 200 steps"""
64
+ attempt_download_files()
65
+
66
+ device = "cpu" # stories260K is small enough to just breeze through it on CPU
67
+ checkpoint = os.path.join(test_ckpt_dir, "stories260K.pt")
68
+ checkpoint_dict = torch.load(checkpoint, map_location=device)
69
+ gptconf = ModelArgs(**checkpoint_dict['model_args'])
70
+ model = Transformer(gptconf)
71
+ state_dict = checkpoint_dict['model']
72
+ unwanted_prefix = '_orig_mod.'
73
+ for k,v in list(state_dict.items()):
74
+ if k.startswith(unwanted_prefix):
75
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
76
+ model.load_state_dict(state_dict, strict=False)
77
+ model.eval()
78
+ model.to(device)
79
+ x = torch.tensor([[1]], dtype=torch.long, device=device) # 1 is BOS
80
+ with torch.inference_mode():
81
+ y = model.generate(x, max_new_tokens=200, temperature=0.0)
82
+ pt_tokens = y[0].tolist()
83
+
84
+ tokenizer_model = os.path.join(test_ckpt_dir, "tok512.model")
85
+ enc = Tokenizer(tokenizer_model=tokenizer_model)
86
+ text = enc.decode(pt_tokens)
87
+ text = text.encode('ascii') # turn into bytes
88
+
89
+ assert text == expected_stdout
llama2.c/tinystories.py ADDED
@@ -0,0 +1,281 @@
1
+ """
2
+ Download, preprocess and serve the TinyStories dataset as a DataLoader.
3
+ """
4
+
5
+ import argparse
6
+ import glob
7
+ import json
8
+ import os
9
+ import random
10
+ from typing import List
11
+ from concurrent.futures import ProcessPoolExecutor
12
+ from functools import partial
13
+
14
+ import numpy as np
15
+ import requests
16
+ import sentencepiece as spm
17
+ import torch
18
+ import torch.distributed as dist
19
+ from tqdm import tqdm
20
+
21
+ from tokenizer import Tokenizer
22
+
23
+ DATA_CACHE_DIR = "data"
24
+
25
+ def download_file(url: str, fname: str, chunk_size=1024):
26
+ """Helper function to download a file from a given url"""
27
+ resp = requests.get(url, stream=True)
28
+ total = int(resp.headers.get("content-length", 0))
29
+ with open(fname, "wb") as file, tqdm(
30
+ desc=fname,
31
+ total=total,
32
+ unit="iB",
33
+ unit_scale=True,
34
+ unit_divisor=1024,
35
+ ) as bar:
36
+ for data in resp.iter_content(chunk_size=chunk_size):
37
+ size = file.write(data)
38
+ bar.update(size)
39
+
40
+
41
+ def download():
42
+ """Downloads the TinyStories dataset to DATA_CACHE_DIR"""
43
+ os.makedirs(DATA_CACHE_DIR, exist_ok=True)
44
+
45
+ # download the TinyStories dataset, unless it's already downloaded
46
+ data_url = "https://huggingface.co/datasets/flopml/toy-datasets/resolve/main/tinystories_minimal.tar.gz"
47
+ data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
48
+ if not os.path.exists(data_filename):
49
+ print(f"Downloading {data_url} to {data_filename}...")
50
+ download_file(data_url, data_filename)
51
+ else:
52
+ print(f"{data_filename} already exists, skipping download...")
53
+
54
+ # unpack the tar.gz file into all the data shards (json files)
55
+ data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
56
+ if not os.path.exists(data_dir):
57
+ os.makedirs(data_dir, exist_ok=True)
58
+ print(f"Unpacking {data_filename}...")
59
+ os.system(f"tar -xzf {data_filename} -C {data_dir}")
60
+ else:
61
+ print(f"{data_dir} already exists, skipping unpacking...")
62
+
63
+ # print a single example just for debugging and such
64
+ shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
65
+ with open(shard_filenames[0], "r") as f:
66
+ data = json.load(f)
67
+ print("Download done.")
68
+ print(f"Number of shards: {len(shard_filenames)}")
69
+ print(f"Example story:\n{data[0]}")
70
+
71
+ def train_vocab(vocab_size):
72
+ """
73
+ Trains a custom sentencepiece tokenizer on the TinyStories dataset.
74
+ The custom tokenizer files will be saved in DATA_CACHE_DIR/tok{N} directories,
75
+ where N is the vocab size. This is also where the pretok .bin files will go.
76
+ """
77
+ assert vocab_size > 0, "Vocab size must be positive"
78
+
79
+ # output file prefix path for sentencepiece
80
+ prefix = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
81
+
82
+ # how many shards we'll use for vocab training, kept low for efficiency
83
+ num_shards = 10
84
+
85
+ # 1) export a large chunk of text as a single text file tiny.txt
86
+ tiny_file = os.path.join(DATA_CACHE_DIR, "tiny.txt")
87
+ data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
88
+ shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
89
+
90
+ print(f"Writing temporary file {tiny_file} with {num_shards} shards...")
91
+ with open(tiny_file, "w", encoding="utf-8") as of:
92
+ for shard in tqdm(shard_filenames[:num_shards]):
93
+ with open(shard, "r") as f:
94
+ data = json.load(f)
95
+ for example in data:
96
+ text = example["story"]
97
+ text = text.strip()
98
+ of.write(text + "\n")
99
+ print(f"Size is: {os.path.getsize(tiny_file) / 1024 / 1024:.2f} MB")
100
+
101
+ # 2) train the sentencepiece model
102
+ print("Will now train the vocab...")
103
+ spm.SentencePieceTrainer.train(input=tiny_file,
104
+ model_prefix=prefix,
105
+ model_type="bpe",
106
+ vocab_size=vocab_size,
107
+ self_test_sample_size=0,
108
+ input_format="text",
109
+ character_coverage=1.0,
110
+ num_threads=os.cpu_count(),
111
+ split_digits=True,
112
+ allow_whitespace_only_pieces=True,
113
+ byte_fallback=True,
114
+ unk_surface=r" \342\201\207 ",
115
+ normalization_rule_name="identity")
116
+
117
+ # 3) optional cleanup, ask the user if they'd like to delete tiny.txt
118
+ dec = input(f"Delete the temporary file {tiny_file}? [y/N] ")
119
+ if dec.lower() == "y":
120
+ os.remove(tiny_file)
121
+ print(f"Deleted {tiny_file}")
122
+
123
+ print(f"Trained tokenizer is in {prefix}.model")
124
+ print("Done.")
125
+
126
+
127
+ def process_shard(args, vocab_size):
128
+ shard_id, shard = args
129
+ tokenizer_model = get_tokenizer_model_path(vocab_size)
130
+ enc = Tokenizer(tokenizer_model)
131
+ with open(shard, "r") as f:
132
+ data = json.load(f)
133
+ all_tokens = []
134
+ for example in tqdm(data, position=shard_id):
135
+ text = example["story"]
136
+ text = text.strip() # get rid of leading/trailing whitespace
137
+ tokens = enc.encode(text, bos=True, eos=False) # encode the text, use BOS
138
+ all_tokens.extend(tokens)
139
+ # convert to uint16 nparray
140
+ all_tokens = np.array(all_tokens, dtype=np.uint16)
141
+ # calculate the output filename
142
+ if vocab_size == 0:
143
+ # if we're using Llama 2, just save the tokenized file in the same dir
144
+ tokenized_filename = shard.replace(".json", ".bin")
145
+ else:
146
+ # save .bin files into a new tok{N} directory
147
+ bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
148
+ shard_basename = os.path.basename(shard)
149
+ bin_basename = shard_basename.replace(".json", ".bin")
150
+ tokenized_filename = os.path.join(bin_dir, bin_basename)
151
+ # write the bytes
152
+ with open(tokenized_filename, "wb") as f:
153
+ f.write(all_tokens.tobytes())
154
+ # calculate the average sequence length (they are separated by BOS=1)
155
+ avg_seq_len = all_tokens.size / ((all_tokens == 1).sum())
156
+ print(f"Saved {tokenized_filename}, average seqlen: {avg_seq_len:.2f}")
157
+
158
+
159
+ def pretokenize(vocab_size):
160
+ # iterate the shards and tokenize all of them one by one
161
+ data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
162
+ shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
163
+ if vocab_size > 0:
164
+ # .bin files will be saved into tok{N} directory, create it once here
165
+ bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
166
+ os.makedirs(bin_dir, exist_ok=True)
167
+
168
+ # process all the shards in a process pool
169
+ fun = partial(process_shard, vocab_size=vocab_size)
170
+ with ProcessPoolExecutor() as executor:
171
+ executor.map(fun, enumerate(shard_filenames))
172
+ print("Done.")
173
+
174
+
175
+ class PretokDataset(torch.utils.data.IterableDataset):
176
+ """Loads pretokenized examples from disk and yields them as PyTorch tensors."""
177
+
178
+ def __init__(self, split, max_seq_len, vocab_size, vocab_source):
179
+ super().__init__()
180
+ self.split = split
181
+ self.max_seq_len = max_seq_len
182
+ self.vocab_size = vocab_size
183
+ self.vocab_source = vocab_source
184
+
185
+ def __iter__(self):
186
+ # get worker info within a DataLoader
187
+ worker_info = torch.utils.data.get_worker_info()
188
+ worker_id = worker_info.id if worker_info else 0
189
+ # get DDP rank info
190
+ rank = dist.get_rank() if dist.is_initialized() else 0
191
+ # combine the worker_id and worker_rank to create a unique seed for rng
192
+ seed = 42 + worker_id + 1337 * rank
193
+ rng = random.Random(seed)
194
+ print(f"Created a PretokDataset with rng seed {seed}")
195
+ if self.vocab_source == "llama2":
196
+ # the .bin files are right along the .json files
197
+ bin_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
198
+ shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
199
+ elif self.vocab_source == "custom":
200
+ # the .bin files are in tok{N} directory
201
+ bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{self.vocab_size}")
202
+ shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
203
+ # train/test split. let's use only shard 0 for test split, rest train
204
+ shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
205
+ assert len(shard_filenames)>0, f"No bin files found in {bin_dir}"
206
+ while True:
207
+ rng.shuffle(shard_filenames)
208
+ for shard in shard_filenames:
209
+ # open the dataset for reading but keep it on disk with memmap
210
+ m = np.memmap(shard, dtype=np.uint16, mode="r")
211
+ num_batches = len(m) // self.max_seq_len
212
+ num_batches -= 1 # drop the last partial batch
213
+ assert num_batches > 0, "this shard is way too small? investigate."
214
+ ixs = list(range(num_batches))
215
+ rng.shuffle(ixs)
216
+ for ix in ixs:
217
+ start = ix * self.max_seq_len
218
+ end = start + self.max_seq_len + 1
219
+ # calling .astype will copy the data into a new numpy array, now in RAM
220
+ chunk = torch.from_numpy((m[start:end]).astype(np.int64))
221
+ x = chunk[:-1]
222
+ y = chunk[1:]
223
+ yield x, y
224
+
225
+ # -----------------------------------------------------------------------------
226
+ # public interface functions
227
+
228
+ def get_tokenizer_model_path(vocab_size):
229
+ """
230
+ Returns path to the sentencepiece tokenizer model for a given vocab size
231
+ vocab_size = 0 designates the default Llama 2 tokenizer, in that case
232
+ None is returned.
233
+ """
234
+ if vocab_size == 0:
235
+ return None
236
+ else:
237
+ return os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")
238
+
239
+ class Task:
240
+
241
+ @staticmethod
242
+ def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
243
+ ds = PretokDataset(**dataset_kwargs)
244
+ dl = torch.utils.data.DataLoader(
245
+ ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
246
+ )
247
+ for x, y in dl:
248
+ x = x.to(device, non_blocking=True)
249
+ y = y.to(device, non_blocking=True)
250
+ yield x, y
251
+
252
+ # -----------------------------------------------------------------------------
253
+ # CLI for constructing the dataset
254
+
255
+ if __name__ == "__main__":
256
+ """
257
+ These stages are designed to be run in order.
258
+
259
+ To tokenize data with the Llama 2 tokenizer:
260
+ python tinystories.py download
261
+ python tinystories.py pretokenize
262
+
263
+ To tokenize data with a custom tokenizer we train ourselves with sentencepiece, e.g.:
264
+ python tinystories.py download
265
+ python tinystories.py train_vocab --vocab_size=2048
266
+ python tinystories.py pretokenize --vocab_size=2048
267
+ """
268
+ parser = argparse.ArgumentParser()
269
+ parser.add_argument("stage", type=str, choices=["download", "pretokenize", "train_vocab"])
270
+ parser.add_argument("--vocab_size", type=int, default=0, help="pretokenization vocab size. 0 = use Llama 2 tokenizer.")
271
+ args = parser.parse_args()
272
+
273
+ # depending on the stage call the appropriate function
274
+ if args.stage == "download":
275
+ download()
276
+ elif args.stage == "train_vocab":
277
+ train_vocab(vocab_size=args.vocab_size)
278
+ elif args.stage == "pretokenize":
279
+ pretokenize(vocab_size=args.vocab_size)
280
+ else:
281
+ raise ValueError(f"Unknown stage {args.stage}")
llama2.c/tokenizer.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50a52ef822ee9e83de5ce9d0be0a025a773d019437f58b5ff9dcafb063ece361
3
+ size 433869
llama2.c/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
llama2.c/tokenizer.py ADDED
@@ -0,0 +1,78 @@
1
+ # Taken from llama code and lightly modified
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
4
+
5
+ import os
6
+ import struct
7
+ import argparse
8
+ from typing import List
9
+
10
+ from sentencepiece import SentencePieceProcessor
11
+
12
+ TOKENIZER_MODEL = "tokenizer.model" # the llama sentencepiece tokenizer model
13
+
14
+ class Tokenizer:
15
+ def __init__(self, tokenizer_model=None):
16
+ model_path = tokenizer_model if tokenizer_model else TOKENIZER_MODEL
17
+ assert os.path.isfile(model_path), model_path
18
+ self.sp_model = SentencePieceProcessor(model_file=model_path)
19
+ self.model_path = model_path
20
+
21
+ # BOS / EOS token IDs
22
+ self.n_words: int = self.sp_model.vocab_size()
23
+ self.bos_id: int = self.sp_model.bos_id()
24
+ self.eos_id: int = self.sp_model.eos_id()
25
+ self.pad_id: int = self.sp_model.pad_id()
26
+ #print(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")
27
+ assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
28
+
29
+ def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
30
+ assert type(s) is str
31
+ t = self.sp_model.encode(s)
32
+ if bos:
33
+ t = [self.bos_id] + t
34
+ if eos:
35
+ t = t + [self.eos_id]
36
+ return t
37
+
38
+ def decode(self, t: List[int]) -> str:
39
+ return self.sp_model.decode(t)
40
+
41
+ def export(self):
42
+
43
+ # get all the tokens (postprocessed) and their scores as floats
44
+ tokens, scores = [], []
45
+ for i in range(self.n_words):
46
+
47
+ # decode the token and light postprocessing
48
+ t = self.sp_model.id_to_piece(i)
49
+ s = self.sp_model.get_score(i)
50
+ if i == self.bos_id:
51
+ t = '\n<s>\n'
52
+ elif i == self.eos_id:
53
+ t = '\n</s>\n'
54
+ t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace
55
+ b = t.encode('utf-8') # bytes of this token, utf-8 encoded
56
+
57
+ tokens.append(b)
58
+ scores.append(s)
59
+
60
+ # record the max token length
61
+ max_token_length = max(len(t) for t in tokens)
62
+
63
+ # write to a binary file
64
+ # the tokenizer.bin file is the same as .model file, but .bin
65
+ tokenizer_bin = self.model_path.replace('.model', '.bin')
66
+ with open(tokenizer_bin, 'wb') as f:
67
+ f.write(struct.pack("I", max_token_length))
68
+ for bytes, score in zip(tokens, scores):
69
+ f.write(struct.pack("fI", score, len(bytes)))
70
+ f.write(bytes)
71
+
72
+ if __name__ == "__main__":
73
+ parser = argparse.ArgumentParser()
74
+ parser.add_argument("-t", "--tokenizer-model", type=str, help="optional path to custom tokenizer ")
75
+ args = parser.parse_args()
76
+
77
+ t = Tokenizer(args.tokenizer_model)
78
+ t.export()
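
A quick cross-check of this wrapper against the token ids hard-coded in test.c above, assuming the stock Llama 2 tokenizer.model in this directory:

    # should reproduce test 1 from test.c when run next to tokenizer.model
    from tokenizer import Tokenizer

    enc = Tokenizer("tokenizer.model")
    ids = enc.encode("I believe the meaning of life is", bos=True, eos=False)
    print(ids)  # expected per test.c: [1, 306, 4658, 278, 6593, 310, 2834, 338]
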
llama2.c/train.py ADDED
@@ -0,0 +1,343 @@
1
+ """
2
+ This training script can be run both on a single gpu in debug mode,
3
+ and also in a larger training run with distributed data parallel (ddp).
4
+
5
+ To run on a single GPU small debug run, example:
6
+ $ python train.py --compile=False --eval_iters=10 --batch_size=8
7
+
8
+ To run with DDP on 4 gpus on 1 node, example:
9
+ $ torchrun --standalone --nproc_per_node=4 train.py
10
+
11
+ To run with DDP on 4 gpus across 2 nodes, example:
12
+ - Run on the first (master) node with example IP 123.456.123.456:
13
+ $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
14
+ - Run on the worker node:
15
+ $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
16
+ (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1)
17
+ """
18
+
19
+ import math
20
+ import os
21
+ import time
22
+ from contextlib import nullcontext
23
+ from datetime import datetime
24
+ from functools import partial
25
+
26
+ import torch
27
+ from model import Transformer, ModelArgs
28
+ from torch.distributed import destroy_process_group, init_process_group
29
+ from torch.nn.parallel import DistributedDataParallel as DDP
30
+
31
+ from tinystories import Task
32
+ from export import model_export
33
+
34
+ # -----------------------------------------------------------------------------
35
+ # I/O
36
+ out_dir = "out"
37
+ eval_interval = 2000
38
+ log_interval = 1
39
+ eval_iters = 100
40
+ eval_only = False # if True, script exits right after the first eval
41
+ always_save_checkpoint = False # if True, always save a checkpoint after each eval
42
+ init_from = "scratch" # 'scratch' or 'resume'
43
+ # wandb logging
44
+ wandb_log = False # disabled by default
45
+ wandb_project = "llamac"
46
+ wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
47
+ # data
48
+ batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
49
+ max_seq_len = 256
50
+ vocab_source = "llama2" # llama2|custom; use Llama 2 vocab from Meta, or custom trained
51
+ vocab_size = 32000 # the Llama 2 tokenizer has 32K tokens
52
+ # model
53
+ dim = 288
54
+ n_layers = 2
55
+ n_heads = 2
56
+ n_kv_heads = 2
57
+ multiple_of = 32
58
+ dropout = 0.0
59
+ # adamw optimizer
60
+ gradient_accumulation_steps = 4 # used to simulate larger batch sizes
61
+ learning_rate = 5e-4 # max learning rate
62
+ max_iters = 1000 # total number of training iterations
63
+ weight_decay = 1e-1
64
+ beta1 = 0.9
65
+ beta2 = 0.95
66
+ grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
67
+ # learning rate decay settings
68
+ decay_lr = True # whether to decay the learning rate
69
+ warmup_iters = 1000 # how many steps to warm up for
70
+ # system
71
+ device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
72
+ dtype = "bfloat16" # float32|bfloat16|float16
73
+ compile = True # use PyTorch 2.0 to compile the model to be faster
74
+ # -----------------------------------------------------------------------------
75
+ config_keys = [
76
+ k
77
+ for k, v in globals().items()
78
+ if not k.startswith("_") and isinstance(v, (int, float, bool, str))
79
+ ]
80
+ exec(open("configurator.py").read()) # overrides from command line or config file
81
+ config = {k: globals()[k] for k in config_keys} # will be useful for logging
82
+ # -----------------------------------------------------------------------------
83
+
84
+ # fixing some hyperparams to sensible defaults
85
+ lr_decay_iters = max_iters # should be ~= max_iters per Chinchilla
86
+ min_lr = 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
87
+
88
+ # validating checks
89
+ assert vocab_source in ["llama2", "custom"]
90
+ assert vocab_source == "custom" or vocab_size == 32000, "The vocab from Meta has 32K tokens"
91
+
92
+ # various inits, derived attributes, I/O setup
93
+ ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
94
+ if ddp:
95
+ init_process_group(backend="nccl")
96
+ ddp_rank = int(os.environ["RANK"])
97
+ ddp_local_rank = int(os.environ["LOCAL_RANK"])
98
+ ddp_world_size = int(os.environ["WORLD_SIZE"])
99
+ device = f"cuda:{ddp_local_rank}"
100
+ torch.cuda.set_device(device)
101
+ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
102
+ seed_offset = ddp_rank # each process gets a different seed
103
+ # world_size number of processes will be training simultaneously, so we can scale
104
+ # down the desired gradient accumulation iterations per process proportionally
105
+     assert gradient_accumulation_steps % ddp_world_size == 0
+     gradient_accumulation_steps //= ddp_world_size
+ else:
+     # if not ddp, we are running on a single gpu, and one process
+     master_process = True
+     seed_offset = 0
+     ddp_world_size = 1
+ tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
+ if master_process:
+     print(f"tokens per iteration will be: {tokens_per_iter:,}")
+     print(f"breaks down as: {gradient_accumulation_steps} grad accum steps * {ddp_world_size} processes * {batch_size} batch size * {max_seq_len} max seq len")
+
+ if master_process:
+     os.makedirs(out_dir, exist_ok=True)
+ torch.manual_seed(1337 + seed_offset)
+ torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
+ torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
+ device_type = "cuda" if "cuda" in device else "cpu"  # for later use in torch.autocast
+ # note: float16 data type will automatically use a GradScaler
+ ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
+ ctx = (
+     nullcontext()
+     if device_type == "cpu"
+     else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+ )
+
+ # task-specific setup
+ iter_batches = partial(
+     Task.iter_batches,
+     batch_size=batch_size,
+     max_seq_len=max_seq_len,
+     vocab_size=vocab_size,
+     vocab_source=vocab_source,
+     device=device,
+     num_workers=0,
+ )
+
+ # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
+ iter_num = 0
+ best_val_loss = 1e9
+
+ # model init
+ model_args = dict(
+     dim=dim,
+     n_layers=n_layers,
+     n_heads=n_heads,
+     n_kv_heads=n_kv_heads,
+     vocab_size=vocab_size,
+     multiple_of=multiple_of,
+     max_seq_len=max_seq_len,
+     dropout=dropout,
+ )  # start with model_args from command line
+ if init_from == "scratch":
+     # init a new model from scratch
+     print("Initializing a new model from scratch")
+     gptconf = ModelArgs(**model_args)
+     model = Transformer(gptconf)
+ elif init_from == "resume":
+     print(f"Resuming training from {out_dir}")
+     # resume training from a checkpoint.
+     ckpt_path = os.path.join(out_dir, "ckpt.pt")
+     checkpoint = torch.load(ckpt_path, map_location=device)
+     checkpoint_model_args = checkpoint["model_args"]
+     # force these config attributes to be equal otherwise we can't even resume training
+     # the rest of the attributes (e.g. dropout) can stay as desired from command line
+     for k in ["dim", "n_layers", "n_heads", "n_kv_heads", "vocab_size", "multiple_of", "max_seq_len"]:
+         model_args[k] = checkpoint_model_args[k]
+     # create the model
+     gptconf = ModelArgs(**model_args)
+     model = Transformer(gptconf)
+     state_dict = checkpoint["model"]
+     # fix the keys of the state dictionary :(
+     # honestly no idea how checkpoints sometimes get this prefix, have to debug more
+     unwanted_prefix = "_orig_mod."
+     for k, v in list(state_dict.items()):
+         if k.startswith(unwanted_prefix):
+             state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
+     model.load_state_dict(state_dict)
+     iter_num = checkpoint["iter_num"]
+     best_val_loss = checkpoint["best_val_loss"]
+ model.to(device)
+
+ # initialize a GradScaler. If enabled=False scaler is a no-op
+ scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))
+
+ # optimizer
+ optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
+ if init_from == "resume" and "optimizer" in checkpoint:
+     optimizer.load_state_dict(checkpoint["optimizer"])
+ checkpoint = None  # free up memory
+
+ # compile the model
+ if compile:
+     print("compiling the model... (takes a ~minute)")
+     unoptimized_model = model
+     model = torch.compile(model)  # requires PyTorch 2.0
+
+ # wrap model into DDP container
+ if ddp:
+     # Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
+     # construction time since NCCL does not support `ComplexFloat`
+     prefix = "_orig_mod." if compile else ""
+     model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
+     model = DDP(model, device_ids=[ddp_local_rank])
+
+ # helps estimate an arbitrarily accurate loss over either split using many batches
+ @torch.no_grad()
+ def estimate_loss():
+     out = {}
+     model.eval()
+     for split in ["train", "val"]:
+         batch_iter = iter_batches(split=split)
+         losses = torch.zeros(eval_iters)  # keep on CPU
+         for k in range(eval_iters):
+             X, Y = next(batch_iter)
+             with ctx:
+                 logits = model(X, Y)
+                 loss = raw_model.last_loss
+             losses[k] = loss.item()
+         out[split] = losses.mean()
+     model.train()
+     return out
+
+ # learning rate decay scheduler (cosine with warmup)
+ def get_lr(it):
+     # 1) linear warmup for warmup_iters steps
+     if it < warmup_iters:
+         return learning_rate * it / warmup_iters
+     # 2) if it > lr_decay_iters, return min learning rate
+     if it > lr_decay_iters:
+         return min_lr
+     # 3) in between, use cosine decay down to min learning rate
+     decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
+     assert 0 <= decay_ratio <= 1
+     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
+     return min_lr + coeff * (learning_rate - min_lr)
+
+ # logging
+ if wandb_log and master_process:
+     import wandb
+     wandb.init(project=wandb_project, name=wandb_run_name, config=config)
+
+ # training loop
+ train_batch_iter = iter_batches(split="train")
+ X, Y = next(train_batch_iter)  # fetch the very first batch
+ t0 = time.time()
+ local_iter_num = 0  # number of iterations in the lifetime of this process
+ raw_model = model.module if ddp else model  # unwrap DDP container if needed
+ running_mfu = -1.0
+ while True:
+     # determine and set the learning rate for this iteration
+     lr = get_lr(iter_num) if decay_lr else learning_rate
+     for param_group in optimizer.param_groups:
+         param_group["lr"] = lr
+
+     # evaluate the loss on train/val sets and write checkpoints
+     if iter_num % eval_interval == 0 and master_process:
+         losses = estimate_loss()
+         print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
+         if wandb_log:
+             try:
+                 wandb.log(
+                     {
+                         "iter": iter_num,
+                         "tokens": iter_num * tokens_per_iter,
+                         "loss/train": losses["train"],
+                         "loss/val": losses["val"],
+                         "lr": lr,
+                         "mfu": running_mfu * 100,  # convert to percentage
+                     }, step = iter_num
+                 )
+             except Exception as e:
+                 print(f"logging to wandb failed: {e}")
+         if losses["val"] < best_val_loss or always_save_checkpoint:
+             best_val_loss = losses["val"]
+             if iter_num > 0:
+                 checkpoint = {
+                     "model": raw_model.state_dict(),
+                     "optimizer": optimizer.state_dict(),
+                     "model_args": model_args,
+                     "iter_num": iter_num,
+                     "best_val_loss": best_val_loss,
+                     "config": config,
+                 }
+                 print(f"saving checkpoint to {out_dir}")
+                 torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
+                 model_export(raw_model, os.path.join(out_dir, "model.bin"), version=0)
+     if iter_num == 0 and eval_only:
+         break
+
+     # forward backward update, with optional gradient accumulation to simulate larger batch size
+     # and using the GradScaler if data type is float16
+     for micro_step in range(gradient_accumulation_steps):
+         if ddp:
+             # in DDP training we only need to sync gradients at the last micro step.
+             # the official way to do this is with model.no_sync() context manager, but
+             # I really dislike that this bloats the code and forces us to repeat code
+             # looking at the source of that context manager, it just toggles this variable
+             model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
+         with ctx:
+             logits = model(X, Y)
+             loss = raw_model.last_loss
+             loss = loss / gradient_accumulation_steps
+         # immediately async prefetch next batch while model is doing the forward pass on the GPU
+         X, Y = next(train_batch_iter)
+         # backward pass, with gradient scaling if training in fp16
+         scaler.scale(loss).backward()
+     # clip the gradient
+     if grad_clip != 0.0:
+         scaler.unscale_(optimizer)
+         torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+     # step the optimizer and scaler if training in fp16
+     scaler.step(optimizer)
+     scaler.update()
+     # flush the gradients as soon as we can, no need for this memory anymore
+     optimizer.zero_grad(set_to_none=True)
+
+     # timing and logging
+     t1 = time.time()
+     dt = t1 - t0
+     t0 = t1
+     if iter_num % log_interval == 0 and master_process:
+         # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
+         lossf = loss.item() * gradient_accumulation_steps
+         if local_iter_num >= 5:  # let the training loop settle a bit
+             mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
+             running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
+         print(
+             f"{iter_num} | loss {lossf:.4f} | lr {lr:e} | {dt*1000:.2f}ms | mfu {running_mfu*100:.2f}%"
+         )
+     iter_num += 1
+     local_iter_num += 1
+
+     # termination conditions
+     if iter_num > max_iters:
+         break
+
+ if ddp:
+     destroy_process_group()
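
Note: for a quick intuition check of the warmup-plus-cosine schedule computed by `get_lr` above, here is a minimal standalone sketch, reimplemented in C purely for illustration. The hyperparameter values in `main` (5e-4 peak LR, 1000 warmup iters, 100000 decay iters, 0.0 min LR) are assumed example settings, not necessarily this commit's defaults.

// sketch: warmup + cosine LR decay, mirroring get_lr() in train.py above
#include <math.h>
#include <stdio.h>

static double sketch_get_lr(int it, double learning_rate, double min_lr,
                            int warmup_iters, int lr_decay_iters) {
    if (it < warmup_iters)                       // 1) linear warmup
        return learning_rate * it / warmup_iters;
    if (it > lr_decay_iters)                     // 2) past the decay horizon: floor at min_lr
        return min_lr;
    const double pi = acos(-1.0);
    double decay_ratio = (double)(it - warmup_iters) / (lr_decay_iters - warmup_iters);
    double coeff = 0.5 * (1.0 + cos(pi * decay_ratio)); // 3) cosine coefficient from 1 down to 0
    return min_lr + coeff * (learning_rate - min_lr);
}

int main(void) {
    // print the schedule at a few checkpoints (example hyperparameters, see note above)
    int iters[] = {0, 500, 1000, 50000, 100000, 120000};
    for (int i = 0; i < 6; i++)
        printf("iter %6d -> lr %.6f\n",
               iters[i], sketch_get_lr(iters[i], 5e-4, 0.0, 1000, 100000));
    return 0;
}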
llama2.c/win.c ADDED
@@ -0,0 +1,180 @@
+ #include "win.h"
+ #include <errno.h>
+ #include <io.h>
+
+ #ifndef FILE_MAP_EXECUTE
+ #define FILE_MAP_EXECUTE 0x0020
+ #endif /* FILE_MAP_EXECUTE */
+
+ static int __map_mman_error(const uint32_t err, const int deferr)
+ {
+     if (err == 0)
+         return 0;
+     //TODO: implement
+     return err;
+ }
+
+ static uint32_t __map_mmap_prot_page(const int prot)
+ {
+     uint32_t protect = 0;
+
+     if (prot == PROT_NONE)
+         return protect;
+
+     if ((prot & PROT_EXEC) != 0)
+     {
+         protect = ((prot & PROT_WRITE) != 0) ?
+                     PAGE_EXECUTE_READWRITE : PAGE_EXECUTE_READ;
+     }
+     else
+     {
+         protect = ((prot & PROT_WRITE) != 0) ?
+                     PAGE_READWRITE : PAGE_READONLY;
+     }
+
+     return protect;
+ }
+
+ static uint32_t __map_mmap_prot_file(const int prot)
+ {
+     uint32_t desiredAccess = 0;
+
+     if (prot == PROT_NONE)
+         return desiredAccess;
+
+     if ((prot & PROT_READ) != 0)
+         desiredAccess |= FILE_MAP_READ;
+     if ((prot & PROT_WRITE) != 0)
+         desiredAccess |= FILE_MAP_WRITE;
+     if ((prot & PROT_EXEC) != 0)
+         desiredAccess |= FILE_MAP_EXECUTE;
+
+     return desiredAccess;
+ }
+
+ void* mmap(void *addr, size_t len, int prot, int flags, int fildes, ssize_t off)
+ {
+     HANDLE fm, h;
+     void * map = MAP_FAILED;
+
+ #ifdef _MSC_VER
+ #pragma warning(push)
+ #pragma warning(disable: 4293)
+ #endif
+
+     const uint32_t dwFileOffsetLow = (uint32_t)(off & 0xFFFFFFFFL);
+     const uint32_t dwFileOffsetHigh = (uint32_t)((off >> 32) & 0xFFFFFFFFL);
+     const uint32_t protect = __map_mmap_prot_page(prot);
+     const uint32_t desiredAccess = __map_mmap_prot_file(prot);
+
+     const ssize_t maxSize = off + (ssize_t)len;
+
+     const uint32_t dwMaxSizeLow = (uint32_t)(maxSize & 0xFFFFFFFFL);
+     const uint32_t dwMaxSizeHigh = (uint32_t)((maxSize >> 32) & 0xFFFFFFFFL);
+
+ #ifdef _MSC_VER
+ #pragma warning(pop)
+ #endif
+
+     errno = 0;
+
+     if (len == 0
+         /* Unsupported flag combinations */
+         || (flags & MAP_FIXED) != 0
+         /* Unsupported protection combinations */
+         || prot == PROT_EXEC)
+     {
+         errno = EINVAL;
+         return MAP_FAILED;
+     }
+
+     h = ((flags & MAP_ANONYMOUS) == 0) ?
+             (HANDLE)_get_osfhandle(fildes) : INVALID_HANDLE_VALUE;
+
+     if ((flags & MAP_ANONYMOUS) == 0 && h == INVALID_HANDLE_VALUE)
+     {
+         errno = EBADF;
+         return MAP_FAILED;
+     }
+
+     fm = CreateFileMapping(h, NULL, protect, dwMaxSizeHigh, dwMaxSizeLow, NULL);
+
+     if (fm == NULL)
+     {
+         errno = __map_mman_error(GetLastError(), EPERM);
+         return MAP_FAILED;
+     }
+
+     map = MapViewOfFile(fm, desiredAccess, dwFileOffsetHigh, dwFileOffsetLow, len);
+
+     CloseHandle(fm);
+
+     if (map == NULL)
+     {
+         errno = __map_mman_error(GetLastError(), EPERM);
+         return MAP_FAILED;
+     }
+
+     return map;
+ }
+
+ int munmap(void *addr, size_t len)
+ {
+     if (UnmapViewOfFile(addr))
+         return 0;
+
+     errno = __map_mman_error(GetLastError(), EPERM);
+
+     return -1;
+ }
+
+ int mprotect(void *addr, size_t len, int prot)
+ {
+     uint32_t newProtect = __map_mmap_prot_page(prot);
+     uint32_t oldProtect = 0;
+
+     if (VirtualProtect(addr, len, newProtect, &oldProtect))
+         return 0;
+
+     errno = __map_mman_error(GetLastError(), EPERM);
+
+     return -1;
+ }
+
+ int msync(void *addr, size_t len, int flags)
+ {
+     if (FlushViewOfFile(addr, len))
+         return 0;
+
+     errno = __map_mman_error(GetLastError(), EPERM);
+
+     return -1;
+ }
+
+ int mlock(const void *addr, size_t len)
+ {
+     if (VirtualLock((LPVOID)addr, len))
+         return 0;
+
+     errno = __map_mman_error(GetLastError(), EPERM);
+
+     return -1;
+ }
+
+ int munlock(const void *addr, size_t len)
+ {
+     if (VirtualUnlock((LPVOID)addr, len))
+         return 0;
+
+     errno = __map_mman_error(GetLastError(), EPERM);
+
+     return -1;
+ }
+
+ // Portable clock_gettime function for Windows
+ int clock_gettime(int clk_id, struct timespec *tp) {
+     uint32_t ticks = GetTickCount();
+     tp->tv_sec = ticks / 1000;
+     tp->tv_nsec = (ticks % 1000) * 1000000;
+     return 0;
+ }
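
Note: a minimal sketch of how this mmap shim is typically exercised on Windows, mapping a file read-only and then unmapping it, in the spirit of how the runner maps a checkpoint. The file name "model.bin" is a placeholder, the _open/_lseeki64 calls are one assumed way to obtain a descriptor and file size, and error handling is abbreviated.

// sketch: map a file read-only through the mmap shim above, then unmap
#include "win.h"
#include <fcntl.h>   // _O_RDONLY, _O_BINARY
#include <io.h>      // _open, _close, _lseeki64
#include <stdio.h>

int main(void) {
    int fd = _open("model.bin", _O_RDONLY | _O_BINARY);  // placeholder path
    if (fd < 0) { perror("open"); return 1; }

    ssize_t size = _lseeki64(fd, 0, SEEK_END);  // file size via seek-to-end
    _lseeki64(fd, 0, SEEK_SET);

    void *data = mmap(NULL, (size_t)size, PROT_READ, MAP_PRIVATE, fd, 0);
    if (data == MAP_FAILED) { fprintf(stderr, "mmap failed\n"); _close(fd); return 1; }

    // ... read weights or other bytes out of `data` here ...
    printf("mapped %lld bytes at %p\n", (long long)size, data);

    munmap(data, (size_t)size);
    _close(fd);
    return 0;
}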
llama2.c/win.h ADDED
@@ -0,0 +1,69 @@
+ #ifndef _WIN_H_
+ #define _WIN_H_
+
+ #define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers
+ #include <windows.h>
+ #include <time.h>
+ #include <stdint.h>
+
+ #define ssize_t int64_t
+ #define ftell _ftelli64
+
+ // Below code is originally from mman-win32
+ //
+ /*
+  * sys/mman.h
+  * mman-win32
+  */
+
+ #ifndef _WIN32_WINNT // Allow use of features specific to Windows XP or later.
+ #define _WIN32_WINNT 0x0501 // Change this to the appropriate value to target other versions of Windows.
+ #endif
+
+ /* All the headers include this file. */
+ #ifndef _MSC_VER
+ #include <_mingw.h>
+ #endif
+
+ #include <sys/types.h>
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ #define PROT_NONE 0
+ #define PROT_READ 1
+ #define PROT_WRITE 2
+ #define PROT_EXEC 4
+
+ #define MAP_FILE 0
+ #define MAP_SHARED 1
+ #define MAP_PRIVATE 2
+ #define MAP_TYPE 0xf
+ #define MAP_FIXED 0x10
+ #define MAP_ANONYMOUS 0x20
+ #define MAP_ANON MAP_ANONYMOUS
+
+ #define MAP_FAILED ((void *)-1)
+
+ /* Flags for msync. */
+ #define MS_ASYNC 1
+ #define MS_SYNC 2
+ #define MS_INVALIDATE 4
+
+ /* Flags for portable clock_gettime call. */
+ #define CLOCK_REALTIME 0
+
+ void* mmap(void *addr, size_t len, int prot, int flags, int fildes, ssize_t off);
+ int munmap(void *addr, size_t len);
+ int mprotect(void *addr, size_t len, int prot);
+ int msync(void *addr, size_t len, int flags);
+ int mlock(const void *addr, size_t len);
+ int munlock(const void *addr, size_t len);
+ int clock_gettime(int clk_id, struct timespec *tp);
+
+ #ifdef __cplusplus
+ };
+ #endif
+
+ #endif /* _WIN_H_ */
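
Note: the clock_gettime shim declared here is backed by GetTickCount in win.c, so it has roughly millisecond resolution and its tick counter wraps after about 49.7 days. A minimal sketch of measuring elapsed time through it follows; the time_in_ms helper name is illustrative (it mirrors the millisecond-timing pattern the C runner uses), and the Sleep call just stands in for real work.

// sketch: millisecond timing on top of the clock_gettime shim declared above
#include "win.h"
#include <stdio.h>

// return wall-clock time in milliseconds (illustrative helper)
static long time_in_ms(void) {
    struct timespec t;
    clock_gettime(CLOCK_REALTIME, &t);
    return t.tv_sec * 1000 + t.tv_nsec / 1000000;
}

int main(void) {
    long start = time_in_ms();
    Sleep(250);  // stand-in for real work; Sleep comes from windows.h via win.h
    long elapsed = time_in_ms() - start;
    printf("elapsed: ~%ld ms\n", elapsed);
    return 0;
}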