Upload folder using huggingface_hub

#9
Files changed (19)
  1. LICENSE +21 -0
  2. README.md +172 -0
  3. __pycache__/model.cpython-311.pyc +0 -0
  4. assets/llama_cute.jpg +0 -0
  5. build_msvc.bat +1 -0
  6. export.py +567 -0
  7. llama3_8b_instruct_q80.bin +3 -0
  8. model.py +343 -0
  9. requirements.txt +8 -0
  10. run.c +1027 -0
  11. rundll.h +19 -0
  12. runq +0 -0
  13. runq.c +1146 -0
  14. runqdll.c +1116 -0
  15. tokenizer.bin +3 -0
  16. tokenizer.model +3 -0
  17. tokenizer.py +115 -0
  18. win.c +167 -0
  19. win.h +69 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Andrej
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,172 @@
+ ## llama3.c - A faithful clone of Karpathy's llama2.c, fully functional with the LLaMA 3 8B base and instruct models
+
+ See [Andrej Karpathy's repo](https://github.com/karpathy/llama2.c) for the real deal, built for the llama2 architecture, and the many other cool models he has built.
+
+ <p align="center">
+ <img src="assets/llama_cute.jpg" width="300" height="300" alt="Cute Llama">
+ </p>
+
+ Have you ever wanted to inference a baby [Llama 3](https://ai.meta.com/llama/) model in pure C? No? Well, now you can!
+
+ Run the LLaMA 3 8B models with one simple ~1000-line C file ([run.c](run.c)).
+
+ The current code inferences models in both fp32 and int8 (see below).
+
+ Please note that this repo is a modification of Andrej Karpathy's llama2.c, changing the hard-coded tokenization to work with the modified-tiktoken tokenization used by the suite of Meta LLaMA 3 models (a sketch of the tokenizer setup follows).
+
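+ For reference, the tokenizer wiring mirrors Meta's published Llama 3 reference code; the version actually used by this repo lives in [tokenizer.py](tokenizer.py). A minimal sketch (the special-token list below is abbreviated, and the regex is Meta's published pattern):
+
+ ```python
+ # Minimal sketch of loading the Llama 3 tokenizer with tiktoken.
+ # Assumes `tokenizer.model` from the `original` download is in the cwd.
+ import tiktoken
+ from tiktoken.load import load_tiktoken_bpe
+
+ mergeable_ranks = load_tiktoken_bpe("tokenizer.model")  # 128000 base tokens
+ special_tokens = {  # abbreviated; Llama 3 reserves 256 special ids in total
+     "<|begin_of_text|>": 128000,
+     "<|end_of_text|>": 128001,
+     "<|start_header_id|>": 128006,
+     "<|end_header_id|>": 128007,
+     "<|eot_id|>": 128009,
+ }
+ enc = tiktoken.Encoding(
+     name="llama3",
+     pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
+     mergeable_ranks=mergeable_ranks,
+     special_tokens=special_tokens,
+ )
+ print(enc.encode("Once upon a time"))
+ ```
+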
+ ## getting started
+
+ First, navigate to the folder where you keep your projects and clone this repository into it:
+
+ ```bash
+ git clone https://github.com/jameswdelancey/llama3.c.git
+ ```
+
+ Then, open the repository folder:
+
+ ```bash
+ cd llama3.c
+ ```
+
+ ## which model do i download? base or instruct
+ - If you do not know, go with the instruct model. It works with both the single-shot "generate" mode and the "chat" mode of llama3.c.
+ - The "chat" mode of llama3.c only supports the instruct model and will not work with the base model. You can try it for fun and learning at your own risk :).
+
+ Download LLaMA 3 8B base and/or instruct. The Hugging Face site works; you'll need to sign up and get approved.
+ Specifically, download the `original` directory.
+
+ ```
+ https://huggingface.co/meta-llama/Meta-Llama-3-8B
+ https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
+ ```
+
+ When downloading these models I had to rename `original_params.json` to `params.json` for export.py to work:
+
+ ```
+ mv /d/llama3-8b-instruct/original_params.json /d/llama3-8b-instruct/params.json
+ ```
+
+ ## compile and run the C code
+
+ ```bash
+ gcc -Ofast run.c -o run
+ ./run.exe "llama3_8b_instruct.bin" -z "tokenizer_llama3.bin" -m chat
+ ./run.exe "llama3_8b_instruct.bin" -z "../dev/tokenizer_llama3.bin" -i "Once upon a time"
+ ```
+
+ ## high performance
+
+ - `-fopenmp`: if you have OpenMP, you can run the model much faster. I'm running an Intel i3 14th gen and get 1.9 tok/s with OpenMP.
+ - `-march=native`: required for gcc or clang to use SIMD intrinsics, and will speed up your runs.
+ - `win.c`: optional unless you're on Windows. I'm compiling with MINGW64 and it works well.
+ - gcc or clang: both work well, and I get very close results between the two.
+
+ ```bash
+ $ gcc -Ofast -fopenmp -march=native run.c win.c -o run
+ ```
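+
+ With `-fopenmp` compiled in, you can also pin the thread count at runtime via the standard OpenMP environment variable (the timings quoted below use 8 threads):
+
+ ```bash
+ OMP_NUM_THREADS=8 ./run "llama3_8b_instruct.bin" -z "tokenizer_llama3.bin" -m chat
+ ```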
+
+ ## base model single shot
+ This still runs at interactive rates and samples more coherent and diverse stories:
+
+ ```bash
+ ./run "llama3_8b_base.bin" -z "../dev/tokenizer_llama3.bin" -n 50 -i "Once upon a time"
+ ```
+
+ > Once upon a time, a girl named Lily who had grown up in a world where people weren't allowed to live happily ever after. One day, Lily decided to take a chance and do something that was a little bit crazy, something that she had
+
+ You can also prompt the model with a prefix or a number of additional command line arguments, e.g. to sample at temperature 0.8 for 256 steps and with a prompt:
+
+ ```bash
+ ./run "llama3_8b_base.bin" -z "../dev/tokenizer_llama3.bin" -t 0.8 -n 256 -i "One day, Lily met a Shoggoth"
+ ```
+
+ > One day, Lily met a Shoggoth named Lurking. She had never seen a Shoggoth before, but she was immediately impressed by its size and ferocity. They struck up a conversation, and Lurking told her about the Shoggoths' origins and their place in the universe. Lily was fascinated by what Lurking told her, and she learned a lot about the Shoggoths that day. She also made a new friend in Lurking.
+
+ ## instruct model, chat
+
+ You're very likely to get results that you like better from the instruct model. You must use it with the chat mode, as the token setup (preamble) is totally different (sketched after the transcript below). Here's an example:
+
+ ```bash
+ ./run "llama3_8b_instruct.bin" -z "../dev/tokenizer_llama3.bin" -m chat
+ ```
+
+ > Enter system prompt (optional): You are a professional software programmer with 20 years experience and an IQ of 200 and you love helping people.
+ >
+ > User: Give me a terse hello world in Rust
+ >
+ > Assistant: A delightful request!
+ >
+ > Here is a terse "Hello, World!" program in Rust:
+ > ```rust
+ > fn main() {
+ >     println!("Hello, World!");
+ > }
+ > ```
+ > This program defines a `main` function that calls the `println!` macro to print the string "Hello, World!" followed by a newline character.
+ >
+ > That's it! You can run this program using the `rustc` compiler and then executing the resulting executable file.
+ >
+ > User:
+
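+ For reference, the preamble that chat mode constructs follows Meta's documented Llama 3 instruct format (the exact token handling lives in run.c; this is a sketch of the template, not repo output):
+
+ ```
+ <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+ {system prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+ {user message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+ ```
+
+ The base model was never trained with these special tokens, which is why chat mode only makes sense with the instruct weights.
+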
+ ## Meta's Llama 3 models
+
+ Sadly there is a bit of friction here due to licensing (I can't directly upload the checkpoints, I think). So Step 1, get the Llama 3 checkpoints by following the [Meta instructions](https://github.com/facebookresearch/llama3). Once we have those checkpoints, we have to convert them into the llama3.c format.
+ For this we need to install the python dependencies (`pip install -r requirements.txt`) and then use the `export.py` file, e.g. for the 8B model:
+
+ ```bash
+ python export.py llama3_8b_base.bin --meta-llama ../llama3-8b-base/
+ ```
+
+ The export will take ~10 minutes and generate a 31GB file (the weights of the 8B model in float32) called `llama3_8b_base.bin` in the current directory. Once the export is done, we can run it:
+
+ ```bash
+ ./run "llama3_8b_base.bin" -z "../dev/tokenizer_llama3.bin"
+ ```
+
+ This ran at about 2 tokens/s compiled with OpenMP on 8 threads on my Intel i3 14th gen. Example output:
+
+ ```bash
+ ./run "llama3_8b_base.bin" -z "../dev/tokenizer_llama3.bin"
+ ```
+ > Question:
+ > What is the second derivative of 2*p**3 - 4*p**2 - 12*p?
+ > Answer:
+ > 12*p - 8
+
+ base models... ¯\\_(ツ)_/¯. Since we can inference the base model, it should be possible to also inference the instruct model quite easily, and have a conversation with it. And if we can find a way to run 8B more efficiently, we can start adding LoRA to our training script, and go wild with finetunes all within the repo!
+
+ You can also chat with the Llama Chat models. Export the chat model exactly as above:
+
+ ```bash
+ python export.py llama3_8b_instruct.bin --meta-llama ../llama3-8b-instruct/
+ ```
+
+ Then chat with it by specifying the chat mode using the `-m` flag, e.g.:
+
+ ```bash
+ ./run "llama3_8b_instruct.bin" -z "../dev/tokenizer_llama3.bin" -m chat
+ ```
+ ## Int8 Quantization
+
+ Compile the quantized version of the runtime:
+ ```bash
+ gcc -Ofast -fopenmp -march=native runq.c win.c -o runq
+ ```
+ Export a quantized version of the model. It is about 8.5 GB vs 31 GB for fp32.
+
+ ```bash
+ python export.py llama3_8b_instruct_q80.bin --meta-llama ../llama3-8b-instruct/ --version 2
+ ```
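+
+ The size difference follows from the Q8_0 layout that export.py writes: every quantized weight costs one int8 byte plus a shared fp32 scale per group of `group_size=64` values, while only the small rmsnorm vectors stay in fp32. A rough back-of-the-envelope check (the ~8.0e9 parameter count is approximate):
+
+ ```python
+ # Back-of-the-envelope file sizes for an ~8.0e9-parameter model.
+ # Assumes nearly all parameters are quantized matmul weights (the fp32
+ # rmsnorm vectors are negligible by comparison).
+ n_params = 8.0e9
+ group_size = 64
+
+ fp32_bytes = n_params * 4                    # ~32 GB, the "31GB" fp32 export
+ q80_bytes = n_params * (1 + 4 / group_size)  # int8 byte + fp32 scale per group
+ print(f"fp32: {fp32_bytes / 1e9:.1f} GB, Q8_0: {q80_bytes / 1e9:.1f} GB")
+ # -> fp32: 32.0 GB, Q8_0: 8.5 GB
+ ```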
+
+ The export will take ~10 minutes. Once the export is done, we can run it:
+
+ ```bash
+ ./runq "llama3_8b_instruct_q80.bin" -z "../dev/tokenizer_llama3.bin"
+ ```
+
+ This ran at about 4 tokens/s compiled with OpenMP on 8 threads on my Intel i3 14th gen.
+
+ ## License
+
+ MIT
__pycache__/model.cpython-311.pyc ADDED
Binary file (25.5 kB).
 
assets/llama_cute.jpg ADDED
build_msvc.bat ADDED
@@ -0,0 +1 @@
+ cl.exe /fp:fast /Ox /openmp /I. run.c win.c
export.py ADDED
@@ -0,0 +1,567 @@
+ """
+ This script has functions and utilities for model export.
+ Basically, we have a bunch of versions of the model, and we
+ want to export them to .bin files to be read from and inferenced in C.
+
+ Among the "input" versions of PyTorch files/models:
+ - Official Llama 2 weights released by Meta
+ - Huggingface weights available on the hub
+ - llama2.c (this repo) trained models
+
+ Among the "output" versions of .bin files:
+ - v0: Legacy files of the original llama2.c repo (will eventually be DEPRECATED)
+ - v1-vN: Improved .bin files with a proper header, cache alignment, etc.
+
+ This script aspires to provide all of these conversions.
+ """
+ import os
+ import gzip
+ import shutil
+ import struct
+ import argparse
+ import json
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from torch import nn
+
+ from model import ModelArgs, Transformer
+
+ # -----------------------------------------------------------------------------
+ # common utilities
+
+ def serialize_fp32(file, tensor):
+     """ writes one fp32 tensor to file that is open in wb mode """
+     d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
+     b = struct.pack(f'{len(d)}f', *d)
+     file.write(b)
+
+ def serialize_int8(file, tensor):
+     """ writes one int8 tensor to file that is open in wb mode """
+     d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
+     b = struct.pack(f'{len(d)}b', *d)
+     file.write(b)
+
+ def quantize_q80(w, group_size):
+     """
+     takes a tensor and returns the Q8_0 quantized version
+     i.e. symmetric quantization into int8, range [-127,127]
+     """
+     assert w.numel() % group_size == 0
+     ori_shape = w.shape
+     w = w.float() # convert to float32
+     w = w.reshape(-1, group_size)
+     # find the max in each group
+     wmax = torch.abs(w).max(dim=1).values
+     # calculate the scaling factor such that float = quant * scale
+     scale = wmax / 127.0
+     # scale into range [-127, 127]
+     quant = w / scale[:,None]
+     # round to nearest integer
+     int8val = torch.round(quant).to(torch.int8)
+     # dequantize by rescaling
+     fp32val = (int8val.float() * scale[:,None]).view(-1)
+     fp32valr = fp32val.reshape(-1, group_size)
+     # calculate the max error in each group
+     err = torch.abs(fp32valr - w).max(dim=1).values
+     # find the max error across all groups
+     maxerr = err.max().item()
+     return int8val, scale, maxerr
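+
+ # For intuition (hypothetical shapes): calling quantize_q80 on a 4096x4096
+ # fp32 weight with group_size=64 returns int8val of shape (262144, 64), one
+ # fp32 scale per group in `scale` of shape (262144,), and `maxerr` as a float
+ # (typically O(0.001) for well-behaved weights).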
+
+ # -----------------------------------------------------------------------------
+ # legacy
+
+ def legacy_export(model, filepath):
+     """ Original export of llama2.c bin files, i.e. version v0 """
+     out_file = open(filepath, 'wb')
+
+     # first write out the header
+     hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
+     p = model.params
+     shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
+     # legacy format uses negative/positive vocab size as a shared classifier flag
+     if not shared_classifier:
+         p.vocab_size = -p.vocab_size
+     n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
+     header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
+                          n_kv_heads, p.vocab_size, p.max_seq_len)
+     out_file.write(header)
+
+     # next write out the embedding weights
+     serialize_fp32(out_file, model.tok_embeddings.weight)
+
+     # now all the layers
+     # attention weights
+     for layer in model.layers:
+         serialize_fp32(out_file, layer.attention_norm.weight)
+     for layer in model.layers:
+         serialize_fp32(out_file, layer.attention.wq.weight)
+     for layer in model.layers:
+         serialize_fp32(out_file, layer.attention.wk.weight)
+     for layer in model.layers:
+         serialize_fp32(out_file, layer.attention.wv.weight)
+     for layer in model.layers:
+         serialize_fp32(out_file, layer.attention.wo.weight)
+     # ffn weights
+     for layer in model.layers:
+         serialize_fp32(out_file, layer.ffn_norm.weight)
+     for layer in model.layers:
+         serialize_fp32(out_file, layer.feed_forward.w1.weight)
+     for layer in model.layers:
+         serialize_fp32(out_file, layer.feed_forward.w2.weight)
+     for layer in model.layers:
+         serialize_fp32(out_file, layer.feed_forward.w3.weight)
+     # final rmsnorm
+     serialize_fp32(out_file, model.norm.weight)
+     # freqs_cis
+     serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
+     serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])
+
+     # final classifier weights
+     if not shared_classifier:
+         serialize_fp32(out_file, model.output.weight)
+
+     # write to binary file
+     out_file.close()
+     print(f"wrote {filepath}")
+
+ # -----------------------------------------------------------------------------
+ # new version
+
+ def version1_export(model, filepath):
+     """
+     Export the model weights in full float32 .bin file to be read from C.
+     This is the same as legacy_export, but with a proper header.
+     """
+     version = 1
+
+     out_file = open(filepath, 'wb')
+     # first write out the header. the header will be 256 bytes
+     # 1) write magic, which will be uint32 of "ak42" in ASCII
+     out_file.write(struct.pack('I', 0x616b3432))
+     # 2) write version, which will be int
+     out_file.write(struct.pack('i', version))
+     # 3) write the params, which will be 7 ints
+     p = model.params
+     hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
+     n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
+     header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
+                          n_kv_heads, p.vocab_size, p.max_seq_len)
+     out_file.write(header)
+     # 4) write some other flags
+     shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
+     out_file.write(struct.pack('B', int(shared_classifier)))
+     pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
+     assert pad >= 0
+     out_file.write(b'\0' * pad)
+
+     # now let's write out all the params
+     weights = [
+         *[layer.attention_norm.weight for layer in model.layers],
+         *[layer.ffn_norm.weight for layer in model.layers],
+         model.norm.weight,
+         model.tok_embeddings.weight,
+         *[layer.attention.wq.weight for layer in model.layers],
+         *[layer.attention.wk.weight for layer in model.layers],
+         *[layer.attention.wv.weight for layer in model.layers],
+         *[layer.attention.wo.weight for layer in model.layers],
+         *[layer.feed_forward.w1.weight for layer in model.layers],
+         *[layer.feed_forward.w2.weight for layer in model.layers],
+         *[layer.feed_forward.w3.weight for layer in model.layers],
+     ]
+     if not shared_classifier:
+         weights.append(model.output.weight)
+     for w in weights:
+         serialize_fp32(out_file, w)
+
+     # write to binary file
+     out_file.close()
+     print(f"wrote {filepath}")
+
+ def version2_export(model, filepath, group_size=64):
+     """
+     Export the model weights in Q8_0 into .bin file to be read from C.
+     That is:
+     - quantize all weights to symmetric int8, in range [-127, 127]
+     - all other tensors (the rmsnorm params) are kept and exported in fp32
+     - quantization is done in groups of group_size to reduce the effects of any outliers
+     """
+     version = 2
+
+     # let's first do some validation for this export type
+     while model.params.dim % group_size != 0:
+         group_size //= 2
+         print(f"BACKOFF: reducing group size to {group_size} to fit hidden_dim")
+     weights = [
+         model.tok_embeddings.weight,
+         *[layer.attention.wq.weight for layer in model.layers],
+         *[layer.attention.wk.weight for layer in model.layers],
+         *[layer.attention.wv.weight for layer in model.layers],
+         *[layer.attention.wo.weight for layer in model.layers],
+         *[layer.feed_forward.w1.weight for layer in model.layers],
+         *[layer.feed_forward.w2.weight for layer in model.layers],
+         *[layer.feed_forward.w3.weight for layer in model.layers],
+     ]
+     shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
+     if not shared_classifier:
+         weights.append(model.output.weight)
+     for i, w in enumerate(weights):  # enumerate so the assert message can name the weight
+         assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"
+
+     # write
+     out_file = open(filepath, 'wb')
+     # first write out the header. the header will be 256 bytes
+     # 1) write magic, which will be uint32 of "ak42" in ASCII
+     out_file.write(struct.pack('I', 0x616b3432))
+     # 2) write version, which will be int
+     out_file.write(struct.pack('i', version))
+     # 3) write the params, which will be 7 ints
+     p = model.params
+     hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
+     n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
+     header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
+                          n_kv_heads, p.vocab_size, p.max_seq_len)
+     out_file.write(header)
+     # 4) write some other flags
+     out_file.write(struct.pack('B', int(shared_classifier)))
+     out_file.write(struct.pack('i', group_size)) # group size used for quantization
+     pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
+     assert pad >= 0
+     out_file.write(b'\0' * pad)
+     # now that the header is done, let's write out the model
+
+     # first let's write out all the params that we are keeping in fp32: the norms
+     for layer in model.layers: # attention norms
+         serialize_fp32(out_file, layer.attention_norm.weight)
+     for layer in model.layers: # MLP norms
+         serialize_fp32(out_file, layer.ffn_norm.weight)
+     serialize_fp32(out_file, model.norm.weight) # final pre-classifier norm
+
+     # now let's write out all the params that we are quantizing to Q8_0
+     # note we skip classifier weights, which are shared with the embedding
+     ew = []
+     for i, w in enumerate(weights):
+         # quantize this weight
+         q, s, err = quantize_q80(w, group_size)
+         # save the int8 weights to file
+         serialize_int8(out_file, q) # save the tensor in int8
+         serialize_fp32(out_file, s) # save scale factors
+         # logging
+         ew.append((err, w.shape))
+         print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}")
+
+     # print the highest error across all weights, should be very small, e.g. O(~0.001)
+     ew.sort(reverse=True)
+     print(f"max quantization group error across all weights: {ew[0][0]}")
+
+     # write to binary file
+     out_file.close()
+     print(f"wrote {filepath}")
+
+ def hf_export(llama_model, filepath, group_size=64, dtype=torch.float32):
+     """ Generate the pytorch_model.bin state_dict and config.json for HuggingFace """
+
+     try:
+         from transformers.models.llama.configuration_llama import LlamaConfig
+     except ImportError:
+         print("Error: transformers package is required to load huggingface models")
+         print("Please run `pip install transformers` to install it")
+         return None
+
+     # Generate LlamaModel state_dict
+     hf_state_dict = {}
+
+     # Sometimes we have repeated key values for the heads
+     dim = llama_model.params.dim
+     num_key_value_heads = llama_model.params.n_kv_heads
+     n_rep = llama_model.params.n_heads // num_key_value_heads
+     key_value_dim = dim // n_rep
+
+     # HuggingFace needs the weights permuted.
+     # See: https://github.com/huggingface/transformers/blob/b132c1703eb1c8bd9dfa4ad6a9be2bfd6ef819e9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122
+     def permute_original(w, n_heads=llama_model.params.n_heads, dim1=dim, dim2=dim):
+         return w.view(dim1, dim2).reshape(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
+
+     # Transfer weights from llama model to the HF state dictionary format
+     hf_state_dict['model.embed_tokens.weight'] = llama_model.tok_embeddings.weight.clone().to(dtype)
+     hf_state_dict['model.norm.weight'] = llama_model.norm.weight.clone().to(dtype)
+
+     # Add each layer's weights to the HF state dictionary
+     for i, layer in enumerate(llama_model.layers):
+         layer_id = layer.layer_id
+         hf_state_dict[f'model.layers.{i}.input_layernorm.weight'] = llama_model.layers[layer_id].attention_norm.weight.clone().to(dtype)
+         hf_state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wq.weight.clone()).to(dtype)
+         hf_state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wk.weight.clone(), num_key_value_heads, key_value_dim, dim).to(dtype)
+         hf_state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = llama_model.layers[layer_id].attention.wv.weight.clone().to(dtype)
+         hf_state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = llama_model.layers[layer_id].attention.wo.weight.clone().to(dtype)
+         hf_state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = llama_model.layers[layer_id].ffn_norm.weight.clone().to(dtype)
+         hf_state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = llama_model.layers[layer_id].feed_forward.w1.weight.clone().to(dtype)
+         hf_state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = llama_model.layers[layer_id].feed_forward.w2.weight.clone().to(dtype)
+         hf_state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = llama_model.layers[layer_id].feed_forward.w3.weight.clone().to(dtype)
+
+     # llama2.c usually uses tied weights -> reference the embed_tokens.weights instead
+     hf_state_dict['lm_head.weight'] = hf_state_dict['model.embed_tokens.weight']
+
+     # We check that the embeddings are tied, else use manual output weights
+     _embeddings_are_tied: bool = torch.equal(llama_model.tok_embeddings.weight, llama_model.output.weight)
+     if not _embeddings_are_tied:
+         hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype)
+
+     # Generate LlamaConfig (seen in transformers.models.llama.configuration_llama)
+
+     # Extract necessary attributes from llama.c model
+     vocab_size = llama_model.params.vocab_size
+     hidden_size = llama_model.params.dim
+     intermediate_size = llama_model.layers[0].feed_forward.w1.weight.shape[0]
+     num_hidden_layers = llama_model.params.n_layers
+     num_attention_heads = llama_model.params.n_heads
+     num_key_value_heads = llama_model.params.n_kv_heads
+     max_position_embeddings = llama_model.params.max_seq_len
+     rms_norm_eps = llama_model.params.norm_eps
+
+     # TODO check values for:
+     # pretraining_tp, initializer_range, use_cache,
+     # rope_theta, and rope_scaling.
+
+     config = LlamaConfig(
+         vocab_size=vocab_size,
+         hidden_size=hidden_size,
+         intermediate_size=intermediate_size,
+         num_hidden_layers=num_hidden_layers,
+         num_attention_heads=num_attention_heads,
+         num_key_value_heads=num_key_value_heads,
+         max_position_embeddings=max_position_embeddings,
+         rms_norm_eps=rms_norm_eps,
+         tie_word_embeddings=_embeddings_are_tied,
+         # Manual
+         architectures=["LlamaForCausalLM"],
+         hidden_act="silu",
+     )
+
+     # Save files in directory filepath
+     # First make the directory if it doesn't exist
+     os.makedirs(filepath, exist_ok=True)
+
+     # Save the state dictionary in .bin format, and config as .json
+     torch.save(hf_state_dict, os.path.join(filepath, "pytorch_model.bin"))
+     config.save_pretrained(filepath)
+
+
+ # -----------------------------------------------------------------------------
+ # Load / import functions
+
+ def load_checkpoint(checkpoint):
+
+     # load the provided model checkpoint
+     checkpoint_dict = torch.load(checkpoint, map_location='cpu')
+     gptconf = ModelArgs(**checkpoint_dict['model_args'])
+     model = Transformer(gptconf)
+     state_dict = checkpoint_dict['model']
+     unwanted_prefix = '_orig_mod.'
+     for k,v in list(state_dict.items()):
+         if k.startswith(unwanted_prefix):
+             state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+     model.load_state_dict(state_dict, strict=False)
+     model.eval()
+     return model
+
+ def load_meta_model(model_path):
+     params_path = os.path.join(model_path, 'params.json')
+     with open(params_path) as f:
+         params = json.load(f)
+         print(params)
+
+     model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
+     models = [torch.load(p, map_location='cpu') for p in model_paths]
+
+     def concat_weights(models):
+         state_dict = {}
+         for name in list(models[0]):
+             tensors = [model[name] for model in models]
+             if len(tensors) == 1 or len(tensors[0].shape) == 1:
+                 state_dict[name] = tensors[0]
+                 continue
+             is_axis_1 = (
+                 name.startswith('tok_embeddings.')
+                 or name.endswith('.attention.wo.weight')
+                 or name.endswith('.feed_forward.w2.weight')
+             )
+             axis = 1 if is_axis_1 else 0
+             state_dict[name] = torch.cat(tensors, dim=axis)
+             for model in models:
+                 del model[name]
+         return state_dict
+
+     state_dict = concat_weights(models)
+     del models
+
+     # set ModelArgs
+     config = ModelArgs()
+     config.dim = params["dim"]
+     config.n_layers = params["n_layers"]
+     config.n_heads = params["n_heads"]
+     config.n_kv_heads = params.get('n_kv_heads') or params['n_heads']
+     config.multiple_of = params["multiple_of"]
+     config.norm_eps = params["norm_eps"]
+
+     config.vocab_size = state_dict['tok_embeddings.weight'].shape[0]
+     config.max_seq_len = 2048
+
+     # create a new Transformer object and set weights
+     model = Transformer(config)
+
+     model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight'])
+     model.norm.weight = nn.Parameter(state_dict['norm.weight'])
+
+     for layer in model.layers:
+         i = layer.layer_id
+         layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight'])
+         layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight'])
+         layer.attention.wk.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight'])
+         layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight'])
+         layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight'])
+         layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight'])
+         layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight'])
+         layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight'])
+         layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight'])
+
+     # final classifier
+     model.output.weight = nn.Parameter(state_dict['output.weight'])
+     model.eval()
+     return model
+
+ def load_hf_model(model_path):
+
+     try:
+         from transformers import AutoModelForCausalLM
+     except ImportError:
+         print("Error: transformers package is required to load huggingface models")
+         print("Please run `pip install transformers` to install it")
+         return None
+
+     # load HF model
+     hf_model = AutoModelForCausalLM.from_pretrained(model_path)
+     hf_dict = hf_model.state_dict()
+
+     # convert LlamaConfig to ModelArgs
+     config = ModelArgs()
+     config.dim = hf_model.config.hidden_size
+     config.n_layers = hf_model.config.num_hidden_layers
+     config.n_heads = hf_model.config.num_attention_heads
+     # use the true kv-head count so grouped-query models (e.g. Llama 3 8B) load correctly
+     config.n_kv_heads = hf_model.config.num_key_value_heads
+     config.vocab_size = hf_model.config.vocab_size
+     config.hidden_dim = hf_model.config.intermediate_size
+     config.norm_eps = hf_model.config.rms_norm_eps
+     config.max_seq_len = hf_model.config.max_position_embeddings
+
+     # create a new Transformer object and set weights
+     model = Transformer(config)
+
+     model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight'])
+     model.norm.weight = nn.Parameter(hf_dict['model.norm.weight'])
+
+     # huggingface permutes WQ and WK, this function reverses it
+     def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim):
+         return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
+
+     kv_dim = config.n_kv_heads * (config.dim // config.n_heads)  # rows of wk/wv under GQA
+     for layer in model.layers:
+         i = layer.layer_id
+         layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight'])
+         layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight']))
+         # wk uses the kv-head count and kv dim (mirrors permute_original in hf_export)
+         layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight'], config.n_kv_heads, kv_dim, config.dim))
+         layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight'])
+         layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight'])
+         layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight'])
+         layer.feed_forward.w1.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.gate_proj.weight'])
+         layer.feed_forward.w2.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.down_proj.weight'])
+         layer.feed_forward.w3.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.up_proj.weight'])
+
+     # final classifier
+     model.output.weight = nn.Parameter(hf_dict['lm_head.weight'])
+     model.eval()
+     return model
+
+
+ # -----------------------------------------------------------------------------
+ # API entrypoint
+
+ def model_export(model, filepath, version, dtype=torch.float32):
+     """
+     Versions docs:
+     v-1: huggingface export, i.e. intended for use outside of this repo, in HF
+     v0: legacy llama2.c float format, DEPRECATED
+     v1: float32 export
+     v2: int8 quantized Q8_0 export, similar to llama.cpp, in groups
+     # TODO: add dtype export support for other versions (?)
+     """
+     if version == 0:
+         legacy_export(model, filepath)
+     elif version == 1:
+         version1_export(model, filepath)
+     elif version == 2:
+         version2_export(model, filepath)
+     elif version == -1:
+         hf_export(model, filepath, dtype=dtype)  # pass dtype by keyword so it doesn't land in group_size
+     else:
+         raise ValueError(f"unknown version {version}")
+
+ def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
+     """
+     (This was submitted via a PR earlier. Leaving it here, but "orphaned" for now)
+     Saves the model as a TorchScript.
+     The resulting file can be loaded in C++ code and then used for training or
+     inference with:
+         #include <torch/script.h>
+         torch::jit::Module module = torch::jit::load("model.pt")
+     Note that the serialized model includes the initial parameters and with the default
+     ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
+     the model parameters separately you can zero out the parameters before saving it and
+     it will gzip down to 780K.
+     """
+
+     # If requested zero params before saving the model. This is useful in
+     # conjunction with gzip_output.
+     if zero_params:
+         for p in model.parameters():
+             p.detach().zero_()
+
+     torch.jit.save(torch.jit.script(model), filepath)
+
+     if gzip_output:
+         with open(filepath, "rb") as f_in:
+             with gzip.open(f"{filepath}.gz", "wb") as f_out:
+                 shutil.copyfileobj(f_in, f_out)
+         os.unlink(filepath)
+
+ # -----------------------------------------------------------------------------
+ # CLI entrypoint
+
+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("filepath", type=str, help="the output filepath")
+     parser.add_argument("--version", default=0, type=int, help="the version to export with")
+     parser.add_argument("--dtype", type=str, help="dtype of the model (fp16, fp32)", default="fp32")
+     group = parser.add_mutually_exclusive_group(required=True)
+     group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
+     group.add_argument("--meta-llama", type=str, help="meta llama model path")
+     group.add_argument("--hf", type=str, help="huggingface model path")
+     args = parser.parse_args()
+     dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype]
+
+     if args.checkpoint:
+         model = load_checkpoint(args.checkpoint)
+     elif args.meta_llama:
+         model = load_meta_model(args.meta_llama)
+     elif args.hf:
+         model = load_hf_model(args.hf)
+
+     if model is None:
+         parser.error("Can't load input model!")
+
+     # export
+     model_export(model, args.filepath, args.version, dtype)  # use the parsed torch dtype, not the raw string
llama3_8b_instruct_q80.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:47e71d38198cf15f82fc2f6836052965c6889b9063edfd773bf69a1d5089c1c5
+ size 8532934912
model.py ADDED
@@ -0,0 +1,343 @@
+ import math
+ import struct
+ import inspect
+ from dataclasses import dataclass
+ from typing import Any, Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ @dataclass
+ class ModelArgs:
+     # default hyperparameters for the Llama 2 7B model
+     dim: int = 4096
+     n_layers: int = 32
+     n_heads: int = 32
+     n_kv_heads: Optional[int] = None
+     vocab_size: int = 32000
+     hidden_dim: Optional[int] = None
+     multiple_of: int = 256  # MLP hidden layer size will be multiple of
+     norm_eps: float = 1e-5
+     max_seq_len: int = 2048
+     dropout: float = 0.0
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int, eps: float):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)  # type: ignore
+     freqs = torch.outer(t, freqs).float()  # type: ignore
+     freqs_cos = torch.cos(freqs)  # real part
+     freqs_sin = torch.sin(freqs)  # imaginary part
+     return freqs_cos, freqs_sin
+
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+     ndim = x.ndim
+     assert 0 <= 1 < ndim
+     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+     return freqs_cis.view(shape)
+
+ def apply_rotary_emb(
+     xq: torch.Tensor,
+     xk: torch.Tensor,
+     freqs_cos: torch.Tensor,
+     freqs_sin: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+     # reshape xq and xk to match the complex representation
+     xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
+     xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
+
+     # reshape freqs_cos and freqs_sin for broadcasting
+     freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
+     freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
+
+     # apply rotation using real numbers
+     xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
+     xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
+     xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
+     xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
+
+     # flatten last two dimensions
+     xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
+     xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
+
+     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+     bs, slen, n_kv_heads, head_dim = x.shape
+     if n_rep == 1:
+         return x
+     return (
+         x[:, :, :, None, :]
+         .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+         .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+     )
+
+ class Attention(nn.Module):
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+         assert args.n_heads % self.n_kv_heads == 0
+         model_parallel_size = 1
+         self.n_local_heads = args.n_heads // model_parallel_size
+         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
+         self.n_rep = self.n_local_heads // self.n_local_kv_heads
+         self.head_dim = args.dim // args.n_heads
+         self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
+         self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
+         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
+         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+         self.attn_dropout = nn.Dropout(args.dropout)
+         self.resid_dropout = nn.Dropout(args.dropout)
+         self.dropout = args.dropout
+
+         # use flash attention or a manual implementation?
+         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+         if not self.flash:
+             print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+             mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
+             mask = torch.triu(mask, diagonal=1)
+             self.register_buffer("mask", mask)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         freqs_cos: torch.Tensor,
+         freqs_sin: torch.Tensor,
+     ):
+         bsz, seqlen, _ = x.shape
+
+         # QKV
+         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+         xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+         # RoPE relative positional embeddings
+         xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
+
+         # grouped multiquery attention: expand out keys and values
+         xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+         xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+
+         # make heads into a batch dimension
+         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+         xk = xk.transpose(1, 2)
+         xv = xv.transpose(1, 2)
+
+         # flash implementation
+         if self.flash:
+             output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
+         else:
+             # manual implementation
+             scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
+             assert hasattr(self, 'mask')
+             scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
+             scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+             scores = self.attn_dropout(scores)
+             output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)
+
+         # restore time as batch dimension and concat heads
+         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+
+         # final projection into the residual stream
+         output = self.wo(output)
+         output = self.resid_dropout(output)
+         return output
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
+         super().__init__()
+         if hidden_dim is None:
+             hidden_dim = 4 * dim
+             hidden_dim = int(2 * hidden_dim / 3)
+             hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, layer_id: int, args: ModelArgs):
+         super().__init__()
+         self.n_heads = args.n_heads
+         self.dim = args.dim
+         self.head_dim = args.dim // args.n_heads
+         self.attention = Attention(args)
+         self.feed_forward = FeedForward(
+             dim=args.dim,
+             hidden_dim=args.hidden_dim,
+             multiple_of=args.multiple_of,
+             dropout=args.dropout,
+         )
+         self.layer_id = layer_id
+         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+     def forward(self, x, freqs_cos, freqs_sin):
+         h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
+         out = h + self.feed_forward.forward(self.ffn_norm(h))
+         return out
+
+
+ class Transformer(nn.Module):
+     last_loss: Optional[torch.Tensor]
+
+     def __init__(self, params: ModelArgs):
+         super().__init__()
+         self.params = params
+         self.vocab_size = params.vocab_size
+         self.n_layers = params.n_layers
+
+         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+         self.dropout = nn.Dropout(params.dropout)
+         self.layers = torch.nn.ModuleList()
+         for layer_id in range(params.n_layers):
+             self.layers.append(TransformerBlock(layer_id, params))
+         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+
+         # share the unembedding parameters with the embedding parameters
+         self.tok_embeddings.weight = self.output.weight  # https://paperswithcode.com/method/weight-tying
+
+         # some useful precompute for the RoPE relative positional embeddings
+         freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
+         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+         self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+
+         # init all weights
+         self.apply(self._init_weights)
+         # apply special scaled init to the residual projections, per GPT-2 paper
+         for pn, p in self.named_parameters():
+             if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
+                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))
+
+         # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
+         self.last_loss = None
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
+         _bsz, seqlen = tokens.shape
+         h = self.tok_embeddings(tokens)
+         h = self.dropout(h)
+         freqs_cos = self.freqs_cos[:seqlen]
+         freqs_sin = self.freqs_sin[:seqlen]
+
+         for layer in self.layers:
+             h = layer(h, freqs_cos, freqs_sin)
+         h = self.norm(h)
+
+         if targets is not None:
+             # if we are given some desired targets also calculate the loss
+             logits = self.output(h)
+             self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+         else:
+             # inference-time mini-optimization: only forward the output on the very last position
+             logits = self.output(h[:, [-1], :])  # note: using list [-1] to preserve the time dim
+             self.last_loss = None
+
+         return logits
+
+     def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
+         # start with all of the candidate parameters
+         param_dict = {pn: p for pn, p in self.named_parameters()}
+         # filter out those that do not require grad
+         param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+         # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
+         # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
+         decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+         nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
+         optim_groups = [
+             {'params': decay_params, 'weight_decay': weight_decay},
+             {'params': nodecay_params, 'weight_decay': 0.0}
+         ]
+         num_decay_params = sum(p.numel() for p in decay_params)
+         num_nodecay_params = sum(p.numel() for p in nodecay_params)
+         print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+         print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+         # Create AdamW optimizer and use the fused version if it is available
+         fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+         use_fused = fused_available and device_type == 'cuda'
+         extra_args = dict(fused=True) if use_fused else dict()
+         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+         print(f"using fused AdamW: {use_fused}")
+
+         return optimizer
+
+     def estimate_mfu(self, fwdbwd_per_iter, dt):
+         """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
+         # first estimate the number of flops we do per iteration.
+         # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
+         N = sum(p.numel() for p in self.parameters())
+         cfg = self.params
+         L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len
+         flops_per_token = 6*N + 12*L*H*Q*T
+         flops_per_fwdbwd = flops_per_token * T
+         flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
+         # express our flops throughput as ratio of A100 bfloat16 peak flops
+         flops_achieved = flops_per_iter * (1.0/dt)  # per second
+         flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
+         mfu = flops_achieved / flops_promised
+         return mfu
+
+     @torch.inference_mode()
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+         """
+         Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
+         the sequence max_new_tokens times, feeding the predictions back into the model each time.
+         Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+         Also note this is a super inefficient version of sampling with no key/value cache.
+         """
+         for _ in range(max_new_tokens):
+             # if the sequence context is growing too long we must crop it at block_size
+             idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
+             # forward the model to get the logits for the index in the sequence
+             logits = self(idx_cond)
+             logits = logits[:, -1, :]  # crop to just the final time step
+             if temperature == 0.0:
+                 # "sample" the single most likely index
+                 _, idx_next = torch.topk(logits, k=1, dim=-1)
+             else:
+                 # pluck the logits at the final step and scale by desired temperature
+                 logits = logits / temperature
+                 # optionally crop the logits to only the top k options
+                 if top_k is not None:
+                     v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                     logits[logits < v[:, [-1]]] = -float('Inf')
+                 # apply softmax to convert logits to (normalized) probabilities
+                 probs = F.softmax(logits, dim=-1)
+                 idx_next = torch.multinomial(probs, num_samples=1)
+             # append sampled index to the running sequence and continue
+             idx = torch.cat((idx, idx_next), dim=1)
+
+         return idx
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ blobfile==2.1.1
+ numpy
+ pytest==8.2.0
+ Requests
+ tiktoken==0.6.0
+ torch
+ tqdm
+ wandb==0.16.6
run.c ADDED
@@ -0,0 +1,1027 @@
1
+ /* Inference for Llama-3 Transformer model in pure C */
2
+
3
+ #include <stdio.h>
4
+ #include <stdlib.h>
5
+ #include <ctype.h>
6
+ #include <time.h>
7
+ #include <math.h>
8
+ #include <string.h>
9
+ #include <fcntl.h>
10
+ #if defined _WIN32
11
+ #include "win.h"
12
+ #else
13
+ #include <unistd.h>
14
+ #include <sys/mman.h>
15
+ #endif
16
+ // ----------------------------------------------------------------------------
17
+ // Transformer model
18
+
19
+ typedef struct {
20
+ int dim; // transformer dimension
21
+ int hidden_dim; // for ffn layers
22
+ int n_layers; // number of layers
23
+ int n_heads; // number of query heads
24
+ int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
25
+ int vocab_size; // vocabulary size, usually 4096 (byte-level)
26
+ int seq_len; // max sequence length
27
+ } Config;
28
+
29
+ typedef struct {
30
+ // token embedding table
31
+ float* token_embedding_table; // (vocab_size, dim)
32
+ // weights for rmsnorms
33
+ float* rms_att_weight; // (layer, dim) rmsnorm weights
34
+ float* rms_ffn_weight; // (layer, dim)
35
+ // weights for matmuls. note dim == n_heads * head_size
36
+ float* wq; // (layer, dim, n_heads * head_size)
37
+ float* wk; // (layer, dim, n_kv_heads * head_size)
38
+ float* wv; // (layer, dim, n_kv_heads * head_size)
39
+ float* wo; // (layer, n_heads * head_size, dim)
40
+ // weights for ffn
41
+ float* w1; // (layer, hidden_dim, dim)
42
+ float* w2; // (layer, dim, hidden_dim)
43
+ float* w3; // (layer, hidden_dim, dim)
44
+ // final rmsnorm
45
+ float* rms_final_weight; // (dim,)
46
+ // (optional) classifier weights for the logits, on the last layer
47
+ float* wcls;
48
+ } TransformerWeights;
49
+
50
+ typedef struct {
51
+ // current wave of activations
52
+ float *x; // activation at current time stamp (dim,)
53
+ float *xb; // same, but inside a residual branch (dim,)
54
+ float *xb2; // an additional buffer just for convenience (dim,)
55
+ float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
56
+ float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
57
+ float *q; // query (dim,)
58
+ float *k; // key (kv_dim,)
59
+ float *v; // value (kv_dim,)
60
+ float *att; // buffer for scores/attention values (n_heads, seq_len)
61
+ float *logits; // output logits
62
+ // kv cache
63
+ float* key_cache; // (layer, seq_len, kv_dim)
64
+ float* value_cache; // (layer, seq_len, kv_dim)
65
+ } RunState;
66
+
67
+ typedef struct {
68
+ Config config; // the hyperparameters of the architecture (the blueprint)
69
+ TransformerWeights weights; // the weights of the model
70
+ RunState state; // buffers for the "wave" of activations in the forward pass
71
+ // some more state needed to properly clean up the memory mapping (sigh)
72
+ int fd; // file descriptor for memory mapping
73
+ float* data; // memory mapped data pointer
74
+ ssize_t file_size; // size of the checkpoint file in bytes
75
+ } Transformer;
76
+
77
+ void malloc_run_state(RunState* s, Config* p) {
78
+ // we calloc instead of malloc to keep valgrind happy
79
+ int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
80
+ s->x = calloc(p->dim, sizeof(float));
81
+ s->xb = calloc(p->dim, sizeof(float));
82
+ s->xb2 = calloc(p->dim, sizeof(float));
83
+ s->hb = calloc(p->hidden_dim, sizeof(float));
84
+ s->hb2 = calloc(p->hidden_dim, sizeof(float));
85
+ s->q = calloc(p->dim, sizeof(float));
86
+ s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
87
+ s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
88
+ s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
89
+ s->logits = calloc(p->vocab_size, sizeof(float));
90
+ // ensure all mallocs went fine
91
+ if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
92
+ || !s->key_cache || !s->value_cache || !s->att || !s->logits) {
93
+ fprintf(stderr, "malloc failed!\n");
94
+ exit(EXIT_FAILURE);
95
+ }
96
+ }
97
+
98
+ void free_run_state(RunState* s) {
99
+ free(s->x);
100
+ free(s->xb);
101
+ free(s->xb2);
102
+ free(s->hb);
103
+ free(s->hb2);
104
+ free(s->q);
105
+ free(s->att);
106
+ free(s->logits);
107
+ free(s->key_cache);
108
+ free(s->value_cache);
109
+ }
110
+
111
+ void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
112
+ int head_size = p->dim / p->n_heads;
113
+ // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
114
+ unsigned long long n_layers = p->n_layers;
115
+ w->token_embedding_table = ptr;
116
+ ptr += p->vocab_size * p->dim;
117
+ w->rms_att_weight = ptr;
118
+ ptr += n_layers * p->dim;
119
+ w->wq = ptr;
120
+ ptr += n_layers * p->dim * (p->n_heads * head_size);
121
+ w->wk = ptr;
122
+ ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
123
+ w->wv = ptr;
124
+ ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
125
+ w->wo = ptr;
126
+ ptr += n_layers * (p->n_heads * head_size) * p->dim;
127
+ w->rms_ffn_weight = ptr;
128
+ ptr += n_layers * p->dim;
129
+ w->w1 = ptr;
130
+ ptr += n_layers * p->dim * p->hidden_dim;
131
+ w->w2 = ptr;
132
+ ptr += n_layers * p->hidden_dim * p->dim;
133
+ w->w3 = ptr;
134
+ ptr += n_layers * p->dim * p->hidden_dim;
135
+ w->rms_final_weight = ptr;
136
+ ptr += p->dim;
137
+ ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE)
138
+ ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_imag (for RoPE)
139
+ w->wcls = shared_weights ? w->token_embedding_table : ptr;
140
+ }
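The weights live in one flat run of floats, mapped in the fixed order above, so a quick parameter count is a handy sanity check against the checkpoint's file size. A minimal sketch, using the published Llama 3 8B hyperparameters as assumptions (it comes out to ~7.5B floats, i.e. ~30 GB in fp32, plus vocab_size*dim more if the classifier is unshared):

```c
/* back-of-the-envelope float count for the fp32 layout mapped above;
   the Llama 3 8B hyperparameters here are illustrative assumptions --
   read the real values from your checkpoint's Config */
#include <stdio.h>

int main(void) {
    long long dim = 4096, hidden_dim = 14336, n_layers = 32;
    long long n_heads = 32, n_kv_heads = 8, vocab_size = 128256, seq_len = 8192;
    long long head_size = dim / n_heads; // 128
    long long n = 0;
    n += vocab_size * dim;                              // token_embedding_table
    n += n_layers * dim;                                // rms_att_weight
    n += n_layers * dim * (n_heads * head_size);        // wq
    n += 2 * n_layers * dim * (n_kv_heads * head_size); // wk + wv
    n += n_layers * (n_heads * head_size) * dim;        // wo
    n += n_layers * dim;                                // rms_ffn_weight
    n += 3 * n_layers * dim * hidden_dim;               // w1 + w2 + w3
    n += dim;                                           // rms_final_weight
    n += seq_len * head_size;                           // legacy freq_cis_{real,imag} skip
    printf("%.2fB floats, %.1f GB in fp32\n", n / 1e9, n * 4.0 / 1e9);
    return 0;
}
```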
141
+
142
+ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
143
+ int* fd, float** data, ssize_t* file_size) {
144
+ FILE *file = fopen(checkpoint, "rb");
145
+ if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
146
+ // read in the config header
147
+ if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); }
148
+ // negative vocab size is hacky way of signaling unshared weights. bit yikes.
149
+ int shared_weights = config->vocab_size > 0 ? 1 : 0;
150
+ config->vocab_size = abs(config->vocab_size);
151
+ // figure out the file size
152
+ #if defined _WIN32
153
+ _fseeki64(file, 0, SEEK_END); // move file pointer to end of file
154
+ *file_size = _ftelli64(file); // get the file size, in bytes
155
+ #else
156
+ fseek(file, 0, SEEK_END); // move file pointer to end of file
157
+ *file_size = ftell(file); // get the file size, in bytes
158
+ #endif
159
+ fclose(file);
160
+ // memory map the Transformer weights into the data pointer
161
+ *fd = open(checkpoint, O_RDONLY); // open in read only mode
162
+ if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
163
+ *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
164
+ if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
165
+ float* weights_ptr = *data + sizeof(Config)/sizeof(float);
166
+ memory_map_weights(weights, config, weights_ptr, shared_weights);
167
+ }
168
+
169
+ void build_transformer(Transformer *t, char* checkpoint_path) {
170
+ // read in the Config and the Weights from the checkpoint
171
+ read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
172
+ // allocate the RunState buffers
173
+ malloc_run_state(&t->state, &t->config);
174
+ }
175
+
176
+ void free_transformer(Transformer* t) {
177
+ // close the memory mapping
178
+ if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); }
179
+ if (t->fd != -1) { close(t->fd); }
180
+ // free the RunState buffers
181
+ free_run_state(&t->state);
182
+ }
183
+
184
+ // ----------------------------------------------------------------------------
185
+ // neural net blocks; the dynamics of the Transformer
186
+
187
+ void rmsnorm(float* o, float* x, float* weight, int size) {
188
+ // calculate sum of squares
189
+ float ss = 0.0f;
190
+ for (int j = 0; j < size; j++) {
191
+ ss += x[j] * x[j];
192
+ }
193
+ ss /= size;
194
+ ss += 1e-5f;
195
+ ss = 1.0f / sqrtf(ss);
196
+ // normalize and scale
197
+ for (int j = 0; j < size; j++) {
198
+ o[j] = weight[j] * (ss * x[j]);
199
+ }
200
+ }
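Written out, the function above is exactly RMSNorm (Zhang & Sennrich, 2019), with the epsilon folded under the square root:

$$
\text{rmsnorm}(x)_i = w_i \cdot \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \varepsilon}}, \qquad \varepsilon = 10^{-5}
$$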
201
+
202
+ void softmax(float* x, int size) {
203
+ // find max value (for numerical stability)
204
+ float max_val = x[0];
205
+ for (int i = 1; i < size; i++) {
206
+ if (x[i] > max_val) {
207
+ max_val = x[i];
208
+ }
209
+ }
210
+ // exp and sum
211
+ float sum = 0.0f;
212
+ for (int i = 0; i < size; i++) {
213
+ x[i] = expf(x[i] - max_val);
214
+ sum += x[i];
215
+ }
216
+ // normalize
217
+ for (int i = 0; i < size; i++) {
218
+ x[i] /= sum;
219
+ }
220
+ }
221
+
222
+ void matmul(float* xout, float* x, float* w, int n, int d) {
223
+ // W (d,n) @ x (n,) -> xout (d,)
224
+ // by far the most amount of time is spent inside this little function
225
+ int i;
226
+ #pragma omp parallel for private(i)
227
+ for (i = 0; i < d; i++) {
228
+ float val = 0.0f;
229
+ for (int j = 0; j < n; j++) {
230
+ val += w[i * n + j] * x[j];
231
+ }
232
+ xout[i] = val;
233
+ }
234
+ }
235
+
236
+ float* forward(Transformer* transformer, int token, int pos) {
237
+
238
+ // a few convenience variables
239
+ Config* p = &transformer->config;
240
+ TransformerWeights* w = &transformer->weights;
241
+ RunState* s = &transformer->state;
242
+ float *x = s->x;
243
+ int dim = p->dim;
244
+ int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
245
+ int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
246
+ int hidden_dim = p->hidden_dim;
247
+ int head_size = dim / p->n_heads;
248
+
249
+ // copy the token embedding into x
250
+ float* content_row = w->token_embedding_table + token * dim;
251
+ memcpy(x, content_row, dim*sizeof(*x));
252
+
253
+ // forward all the layers
254
+ for(unsigned long long l = 0; l < p->n_layers; l++) {
255
+
256
+ // attention rmsnorm
257
+ rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
258
+
259
+ // key and value point to the kv cache
260
+ int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
261
+ s->k = s->key_cache + loff + pos * kv_dim;
262
+ s->v = s->value_cache + loff + pos * kv_dim;
263
+
264
+ // qkv matmuls for this position
265
+ matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
266
+ matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
267
+ matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
268
+
269
+ // RoPE relative positional encoding: complex-valued rotate q and k in each head
270
+ for (int i = 0; i < p->n_heads; i++) {
271
+ for (int j = 0; j < head_size; j += 2) {
272
+ float freq = 1.0f / powf(500000.0f, (float)j / (float)head_size);
273
+ float val = pos * freq;
274
+ float fcr = cosf(val);
275
+ float fci = sinf(val);
276
+ float q0 = s->q[i * head_size + j];
277
+ float q1 = s->q[i * head_size + j + 1];
278
+ s->q[i * head_size + j] = q0 * fcr - q1 * fci;
279
+ s->q[i * head_size + j + 1] = q0 * fci + q1 * fcr;
280
+ if (i < p->n_kv_heads) {
281
+ float k0 = s->k[i * head_size + j];
282
+ float k1 = s->k[i * head_size + j + 1];
283
+ s->k[i * head_size + j] = k0 * fcr - k1 * fci;
284
+ s->k[i * head_size + j + 1] = k0 * fci + k1 * fcr;
285
+ }
286
+ }
287
+ }
288
+
289
+ // multihead attention. iterate over all heads
290
+ int h;
291
+ #pragma omp parallel for private(h)
292
+ for (h = 0; h < p->n_heads; h++) {
293
+ // get the query vector for this head
294
+ float* q = s->q + h * head_size;
295
+ // attention scores for this head
296
+ float* att = s->att + h * p->seq_len;
297
+ // iterate over all timesteps, including the current one
298
+ for (int t = 0; t <= pos; t++) {
299
+ // get the key vector for this head and at this timestep
300
+ float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
301
+ // calculate the attention score as the dot product of q and k
302
+ float score = 0.0f;
303
+ for (int i = 0; i < head_size; i++) {
304
+ score += q[i] * k[i];
305
+ }
306
+ score /= sqrtf(head_size);
307
+ // save the score to the attention buffer
308
+ att[t] = score;
309
+ }
310
+
311
+ // softmax the scores to get attention weights, from 0..pos inclusively
312
+ softmax(att, pos + 1);
313
+
314
+ // weighted sum of the values, store back into xb
315
+ float* xb = s->xb + h * head_size;
316
+ memset(xb, 0, head_size * sizeof(float));
317
+ for (int t = 0; t <= pos; t++) {
318
+ // get the value vector for this head and at this timestep
319
+ float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
320
+ // get the attention weight for this timestep
321
+ float a = att[t];
322
+ // accumulate the weighted value into xb
323
+ for (int i = 0; i < head_size; i++) {
324
+ xb[i] += a * v[i];
325
+ }
326
+ }
327
+ }
328
+
329
+ // final matmul to get the output of the attention
330
+ matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);
331
+
332
+ // residual connection back into x
333
+ for (int i = 0; i < dim; i++) {
334
+ x[i] += s->xb2[i];
335
+ }
336
+
337
+ // ffn rmsnorm
338
+ rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
339
+
340
+ // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
341
+ // first calculate self.w1(x) and self.w3(x)
342
+ matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
343
+ matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
344
+
345
+ // SwiGLU non-linearity
346
+ for (int i = 0; i < hidden_dim; i++) {
347
+ float val = s->hb[i];
348
+ // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
349
+ val *= (1.0f / (1.0f + expf(-val)));
350
+ // elementwise multiply with w3(x)
351
+ val *= s->hb2[i];
352
+ s->hb[i] = val;
353
+ }
354
+
355
+ // final matmul to get the output of the ffn
356
+ matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
357
+
358
+ // residual connection
359
+ for (int i = 0; i < dim; i++) {
360
+ x[i] += s->xb[i];
361
+ }
362
+ }
363
+
364
+ // final rmsnorm
365
+ rmsnorm(x, x, w->rms_final_weight, dim);
366
+
367
+ // classifier into logits
368
+ matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
369
+ return s->logits;
370
+ }
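One Llama 3 specific detail in the forward pass above is the RoPE base of 500000 (llama2.c used 10000). Each even/odd pair of channels within a head is rotated by a position-dependent angle, with position $p$ and pair offset $j$:

$$
\theta_j = 500000^{-j/\text{head\_size}}, \qquad
\begin{pmatrix} x'_j \\ x'_{j+1} \end{pmatrix} =
\begin{pmatrix} \cos(p\,\theta_j) & -\sin(p\,\theta_j) \\ \sin(p\,\theta_j) & \cos(p\,\theta_j) \end{pmatrix}
\begin{pmatrix} x_j \\ x_{j+1} \end{pmatrix},
\qquad j = 0, 2, \ldots, \text{head\_size}-2
$$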
371
+
372
+ // ----------------------------------------------------------------------------
373
+ // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
374
+
375
+ typedef struct {
376
+ char *str;
377
+ int id;
378
+ } TokenIndex;
379
+
380
+ typedef struct {
381
+ char** vocab;
382
+ float* vocab_scores;
383
+ TokenIndex *sorted_vocab;
384
+ int vocab_size;
385
+ unsigned int max_token_length;
386
+ unsigned char byte_pieces[512]; // stores all single-byte strings
387
+ } Tokenizer;
388
+
389
+ int compare_tokens(const void *a, const void *b) {
390
+ return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
391
+ }
392
+
393
+ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
394
+ // i should have written the vocab_size into the tokenizer file... sigh
395
+ t->vocab_size = vocab_size;
396
+ // malloc space to hold the scores and the strings
397
+ t->vocab = (char**)malloc(vocab_size * sizeof(char*));
398
+ t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
399
+ t->sorted_vocab = NULL; // initialized lazily
400
+ for (int i = 0; i < 256; i++) {
401
+ t->byte_pieces[i * 2] = (unsigned char)i;
402
+ t->byte_pieces[i * 2 + 1] = '\0';
403
+ }
404
+ // read in the file
405
+ FILE *file = fopen(tokenizer_path, "rb");
406
+ if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
407
+ if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
408
+ int len;
409
+ for (int i = 0; i < vocab_size; i++) {
410
+ if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
411
+ if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
412
+ t->vocab[i] = (char *)malloc(len + 1);
413
+ if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
414
+ t->vocab[i][len] = '\0'; // add the string terminating token
415
+ }
416
+ fclose(file);
417
+ }
418
+
419
+ void free_tokenizer(Tokenizer* t) {
420
+ for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
421
+ free(t->vocab);
422
+ free(t->vocab_scores);
423
+ free(t->sorted_vocab);
424
+ }
425
+
426
+ char* decode(Tokenizer* t, int prev_token, int token) {
427
+ char *piece = t->vocab[token];
428
+
429
+
430
+ // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
431
+ // parse this, convert it, and return the actual byte
432
+ unsigned char byte_val;
433
+ if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
434
+ piece = (char*)t->byte_pieces + byte_val * 2;
435
+ }
436
+ return piece;
437
+ }
438
+
439
+ void safe_printf(char *piece) {
440
+ // piece might be a raw byte token, and we only want to print printable chars or whitespace
441
+ // because some of the other bytes can be various control codes, backspace, etc.
442
+ if (piece == NULL) { return; }
443
+ if (piece[0] == '\0') { return; }
444
+ if (piece[1] == '\0') {
445
+ unsigned char byte_val = piece[0];
446
+ if (!(isprint(byte_val) || isspace(byte_val))) {
447
+ return; // bad byte, don't print it
448
+ }
449
+ }
450
+ printf("%s", piece);
451
+ }
452
+
453
+ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
454
+ // efficiently find the perfect match for str in vocab, return its index or -1 if not found
455
+ TokenIndex tok = { .str = str }; // acts as the key to search for
456
+ TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
457
+ return res != NULL ? res->id : -1;
458
+ }
459
+
460
+ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
461
+ // encode the string text (input) into an upper-bound preallocated tokens[] array
462
+ // bos != 0 means prepend the BOS token (=128000), eos != 0 means append the EOS token (=128001)
463
+ if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
464
+
465
+ if (t->sorted_vocab == NULL) {
466
+ // lazily malloc and sort the vocabulary
467
+ t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
468
+ for (int i = 0; i < t->vocab_size; i++) {
469
+ t->sorted_vocab[i].str = t->vocab[i];
470
+ t->sorted_vocab[i].id = i;
471
+ }
472
+ qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
473
+ }
474
+
475
+ // create a temporary buffer that will store merge candidates of two or three consecutive tokens
477
+ // *3 for the triple concat, +1 for null terminator, +2 for UTF8 (in case max_token_length is 1)
478
+ char* str_buffer = malloc((t->max_token_length*3 +1 +2) * sizeof(char));
478
+ size_t str_len = 0;
479
+
480
+ // start at 0 tokens
481
+ *n_tokens = 0;
482
+
483
+ // add optional BOS (=128000) token, if desired
484
+ if (bos) tokens[(*n_tokens)++] = 128000;
485
+
486
+ // add_dummy_prefix is true by default
487
+ // so prepend a dummy prefix token to the input string, but only if text != ""
488
+ // TODO: pretty sure this isn't correct in the general case but I don't have the
489
+ // energy to read more of the sentencepiece code to figure out what it's doing
490
+
491
+
492
+
493
+
494
+
495
+ // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
496
+ // Code point ↔ UTF-8 conversion
497
+ // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
498
+ // U+0000 U+007F 0xxxxxxx
499
+ // U+0080 U+07FF 110xxxxx 10xxxxxx
500
+ // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
501
+ // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
502
+
503
+ // process the raw (UTF-8) byte sequence of the input string
504
+ for (char *c = text; *c != '\0'; c++) {
505
+
506
+ // reset buffer if the current byte is ASCII or a leading byte
507
+ // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
508
+ // 0x80 is 10000000
509
+ // in UTF-8, all continuation bytes start with "10" in first two bits
510
+ // so in English this is: "if this byte is not a continuation byte"
511
+ if ((*c & 0xC0) != 0x80) {
512
+ // this byte must be either a leading byte (11...) or an ASCII char (0x...)
513
+ // => reset our location, as we're starting a new UTF-8 codepoint
514
+ str_len = 0;
515
+ }
516
+
517
+ // append the current byte to the buffer
518
+ str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
519
+ str_buffer[str_len] = '\0';
520
+
521
+ // while the next character is a continuation byte, continue appending
522
+ // but if there are too many of them, just stop to avoid overrunning the str_buffer size.
523
+ if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
524
+ continue;
525
+ }
526
+
527
+ // ok c+1 is not a continuation byte, so we've read in a full codepoint
528
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
529
+
530
+ if (id != -1) {
531
+ // we found this codepoint in vocab, add it as a token
532
+ tokens[(*n_tokens)++] = id;
533
+ } else {
534
+ // byte_fallback encoding: just encode each byte as a token
535
+ // the +3 offset is carried over from llama2.c, whose sentencepiece vocab
536
+ // put <unk>, <s>, </s> first, so the individual bytes only start at index 3
537
+ for (int i=0; i < str_len; i++) {
538
+ tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
539
+ }
540
+ }
541
+ str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
542
+ }
543
+
544
+ // merge the best consecutive pair or triple each iteration, according to the scores in vocab_scores
545
+ while (1) {
546
+ float best_score = -1e10;
547
+ int best_id = -1;
548
+ int best_idx = -1;
549
+ int best_len = 2; // length of the best merge sequence (2 for pair, 3 for triple)
550
+
551
+ // first, try to find the best pair to merge
552
+ for (int i = 0; i < (*n_tokens - 1); i++) {
553
+ // check if we can merge the pair (tokens[i], tokens[i+1])
554
+ sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
555
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
556
+ if (id != -1 && t->vocab_scores[id] > best_score) {
557
+ // this merge pair exists in vocab! record its score and position
558
+ best_score = t->vocab_scores[id];
559
+ best_id = id;
560
+ best_idx = i;
561
+ }
562
+ }
563
+
564
+ // if no pair was found, try to find the best triple to merge
565
+ if (best_idx == -1) {
566
+ for (int i = 0; i < (*n_tokens - 2); i++) {
567
+ // check if we can merge the triple (tokens[i], tokens[i+1], tokens[i+2])
568
+ sprintf(str_buffer, "%s%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]], t->vocab[tokens[i+2]]);
569
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
570
+ if (id != -1 && t->vocab_scores[id] > best_score) {
571
+ // this merge triple exists in vocab! record its score and position
572
+ best_score = t->vocab_scores[id];
573
+ best_id = id;
574
+ best_idx = i;
575
+ best_len = 3;
576
+ }
577
+ }
578
+ }
579
+
580
+ if (best_idx == -1) {
581
+ break; // we couldn't find any more pairs or triples to merge, so we're done
582
+ }
583
+
584
+ // merge the consecutive pair or triple (best_idx, best_idx+1[, best_idx+2]) into new token best_id
585
+ tokens[best_idx] = best_id;
586
+ // delete token(s) at position best_idx+1 (and optionally best_idx+2), shift the entire sequence back
587
+ for (int i = best_idx + 1; i < (*n_tokens - best_len + 1); i++) {
588
+ tokens[i] = tokens[i + best_len - 1];
589
+ }
590
+ (*n_tokens) -= (best_len - 1); // token length decreased by the number of merged tokens minus one
591
+ }
592
+
593
+ // add optional EOS (=128001) token, if desired
594
+ if (eos) tokens[(*n_tokens)++] = 128001;
595
+
596
+ free(str_buffer);
597
+ }
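A minimal round-trip sketch of the tokenizer API above (the tokenizer.bin path and the Llama 3 vocab size of 128256 are assumptions; take both from your own checkpoint):

```c
// encode a string, then decode and print the pieces back out
Tokenizer t;
build_tokenizer(&t, "tokenizer.bin", 128256); // assumed path and vocab size
int tokens[64]; // generous upper bound for this short string (+BOS)
int n_tokens;
encode(&t, "Hello, world!", 1, 0, tokens, &n_tokens); // prepend BOS, no EOS
for (int i = 1; i < n_tokens; i++) {                   // start at 1 to skip BOS
    safe_printf(decode(&t, tokens[i - 1], tokens[i]));
}
printf("\n");
free_tokenizer(&t);
```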
598
+
599
+ // ----------------------------------------------------------------------------
600
+ // The Sampler, which takes logits and returns a sampled token
601
+ // sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
602
+
603
+ typedef struct {
604
+ float prob;
605
+ int index;
606
+ } ProbIndex; // struct used when sorting probabilities during top-p sampling
607
+
608
+ typedef struct {
609
+ int vocab_size;
610
+ ProbIndex* probindex; // buffer used in top-p sampling
611
+ float temperature;
612
+ float topp;
613
+ unsigned long long rng_state;
614
+ } Sampler;
615
+
616
+ int sample_argmax(float* probabilities, int n) {
617
+ // return the index that has the highest probability
618
+ int max_i = 0;
619
+ float max_p = probabilities[0];
620
+ for (int i = 1; i < n; i++) {
621
+ if (probabilities[i] > max_p) {
622
+ max_i = i;
623
+ max_p = probabilities[i];
624
+ }
625
+ }
626
+ return max_i;
627
+ }
628
+
629
+ int sample_mult(float* probabilities, int n, float coin) {
630
+ // sample index from probabilities (they must sum to 1!)
631
+ // coin is a random number in [0, 1), usually from random_f32()
632
+ float cdf = 0.0f;
633
+ for (int i = 0; i < n; i++) {
634
+ cdf += probabilities[i];
635
+ if (coin < cdf) {
636
+ return i;
637
+ }
638
+ }
639
+ return n - 1; // in case of rounding errors
640
+ }
641
+
642
+ int compare(const void* a, const void* b) {
643
+ ProbIndex* a_ = (ProbIndex*) a;
644
+ ProbIndex* b_ = (ProbIndex*) b;
645
+ if (a_->prob > b_->prob) return -1;
646
+ if (a_->prob < b_->prob) return 1;
647
+ return 0;
648
+ }
649
+
650
+ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
651
+ // top-p sampling (or "nucleus sampling") samples from the smallest set of
652
+ // tokens that exceed probability topp. This way we never sample tokens that
653
+ // have very low probabilities and are less likely to go "off the rails".
654
+ // coin is a random number in [0, 1), usually from random_f32()
655
+
656
+ int n0 = 0;
657
+ // quicksort indices in descending order of probabilities
658
+ // values smaller than (1 - topp) / (n - 1) cannot be part of the result
659
+ // so for efficiency we crop these out as candidates before sorting
660
+ const float cutoff = (1.0f - topp) / (n - 1);
661
+ for (int i = 0; i < n; i++) {
662
+ if (probabilities[i] >= cutoff) {
663
+ probindex[n0].index = i;
664
+ probindex[n0].prob = probabilities[i];
665
+ n0++;
666
+ }
667
+ }
668
+ qsort(probindex, n0, sizeof(ProbIndex), compare);
669
+
670
+ // truncate the list where cumulative probability exceeds topp
671
+ float cumulative_prob = 0.0f;
672
+ int last_idx = n0 - 1; // in case of rounding errors consider all elements
673
+ for (int i = 0; i < n0; i++) {
674
+ cumulative_prob += probindex[i].prob;
675
+ if (cumulative_prob > topp) {
676
+ last_idx = i;
677
+ break; // we've exceeded topp by including last_idx
678
+ }
679
+ }
680
+
681
+ // sample from the truncated list
682
+ float r = coin * cumulative_prob;
683
+ float cdf = 0.0f;
684
+ for (int i = 0; i <= last_idx; i++) {
685
+ cdf += probindex[i].prob;
686
+ if (r < cdf) {
687
+ return probindex[i].index;
688
+ }
689
+ }
690
+ return probindex[last_idx].index; // in case of rounding errors
691
+ }
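As a worked example: with probabilities (0.5, 0.3, 0.15, 0.05) and topp = 0.9, the cutoff (1 - 0.9) / 3 ≈ 0.033 crops nothing, the sorted cumulative sums run 0.5, 0.8, 0.95, and the loop stops at the third token, so the nucleus is {0.5, 0.3, 0.15}. Scaling the coin by cumulative_prob = 0.95 before the final CDF walk renormalizes the truncated distribution implicitly.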
692
+
693
+ void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
694
+ sampler->vocab_size = vocab_size;
695
+ sampler->temperature = temperature;
696
+ sampler->topp = topp;
697
+ sampler->rng_state = rng_seed;
698
+ // buffer only used with nucleus sampling; may not be needed, but it's small
699
+ sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
700
+ }
701
+
702
+ void free_sampler(Sampler* sampler) {
703
+ free(sampler->probindex);
704
+ }
705
+
706
+ unsigned int random_u32(unsigned long long *state) {
707
+ // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
708
+ *state ^= *state >> 12;
709
+ *state ^= *state << 25;
710
+ *state ^= *state >> 27;
711
+ return (*state * 0x2545F4914F6CDD1Dull) >> 32;
712
+ }
713
+ float random_f32(unsigned long long *state) { // random float32 in [0,1)
714
+ return (random_u32(state) >> 8) / 16777216.0f;
715
+ }
716
+
717
+ int sample(Sampler* sampler, float* logits) {
718
+ // sample the token given the logits and some hyperparameters
719
+ int next;
720
+ if (sampler->temperature == 0.0f) {
721
+ // greedy argmax sampling: take the token with the highest probability
722
+ next = sample_argmax(logits, sampler->vocab_size);
723
+ } else {
724
+ // apply the temperature to the logits
725
+ for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
726
+ // apply softmax to the logits to get the probabilities for next token
727
+ softmax(logits, sampler->vocab_size);
728
+ // flip a (float) coin (this is our source of entropy for sampling)
729
+ float coin = random_f32(&sampler->rng_state);
730
+ // we sample from this distribution to get the next token
731
+ if (sampler->topp <= 0 || sampler->topp >= 1) {
732
+ // simply sample from the predicted probability distribution
733
+ next = sample_mult(logits, sampler->vocab_size, coin);
734
+ } else {
735
+ // top-p (nucleus) sampling, clamping the least likely tokens to zero
736
+ next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
737
+ }
738
+ }
739
+ return next;
740
+ }
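Temperature is applied to the logits before the softmax, so the distribution being sampled is

$$
p_i = \frac{e^{z_i / T}}{\sum_j e^{z_j / T}}
$$

$T \to 0$ approaches greedy argmax (handled as a special case above), $T = 1$ leaves the model's distribution untouched, and $T > 1$ flattens it.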
741
+
742
+ // ----------------------------------------------------------------------------
743
+ // utilities: time
744
+
745
+ long time_in_ms() {
746
+ // return time in milliseconds, for benchmarking the model speed
747
+ struct timespec time;
748
+ clock_gettime(CLOCK_REALTIME, &time);
749
+ return time.tv_sec * 1000 + time.tv_nsec / 1000000;
750
+ }
751
+
752
+ // ----------------------------------------------------------------------------
753
+ // generation loop
754
+
755
+ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
756
+ char *empty_prompt = "";
757
+ if (prompt == NULL) { prompt = empty_prompt; }
758
+
759
+ // encode the (string) prompt into a sequence of tokens
760
+ int num_prompt_tokens = 0;
761
+ int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
762
+ encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
763
+ if (num_prompt_tokens < 1) {
764
+ fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
765
+ exit(EXIT_FAILURE);
766
+ }
767
+
768
+ // start the main loop
769
+ long start = 0; // used to time our code, only initialized after first iteration
770
+ int next; // will store the next token in the sequence
771
+ int token = prompt_tokens[0]; // kick off with the first token in the prompt
772
+ int pos = 0; // position in the sequence
773
+
774
+ while (pos < steps) {
775
+
776
+ // forward the transformer to get logits for the next token
777
+ float* logits = forward(transformer, token, pos);
778
+
779
+ // advance the state machine
780
+ if (pos < num_prompt_tokens - 1) {
781
+ // if we are still processing the input prompt, force the next prompt token
782
+ next = prompt_tokens[pos + 1];
783
+ } else {
784
+ // otherwise sample the next token from the logits
785
+ next = sample(sampler, logits);
786
+ }
787
+ pos++;
788
+
789
+ // data-dependent terminating condition: the EOS (=128001) or EOT (=128009) token ends the sequence
790
+ if ((next == 128001 || next == 128009) && pos > num_prompt_tokens) break;
791
+ // print the token as string, decode it with the Tokenizer object
792
+ char* piece = decode(tokenizer, token, next);
793
+ safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
794
+ fflush(stdout);
795
+ token = next;
796
+
797
+ // init the timer here because the first iteration can be slower
798
+ if (start == 0) { start = time_in_ms(); }
799
+ }
800
+ printf("\n");
801
+
802
+ // report achieved tok/s (pos-1 because the timer starts after first iteration)
803
+ if (pos > 1) {
804
+ long end = time_in_ms();
805
+ fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
806
+ }
807
+
808
+ free(prompt_tokens);
809
+ }
810
+
811
+ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
812
+ // read a line from stdin, up to but not including \n
813
+ printf("%s", guide);
814
+ if (fgets(buffer, bufsize, stdin) != NULL) {
815
+ size_t len = strlen(buffer);
816
+ if (len > 0 && buffer[len - 1] == '\n') {
817
+ buffer[len - 1] = '\0'; // strip newline
818
+ }
819
+ }
820
+ }
821
+
822
+ // ----------------------------------------------------------------------------
823
+ // chat loop
824
+ // I manually inspected the tokens for a few chat conversations against the
825
+ // python reference and they seemed ok, but this was not thoroughly tested and
826
+ // is not safely implemented; it's more a proof of concept atm.
827
+
828
+ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
829
+ char *cli_user_prompt, char *cli_system_prompt, int steps) {
830
+
831
+ // buffers for reading the system prompt and user prompt from stdin
832
+ // you'll notice they are somewhat haphazardly and unsafely set atm
833
+ char* system_prompt = (char*)malloc(32768 * sizeof(char));
834
+ char* user_prompt = (char*)malloc(32768 * sizeof(char));
835
+ int num_prompt_tokens = 0;
836
+ int* prompt_tokens = (int*)malloc(32768 * sizeof(int));
837
+ int* system_prompt_tokens = (int*)malloc(32768 * sizeof(int));
838
+ int* user_prompt_tokens = (int*)malloc(32768 * sizeof(int));
839
+ int user_idx=0;
840
+
841
+ // start the main loop
842
+ int8_t user_turn = 1; // user starts
843
+ int next; // will store the next token in the sequence
844
+ int token; // stores the current token to feed into the transformer
845
+
846
+ int pos = 0; // position in the sequence
847
+ while (pos < steps) {
848
+
849
+ // when it is the user's turn to contribute tokens to the dialog...
850
+ if (user_turn) {
851
+ // get the (optional) system prompt at position 0
852
+ if (pos == 0) {
853
+ // at position 0, the user can also contribute a system prompt
854
+ prompt_tokens[num_prompt_tokens++] = 128000; // "<|begin_of_text|>"
855
+ prompt_tokens[num_prompt_tokens++] = 128006; // "<|start_header_id|>"
856
+ prompt_tokens[num_prompt_tokens++] = 9125; // "system"
857
+ prompt_tokens[num_prompt_tokens++] = 128007; // "<|end_header_id|>"
858
+ prompt_tokens[num_prompt_tokens++] = 271; // "\n\n"
859
+ if (cli_system_prompt == NULL) {
860
+ // system prompt was not passed in, attempt to get it from stdin
861
+ read_stdin("Enter system prompt (optional): ", system_prompt, 32768);
862
+ } else {
863
+ // system prompt was passed in, use it
864
+ strcpy(system_prompt, cli_system_prompt);
865
+ }
866
+ if (system_prompt != NULL) {
867
+ int num_system_prompt_tokens = 0;
868
+ encode(tokenizer, system_prompt, 0, 0, system_prompt_tokens, &num_system_prompt_tokens);
869
+ for (int i=0; i<num_system_prompt_tokens; i++) {
870
+ prompt_tokens[num_prompt_tokens++] = system_prompt_tokens[i];
871
+ }
872
+ }
873
+ prompt_tokens[num_prompt_tokens++] = 128009; // "<|eot_id|>"
874
+ } else {
875
+ num_prompt_tokens = 0;
876
+ }
877
+ prompt_tokens[num_prompt_tokens++] = 128006; // "<|start_header_id|>"
878
+ prompt_tokens[num_prompt_tokens++] = 882; // "user"
879
+ prompt_tokens[num_prompt_tokens++] = 128007; // "<|end_header_id|>"
880
+ prompt_tokens[num_prompt_tokens++] = 271; // "\n\n"
881
+ // get the user prompt
882
+ if (pos == 0 && cli_user_prompt != NULL) {
883
+ // user prompt for position 0 was passed in, use it
884
+ strcpy(user_prompt, cli_user_prompt);
885
+ } else {
886
+ // otherwise get user prompt from stdin
887
+ read_stdin("User (or exit): ", user_prompt, 32768);
888
+ if(strcmp(user_prompt, "exit")==0) break;
889
+ }
890
+ int num_user_prompt_tokens = 0;
891
+ // encode the user prompt into tokens
892
+ encode(tokenizer, user_prompt, 0, 0, user_prompt_tokens, &num_user_prompt_tokens);
893
+ for (int i=0; i<num_user_prompt_tokens; i++) {
894
+ prompt_tokens[num_prompt_tokens++] = user_prompt_tokens[i];
895
+ }
896
+ prompt_tokens[num_prompt_tokens++] = 128009; // "<|eot_id|>"
897
+ prompt_tokens[num_prompt_tokens++] = 128006; // "<|start_header_id|>"
898
+ prompt_tokens[num_prompt_tokens++] = 78191; // "assistant"
899
+ prompt_tokens[num_prompt_tokens++] = 128007; // "<|end_header_id|>"
900
+ prompt_tokens[num_prompt_tokens++] = 271; // "\n\n"
901
+
902
+
903
+ user_idx = 0; // reset the user index
904
+ user_turn = 0;
905
+ printf("Assistant: ");
906
+ }
907
+
908
+ // determine the token to pass into the transformer next
909
+ if (user_idx < num_prompt_tokens) {
910
+ // if we are still processing the input prompt, force the next prompt token
911
+ token = prompt_tokens[user_idx++];
912
+ } else {
913
+ // otherwise use the next token sampled from previous turn
914
+ token = next;
915
+ }
916
+ // EOT (=128009) or EOS (=128001) tokens end the Assistant turn
917
+ if (user_idx >= num_prompt_tokens && (token == 128009 || token == 128001)) { user_turn = 1; }
918
+
919
+ // forward the transformer to get logits for the next token
920
+ float* logits = forward(transformer, token, pos);
921
+ next = sample(sampler, logits);
922
+ pos++;
923
+
924
+ if (user_idx >= num_prompt_tokens && next != 128009 && next != 128001 && next != 128006) {
925
+ // the Assistant is responding, so print its output
926
+ char* piece = decode(tokenizer, token, next);
927
+ safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
928
+ fflush(stdout);
929
+ }
930
+ if (user_idx >= num_prompt_tokens && (next == 128009 || next == 128001)) { printf("\n"); } // parentheses needed: && binds tighter than ||
931
+ }
932
+ printf("\n");
933
+ free(prompt_tokens);
934
+ free(system_prompt_tokens);
935
+ free(user_prompt_tokens);
936
+ free(system_prompt);
937
+ free(user_prompt);
938
+ }
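Written out as text, the hard-coded token IDs above spell out the Llama 3 instruct chat template:

```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{system prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>

{user message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

```

Generation then runs until the model emits <|eot_id|> (128009) or <|end_of_text|> (128001), which hands the turn back to the user.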
939
+
940
+
941
+ // ----------------------------------------------------------------------------
942
+ // CLI, include only if not testing
943
+ #ifndef TESTING
944
+
945
+ void error_usage() {
946
+ fprintf(stderr, "Usage: run <checkpoint> [options]\n");
947
+ fprintf(stderr, "Example: run model.bin -n 4096 -i \"Once upon a time\"\n");
948
+ fprintf(stderr, "Options:\n");
949
+ fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
950
+ fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
951
+ fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
952
+ fprintf(stderr, " -n <int> number of steps to run for, default 4096. 0 = max_seq_len\n");
953
+ fprintf(stderr, " -i <string> input prompt\n");
954
+ fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
955
+ fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
956
+ fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
957
+ exit(EXIT_FAILURE);
958
+ }
959
+
960
+ int main(int argc, char *argv[]) {
961
+
962
+ // default parameters
963
+ char *checkpoint_path = NULL; // e.g. out/model.bin
964
+ char *tokenizer_path = "tokenizer.bin";
965
+ float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
966
+ float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
967
+ int steps = 4096; // number of steps to run for
968
+ char *prompt = NULL; // prompt string
969
+ unsigned long long rng_seed = 0; // seed rng with time by default
970
+ char *mode = "generate"; // generate|chat
971
+ char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
972
+
973
+ // poor man's C argparse so we can override the defaults above from the command line
974
+ if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
975
+ for (int i = 2; i < argc; i+=2) {
976
+ // do some basic validation
977
+ if (i + 1 >= argc) { error_usage(); } // must have arg after flag
978
+ if (argv[i][0] != '-') { error_usage(); } // must start with dash
979
+ if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
980
+ // read in the args
981
+ if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
982
+ else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
983
+ else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
984
+ else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
985
+ else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
986
+ else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
987
+ else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
988
+ else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
989
+ else { error_usage(); }
990
+ }
991
+
992
+ // parameter validation/overrides
993
+ if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
994
+ if (temperature < 0.0) temperature = 0.0;
995
+ if (topp < 0.0 || 1.0 < topp) topp = 0.9;
996
+ if (steps < 0) steps = 0;
997
+
998
+ // build the Transformer via the model .bin file
999
+ Transformer transformer;
1000
+ build_transformer(&transformer, checkpoint_path);
1001
+ if (steps == 0 || steps > transformer.config.seq_len) steps = transformer.config.seq_len; // override to ~max length
1002
+
1003
+ // build the Tokenizer via the tokenizer .bin file
1004
+ Tokenizer tokenizer;
1005
+ build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);
1006
+
1007
+ // build the Sampler
1008
+ Sampler sampler;
1009
+ build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
1010
+
1011
+ // run!
1012
+ if (strcmp(mode, "generate") == 0) {
1013
+ generate(&transformer, &tokenizer, &sampler, prompt, steps);
1014
+ } else if (strcmp(mode, "chat") == 0) {
1015
+ chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
1016
+ } else {
1017
+ fprintf(stderr, "unknown mode: %s\n", mode);
1018
+ error_usage();
1019
+ }
1020
+
1021
+ // memory and file handles cleanup
1022
+ free_sampler(&sampler);
1023
+ free_tokenizer(&tokenizer);
1024
+ free_transformer(&transformer);
1025
+ return 0;
1026
+ }
1027
+ #endif
rundll.h ADDED
@@ -0,0 +1,19 @@
1
+ #ifndef MAIN_H
2
+ #define MAIN_H
3
+
4
+ #ifdef __cplusplus
5
+ extern "C" {
6
+ #endif
7
+
8
+ typedef struct Main Main;
9
+
10
+ Main *build_main(char* checkpoint_path, char* tokenizer_path, float temperature, float topp, int steps,
11
+ char* prompt, unsigned long long rng_seed, char* mode, char* system_prompt);
12
+ void free_main(Main *m);
13
+ char *run_main(Main *m);
14
+
15
+ #ifdef __cplusplus
16
+ }
17
+ #endif
18
+
19
+ #endif // MAIN_H
runq ADDED
Binary file (53.8 kB).
 
runq.c ADDED
@@ -0,0 +1,1146 @@
1
+ /* Inference for Llama-3 Transformer model in pure C, int8 quantized forward pass. */
2
+
3
+ #include <stdio.h>
4
+ #include <stdlib.h>
5
+ #include <ctype.h>
6
+ #include <stdint.h>
7
+ #include <time.h>
8
+ #include <math.h>
9
+ #include <string.h>
10
+ #include <fcntl.h>
11
+ #if defined _WIN32
12
+ #include "win.h"
13
+ #else
14
+ #include <unistd.h>
15
+ #include <sys/mman.h>
16
+ #endif
17
+ // ----------------------------------------------------------------------------
18
+ // Globals
19
+ int GS = 0; // group size global for quantization of the weights
20
+
21
+ // ----------------------------------------------------------------------------
22
+ // Transformer model
23
+
24
+ typedef struct {
25
+ int dim; // transformer dimension
26
+ int hidden_dim; // for ffn layers
27
+ int n_layers; // number of layers
28
+ int n_heads; // number of query heads
29
+ int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
30
+ int vocab_size; // vocabulary size (128256 for the Llama 3 tokenizer)
31
+ int seq_len; // max sequence length
32
+ } Config;
33
+
34
+ typedef struct {
35
+ int8_t* q; // quantized values
36
+ float* s; // scaling factors
37
+ } QuantizedTensor;
38
+
39
+ typedef struct {
40
+ // token embedding table
41
+ QuantizedTensor *q_tokens; // (vocab_size, dim)
42
+ float* token_embedding_table; // same, but dequantized
43
+
44
+ // weights for rmsnorms
45
+ float* rms_att_weight; // (layer, dim) rmsnorm weights
46
+ float* rms_ffn_weight; // (layer, dim)
47
+ // weights for matmuls. note dim == n_heads * head_size
48
+ QuantizedTensor *wq; // (layer, dim, n_heads * head_size)
49
+ QuantizedTensor *wk; // (layer, dim, n_kv_heads * head_size)
50
+ QuantizedTensor *wv; // (layer, dim, n_kv_heads * head_size)
51
+ QuantizedTensor *wo; // (layer, n_heads * head_size, dim)
52
+ // weights for ffn
53
+ QuantizedTensor *w1; // (layer, hidden_dim, dim)
54
+ QuantizedTensor *w2; // (layer, dim, hidden_dim)
55
+ QuantizedTensor *w3; // (layer, hidden_dim, dim)
56
+ // final rmsnorm
57
+ float* rms_final_weight; // (dim,)
58
+ // (optional) classifier weights for the logits, on the last layer
59
+ QuantizedTensor *wcls;
60
+ } TransformerWeights;
61
+
62
+ typedef struct {
63
+ // current wave of activations
64
+ float *x; // activation at current time stamp (dim,)
65
+ float *xb; // same, but inside a residual branch (dim,)
66
+ float *xb2; // an additional buffer just for convenience (dim,)
67
+ float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
68
+ float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
69
+ QuantizedTensor xq; // quantized x (dim,)
70
+ QuantizedTensor hq; // quantized hb (hidden_dim,)
71
+ float *q; // query (dim,)
72
+ float *k; // key (kv_dim,)
73
+ float *v; // value (kv_dim,)
74
+ float *att; // buffer for scores/attention values (n_heads, seq_len)
75
+ float *logits; // output logits
76
+ // kv cache
77
+ float* key_cache; // (layer, seq_len, kv_dim)
78
+ float* value_cache; // (layer, seq_len, kv_dim)
79
+ } RunState;
80
+
81
+ typedef struct {
82
+ Config config; // the hyperparameters of the architecture (the blueprint)
83
+ TransformerWeights weights; // the weights of the model
84
+ RunState state; // buffers for the "wave" of activations in the forward pass
85
+ // some more state needed to properly clean up the memory mapping (sigh)
86
+ int fd; // file descriptor for memory mapping
87
+ float* data; // memory mapped data pointer
88
+ ssize_t file_size; // size of the checkpoint file in bytes
89
+ } Transformer;
90
+
91
+ void malloc_run_state(RunState* s, Config* p) {
92
+ // we calloc instead of malloc to keep valgrind happy
93
+ int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
94
+ s->x = calloc(p->dim, sizeof(float));
95
+ s->xb = calloc(p->dim, sizeof(float));
96
+ s->xb2 = calloc(p->dim, sizeof(float));
97
+ s->hb = calloc(p->hidden_dim, sizeof(float));
98
+ s->hb2 = calloc(p->hidden_dim, sizeof(float));
99
+ s->xq = (QuantizedTensor) { .q = calloc(p->dim, sizeof(int8_t)), .s = calloc(p->dim, sizeof(float)) };
100
+ s->hq = (QuantizedTensor) { .q = calloc(p->hidden_dim, sizeof(int8_t)), .s = calloc(p->hidden_dim, sizeof(float)) };
101
+ s->q = calloc(p->dim, sizeof(float));
102
+ s->k = calloc(kv_dim, sizeof(float));
103
+ s->v = calloc(kv_dim, sizeof(float));
104
+ s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
105
+ s->logits = calloc(p->vocab_size, sizeof(float));
106
+ s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
107
+ s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
108
+ // ensure all mallocs went fine
109
+ if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
110
+ || !s->k || !s->v || !s->att || !s->logits || !s->key_cache
111
+ || !s->value_cache) {
112
+ fprintf(stderr, "malloc failed!\n");
113
+ exit(EXIT_FAILURE);
114
+ }
115
+ }
116
+
117
+ void free_run_state(RunState* s) {
118
+ free(s->x);
119
+ free(s->xb);
120
+ free(s->xb2);
121
+ free(s->hb);
122
+ free(s->hb2);
123
+ free(s->xq.q);
124
+ free(s->xq.s);
125
+ free(s->hq.q);
126
+ free(s->hq.s);
127
+ free(s->q);
128
+ free(s->k);
129
+ free(s->v);
130
+ free(s->att);
131
+ free(s->logits);
132
+ free(s->key_cache);
133
+ free(s->value_cache);
134
+ }
135
+
136
+ // ----------------------------------------------------------------------------
137
+ // Quantization functions
138
+
139
+ void dequantize(QuantizedTensor *qx, float* x, int n) {
140
+ for (int i = 0; i < n; i++) {
141
+ x[i] = qx->q[i] * qx->s[i / GS];
142
+ }
143
+ }
144
+
145
+ void quantize(QuantizedTensor *qx, float* x, int n) {
146
+ int num_groups = n / GS;
147
+ float Q_MAX = 127.0f;
148
+
149
+ for (int group = 0; group < num_groups; group++) {
150
+
151
+ // find the max absolute value in the current group
152
+ float wmax = 0.0;
153
+ for (int i = 0; i < GS; i++) {
154
+ float val = fabs(x[group * GS + i]);
155
+ if (val > wmax) {
156
+ wmax = val;
157
+ }
158
+ }
159
+
160
+ // calculate and write the scaling factor
161
+ float scale = wmax / Q_MAX;
162
+ qx->s[group] = scale;
163
+
164
+ // calculate and write the quantized values
165
+ for (int i = 0; i < GS; i++) {
166
+ float quant_value = x[group * GS + i] / scale; // scale
167
+ int8_t quantized = (int8_t) round(quant_value); // round; values already lie in [-127, 127] by construction of the scale
168
+ qx->q[group * GS + i] = quantized;
169
+ }
170
+ }
171
+ }
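The pair of functions above implements symmetric per-group int8 quantization: each group of GS consecutive floats shares one scale, chosen so the largest magnitude in the group maps to ±127:

$$
s_g = \frac{\max_{i \in g} |x_i|}{127}, \qquad q_i = \mathrm{round}\!\left(\frac{x_i}{s_g}\right), \qquad \hat{x}_i = q_i \, s_g
$$

For example, a group whose largest magnitude is 0.254 gets $s_g = 0.002$, so an element 0.1 quantizes to 50 and dequantizes back to 0.1 exactly; in general the per-element error is at most $s_g / 2$.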
172
+
173
+ /* initialize `n` x quantized tensor (with `size_each` elements), starting from memory pointed at *ptr */
174
+ QuantizedTensor *init_quantized_tensors(void **ptr, int n, int size_each) {
175
+ void *p = *ptr;
176
+ QuantizedTensor *res = malloc(n * sizeof(QuantizedTensor));
177
+ for(int i=0; i<n; i++) {
178
+ /* map quantized int8 values*/
179
+ res[i].q = (int8_t*)p;
180
+ p = (int8_t*)p + size_each;
181
+ /* map scale factors */
182
+ res[i].s = (float*)p;
183
+ p = (float*)p + size_each / GS;
184
+ }
185
+ *ptr = p; // advance ptr to current position
186
+ return res;
187
+ }
188
+
189
+ void memory_map_weights(TransformerWeights *w, Config* p, void* ptr, uint8_t shared_classifier) {
190
+ int head_size = p->dim / p->n_heads;
191
+ // first are the parameters that are kept in fp32 (the rmsnorm (1D) weights)
192
+ float* fptr = (float*) ptr; // cast our pointer to float*
193
+ w->rms_att_weight = fptr;
194
+ fptr += p->n_layers * p->dim;
195
+ w->rms_ffn_weight = fptr;
196
+ fptr += p->n_layers * p->dim;
197
+ w->rms_final_weight = fptr;
198
+ fptr += p->dim;
199
+
200
+ // now read all the quantized weights
201
+ ptr = (void*)fptr; // now cast the pointer back to void*
202
+ w->q_tokens = init_quantized_tensors(&ptr, 1, p->vocab_size * p->dim);
203
+ // dequantize token embedding table
204
+ w->token_embedding_table = malloc(p->vocab_size * p->dim * sizeof(float));
205
+ dequantize(w->q_tokens, w->token_embedding_table, p->vocab_size * p->dim);
206
+
207
+ w->wq = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_heads * head_size));
208
+ w->wk = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
209
+ w->wv = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
210
+ w->wo = init_quantized_tensors(&ptr, p->n_layers, (p->n_heads * head_size) * p->dim);
211
+
212
+ w->w1 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
213
+ w->w2 = init_quantized_tensors(&ptr, p->n_layers, p->hidden_dim * p->dim);
214
+ w->w3 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
215
+
216
+ w->wcls = shared_classifier ? w->q_tokens : init_quantized_tensors(&ptr, 1, p->dim * p->vocab_size);
217
+ }
218
+
219
+ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
220
+ int* fd, float** data, ssize_t* file_size) {
221
+ FILE *file = fopen(checkpoint, "rb");
222
+ if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
223
+ // read in magic number (uint32), has to be 0x616b3432, i.e. "ak42" in ASCII
224
+ uint32_t magic_number;
225
+ if (fread(&magic_number, sizeof(uint32_t), 1, file) != 1) { exit(EXIT_FAILURE); }
226
+ if (magic_number != 0x616b3432) { fprintf(stderr, "Bad magic number\n"); exit(EXIT_FAILURE); }
227
+ // read in the version number (uint32), has to be 2
228
+ int version;
229
+ if (fread(&version, sizeof(int), 1, file) != 1) { exit(EXIT_FAILURE); }
230
+ if (version != 2) { fprintf(stderr, "Bad version %d, need version 2\n", version); exit(EXIT_FAILURE); }
231
+ int header_size = 256; // the header size for version 2 in bytes
232
+ // read in the Config
233
+ if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); }
234
+ // read in flags
235
+ uint8_t shared_classifier; // a byte to indicate if the classifier is shared
236
+ if (fread(&shared_classifier, sizeof(uint8_t), 1, file) != 1) { exit(EXIT_FAILURE); }
237
+ int group_size; // the group size used in quantization
238
+ if (fread(&group_size, sizeof(int), 1, file) != 1) { exit(EXIT_FAILURE); }
239
+ GS = group_size; // set as global, as it will be used in many places
240
+ // figure out the file size
241
+ #if defined _WIN32
242
+ _fseeki64(file, 0, SEEK_END); // move file pointer to end of file
243
+ *file_size = _ftelli64(file); // get the file size, in bytes
244
+ #else
245
+ fseek(file, 0, SEEK_END); // move file pointer to end of file
246
+ *file_size = ftell(file); // get the file size, in bytes
247
+ #endif
248
+ fclose(file);
249
+ // memory map the Transformer weights into the data pointer
250
+ *fd = open(checkpoint, O_RDONLY); // open in read only mode
251
+ if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
252
+ *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
253
+ if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
254
+ void* weights_ptr = ((char*)*data) + header_size; // skip header bytes. char is 1 byte
255
+ memory_map_weights(weights, config, weights_ptr, shared_classifier);
256
+ }
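Spelled out, the version-2 header consumed by the reads above has this layout (a sketch for orientation only; the field-by-field freads are used precisely because a packed struct like this would pick up compiler padding):

```c
/* version-2 .bin header, 256 bytes total, zero-padded at the end */
typedef struct {
    uint32_t magic;             // 0x616b3432, "ak42" in ASCII
    int32_t  version;           // must be 2
    Config   config;            // 7 ints: dim, hidden_dim, n_layers, n_heads,
                                //         n_kv_heads, vocab_size, seq_len
    uint8_t  shared_classifier; // 1 if wcls aliases the token embedding
    int32_t  group_size;        // quantization group size, copied into GS
    /* ... zero padding up to byte 256, then the weight data */
} HeaderV2;                     // illustrative only, do not fread directly
```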
257
+
258
+ void build_transformer(Transformer *t, char* checkpoint_path) {
259
+ // read in the Config and the Weights from the checkpoint
260
+ read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
261
+ // allocate the RunState buffers
262
+ malloc_run_state(&t->state, &t->config);
263
+ }
264
+
265
+ void free_transformer(Transformer* t) {
266
+ // free QuantizedTensors
267
+ free(t->weights.q_tokens);
268
+ free(t->weights.token_embedding_table);
269
+ free(t->weights.wq);
270
+ free(t->weights.wk);
271
+ free(t->weights.wv);
272
+ free(t->weights.wo);
273
+ free(t->weights.w1);
274
+ free(t->weights.w2);
275
+ free(t->weights.w3);
276
+ if(t->weights.wcls != t->weights.q_tokens) { free(t->weights.wcls); }
277
+ // close the memory mapping
278
+ if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); }
279
+ if (t->fd != -1) { close(t->fd); }
280
+ // free the RunState buffers
281
+ free_run_state(&t->state);
282
+ }
283
+
284
+ // ----------------------------------------------------------------------------
285
+ // neural net blocks; the dynamics of the Transformer
286
+
287
+ void rmsnorm(float* o, float* x, float* weight, int size) {
288
+ // calculate sum of squares
289
+ float ss = 0.0f;
290
+ for (int j = 0; j < size; j++) {
291
+ ss += x[j] * x[j];
292
+ }
293
+ ss /= size;
294
+ ss += 1e-5f;
295
+ ss = 1.0f / sqrtf(ss);
296
+ // normalize and scale
297
+ for (int j = 0; j < size; j++) {
298
+ o[j] = weight[j] * (ss * x[j]);
299
+ }
300
+ }
301
+
302
+ void softmax(float* x, int size) {
303
+ // find max value (for numerical stability)
304
+ float max_val = x[0];
305
+ for (int i = 1; i < size; i++) {
306
+ if (x[i] > max_val) {
307
+ max_val = x[i];
308
+ }
309
+ }
310
+ // exp and sum
311
+ float sum = 0.0f;
312
+ for (int i = 0; i < size; i++) {
313
+ x[i] = expf(x[i] - max_val);
314
+ sum += x[i];
315
+ }
316
+ // normalize
317
+ for (int i = 0; i < size; i++) {
318
+ x[i] /= sum;
319
+ }
320
+ }
321
+
322
+ void matmul(float* xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d) {
323
+ // W (d,n) @ x (n,) -> xout (d,)
324
+ // by far the most amount of time is spent inside this little function
325
+ // inputs to this function are both quantized
326
+
327
+ int i;
328
+ #pragma omp parallel for private(i)
329
+ for (i = 0; i < d; i++) {
330
+
331
+ float val = 0.0f;
332
+ int32_t ival = 0;
333
+ int in = i * n;
334
+
335
+ // do the matmul in groups of GS
336
+ int j;
337
+ for (j = 0; j <= n - GS; j += GS) {
338
+ for (int k = 0; k < GS; k++) {
339
+ ival += ((int32_t) x->q[j + k]) * ((int32_t) w->q[in + j + k]);
340
+ }
341
+ val += ((float) ival) * w->s[(in + j) / GS] * x->s[j / GS];
342
+ ival = 0;
343
+ }
344
+
345
+ xout[i] = val;
346
+ }
347
+ }
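Per output row, the quantized matmul above accumulates each GS-wide group as an exact int32 dot product and applies the two scales once per group:

$$
\text{xout}_i = \sum_{g=0}^{n/GS-1} s^{w}_{i,g} \, s^{x}_{g} \sum_{k=0}^{GS-1} q^{w}_{i,\,g \cdot GS + k} \; q^{x}_{g \cdot GS + k}
$$

Since each int8 product is at most $127 \times 127 = 16129$, even large group sizes fit comfortably in the int32 accumulator before the float rescale.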
348
+
+ float* forward(Transformer* transformer, int token, int pos) {
+
+     // a few convenience variables
+     Config* p = &transformer->config;
+     TransformerWeights* w = &transformer->weights;
+     RunState* s = &transformer->state;
+     float *x = s->x;
+     int dim = p->dim;
+     int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
+     int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
+     int hidden_dim = p->hidden_dim;
+     int head_size = dim / p->n_heads;
+
+     // copy the token embedding into x
+     memcpy(x, w->token_embedding_table + token*dim, dim * sizeof(float));
+
+     // forward all the layers
+     for (unsigned long long l = 0; l < p->n_layers; l++) {
+
+         // attention rmsnorm
+         rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
+
+         // qkv matmuls for this position
+         quantize(&s->xq, s->xb, dim);
+         matmul(s->q, &s->xq, w->wq + l, dim, dim);
+         matmul(s->k, &s->xq, w->wk + l, dim, kv_dim);
+         matmul(s->v, &s->xq, w->wv + l, dim, kv_dim);
+
+         // RoPE relative positional encoding: complex-valued rotate q and k in each head
+         for (int i = 0; i < p->n_heads; i++) {
+             for (int j = 0; j < head_size; j += 2) {
+                 float freq = 1.0f / powf(500000.0f, (float)j / (float)head_size);
+                 float val = pos * freq;
+                 float fcr = cosf(val);
+                 float fci = sinf(val);
+                 float q0 = s->q[i * head_size + j];
+                 float q1 = s->q[i * head_size + j + 1];
+                 s->q[i * head_size + j] = q0 * fcr - q1 * fci;
+                 s->q[i * head_size + j + 1] = q0 * fci + q1 * fcr;
+                 if (i < p->n_kv_heads) {
+                     float k0 = s->k[i * head_size + j];
+                     float k1 = s->k[i * head_size + j + 1];
+                     s->k[i * head_size + j] = k0 * fcr - k1 * fci;
+                     s->k[i * head_size + j + 1] = k0 * fci + k1 * fcr;
+                 }
+             }
+         }
+
+         // save key,value at this time step (pos) to our kv cache
+         int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
+         float* key_cache_row = s->key_cache + loff + pos * kv_dim;
+         float* value_cache_row = s->value_cache + loff + pos * kv_dim;
+         memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
+         memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));
+
+         // multihead attention. iterate over all heads
+         int h;
+         #pragma omp parallel for private(h)
+         for (h = 0; h < p->n_heads; h++) {
+             // get the query vector for this head
+             float* q = s->q + h * head_size;
+             // attention scores for this head
+             float* att = s->att + h * p->seq_len;
+             // iterate over all timesteps, including the current one
+             for (int t = 0; t <= pos; t++) {
+                 // get the key vector for this head and at this timestep
+                 float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
+                 // calculate the attention score as the dot product of q and k
+                 float score = 0.0f;
+                 for (int i = 0; i < head_size; i++) {
+                     score += q[i] * k[i];
+                 }
+                 score /= sqrtf(head_size);
+                 // save the score to the attention buffer
+                 att[t] = score;
+             }
+
+             // softmax the scores to get attention weights, from 0..pos inclusively
+             softmax(att, pos + 1);
+
+             // weighted sum of the values, store back into xb
+             float* xb = s->xb + h * head_size;
+             memset(xb, 0, head_size * sizeof(float));
+             for (int t = 0; t <= pos; t++) {
+                 // get the value vector for this head and at this timestep
+                 float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
+                 // get the attention weight for this timestep
+                 float a = att[t];
+                 // accumulate the weighted value into xb
+                 for (int i = 0; i < head_size; i++) {
+                     xb[i] += a * v[i];
+                 }
+             }
+         }
+
+         // final matmul to get the output of the attention
+         quantize(&s->xq, s->xb, dim);
+         matmul(s->xb2, &s->xq, w->wo + l, dim, dim);
+
+         // residual connection back into x
+         for (int i = 0; i < dim; i++) {
+             x[i] += s->xb2[i];
+         }
+
+         // ffn rmsnorm
+         rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
+
+         // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
+         // first calculate self.w1(x) and self.w3(x)
+         quantize(&s->xq, s->xb, dim);
+         matmul(s->hb, &s->xq, w->w1 + l, dim, hidden_dim);
+         matmul(s->hb2, &s->xq, w->w3 + l, dim, hidden_dim);
+
+         // SwiGLU non-linearity
+         for (int i = 0; i < hidden_dim; i++) {
+             float val = s->hb[i];
+             // silu(x)=x*s(x), where s(x) is the logistic sigmoid
+             val *= (1.0f / (1.0f + expf(-val)));
+             // elementwise multiply with w3(x)
+             val *= s->hb2[i];
+             s->hb[i] = val;
+         }
+
+         // final matmul to get the output of the ffn
+         quantize(&s->hq, s->hb, hidden_dim);
+         matmul(s->xb, &s->hq, w->w2 + l, hidden_dim, dim);
+
+         // residual connection
+         for (int i = 0; i < dim; i++) {
+             x[i] += s->xb[i];
+         }
+     }
+
+     // final rmsnorm
+     rmsnorm(x, x, w->rms_final_weight, dim);
+
+     // classifier into logits
+     quantize(&s->xq, x, dim);
+     matmul(s->logits, &s->xq, w->wcls, dim, p->vocab_size);
+     return s->logits;
+ }
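
The RoPE block inside `forward()` treats each consecutive (even, odd) pair of a head's vector as a complex number and rotates it by an angle that grows with the position. Isolated for clarity (a standalone sketch of the same math; `rope_rotate` is a hypothetical name, and 500000.0f is the Llama 3 RoPE base used above):

```c
#include <math.h>
// rotate each (v[j], v[j+1]) pair of one head by pos * theta_j,
// where theta_j = 500000^(-j/head_size), matching the loop in forward()
void rope_rotate(float* v, int head_size, int pos) {
    for (int j = 0; j < head_size; j += 2) {
        float theta = powf(500000.0f, -(float)j / (float)head_size);
        float c = cosf(pos * theta), s = sinf(pos * theta);
        float v0 = v[j], v1 = v[j + 1];
        v[j]     = v0 * c - v1 * s;
        v[j + 1] = v0 * s + v1 * c;
    }
}
```
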
+
+ // ----------------------------------------------------------------------------
+ // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
+
+ typedef struct {
+     char *str;
+     int id;
+ } TokenIndex;
+
+ typedef struct {
+     char** vocab;
+     float* vocab_scores;
+     TokenIndex *sorted_vocab;
+     int vocab_size;
+     unsigned int max_token_length;
+     unsigned char byte_pieces[512]; // stores all single-byte strings
+ } Tokenizer;
+
+ int compare_tokens(const void *a, const void *b) {
+     return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
+ }
+
+ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
+     // i should have written the vocab_size into the tokenizer file... sigh
+     t->vocab_size = vocab_size;
+     // malloc space to hold the scores and the strings
+     t->vocab = (char**)malloc(vocab_size * sizeof(char*));
+     t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
+     t->sorted_vocab = NULL; // initialized lazily
+     for (int i = 0; i < 256; i++) {
+         t->byte_pieces[i * 2] = (unsigned char)i;
+         t->byte_pieces[i * 2 + 1] = '\0';
+     }
+     // read in the file
+     FILE *file = fopen(tokenizer_path, "rb");
+     if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
+     if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
+     int len;
+     for (int i = 0; i < vocab_size; i++) {
+         if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
+         if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
+         t->vocab[i] = (char *)malloc(len + 1);
+         if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
+         t->vocab[i][len] = '\0'; // add the string terminating token
+     }
+     fclose(file);
+ }
+
+ void free_tokenizer(Tokenizer* t) {
+     for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
+     free(t->vocab);
+     free(t->vocab_scores);
+     free(t->sorted_vocab);
+ }
+
+ char* decode(Tokenizer* t, int prev_token, int token) {
+     char *piece = t->vocab[token];
+
+     // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
+     // parse this and convert and return the actual byte
+     unsigned char byte_val;
+     if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
+         piece = (char*)t->byte_pieces + byte_val * 2;
+     }
+     return piece;
+ }
+
+ void safe_printf(char *piece) {
+     // piece might be a raw byte token, and we only want to print printable chars or whitespace
+     // because some of the other bytes can be various control codes, backspace, etc.
+     if (piece == NULL) { return; }
+     if (piece[0] == '\0') { return; }
+     if (piece[1] == '\0') {
+         unsigned char byte_val = piece[0];
+         if (!(isprint(byte_val) || isspace(byte_val))) {
+             return; // bad byte, don't print it
+         }
+     }
+     printf("%s", piece);
+ }
+
+ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
+     // efficiently find the perfect match for str in vocab, return its index or -1 if not found
+     TokenIndex tok = { .str = str }; // acts as the key to search for
+     TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
+     return res != NULL ? res->id : -1;
+ }
+
+ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
+     // encode the string text (input) into an upper-bound preallocated tokens[] array
+     // bos != 0 means prepend the BOS token (=128000), eos != 0 means append the EOS token (=128001)
+     if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
+
+     if (t->sorted_vocab == NULL) {
+         // lazily malloc and sort the vocabulary
+         t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
+         for (int i = 0; i < t->vocab_size; i++) {
+             t->sorted_vocab[i].str = t->vocab[i];
+             t->sorted_vocab[i].id = i;
+         }
+         qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
+     }
+
+     // create a temporary buffer that will store merge candidates of always two consecutive tokens
+     // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
+     char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
+     size_t str_len = 0;
+
+     // start at 0 tokens
+     *n_tokens = 0;
+
+     // add optional BOS (=128000) token, if desired
+     if (bos) tokens[(*n_tokens)++] = 128000;
+
+     // add_dummy_prefix is true by default
+     // so prepend a dummy prefix token to the input string, but only if text != ""
+     // TODO: pretty sure this isn't correct in the general case but I don't have the
+     // energy to read more of the sentencepiece code to figure out what it's doing
+
+     // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
+     // Code point <-> UTF-8 conversion
+     // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
+     // U+0000 U+007F 0xxxxxxx
+     // U+0080 U+07FF 110xxxxx 10xxxxxx
+     // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
+     // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+
+     // process the raw (UTF-8) byte sequence of the input string
+     for (char *c = text; *c != '\0'; c++) {
+
+         // reset buffer if the current byte is ASCII or a leading byte
+         // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
+         // 0x80 is 10000000
+         // in UTF-8, all continuation bytes start with "10" in first two bits
+         // so in English this is: "if this byte is not a continuation byte"
+         if ((*c & 0xC0) != 0x80) {
+             // this byte must be either a leading byte (11...) or an ASCII char (0x...)
+             // => reset our location, as we're starting a new UTF-8 codepoint
+             str_len = 0;
+         }
+
+         // append the current byte to the buffer
+         str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
+         str_buffer[str_len] = '\0';
+
+         // while the next character is a continuation byte, continue appending
+         // but if there are too many of them, just stop to avoid overrunning str_buffer size.
+         if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
+             continue;
+         }
+
+         // ok c+1 is not a continuation byte, so we've read in a full codepoint
+         int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
+
+         if (id != -1) {
+             // we found this codepoint in vocab, add it as a token
+             tokens[(*n_tokens)++] = id;
+         } else {
+             // byte_fallback encoding: just encode each byte as a token
+             // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
+             // so the individual bytes only start at index 3
+             for (int i=0; i < str_len; i++) {
+                 tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
+             }
+         }
+         str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
+     }
+
+     // merge the best consecutive pair or triple each iteration, according to the scores in vocab_scores
+     while (1) {
+         float best_score = -1e10;
+         int best_id = -1;
+         int best_idx = -1;
+         int best_len = 2; // length of the best merge sequence (2 for pair, 3 for triple)
+
+         // first, try to find the best pair to merge
+         for (int i = 0; i < (*n_tokens - 1); i++) {
+             // check if we can merge the pair (tokens[i], tokens[i+1])
+             sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
+             int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
+             if (id != -1 && t->vocab_scores[id] > best_score) {
+                 // this merge pair exists in vocab! record its score and position
+                 best_score = t->vocab_scores[id];
+                 best_id = id;
+                 best_idx = i;
+             }
+         }
+
+         // if no pair was found, try to find the best triple to merge
+         if (best_idx == -1) {
+             for (int i = 0; i < (*n_tokens - 2); i++) {
+                 // check if we can merge the triple (tokens[i], tokens[i+1], tokens[i+2])
+                 sprintf(str_buffer, "%s%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]], t->vocab[tokens[i+2]]);
+                 int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
+                 if (id != -1 && t->vocab_scores[id] > best_score) {
+                     // this merge triple exists in vocab! record its score and position
+                     best_score = t->vocab_scores[id];
+                     best_id = id;
+                     best_idx = i;
+                     best_len = 3;
+                 }
+             }
+         }
+
+         if (best_idx == -1) {
+             break; // we couldn't find any more pairs or triples to merge, so we're done
+         }
+
+         // merge the consecutive pair or triple (best_idx, best_idx+1[, best_idx+2]) into new token best_id
+         tokens[best_idx] = best_id;
+         // delete token(s) at position best_idx+1 (and optionally best_idx+2), shift the entire sequence back
+         for (int i = best_idx + 1; i < (*n_tokens - best_len + 1); i++) {
+             tokens[i] = tokens[i + best_len - 1];
+         }
+         (*n_tokens) -= (best_len - 1); // token length decreased by the number of merged tokens minus one
+     }
+
+     // add optional EOS (=128001) token, if desired
+     if (eos) tokens[(*n_tokens)++] = 128001;
+
+     free(str_buffer);
+ }
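
To make the merge loop in `encode()` concrete, here is a toy trace with a hypothetical vocabulary and scores (not the real Llama 3 merges):

```text
input codepoints:  "t" "h" "e"                     -> 3 tokens
pass 1 candidates: "th" (score 2.0), "he" (1.5)    -> merge best pair "th"
tokens:            "th" "e"                        -> 2 tokens
pass 2 candidates: "the" (score 3.0)               -> merge
tokens:            "the"                           -> 1 token
pass 3 candidates: none                            -> loop breaks
```
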
+
+ // ----------------------------------------------------------------------------
+ // The Sampler, which takes logits and returns a sampled token
+ // sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
+
+ typedef struct {
+     float prob;
+     int index;
+ } ProbIndex; // struct used when sorting probabilities during top-p sampling
+
+ typedef struct {
+     int vocab_size;
+     ProbIndex* probindex; // buffer used in top-p sampling
+     float temperature;
+     float topp;
+     unsigned long long rng_state;
+ } Sampler;
+
+ int sample_argmax(float* probabilities, int n) {
+     // return the index that has the highest probability
+     int max_i = 0;
+     float max_p = probabilities[0];
+     for (int i = 1; i < n; i++) {
+         if (probabilities[i] > max_p) {
+             max_i = i;
+             max_p = probabilities[i];
+         }
+     }
+     return max_i;
+ }
+
+ int sample_mult(float* probabilities, int n, float coin) {
+     // sample index from probabilities (they must sum to 1!)
+     // coin is a random number in [0, 1), usually from random_f32()
+     float cdf = 0.0f;
+     for (int i = 0; i < n; i++) {
+         cdf += probabilities[i];
+         if (coin < cdf) {
+             return i;
+         }
+     }
+     return n - 1; // in case of rounding errors
+ }
+
+ int compare(const void* a, const void* b) {
+     ProbIndex* a_ = (ProbIndex*) a;
+     ProbIndex* b_ = (ProbIndex*) b;
+     if (a_->prob > b_->prob) return -1;
+     if (a_->prob < b_->prob) return 1;
+     return 0;
+ }
+
+ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
+     // top-p sampling (or "nucleus sampling") samples from the smallest set of
+     // tokens that exceed probability topp. This way we never sample tokens that
+     // have very low probabilities and are less likely to go "off the rails".
+     // coin is a random number in [0, 1), usually from random_f32()
+
+     int n0 = 0;
+     // quicksort indices in descending order of probabilities
+     // values smaller than (1 - topp) / (n - 1) cannot be part of the result
+     // so for efficiency we crop these out as candidates before sorting
+     const float cutoff = (1.0f - topp) / (n - 1);
+     for (int i = 0; i < n; i++) {
+         if (probabilities[i] >= cutoff) {
+             probindex[n0].index = i;
+             probindex[n0].prob = probabilities[i];
+             n0++;
+         }
+     }
+     qsort(probindex, n0, sizeof(ProbIndex), compare);
+
+     // truncate the list where cumulative probability exceeds topp
+     float cumulative_prob = 0.0f;
+     int last_idx = n0 - 1; // in case of rounding errors consider all elements
+     for (int i = 0; i < n0; i++) {
+         cumulative_prob += probindex[i].prob;
+         if (cumulative_prob > topp) {
+             last_idx = i;
+             break; // we've exceeded topp by including last_idx
+         }
+     }
+
+     // sample from the truncated list
+     float r = coin * cumulative_prob;
+     float cdf = 0.0f;
+     for (int i = 0; i <= last_idx; i++) {
+         cdf += probindex[i].prob;
+         if (r < cdf) {
+             return probindex[i].index;
+         }
+     }
+     return probindex[last_idx].index; // in case of rounding errors
+ }
+
+ void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
+     sampler->vocab_size = vocab_size;
+     sampler->temperature = temperature;
+     sampler->topp = topp;
+     sampler->rng_state = rng_seed;
+     // buffer only used with nucleus sampling; may not need but it's ~small
+     sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
+ }
+
+ void free_sampler(Sampler* sampler) {
+     free(sampler->probindex);
+ }
+
+ unsigned int random_u32(unsigned long long *state) {
+     // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
+     *state ^= *state >> 12;
+     *state ^= *state << 25;
+     *state ^= *state >> 27;
+     return (*state * 0x2545F4914F6CDD1Dull) >> 32;
+ }
+ float random_f32(unsigned long long *state) { // random float32 in [0,1)
+     return (random_u32(state) >> 8) / 16777216.0f;
+ }
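
Because the xorshift* generator above is fully determined by `rng_state`, the same seed reproduces the same token stream. A minimal sketch (`sampling_demo` is a hypothetical driver, assuming the definitions above are in scope):

```c
// reproducible sampling from a tiny 3-way distribution,
// reusing random_f32() and sample_mult() defined above
void sampling_demo(void) {
    unsigned long long rng_state = 1234ULL;  // nonzero seed -> same stream every run
    float probs[3] = {0.1f, 0.7f, 0.2f};     // must sum to 1 for sample_mult
    for (int i = 0; i < 5; i++) {
        float coin = random_f32(&rng_state); // uniform in [0, 1)
        printf("%d\n", sample_mult(probs, 3, coin));
    }
}
```
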
+
+ int sample(Sampler* sampler, float* logits) {
+     // sample the token given the logits and some hyperparameters
+     int next;
+     if (sampler->temperature == 0.0f) {
+         // greedy argmax sampling: take the token with the highest probability
+         next = sample_argmax(logits, sampler->vocab_size);
+     } else {
+         // apply the temperature to the logits
+         for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
+         // apply softmax to the logits to get the probabilities for next token
+         softmax(logits, sampler->vocab_size);
+         // flip a (float) coin (this is our source of entropy for sampling)
+         float coin = random_f32(&sampler->rng_state);
+         // we sample from this distribution to get the next token
+         if (sampler->topp <= 0 || sampler->topp >= 1) {
+             // simply sample from the predicted probability distribution
+             next = sample_mult(logits, sampler->vocab_size, coin);
+         } else {
+             // top-p (nucleus) sampling, clamping the least likely tokens to zero
+             next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
+         }
+     }
+     return next;
+ }
+
+ // ----------------------------------------------------------------------------
+ // utilities: time
+
+ long time_in_ms() {
+     // return time in milliseconds, for benchmarking the model speed
+     struct timespec time;
+     clock_gettime(CLOCK_REALTIME, &time);
+     return time.tv_sec * 1000 + time.tv_nsec / 1000000;
+ }
+
+ // ----------------------------------------------------------------------------
+ // generation loop
+
+ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
+     char *empty_prompt = "";
+     if (prompt == NULL) { prompt = empty_prompt; }
+
+     // encode the (string) prompt into tokens sequence
+     int num_prompt_tokens = 0;
+     int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
+     encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
+     if (num_prompt_tokens < 1) {
+         fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
+         exit(EXIT_FAILURE);
+     }
+
+     // start the main loop
+     long start = 0;  // used to time our code, only initialized after first iteration
+     int next;        // will store the next token in the sequence
+     int token = prompt_tokens[0]; // kick off with the first token in the prompt
+     int pos = 0;     // position in the sequence
+
+     while (pos < steps) {
+
+         // forward the transformer to get logits for the next token
+         float* logits = forward(transformer, token, pos);
+
+         // advance the state machine
+         if (pos < num_prompt_tokens - 1) {
+             // if we are still processing the input prompt, force the next prompt token
+             next = prompt_tokens[pos + 1];
+         } else {
+             // otherwise sample the next token from the logits
+             next = sample(sampler, logits);
+         }
+         pos++;
+
+         // data-dependent terminating condition: an <|end_of_text|> (=128001) or <|eot_id|> (=128009) token ends the sequence
+         if ((next == 128001 || next == 128009) && pos > num_prompt_tokens) break;
+         // print the token as string, decode it with the Tokenizer object
+         char* piece = decode(tokenizer, token, next);
+         safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
+         fflush(stdout);
+         token = next;
+
+         // init the timer here because the first iteration can be slower
+         if (start == 0) { start = time_in_ms(); }
+     }
+     printf("\n");
+
+     // report achieved tok/s (pos-1 because the timer starts after first iteration)
+     if (pos > 1) {
+         long end = time_in_ms();
+         fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
+     }
+
+     free(prompt_tokens);
+ }
+
+ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
+     // read a line from stdin, up to but not including \n
+     printf("%s", guide);
+     if (fgets(buffer, bufsize, stdin) != NULL) {
+         size_t len = strlen(buffer);
+         if (len > 0 && buffer[len - 1] == '\n') {
+             buffer[len - 1] = '\0'; // strip newline
+         }
+     }
+ }
+
+ // ----------------------------------------------------------------------------
+ // chat loop
+ // I manually inspected the tokens for a few chat conversations compared to
+ // the python reference and that seemed ok, but this was not thoroughly tested and
+ // is not safely implemented, it's more a proof of concept atm.
+
+ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
+           char *cli_user_prompt, char *cli_system_prompt, int steps) {
+
+     // buffers for reading the system prompt and user prompt from stdin
+     // you'll notice they are somewhat haphazardly and unsafely set atm
+     char* system_prompt = (char*)malloc(32768 * sizeof(char));
+     char* user_prompt = (char*)malloc(32768 * sizeof(char));
+     int num_prompt_tokens = 0;
+     int* prompt_tokens = (int*)malloc(32768 * sizeof(int));
+     int* system_prompt_tokens = (int*)malloc(32768 * sizeof(int));
+     int* user_prompt_tokens = (int*)malloc(32768 * sizeof(int));
+     int user_idx = 0;
+
+     // start the main loop
+     int8_t user_turn = 1; // user starts
+     int next;  // will store the next token in the sequence
+     int token; // stores the current token to feed into the transformer
+
+     int pos = 0; // position in the sequence
+     while (pos < steps) {
+
+         // when it is the user's turn to contribute tokens to the dialog...
+         if (user_turn) {
+             // get the (optional) system prompt at position 0
+             if (pos == 0) {
+                 // at position 0, the user can also contribute a system prompt
+                 prompt_tokens[num_prompt_tokens++] = 128000; // "<|begin_of_text|>"
+                 prompt_tokens[num_prompt_tokens++] = 128006; // "<|start_header_id|>"
+                 prompt_tokens[num_prompt_tokens++] = 9125;   // "system"
+                 prompt_tokens[num_prompt_tokens++] = 128007; // "<|end_header_id|>"
+                 prompt_tokens[num_prompt_tokens++] = 271;    // "\n\n"
+                 if (cli_system_prompt == NULL) {
+                     // system prompt was not passed in, attempt to get it from stdin
+                     read_stdin("Enter system prompt (optional): ", system_prompt, 32768);
+                 } else {
+                     // system prompt was passed in, use it
+                     strcpy(system_prompt, cli_system_prompt);
+                 }
+                 if (system_prompt != NULL) {
+                     int num_system_prompt_tokens = 0;
+                     encode(tokenizer, system_prompt, 0, 0, system_prompt_tokens, &num_system_prompt_tokens);
+                     for (int i=0; i<num_system_prompt_tokens; i++) {
+                         prompt_tokens[num_prompt_tokens++] = system_prompt_tokens[i];
+                     }
+                 }
+                 prompt_tokens[num_prompt_tokens++] = 128009; // "<|eot_id|>"
+             } else {
+                 num_prompt_tokens = 0;
+             }
+             prompt_tokens[num_prompt_tokens++] = 128006; // "<|start_header_id|>"
+             prompt_tokens[num_prompt_tokens++] = 882;    // "user"
+             prompt_tokens[num_prompt_tokens++] = 128007; // "<|end_header_id|>"
+             prompt_tokens[num_prompt_tokens++] = 271;    // "\n\n"
+             // get the user prompt
+             if (pos == 0 && cli_user_prompt != NULL) {
+                 // user prompt for position 0 was passed in, use it
+                 strcpy(user_prompt, cli_user_prompt);
+             } else {
+                 // otherwise get user prompt from stdin
+                 read_stdin("User (or exit): ", user_prompt, 32768);
+                 if (strcmp(user_prompt, "exit") == 0) break;
+             }
+             int num_user_prompt_tokens = 0;
+             // encode the user prompt into tokens
+             encode(tokenizer, user_prompt, 0, 0, user_prompt_tokens, &num_user_prompt_tokens);
+             for (int i=0; i<num_user_prompt_tokens; i++) {
+                 prompt_tokens[num_prompt_tokens++] = user_prompt_tokens[i];
+             }
+             prompt_tokens[num_prompt_tokens++] = 128009; // "<|eot_id|>"
+             prompt_tokens[num_prompt_tokens++] = 128006; // "<|start_header_id|>"
+             prompt_tokens[num_prompt_tokens++] = 78191;  // "assistant"
+             prompt_tokens[num_prompt_tokens++] = 128007; // "<|end_header_id|>"
+             prompt_tokens[num_prompt_tokens++] = 271;    // "\n\n"
+
+             user_idx = 0; // reset the user index
+             user_turn = 0;
+             printf("Assistant: ");
+         }
+
+         // determine the token to pass into the transformer next
+         if (user_idx < num_prompt_tokens) {
+             // if we are still processing the input prompt, force the next prompt token
+             token = prompt_tokens[user_idx++];
+         } else {
+             // otherwise use the next token sampled from the previous turn
+             token = next;
+         }
+         // an <|eot_id|> (=128009) or <|end_of_text|> (=128001) token ends the Assistant turn
+         if (user_idx >= num_prompt_tokens && (token == 128009 || token == 128001)) { user_turn = 1; }
+
+         // forward the transformer to get logits for the next token
+         float* logits = forward(transformer, token, pos);
+         next = sample(sampler, logits);
+         pos++;
+
+         if (user_idx >= num_prompt_tokens && next != 128009 && next != 128001 && next != 128006) {
+             // the Assistant is responding, so print its output
+             char* piece = decode(tokenizer, token, next);
+             safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
+             fflush(stdout);
+         }
+         // note: parenthesized so the newline prints only at the end of the Assistant turn
+         if (user_idx >= num_prompt_tokens && (next == 128009 || next == 128001)) { printf("\n"); }
+     }
+     printf("\n");
+     free(prompt_tokens);
+     free(system_prompt_tokens);
+     free(user_prompt_tokens);
+     free(system_prompt);
+     free(user_prompt);
+ }
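
For reference, the hard-coded token ids in `chat()` assemble the Llama 3 Instruct prompt format. Rendered as text (special-token names taken from the comments above), one round of dialog looks like:

```text
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{system prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>

{user prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{assistant reply}<|eot_id|>
```
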
+
+ // ----------------------------------------------------------------------------
+ // CLI, include only if not testing
+ #ifndef TESTING
+
+ void error_usage() {
+     fprintf(stderr, "Usage: run <checkpoint> [options]\n");
+     fprintf(stderr, "Example: run model.bin -n 4096 -i \"Once upon a time\"\n");
+     fprintf(stderr, "Options:\n");
+     fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
+     fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
+     fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
+     fprintf(stderr, " -n <int> number of steps to run for, default 4096. 0 = max_seq_len\n");
+     fprintf(stderr, " -i <string> input prompt\n");
+     fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
+     fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
+     fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
+     exit(EXIT_FAILURE);
+ }
+
+ int main(int argc, char *argv[]) {
+
+     // default parameters
+     char *checkpoint_path = NULL; // e.g. out/model.bin
+     char *tokenizer_path = "tokenizer.bin";
+     float temperature = 1.0f;  // 0.0 = greedy deterministic. 1.0 = original. don't set higher
+     float topp = 0.9f;         // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
+     int steps = 4096;          // number of steps to run for
+     char *prompt = NULL;       // prompt string
+     unsigned long long rng_seed = 0; // seed rng with time by default
+     char *mode = "generate";   // generate|chat
+     char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
+
+     // poor man's C argparse so we can override the defaults above from the command line
+     if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
+     for (int i = 2; i < argc; i+=2) {
+         // do some basic validation
+         if (i + 1 >= argc) { error_usage(); } // must have arg after flag
+         if (argv[i][0] != '-') { error_usage(); } // must start with dash
+         if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
+         // read in the args
+         if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
+         else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
+         else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
+         else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
+         else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
+         else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
+         else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
+         else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
+         else { error_usage(); }
+     }
+
+     // parameter validation/overrides
+     if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
+     if (temperature < 0.0) temperature = 0.0;
+     if (topp < 0.0 || 1.0 < topp) topp = 0.9;
+     if (steps < 0) steps = 0;
+
+     // build the Transformer via the model .bin file
+     Transformer transformer;
+     build_transformer(&transformer, checkpoint_path);
+     if (steps == 0 || steps > transformer.config.seq_len) steps = transformer.config.seq_len; // override to ~max length
+
+     // build the Tokenizer via the tokenizer .bin file
+     Tokenizer tokenizer;
+     build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);
+
+     // build the Sampler
+     Sampler sampler;
+     build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
+
+     // run!
+     if (strcmp(mode, "generate") == 0) {
+         generate(&transformer, &tokenizer, &sampler, prompt, steps);
+     } else if (strcmp(mode, "chat") == 0) {
+         chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
+     } else {
+         fprintf(stderr, "unknown mode: %s\n", mode);
+         error_usage();
+     }
+
+     // memory and file handles cleanup
+     free_sampler(&sampler);
+     free_tokenizer(&tokenizer);
+     free_transformer(&transformer);
+     return 0;
+ }
+ #endif
runqdll.c ADDED
@@ -0,0 +1,1116 @@
+ /* Inference for Llama-3 Transformer model in pure C, int8 quantized forward pass. */
+
+ #include <stdio.h>
+ #include <stdlib.h>
+ #include <ctype.h>
+ #include <stdint.h>
+ #include <time.h>
+ #include <math.h>
+ #include <string.h>
+ #include <fcntl.h>
+ #if defined _WIN32
+     #include "win.h"
+ #else
+     #include <unistd.h>
+     #include <sys/mman.h>
+ #endif
+ // ----------------------------------------------------------------------------
+ // Globals
+ int GS = 0; // group size global for quantization of the weights
+
+ // ----------------------------------------------------------------------------
+ // Transformer model
+
+ typedef struct {
+     int dim;        // transformer dimension
+     int hidden_dim; // for ffn layers
+     int n_layers;   // number of layers
+     int n_heads;    // number of query heads
+     int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
+     int vocab_size; // vocabulary size (128256 for Llama 3)
+     int seq_len;    // max sequence length
+ } Config;
+
+ typedef struct {
+     int8_t* q; // quantized values
+     float* s;  // scaling factors
+ } QuantizedTensor;
+
+ typedef struct {
+     // token embedding table
+     QuantizedTensor *q_tokens;    // (vocab_size, dim)
+     float* token_embedding_table; // same, but dequantized
+
+     // weights for rmsnorms
+     float* rms_att_weight; // (layer, dim) rmsnorm weights
+     float* rms_ffn_weight; // (layer, dim)
+     // weights for matmuls. note dim == n_heads * head_size
+     QuantizedTensor *wq; // (layer, dim, n_heads * head_size)
+     QuantizedTensor *wk; // (layer, dim, n_kv_heads * head_size)
+     QuantizedTensor *wv; // (layer, dim, n_kv_heads * head_size)
+     QuantizedTensor *wo; // (layer, n_heads * head_size, dim)
+     // weights for ffn
+     QuantizedTensor *w1; // (layer, hidden_dim, dim)
+     QuantizedTensor *w2; // (layer, dim, hidden_dim)
+     QuantizedTensor *w3; // (layer, hidden_dim, dim)
+     // final rmsnorm
+     float* rms_final_weight; // (dim,)
+     // (optional) classifier weights for the logits, on the last layer
+     QuantizedTensor *wcls;
+ } TransformerWeights;
+
+ typedef struct {
+     // current wave of activations
+     float *x;   // activation at current time stamp (dim,)
+     float *xb;  // same, but inside a residual branch (dim,)
+     float *xb2; // an additional buffer just for convenience (dim,)
+     float *hb;  // buffer for hidden dimension in the ffn (hidden_dim,)
+     float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
+     QuantizedTensor xq; // quantized x (dim,)
+     QuantizedTensor hq; // quantized hb (hidden_dim,)
+     float *q;   // query (dim,)
+     float *k;   // key (dim,)
+     float *v;   // value (dim,)
+     float *att; // buffer for scores/attention values (n_heads, seq_len)
+     float *logits; // output logits
+     // kv cache
+     float* key_cache;   // (layer, seq_len, dim)
+     float* value_cache; // (layer, seq_len, dim)
+ } RunState;
+
+ typedef struct {
+     Config config; // the hyperparameters of the architecture (the blueprint)
+     TransformerWeights weights; // the weights of the model
+     RunState state; // buffers for the "wave" of activations in the forward pass
+     // some more state needed to properly clean up the memory mapping (sigh)
+     int fd;            // file descriptor for memory mapping
+     float* data;       // memory mapped data pointer
+     ssize_t file_size; // size of the checkpoint file in bytes
+ } Transformer;
+
+ void malloc_run_state(RunState* s, Config* p) {
+     // we calloc instead of malloc to keep valgrind happy
+     int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
+     s->x = calloc(p->dim, sizeof(float));
+     s->xb = calloc(p->dim, sizeof(float));
+     s->xb2 = calloc(p->dim, sizeof(float));
+     s->hb = calloc(p->hidden_dim, sizeof(float));
+     s->hb2 = calloc(p->hidden_dim, sizeof(float));
+     s->xq = (QuantizedTensor) { .q = calloc(p->dim, sizeof(int8_t)), .s = calloc(p->dim, sizeof(float)) };
+     s->hq = (QuantizedTensor) { .q = calloc(p->hidden_dim, sizeof(int8_t)), .s = calloc(p->hidden_dim, sizeof(float)) };
+     s->q = calloc(p->dim, sizeof(float));
+     s->k = calloc(kv_dim, sizeof(float));
+     s->v = calloc(kv_dim, sizeof(float));
+     s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
+     s->logits = calloc(p->vocab_size, sizeof(float));
+     s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
+     s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
+     // ensure all mallocs went fine
+     if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
+         || !s->k || !s->v || !s->att || !s->logits || !s->key_cache
+         || !s->value_cache) {
+         fprintf(stderr, "malloc failed!\n");
+         exit(EXIT_FAILURE);
+     }
+ }
+
+ void free_run_state(RunState* s) {
+     free(s->x);
+     free(s->xb);
+     free(s->xb2);
+     free(s->hb);
+     free(s->hb2);
+     free(s->xq.q);
+     free(s->xq.s);
+     free(s->hq.q);
+     free(s->hq.s);
+     free(s->q);
+     free(s->k);
+     free(s->v);
+     free(s->att);
+     free(s->logits);
+     free(s->key_cache);
+     free(s->value_cache);
+ }
+
+ // ----------------------------------------------------------------------------
+ // Quantization functions
+
+ void dequantize(QuantizedTensor *qx, float* x, int n) {
+     for (int i = 0; i < n; i++) {
+         x[i] = qx->q[i] * qx->s[i / GS];
+     }
+ }
+
+ void quantize(QuantizedTensor *qx, float* x, int n) {
+     int num_groups = n / GS;
+     float Q_MAX = 127.0f;
+
+     for (int group = 0; group < num_groups; group++) {
+
+         // find the max absolute value in the current group
+         float wmax = 0.0;
+         for (int i = 0; i < GS; i++) {
+             float val = fabs(x[group * GS + i]);
+             if (val > wmax) {
+                 wmax = val;
+             }
+         }
+
+         // calculate and write the scaling factor
+         float scale = wmax / Q_MAX;
+         qx->s[group] = scale;
+
+         // calculate and write the quantized values
+         for (int i = 0; i < GS; i++) {
+             float quant_value = x[group * GS + i] / scale; // scale
+             int8_t quantized = (int8_t) round(quant_value); // round to nearest
+             qx->q[group * GS + i] = quantized;
+         }
+     }
+ }
173
+ /* initialize `n` x quantized tensor (with `size_each` elements), starting from memory pointed at *ptr */
174
+ QuantizedTensor *init_quantized_tensors(void **ptr, int n, int size_each) {
175
+ void *p = *ptr;
176
+ QuantizedTensor *res = malloc(n * sizeof(QuantizedTensor));
177
+ for(int i=0; i<n; i++) {
178
+ /* map quantized int8 values*/
179
+ res[i].q = (int8_t*)p;
180
+ p = (int8_t*)p + size_each;
181
+ /* map scale factors */
182
+ res[i].s = (float*)p;
183
+ p = (float*)p + size_each / GS;
184
+ }
185
+ *ptr = p; // advance ptr to current position
186
+ return res;
187
+ }
188
+
189
+ void memory_map_weights(TransformerWeights *w, Config* p, void* ptr, uint8_t shared_classifier) {
190
+ int head_size = p->dim / p->n_heads;
191
+ // first are the parameters that are kept in fp32 (the rmsnorm (1D) weights)
192
+ float* fptr = (float*) ptr; // cast our pointer to float*
193
+ w->rms_att_weight = fptr;
194
+ fptr += p->n_layers * p->dim;
195
+ w->rms_ffn_weight = fptr;
196
+ fptr += p->n_layers * p->dim;
197
+ w->rms_final_weight = fptr;
198
+ fptr += p->dim;
199
+
200
+ // now read all the quantized weights
201
+ ptr = (void*)fptr; // now cast the pointer back to void*
202
+ w->q_tokens = init_quantized_tensors(&ptr, 1, p->vocab_size * p->dim);
203
+ // dequantize token embedding table
204
+ w->token_embedding_table = malloc(p->vocab_size * p->dim * sizeof(float));
205
+ dequantize(w->q_tokens, w->token_embedding_table, p->vocab_size * p->dim);
206
+
207
+ w->wq = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_heads * head_size));
208
+ w->wk = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
209
+ w->wv = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
210
+ w->wo = init_quantized_tensors(&ptr, p->n_layers, (p->n_heads * head_size) * p->dim);
211
+
212
+ w->w1 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
213
+ w->w2 = init_quantized_tensors(&ptr, p->n_layers, p->hidden_dim * p->dim);
214
+ w->w3 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
215
+
216
+ w->wcls = shared_classifier ? w->q_tokens : init_quantized_tensors(&ptr, 1, p->dim * p->vocab_size);
217
+ }
218
+
219
+ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
220
+ int* fd, float** data, ssize_t* file_size) {
221
+ FILE *file = fopen(checkpoint, "rb");
222
+ if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
223
+ // read in magic number (uint32), has to be 0x616b3432, i.e. "ak42" in ASCII
224
+ uint32_t magic_number;
225
+ if (fread(&magic_number, sizeof(uint32_t), 1, file) != 1) { exit(EXIT_FAILURE); }
226
+ if (magic_number != 0x616b3432) { fprintf(stderr, "Bad magic number\n"); exit(EXIT_FAILURE); }
227
+ // read in the version number (uint32), has to be 2
228
+ int version;
229
+ if (fread(&version, sizeof(int), 1, file) != 1) { exit(EXIT_FAILURE); }
230
+ if (version != 2) { fprintf(stderr, "Bad version %d, need version 2\n", version); exit(EXIT_FAILURE); }
231
+ int header_size = 256; // the header size for version 2 in bytes
232
+ // read in the Config
233
+ if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); }
234
+ // read in flags
235
+ uint8_t shared_classifier; // a byte to indicate if the classifier is shared
236
+ if (fread(&shared_classifier, sizeof(uint8_t), 1, file) != 1) { exit(EXIT_FAILURE); }
237
+ int group_size; // the group size used in quantization
238
+ if (fread(&group_size, sizeof(int), 1, file) != 1) { exit(EXIT_FAILURE); }
239
+ GS = group_size; // set as global, as it will be used in many places
240
+ // figure out the file size
241
+ fseek(file, 0, SEEK_END); // move file pointer to end of file
242
+ *file_size = ftell(file); // get the file size, in bytes
243
+ fclose(file);
244
+ // memory map the Transformer weights into the data pointer
245
+ *fd = open(checkpoint, O_RDONLY); // open in read only mode
246
+ if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
247
+ *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
248
+ if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
249
+ void* weights_ptr = ((char*)*data) + header_size; // skip header bytes. char is 1 byte
250
+ memory_map_weights(weights, config, weights_ptr, shared_classifier);
251
+ }
252
+
253
+ void build_transformer(Transformer *t, char* checkpoint_path) {
254
+ // read in the Config and the Weights from the checkpoint
255
+ read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
256
+ // allocate the RunState buffers
257
+ malloc_run_state(&t->state, &t->config);
258
+ }
259
+
260
+ void free_transformer(Transformer* t) {
261
+ // free QuantizedTensors
262
+ free(t->weights.q_tokens);
263
+ free(t->weights.token_embedding_table);
264
+ free(t->weights.wq);
265
+ free(t->weights.wk);
266
+ free(t->weights.wv);
267
+ free(t->weights.wo);
268
+ free(t->weights.w1);
269
+ free(t->weights.w2);
270
+ free(t->weights.w3);
271
+ if(t->weights.wcls != t->weights.q_tokens) { free(t->weights.wcls); }
272
+ // close the memory mapping
273
+ if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); }
274
+ if (t->fd != -1) { close(t->fd); }
275
+ // free the RunState buffers
276
+ free_run_state(&t->state);
277
+ }
278
+
279
+ // ----------------------------------------------------------------------------
280
+ // neural net blocks; the dynamics of the Transformer
281
+
282
+ void rmsnorm(float* o, float* x, float* weight, int size) {
283
+ // calculate sum of squares
284
+ float ss = 0.0f;
285
+ for (int j = 0; j < size; j++) {
286
+ ss += x[j] * x[j];
287
+ }
288
+ ss /= size;
289
+ ss += 1e-5f;
290
+ ss = 1.0f / sqrtf(ss);
291
+ // normalize and scale
292
+ for (int j = 0; j < size; j++) {
293
+ o[j] = weight[j] * (ss * x[j]);
294
+ }
295
+ }
296
+
297
+ void softmax(float* x, int size) {
298
+ // find max value (for numerical stability)
299
+ float max_val = x[0];
300
+ for (int i = 1; i < size; i++) {
301
+ if (x[i] > max_val) {
302
+ max_val = x[i];
303
+ }
304
+ }
305
+ // exp and sum
306
+ float sum = 0.0f;
307
+ for (int i = 0; i < size; i++) {
308
+ x[i] = expf(x[i] - max_val);
309
+ sum += x[i];
310
+ }
311
+ // normalize
312
+ for (int i = 0; i < size; i++) {
313
+ x[i] /= sum;
314
+ }
315
+ }
316
+
317
+ void matmul(float* xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d) {
318
+ // W (d,n) @ x (n,) -> xout (d,)
319
+ // by far the most amount of time is spent inside this little function
320
+ // inputs to this function are both quantized
321
+
322
+ int i;
323
+ #pragma omp parallel for private(i)
324
+ for (i = 0; i < d; i++) {
325
+
326
+ float val = 0.0f;
327
+ int32_t ival = 0;
328
+ int in = i * n;
329
+
330
+ // do the matmul in groups of GS
331
+ int j;
332
+ for (j = 0; j <= n - GS; j += GS) {
333
+ for (int k = 0; k < GS; k++) {
334
+ ival += ((int32_t) x->q[j + k]) * ((int32_t) w->q[in + j + k]);
335
+ }
336
+ val += ((float) ival) * w->s[(in + j) / GS] * x->s[j / GS];
337
+ ival = 0;
338
+ }
339
+
340
+ xout[i] = val;
341
+ }
342
+ }
343
+
344
+ float* forward(Transformer* transformer, int token, int pos) {
345
+
346
+ // a few convenience variables
347
+ Config* p = &transformer->config;
348
+ TransformerWeights* w = &transformer->weights;
349
+ RunState* s = &transformer->state;
350
+ float *x = s->x;
351
+ int dim = p->dim;
352
+ int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
353
+ int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
354
+ int hidden_dim = p->hidden_dim;
355
+ int head_size = dim / p->n_heads;
356
+
357
+ // copy the token embedding into x
358
+ memcpy(x, w->token_embedding_table + token*dim, dim * sizeof(float));
359
+
360
+ // forward all the layers
361
+ for(unsigned long long l = 0; l < p->n_layers; l++) {
362
+
363
+ // attention rmsnorm
364
+ rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
365
+
366
+ // qkv matmuls for this position
367
+ quantize(&s->xq, s->xb, dim);
368
+ matmul(s->q, &s->xq, w->wq + l, dim, dim);
369
+ matmul(s->k, &s->xq, w->wk + l, dim, kv_dim);
370
+ matmul(s->v, &s->xq, w->wv + l, dim, kv_dim);
371
+
372
+ // RoPE relative positional encoding: complex-valued rotate q and k in each head
373
+ for (int i = 0; i < p->n_heads; i++) {
374
+ for (int j = 0; j < head_size; j += 2) {
375
+ float freq = 1.0f / powf(500000.0f, (float)j / (float)head_size);
376
+ float val = pos * freq;
377
+ float fcr = cosf(val);
378
+ float fci = sinf(val);
379
+ float q0 = s->q[i * head_size + j];
380
+ float q1 = s->q[i * head_size + j + 1];
381
+ s->q[i * head_size + j] = q0 * fcr - q1 * fci;
382
+ s->q[i * head_size + j + 1] = q0 * fci + q1 * fcr;
383
+ if (i < p->n_kv_heads) {
384
+ float k0 = s->k[i * head_size + j];
385
+ float k1 = s->k[i * head_size + j + 1];
386
+ s->k[i * head_size + j] = k0 * fcr - k1 * fci;
387
+ s->k[i * head_size + j + 1] = k0 * fci + k1 * fcr;
388
+ }
389
+ }
390
+ }
391
+
392
+ // save key,value at this time step (pos) to our kv cache
393
+ int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
394
+ float* key_cache_row = s->key_cache + loff + pos * kv_dim;
395
+ float* value_cache_row = s->value_cache + loff + pos * kv_dim;
396
+ memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
397
+ memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));
398
+
399
+ // multihead attention. iterate over all heads
400
+ int h;
401
+ #pragma omp parallel for private(h)
402
+ for (h = 0; h < p->n_heads; h++) {
403
+ // get the query vector for this head
404
+ float* q = s->q + h * head_size;
405
+ // attention scores for this head
406
+ float* att = s->att + h * p->seq_len;
407
+ // iterate over all timesteps, including the current one
408
+ for (int t = 0; t <= pos; t++) {
409
+ // get the key vector for this head and at this timestep
410
+ float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
411
+ // calculate the attention score as the dot product of q and k
412
+ float score = 0.0f;
413
+ for (int i = 0; i < head_size; i++) {
414
+ score += q[i] * k[i];
415
+ }
416
+ score /= sqrtf(head_size);
417
+ // save the score to the attention buffer
418
+ att[t] = score;
419
+ }
420
+
421
+ // softmax the scores to get attention weights, from 0..pos inclusively
422
+ softmax(att, pos + 1);
423
+
424
+ // weighted sum of the values, store back into xb
425
+ float* xb = s->xb + h * head_size;
426
+ memset(xb, 0, head_size * sizeof(float));
427
+ for (int t = 0; t <= pos; t++) {
428
+ // get the value vector for this head and at this timestep
429
+ float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
430
+ // get the attention weight for this timestep
431
+ float a = att[t];
432
+ // accumulate the weighted value into xb
433
+ for (int i = 0; i < head_size; i++) {
434
+ xb[i] += a * v[i];
435
+ }
436
+ }
437
+ }
438
+
439
+ // final matmul to get the output of the attention
440
+ quantize(&s->xq, s->xb, dim);
441
+ matmul(s->xb2, &s->xq, w->wo + l, dim, dim);
442
+
443
+ // residual connection back into x
444
+ for (int i = 0; i < dim; i++) {
445
+ x[i] += s->xb2[i];
446
+ }
447
+
448
+ // ffn rmsnorm
449
+ rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
450
+
451
+ // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
452
+ // first calculate self.w1(x) and self.w3(x)
453
+ quantize(&s->xq, s->xb, dim);
454
+ matmul(s->hb, &s->xq, w->w1 + l, dim, hidden_dim);
455
+ matmul(s->hb2, &s->xq, w->w3 + l, dim, hidden_dim);
456
+
457
+ // SwiGLU non-linearity
458
+ for (int i = 0; i < hidden_dim; i++) {
459
+ float val = s->hb[i];
460
+ // silu(x)=x*s(x), where s(x) is the logistic sigmoid
461
+ val *= (1.0f / (1.0f + expf(-val)));
462
+ // elementwise multiply with w3(x)
463
+ val *= s->hb2[i];
464
+ s->hb[i] = val;
465
+ }
466
+
467
+ // final matmul to get the output of the ffn
468
+ quantize(&s->hq, s->hb, hidden_dim);
469
+ matmul(s->xb, &s->hq, w->w2 + l, hidden_dim, dim);
470
+
471
+ // residual connection
472
+ for (int i = 0; i < dim; i++) {
473
+ x[i] += s->xb[i];
474
+ }
475
+ }
476
+
477
+ // final rmsnorm
478
+ rmsnorm(x, x, w->rms_final_weight, dim);
479
+
480
+ // classifier into logits
481
+ quantize(&s->xq, x, dim);
482
+ matmul(s->logits, &s->xq, w->wcls, dim, p->vocab_size);
483
+ return s->logits;
484
+ }
485
+
486
+ // ----------------------------------------------------------------------------
487
+ // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
488
+
489
+ typedef struct {
490
+ char *str;
491
+ int id;
492
+ } TokenIndex;
493
+
494
+ typedef struct {
495
+ char** vocab;
496
+ float* vocab_scores;
497
+ TokenIndex *sorted_vocab;
498
+ int vocab_size;
499
+ unsigned int max_token_length;
500
+ unsigned char byte_pieces[512]; // stores all single-byte strings
501
+ } Tokenizer;
502
+
503
+ int compare_tokens(const void *a, const void *b) {
504
+ return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
505
+ }
506
+
507
+ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
508
+ // i should have written the vocab_size into the tokenizer file... sigh
509
+ t->vocab_size = vocab_size;
510
+ // malloc space to hold the scores and the strings
511
+ t->vocab = (char**)malloc(vocab_size * sizeof(char*));
512
+ t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
513
+ t->sorted_vocab = NULL; // initialized lazily
514
+ for (int i = 0; i < 256; i++) {
515
+ t->byte_pieces[i * 2] = (unsigned char)i;
516
+ t->byte_pieces[i * 2 + 1] = '\0';
517
+ }
518
+ // read in the file
519
+ FILE *file = fopen(tokenizer_path, "rb");
520
+ if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
521
+ if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
522
+ int len;
523
+ for (int i = 0; i < vocab_size; i++) {
524
+ if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
525
+ if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
526
+ t->vocab[i] = (char *)malloc(len + 1);
527
+ if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
528
+ t->vocab[i][len] = '\0'; // add the string terminating token
529
+ }
530
+ fclose(file);
531
+ }
532
+
533
+ void free_tokenizer(Tokenizer* t) {
534
+ for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
535
+ free(t->vocab);
536
+ free(t->vocab_scores);
537
+ free(t->sorted_vocab);
538
+ }
539
+
540
+ char* decode(Tokenizer* t, int prev_token, int token) {
541
+ char *piece = t->vocab[token];
542
+
543
+
544
+ // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
545
+ // parse this and convert and return the actual byte
546
+ unsigned char byte_val;
547
+ if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
548
+ piece = (char*)t->byte_pieces + byte_val * 2;
549
+ }
550
+ return piece;
551
+ }
552
+
553
+ void safe_printf(char *piece, char *out_buffer) {
554
+ // piece might be a raw byte token, and we only want to print printable chars or whitespace
555
+ // because some of the other bytes can be various control codes, backspace, etc.
556
+ if (piece == NULL) { return; }
557
+ if (piece[0] == '\0') { return; }
558
+ if (piece[1] == '\0') {
559
+ unsigned char byte_val = piece[0];
560
+ if (!(isprint(byte_val) || isspace(byte_val))) {
561
+ return; // bad byte, don't print it
562
+ }
563
+ }
564
+ strcat(out_buffer, piece);
565
+ }
566
+
567
+ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
568
+ // efficiently find the perfect match for str in vocab, return its index or -1 if not found
569
+ TokenIndex tok = { .str = str }; // acts as the key to search for
570
+ TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
571
+ return res != NULL ? res->id : -1;
572
+ }
573
+
574
+ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
575
+ // encode the string text (input) into an upper-bound preallocated tokens[] array
576
+ // bos != 0 means prepend the BOS token (=128000), eos != 0 means append the EOS token (=128001)
577
+ if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
578
+
579
+ if (t->sorted_vocab == NULL) {
580
+ // lazily malloc and sort the vocabulary
581
+ t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
582
+ for (int i = 0; i < t->vocab_size; i++) {
583
+ t->sorted_vocab[i].str = t->vocab[i];
584
+ t->sorted_vocab[i].id = i;
585
+ }
586
+ qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
587
+ }
588
+
589
+ // create a temporary buffer that will store merge candidates of two or three consecutive tokens
590
+ // *3 for triple concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
591
+ char* str_buffer = malloc((t->max_token_length*3 +1 +2) * sizeof(char));
592
+ size_t str_len = 0;
593
+
594
+ // start at 0 tokens
595
+ *n_tokens = 0;
596
+
597
+ // add optional BOS (=128000) token, if desired
598
+ if (bos) tokens[(*n_tokens)++] = 128000;
599
+
600
+ // note: llama2.c prepended a sentencepiece "dummy prefix" token to the input here;
601
+ // the tiktoken-style llama 3 vocab has no dummy-prefix convention, so no token is added
604
+
605
+
606
+
607
+
608
+
609
+ // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
610
+ // Code point ↔ UTF-8 conversion
611
+ // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
612
+ // U+0000 U+007F 0xxxxxxx
613
+ // U+0080 U+07FF 110xxxxx 10xxxxxx
614
+ // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
615
+ // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
616
+
617
+ // process the raw (UTF-8) byte sequence of the input string
618
+ for (char *c = text; *c != '\0'; c++) {
619
+
620
+ // reset buffer if the current byte is ASCII or a leading byte
621
+ // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
622
+ // 0x80 is 10000000
623
+ // in UTF-8, all continuation bytes start with "10" in first two bits
624
+ // so in English this is: "if this byte is not a continuation byte"
625
+ if ((*c & 0xC0) != 0x80) {
626
+ // this byte must be either a leading byte (11...) or an ASCII char (0x...)
627
+ // => reset our location, as we're starting a new UTF-8 codepoint
628
+ str_len = 0;
629
+ }
630
+
631
+ // append the current byte to the buffer
632
+ str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
633
+ str_buffer[str_len] = '\0';
634
+
635
+ // while the next character is a continuation byte, continue appending
636
+ // but if there are too many of them, just stop to avoid overrunning the str_buffer.
637
+ if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
638
+ continue;
639
+ }
640
+
641
+ // ok c+1 is not a continuation byte, so we've read in a full codepoint
642
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
643
+
644
+ if (id != -1) {
645
+ // we found this codepoint in vocab, add it as a token
646
+ tokens[(*n_tokens)++] = id;
647
+ } else {
648
+ // byte_fallback encoding: just encode each byte as a token
649
+ // the +3 offset is inherited from llama2.c, whose first 3 vocab elements are <unk>, <s>, </s>;
650
+ // the tiktoken vocab already contains every single byte, so this path is rarely taken
651
+ for (int i=0; i < str_len; i++) {
652
+ tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
653
+ }
654
+ }
655
+ str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
656
+ }
657
+
658
+ // merge the best consecutive pair or triple each iteration, according to the scores in vocab_scores
659
+ while (1) {
660
+ float best_score = -1e10;
661
+ int best_id = -1;
662
+ int best_idx = -1;
663
+ int best_len = 2; // length of the best merge sequence (2 for pair, 3 for triple)
664
+
665
+ // first, try to find the best pair to merge
666
+ for (int i = 0; i < (*n_tokens - 1); i++) {
667
+ // check if we can merge the pair (tokens[i], tokens[i+1])
668
+ sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
669
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
670
+ if (id != -1 && t->vocab_scores[id] > best_score) {
671
+ // this merge pair exists in vocab! record its score and position
672
+ best_score = t->vocab_scores[id];
673
+ best_id = id;
674
+ best_idx = i;
675
+ }
676
+ }
677
+
678
+ // if no pair was found, try to find the best triple to merge
679
+ if (best_idx == -1) {
680
+ for (int i = 0; i < (*n_tokens - 2); i++) {
681
+ // check if we can merge the triple (tokens[i], tokens[i+1], tokens[i+2])
682
+ sprintf(str_buffer, "%s%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]], t->vocab[tokens[i+2]]);
683
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
684
+ if (id != -1 && t->vocab_scores[id] > best_score) {
685
+ // this merge triple exists in vocab! record its score and position
686
+ best_score = t->vocab_scores[id];
687
+ best_id = id;
688
+ best_idx = i;
689
+ best_len = 3;
690
+ }
691
+ }
692
+ }
693
+
694
+ if (best_idx == -1) {
695
+ break; // we couldn't find any more pairs or triples to merge, so we're done
696
+ }
697
+
698
+ // merge the consecutive pair or triple (best_idx, best_idx+1[, best_idx+2]) into new token best_id
699
+ tokens[best_idx] = best_id;
700
+ // delete token(s) at position best_idx+1 (and optionally best_idx+2), shift the entire sequence back
701
+ for (int i = best_idx + 1; i < (*n_tokens - best_len + 1); i++) {
702
+ tokens[i] = tokens[i + best_len - 1];
703
+ }
704
+ (*n_tokens) -= (best_len - 1); // token length decreased by the number of merged tokens minus one
705
+ }
706
+
707
+ // add optional EOS (=128001) token, if desired
708
+ if (eos) tokens[(*n_tokens)++] = 128001;
709
+
710
+ free(str_buffer);
711
+ }
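For reference, here is a minimal sketch of how the tokenizer pieces above fit together. It assumes this file is compiled as-is with a valid `tokenizer.bin` next to it; the 128256 vocab size (128000 base tokens + 256 special tokens) matches the llama 3 models, and `demo_tokenizer` is just a hypothetical driver, not part of the repo:

```c
// sketch: tokenize a string and decode it back, assuming the functions above
void demo_tokenizer(void) {
    Tokenizer tok;
    build_tokenizer(&tok, "tokenizer.bin", 128256); // 128256 = llama 3 vocab size
    int tokens[64];
    int n_tokens = 0;
    encode(&tok, "hello world", 1, 0, tokens, &n_tokens); // prepend BOS, no EOS
    for (int i = 1; i < n_tokens; i++) { // skip the BOS token at index 0
        printf("%d -> %s\n", tokens[i], decode(&tok, tokens[i - 1], tokens[i]));
    }
    free_tokenizer(&tok);
}
```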
712
+
713
+ // ----------------------------------------------------------------------------
714
+ // The Sampler, which takes logits and returns a sampled token
715
+ // sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
716
+
717
+ typedef struct {
718
+ float prob;
719
+ int index;
720
+ } ProbIndex; // struct used when sorting probabilities during top-p sampling
721
+
722
+ typedef struct {
723
+ int vocab_size;
724
+ ProbIndex* probindex; // buffer used in top-p sampling
725
+ float temperature;
726
+ float topp;
727
+ unsigned long long rng_state;
728
+ } Sampler;
729
+
730
+ int sample_argmax(float* probabilities, int n) {
731
+ // return the index that has the highest probability
732
+ int max_i = 0;
733
+ float max_p = probabilities[0];
734
+ for (int i = 1; i < n; i++) {
735
+ if (probabilities[i] > max_p) {
736
+ max_i = i;
737
+ max_p = probabilities[i];
738
+ }
739
+ }
740
+ return max_i;
741
+ }
742
+
743
+ int sample_mult(float* probabilities, int n, float coin) {
744
+ // sample index from probabilities (they must sum to 1!)
745
+ // coin is a random number in [0, 1), usually from random_f32()
746
+ float cdf = 0.0f;
747
+ for (int i = 0; i < n; i++) {
748
+ cdf += probabilities[i];
749
+ if (coin < cdf) {
750
+ return i;
751
+ }
752
+ }
753
+ return n - 1; // in case of rounding errors
754
+ }
755
+
756
+ int compare(const void* a, const void* b) {
757
+ ProbIndex* a_ = (ProbIndex*) a;
758
+ ProbIndex* b_ = (ProbIndex*) b;
759
+ if (a_->prob > b_->prob) return -1;
760
+ if (a_->prob < b_->prob) return 1;
761
+ return 0;
762
+ }
763
+
764
+ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
765
+ // top-p sampling (or "nucleus sampling") samples from the smallest set of
766
+ // tokens that exceed probability topp. This way we never sample tokens that
767
+ // have very low probabilities and are less likely to go "off the rails".
768
+ // coin is a random number in [0, 1), usually from random_f32()
769
+
770
+ int n0 = 0;
771
+ // quicksort indices in descending order of probabilities
772
+ // values smaller than (1 - topp) / (n - 1) cannot be part of the result:
773
+ // any nucleus token at sorted rank i >= 1 satisfies (n - i) * p_i >= 1 - topp,
+ // since the tokens from rank i onward sum to at least 1 - topp and each is <= p_i.
+ // so for efficiency we crop these out as candidates before sorting
774
+ const float cutoff = (1.0f - topp) / (n - 1);
775
+ for (int i = 0; i < n; i++) {
776
+ if (probabilities[i] >= cutoff) {
777
+ probindex[n0].index = i;
778
+ probindex[n0].prob = probabilities[i];
779
+ n0++;
780
+ }
781
+ }
782
+ qsort(probindex, n0, sizeof(ProbIndex), compare);
783
+
784
+ // truncate the list where cumulative probability exceeds topp
785
+ float cumulative_prob = 0.0f;
786
+ int last_idx = n0 - 1; // in case of rounding errors consider all elements
787
+ for (int i = 0; i < n0; i++) {
788
+ cumulative_prob += probindex[i].prob;
789
+ if (cumulative_prob > topp) {
790
+ last_idx = i;
791
+ break; // we've exceeded topp by including last_idx
792
+ }
793
+ }
794
+
795
+ // sample from the truncated list
796
+ float r = coin * cumulative_prob;
797
+ float cdf = 0.0f;
798
+ for (int i = 0; i <= last_idx; i++) {
799
+ cdf += probindex[i].prob;
800
+ if (r < cdf) {
801
+ return probindex[i].index;
802
+ }
803
+ }
804
+ return probindex[last_idx].index; // in case of rounding errors
805
+ }
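As a small worked example of the truncation above (a sketch, assuming it is compiled with the functions in this file; the distribution and `demo_topp` are made up for illustration):

```c
// sketch: nucleus sampling over a toy 4-token distribution
void demo_topp(void) {
    float probs[4] = {0.5f, 0.3f, 0.15f, 0.05f};
    ProbIndex buf[4];
    // with topp = 0.9: 0.5 + 0.3 = 0.8 <= 0.9, and adding 0.15 crosses 0.9,
    // so the nucleus is {0, 1, 2} and index 3 can never be sampled
    float coin = 0.5f; // any fixed value in [0, 1) works for the demo
    int idx = sample_topp(probs, 4, 0.9f, buf, coin);
    printf("sampled index: %d\n", idx); // prints 0: r = 0.5 * 0.95 = 0.475 < 0.5
}
```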
806
+
807
+ void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
808
+ sampler->vocab_size = vocab_size;
809
+ sampler->temperature = temperature;
810
+ sampler->topp = topp;
811
+ sampler->rng_state = rng_seed;
812
+ // buffer only used with nucleus sampling; it may go unused, but it's ~small
813
+ sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
814
+ }
815
+
816
+ void free_sampler(Sampler* sampler) {
817
+ free(sampler->probindex);
818
+ }
819
+
820
+ unsigned int random_u32(unsigned long long *state) {
821
+ // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
822
+ *state ^= *state >> 12;
823
+ *state ^= *state << 25;
824
+ *state ^= *state >> 27;
825
+ return (*state * 0x2545F4914F6CDD1Dull) >> 32;
826
+ }
827
+ float random_f32(unsigned long long *state) { // random float32 in [0,1)
828
+ return (random_u32(state) >> 8) / 16777216.0f;
829
+ }
830
+
831
+ int sample(Sampler* sampler, float* logits) {
832
+ // sample the token given the logits and some hyperparameters
833
+ int next;
834
+ if (sampler->temperature == 0.0f) {
835
+ // greedy argmax sampling: take the token with the highest probability
836
+ next = sample_argmax(logits, sampler->vocab_size);
837
+ } else {
838
+ // apply the temperature to the logits
839
+ for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
840
+ // apply softmax to the logits to get the probabilities for next token
841
+ softmax(logits, sampler->vocab_size);
842
+ // flip a (float) coin (this is our source of entropy for sampling)
843
+ float coin = random_f32(&sampler->rng_state);
844
+ // we sample from this distribution to get the next token
845
+ if (sampler->topp <= 0 || sampler->topp >= 1) {
846
+ // simply sample from the predicted probability distribution
847
+ next = sample_mult(logits, sampler->vocab_size, coin);
848
+ } else {
849
+ // top-p (nucleus) sampling, clamping the least likely tokens to zero
850
+ next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
851
+ }
852
+ }
853
+ return next;
854
+ }
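In math terms, the temperature step above reshapes the next-token distribution before sampling: with logits $z_q$ and temperature $T$,

$$p_q = \frac{\exp(z_q / T)}{\sum_j \exp(z_j / T)}$$

so $T \to 0$ approaches greedy argmax (handled as an explicit special case above), $T = 1$ samples from the raw softmax, and $T > 1$ flattens the distribution.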
855
+
856
+ // ----------------------------------------------------------------------------
857
+ // utilities: time
858
+
859
+ long time_in_ms() {
860
+ // return time in milliseconds, for benchmarking the model speed
861
+ struct timespec time;
862
+ clock_gettime(CLOCK_REALTIME, &time);
863
+ return time.tv_sec * 1000 + time.tv_nsec / 1000000;
864
+ }
865
+
866
+ // ----------------------------------------------------------------------------
867
+ // generation loop
868
+
869
+ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps, char *out_buffer) {
870
+ char *empty_prompt = "";
871
+ if (prompt == NULL) { prompt = empty_prompt; }
872
+
873
+ // encode the (string) prompt into tokens sequence
874
+ int num_prompt_tokens = 0;
875
+ int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
876
+ encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
877
+ if (num_prompt_tokens < 1) {
878
+ fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
879
+ exit(EXIT_FAILURE);
880
+ }
881
+
882
+ // start the main loop
883
+ long start = 0; // used to time our code, only initialized after first iteration
884
+ int next; // will store the next token in the sequence
885
+ int token = prompt_tokens[0]; // kick off with the first token in the prompt
886
+ int pos = 0; // position in the sequence
887
+
888
+ while (pos < steps) {
889
+
890
+ // forward the transformer to get logits for the next token
891
+ float* logits = forward(transformer, token, pos);
892
+
893
+ // advance the state machine
894
+ if (pos < num_prompt_tokens - 1) {
895
+ // if we are still processing the input prompt, force the next prompt token
896
+ next = prompt_tokens[pos + 1];
897
+ } else {
898
+ // otherwise sample the next token from the logits
899
+ next = sample(sampler, logits);
900
+ }
901
+ pos++;
902
+
903
+ // data-dependent terminating condition: the EOS (=128001) or EOT (=128009) tokens delimit sequences
904
+ if ((next == 128001 || next == 128009) && pos > num_prompt_tokens) break;
905
+ // print the token as string, decode it with the Tokenizer object
906
+ char* piece = decode(tokenizer, token, next);
907
+ safe_printf(piece, out_buffer); // appends piece to out_buffer (instead of printing), skipping "unsafe" bytes
908
+ fflush(stdout);
909
+ token = next;
910
+
911
+ // init the timer here because the first iteration can be slower
912
+ if (start == 0) { start = time_in_ms(); }
913
+ }
914
+ strcat(out_buffer, "\n");
915
+
916
+ // report achieved tok/s (pos-1 because the timer starts after first iteration)
917
+ if (pos > 1) {
918
+ long end = time_in_ms();
919
+ fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
920
+ }
921
+
922
+ free(prompt_tokens);
923
+ }
924
+
925
+ void read_stdin(const char* guide, char* buffer, size_t bufsize, char *out_buffer) {
926
+ // read a line from stdin, up to but not including \n
927
+ strcat(out_buffer, guide);
928
+ if (fgets(buffer, bufsize, stdin) != NULL) {
929
+ size_t len = strlen(buffer);
930
+ if (len > 0 && buffer[len - 1] == '\n') {
931
+ buffer[len - 1] = '\0'; // strip newline
932
+ }
933
+ }
934
+ }
935
+
936
+ // ----------------------------------------------------------------------------
937
+ // chat loop
938
+ // I manually inspected the tokens for a few chat conversations compared to
939
+ // python reference and that seemed ok, but this was not thoroughly tested and
940
+ // is not safely implemented, it's more a proof of concept atm.
941
+
942
+ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
943
+ char *cli_user_prompt, char *cli_system_prompt, int steps, char *out_buffer) {
944
+
945
+ // buffers for reading the system prompt and user prompt from stdin
946
+ // you'll notice they are somewhat haphazardly and unsafely set atm
947
+ char* system_prompt = (char*)malloc(32768 * sizeof(char));
948
+ char* user_prompt = (char*)malloc(32768 * sizeof(char));
949
+ int num_prompt_tokens = 0;
950
+ int* prompt_tokens = (int*)malloc(32768 * sizeof(int));
951
+ int* system_prompt_tokens = (int*)malloc(32768 * sizeof(int));
952
+ int* user_prompt_tokens = (int*)malloc(32768 * sizeof(int));
953
+ int user_idx=0;
954
+
955
+ // start the main loop
956
+ int8_t user_turn = 1; // user starts
957
+ int next; // will store the next token in the sequence
958
+ int token; // stores the current token to feed into the transformer
959
+
960
+ int pos = 0; // position in the sequence
961
+ while (pos < steps) {
962
+
963
+ // when it is the user's turn to contribute tokens to the dialog...
964
+ if (user_turn) {
965
+ // get the (optional) system prompt at position 0
966
+ if (pos == 0) {
967
+ // at position 0, the user can also contribute a system prompt
968
+ prompt_tokens[num_prompt_tokens++] = 128000; // "<|begin_of_text|>"
969
+ prompt_tokens[num_prompt_tokens++] = 128006; // "<|start_header_id|>"
970
+ prompt_tokens[num_prompt_tokens++] = 9125; // "system"
971
+ prompt_tokens[num_prompt_tokens++] = 128007; // "<|end_header_id|>"
972
+ prompt_tokens[num_prompt_tokens++] = 271; // "\n\n"
973
+ if (cli_system_prompt == NULL) {
974
+ // system prompt was not passed in, attempt to get it from stdin
975
+ read_stdin("Enter system prompt (optional): ", system_prompt, 32768, out_buffer);
976
+ } else {
977
+ // system prompt was passed in, use it
978
+ strcpy(system_prompt, cli_system_prompt);
979
+ }
980
+ if (system_prompt != NULL) {
981
+ int num_system_prompt_tokens = 0;
982
+ encode(tokenizer, system_prompt, 0, 0, system_prompt_tokens, &num_system_prompt_tokens);
983
+ for (int i=0; i<num_system_prompt_tokens; i++) {
984
+ prompt_tokens[num_prompt_tokens++] = system_prompt_tokens[i];
985
+ }
986
+ }
987
+ prompt_tokens[num_prompt_tokens++] = 128009; // "<|eot_id|>"
988
+ } else {
989
+ num_prompt_tokens = 0;
990
+ }
991
+ prompt_tokens[num_prompt_tokens++] = 128006; // "<|start_header_id|>"
992
+ prompt_tokens[num_prompt_tokens++] = 882; // "user"
993
+ prompt_tokens[num_prompt_tokens++] = 128007; // "<|end_header_id|>"
994
+ prompt_tokens[num_prompt_tokens++] = 271; // "\n\n"
995
+ // get the user prompt
996
+ if (pos == 0 && cli_user_prompt != NULL) {
997
+ // user prompt for position 0 was passed in, use it
998
+ strcpy(user_prompt, cli_user_prompt);
999
+ } else {
1000
+ // otherwise get user prompt from stdin
1001
+ read_stdin("User (or exit): ", user_prompt, 32768, out_buffer);
1002
+ if(strcmp(user_prompt, "exit")==0) break;
1003
+ }
1004
+ int num_user_prompt_tokens = 0;
1005
+ // encode the user prompt into tokens
1006
+ encode(tokenizer, user_prompt, 0, 0, user_prompt_tokens, &num_user_prompt_tokens);
1007
+ for (int i=0; i<num_user_prompt_tokens; i++) {
1008
+ prompt_tokens[num_prompt_tokens++] = user_prompt_tokens[i];
1009
+ }
1010
+ prompt_tokens[num_prompt_tokens++] = 128009; // "<|eot_id|>"
1011
+ prompt_tokens[num_prompt_tokens++] = 128006; // "<|start_header_id|>"
1012
+ prompt_tokens[num_prompt_tokens++] = 78191; // "assistant"
1013
+ prompt_tokens[num_prompt_tokens++] = 128007; // "<|end_header_id|>"
1014
+ prompt_tokens[num_prompt_tokens++] = 271; // "\n\n"
1015
+
1016
+
1017
+ user_idx = 0; // reset the user index
1018
+ user_turn = 0;
1019
+ strcat(out_buffer, "Assistant: ");
1020
+ }
1021
+
1022
+ // determine the token to pass into the transformer next
1023
+ if (user_idx < num_prompt_tokens) {
1024
+ // if we are still processing the input prompt, force the next prompt token
1025
+ token = prompt_tokens[user_idx++];
1026
+ } else {
1027
+ // otherwise use the next token sampled from previous turn
1028
+ token = next;
1029
+ }
1030
+ // EOT (=128009) or EOS (=128001) token ends the Assistant turn
1031
+ if (user_idx >= num_prompt_tokens && (token == 128009 || token == 128001)) { user_turn = 1; }
1032
+
1033
+ // forward the transformer to get logits for the next token
1034
+ float* logits = forward(transformer, token, pos);
1035
+ next = sample(sampler, logits);
1036
+ pos++;
1037
+
1038
+ if (user_idx >= num_prompt_tokens && next != 128009 && next != 128001 && next != 128006) {
1039
+ // the Assistant is responding, so print its output
1040
+ char* piece = decode(tokenizer, token, next);
1041
+ safe_printf(piece, out_buffer); // appends piece to out_buffer (instead of printing), skipping "unsafe" bytes
1042
+ fflush(stdout);
1043
+ }
1044
+ if (user_idx >= num_prompt_tokens && (next == 128009 || next == 128001)) { strcat(out_buffer, "\n"); } // parentheses fix an && / || precedence bug; the newline goes to out_buffer like the rest of the output
1045
+ }
1046
+ strcat(out_buffer, "\n");
1047
+ free(prompt_tokens);
1048
+ free(system_prompt_tokens);
1049
+ free(user_prompt_tokens);
1050
+ free(system_prompt);
1051
+ free(user_prompt);
1052
+ }
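For reference, the hard-coded token IDs above (128000, 128006, 9125, 128007, 271, 128009, 882, 78191) spell out the standard llama 3 chat template. Rendered as text, one turn of the dialog looks like this, with `{system prompt}` and `{user prompt}` as placeholders:

```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{system prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>

{user prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

```

The assistant's sampled tokens follow, and an `<|eot_id|>` (128009) hands the turn back to the user.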
1053
+
1054
+ typedef struct {
1055
+ char *checkpoint_path;
1056
+ char *tokenizer_path;
1057
+ float temperature;
1058
+ float topp;
1059
+ int steps;
1060
+ char *prompt;
1061
+ unsigned long long rng_seed;
1062
+ char *mode;
1063
+ char *system_prompt;
1064
+ char out_buffer[32768];
1065
+ Transformer transformer;
1066
+ Tokenizer tokenizer;
1067
+ Sampler sampler;
1068
+ } Main;
1069
+
1070
+ #define DEFAULT_CHECKPOINT_PATH "model.bin"
1071
+ #define DEFAULT_TOKENIZER_PATH "tokenizer.bin"
1072
+ #define DEFAULT_MAIN_MODE "generate"
1073
+
1074
+ __declspec(dllexport) Main *build_main(char* checkpoint_path, char* tokenizer_path, float temperature, float topp, int steps,
1075
+ char* prompt, unsigned long long rng_seed, char* mode, char* system_prompt) {
1076
+ // parameter validation/overrides
1077
+ Main *ret = (Main *)calloc(1, sizeof(Main));
1078
+ if (!ret) return ret;
1079
+ ret->checkpoint_path = checkpoint_path ? checkpoint_path : DEFAULT_CHECKPOINT_PATH;
1080
+ ret->tokenizer_path = tokenizer_path ? tokenizer_path : DEFAULT_TOKENIZER_PATH;
1081
+ ret->temperature = (temperature < 0.0) ? 0.0f : (temperature ? temperature : 1.0f);
1082
+ ret->topp = topp ? topp : 0.9f;
1083
+ ret->steps = (steps < 0) ? 0 : steps;
1084
+ ret->prompt = prompt ? prompt : NULL;
1085
+ ret->rng_seed = (rng_seed <= 0) ? (unsigned int)time(NULL) : rng_seed;
1086
+ ret->mode = mode ? mode : DEFAULT_MAIN_MODE;
1087
+ ret->system_prompt = system_prompt ? system_prompt : NULL;
1088
+ // build the Transformer via the model .bin file
1089
+ build_transformer(&ret->transformer, ret->checkpoint_path);
1090
+ ret->steps = (ret->steps == 0 || ret->steps > ret->transformer.config.seq_len) ? ret->transformer.config.seq_len : ret->steps; // override to ~max length
1091
+ // build the Tokenizer via the tokenizer .bin file
1092
+ build_tokenizer(&ret->tokenizer, ret->tokenizer_path, ret->transformer.config.vocab_size);
1093
+ // build the Sampler
1094
+ build_sampler(&ret->sampler, ret->transformer.config.vocab_size, ret->temperature, ret->topp, ret->rng_seed);
1095
+ return ret;
1096
+ }
1097
+
1098
+ __declspec(dllexport) void free_main(Main *m) {
1099
+ // memory and file handles cleanup
1100
+ free_sampler(&m->sampler);
1101
+ free_tokenizer(&m->tokenizer);
1102
+ free_transformer(&m->transformer);
1103
+ free(m);
1104
+ }
1105
+
1106
+ __declspec(dllexport) char *run_main(Main *m) {
1107
+ // run!
1108
+ if (strcmp(m->mode, "generate") == 0) {
1109
+ generate(&m->transformer, &m->tokenizer, &m->sampler, m->prompt, m->steps, m->out_buffer);
1110
+ } else if (strcmp(m->mode, "chat") == 0) {
1111
+ chat(&m->transformer, &m->tokenizer, &m->sampler, m->prompt, m->system_prompt, m->steps, m->out_buffer);
1112
+ } else {
1113
+ fprintf(stderr, "unknown mode: %s\n", m->mode);
1114
+ }
1115
+ return m->out_buffer;
1116
+ }
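A minimal sketch of how a host program might drive these exports. The DLL name `runq.dll` is an assumption for illustration (the repo ships `runqdll.c` but no build script for it); the function-pointer types mirror the exported signatures above:

```c
#include <windows.h>
#include <stdio.h>

// function-pointer types matching the __declspec(dllexport) signatures above
typedef void *(*build_main_fn)(char *, char *, float, float, int,
                               char *, unsigned long long, char *, char *);
typedef char *(*run_main_fn)(void *);
typedef void (*free_main_fn)(void *);

int main(void) {
    HMODULE dll = LoadLibraryA("runq.dll"); // assumed DLL name
    if (!dll) { fprintf(stderr, "could not load runq.dll\n"); return 1; }
    build_main_fn build_main = (build_main_fn)GetProcAddress(dll, "build_main");
    run_main_fn run_main = (run_main_fn)GetProcAddress(dll, "run_main");
    free_main_fn free_main = (free_main_fn)GetProcAddress(dll, "free_main");

    void *m = build_main("llama3_8b_instruct_q80.bin", "tokenizer.bin",
                         1.0f, 0.9f, 256, "Once upon a time", 0, "generate", NULL);
    printf("%s", run_main(m)); // run_main returns the filled out_buffer
    free_main(m);
    FreeLibrary(dll);
    return 0;
}
```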
tokenizer.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5718704735f72fe91c60a346d527822e2fc551d2424635a757109a30f25e325b
3
+ size 1864861
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82e9d31979e92ab929cd544440f129d9ecd797b69e327f80f17e1c50d5551b55
3
+ size 2183982
tokenizer.py ADDED
@@ -0,0 +1,115 @@
1
+ # Taken from llama code and lightly modified
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
4
+
5
+ import array
6
+ import os
7
+ import struct
8
+ import argparse
9
+ from pathlib import Path
10
+ from typing import List
11
+
12
+ import tiktoken
13
+ from tiktoken.load import load_tiktoken_bpe
14
+
15
+ TOKENIZER_MODEL = "tokenizer.model" # the llama tiktoken tokenizer model
16
+
17
+
18
+ class Tokenizer:
19
+ pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
20
+
21
+ def __init__(self, tokenizer_model=None):
22
+ model_path = tokenizer_model if tokenizer_model else TOKENIZER_MODEL
23
+ assert os.path.isfile(model_path), model_path
24
+ mergeable_ranks = load_tiktoken_bpe(model_path)
25
+ self.model_path = model_path
26
+
27
+ # BOS / EOS token IDs
28
+ num_base_tokens = len(mergeable_ranks)
29
+ num_reserved_special_tokens = 256
30
+
31
+ special_tokens = [
32
+ "<|begin_of_text|>",
33
+ "<|end_of_text|>",
34
+ "<|reserved_special_token_0|>",
35
+ "<|reserved_special_token_1|>",
36
+ "<|reserved_special_token_2|>",
37
+ "<|reserved_special_token_3|>",
38
+ "<|start_header_id|>",
39
+ "<|end_header_id|>",
40
+ "<|reserved_special_token_4|>",
41
+ "<|eot_id|>", # end of turn
42
+ ] + [
43
+ f"<|reserved_special_token_{i}|>"
44
+ for i in range(5, num_reserved_special_tokens - 5)
45
+ ]
46
+ self.special_tokens = {
47
+ token: num_base_tokens + i for i, token in enumerate(special_tokens)
48
+ }
49
+ self.model = tiktoken.Encoding(
50
+ name=Path(model_path).name,
51
+ pat_str=self.pat_str,
52
+ mergeable_ranks=mergeable_ranks,
53
+ special_tokens=self.special_tokens,
54
+ )
55
+ self.n_words = self.model.n_vocab
56
+ self.bos_id = self.special_tokens["<|begin_of_text|>"]
57
+ self.eos_id = self.special_tokens["<|end_of_text|>"]
58
+ self.pad_id = -1
59
+ self.stop_tokens = {
60
+ self.special_tokens["<|end_of_text|>"],
61
+ self.special_tokens["<|eot_id|>"],
62
+ }
63
+
64
+ def encode(
65
+ self, s: str, bos: bool, eos: bool, allowed_special, disallowed_special
66
+ ) -> List[int]:
67
+ assert type(s) is str
68
+ t = self.model.encode(
69
+ s,
70
+ allowed_special=allowed_special,
71
+ disallowed_special=disallowed_special,
72
+ )
73
+
74
+ if bos:
75
+ t.insert(0, self.bos_id)
76
+ if eos:
77
+ t.append(self.eos_id)
78
+ return t
79
+
80
+ def decode(self, t: List[int]) -> str:
81
+ return self.model.decode(t)
82
+
83
+ def export(self):
84
+
85
+ # get all the tokens (postprocessed) and their scores as floats
86
+ tokens, scores = [], []
87
+ for i in range(self.n_words):
88
+
89
+ # decode the token bytes; the score is simply the merge rank
90
+ t = self.model.decode_single_token_bytes(i)
91
+ s = i
92
+ tokens.append(t)
93
+ scores.append(s)
94
+
95
+ # record the max token length
96
+ max_token_length = max(len(t) for t in tokens)
97
+
98
+ # write to a binary file
99
+ # the tokenizer.bin file is the same as .model file, but .bin
100
+ tokenizer_bin = self.model_path.replace(".model", ".bin")
101
+ with open(tokenizer_bin, "wb") as f:
102
+ f.write(struct.pack("I", max_token_length))
103
+ for token_bytes, score in zip(tokens, scores):
104
+ f.write(struct.pack("fI", score, len(token_bytes)))
105
+ f.write(token_bytes)
106
+
107
+
108
+ if __name__ == "__main__":
109
+ parser = argparse.ArgumentParser()
110
+ parser.add_argument("-t", "--tokenizer-model", type=str, help="optional path to custom tokenizer ")
111
+
112
+ args = parser.parse_args()
113
+
114
+ t = Tokenizer(args.tokenizer_model)
115
+ t.export()
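To regenerate `tokenizer.bin` from the tiktoken model, run the script directly; `-t` is optional and defaults to `tokenizer.model`:

```bash
python tokenizer.py -t tokenizer.model
```

The resulting `.bin` is the flat format read by `build_tokenizer` in the C code: a `max_token_length` header, then for each token a float score, an int byte length, and the raw token bytes.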
win.c ADDED
@@ -0,0 +1,167 @@
1
+ #include "win.h"
2
+ #include <errno.h>
3
+ #include <io.h>
4
+
5
+ #ifndef FILE_MAP_EXECUTE
6
+ #define FILE_MAP_EXECUTE 0x0020
7
+ #endif /* FILE_MAP_EXECUTE */
8
+
9
+ static int __map_mman_error(const uint32_t err, const int deferr)
10
+ {
11
+ if (err == 0)
12
+ return 0;
13
+ //TODO: implement
14
+ return err;
15
+ }
16
+
17
+ static uint32_t __map_mmap_prot_page(const int prot)
18
+ {
19
+ uint32_t protect = 0;
20
+
21
+ if (prot == PROT_NONE)
22
+ return protect;
23
+
24
+ if ((prot & PROT_EXEC) != 0)
25
+ {
26
+ protect = ((prot & PROT_WRITE) != 0) ?
27
+ PAGE_EXECUTE_READWRITE : PAGE_EXECUTE_READ;
28
+ }
29
+ else
30
+ {
31
+ protect = ((prot & PROT_WRITE) != 0) ?
32
+ PAGE_READWRITE : PAGE_READONLY;
33
+ }
34
+
35
+ return protect;
36
+ }
37
+
38
+ static uint32_t __map_mmap_prot_file(const int prot)
39
+ {
40
+ uint32_t desiredAccess = 0;
41
+
42
+ if (prot == PROT_NONE)
43
+ return desiredAccess;
44
+
45
+ if ((prot & PROT_READ) != 0)
46
+ desiredAccess |= FILE_MAP_READ;
47
+ if ((prot & PROT_WRITE) != 0)
48
+ desiredAccess |= FILE_MAP_WRITE;
49
+ if ((prot & PROT_EXEC) != 0)
50
+ desiredAccess |= FILE_MAP_EXECUTE;
51
+
52
+ return desiredAccess;
53
+ }
54
+
55
+ void* mmap(void *addr, size_t len, int prot, int flags, int fildes, ssize_t off)
56
+ {
57
+ HANDLE fm, h;
58
+ void * map = MAP_FAILED;
59
+
60
+ #ifdef _MSC_VER
61
+ #pragma warning(push)
62
+ #pragma warning(disable: 4293)
63
+ #endif
64
+
65
+ const uint32_t dwFileOffsetLow = (uint32_t)(off & 0xFFFFFFFFL);
66
+ const uint32_t dwFileOffsetHigh = (uint32_t)((off >> 32) & 0xFFFFFFFFL);
67
+ const uint32_t protect = __map_mmap_prot_page(prot);
68
+ const uint32_t desiredAccess = __map_mmap_prot_file(prot);
69
+
70
+ const ssize_t maxSize = off + (ssize_t)len;
71
+
72
+ const uint32_t dwMaxSizeLow = (uint32_t)(maxSize & 0xFFFFFFFFL);
73
+ const uint32_t dwMaxSizeHigh = (uint32_t)((maxSize >> 32) & 0xFFFFFFFFL);
74
+
75
+ #ifdef _MSC_VER
76
+ #pragma warning(pop)
77
+ #endif
78
+
79
+ errno = 0;
80
+
81
+ if (len == 0
82
+ /* Unsupported flag combinations */
83
+ || (flags & MAP_FIXED) != 0
84
+ /* Unsupported protection combinations */
85
+ || prot == PROT_EXEC)
86
+ {
87
+ errno = EINVAL;
88
+ return MAP_FAILED;
89
+ }
90
+
91
+ h = ((flags & MAP_ANONYMOUS) == 0) ?
92
+ (HANDLE)_get_osfhandle(fildes) : INVALID_HANDLE_VALUE;
93
+
94
+ if ((flags & MAP_ANONYMOUS) == 0 && h == INVALID_HANDLE_VALUE)
95
+ {
96
+ errno = EBADF;
97
+ return MAP_FAILED;
98
+ }
99
+
100
+ fm = CreateFileMapping(h, NULL, protect, dwMaxSizeHigh, dwMaxSizeLow, NULL);
101
+
102
+ if (fm == NULL)
103
+ {
104
+ errno = __map_mman_error(GetLastError(), EPERM);
105
+ return MAP_FAILED;
106
+ }
107
+
108
+ map = MapViewOfFile(fm, desiredAccess, dwFileOffsetHigh, dwFileOffsetLow, len);
109
+
110
+ CloseHandle(fm);
111
+
112
+ if (map == NULL)
113
+ {
114
+ errno = __map_mman_error(GetLastError(), EPERM);
115
+ return MAP_FAILED;
116
+ }
117
+
118
+ return map;
119
+ }
120
+
121
+ int munmap(void *addr, size_t len)
122
+ {
123
+ if (UnmapViewOfFile(addr))
124
+ return 0;
125
+
126
+ errno = __map_mman_error(GetLastError(), EPERM);
127
+
128
+ return -1;
129
+ }
130
+
131
+ int msync(void *addr, size_t len, int flags)
132
+ {
133
+ if (FlushViewOfFile(addr, len))
134
+ return 0;
135
+
136
+ errno = __map_mman_error(GetLastError(), EPERM);
137
+
138
+ return -1;
139
+ }
140
+
141
+ int mlock(const void *addr, size_t len)
142
+ {
143
+ if (VirtualLock((LPVOID)addr, len))
144
+ return 0;
145
+
146
+ errno = __map_mman_error(GetLastError(), EPERM);
147
+
148
+ return -1;
149
+ }
150
+
151
+ int munlock(const void *addr, size_t len)
152
+ {
153
+ if (VirtualUnlock((LPVOID)addr, len))
154
+ return 0;
155
+
156
+ errno = __map_mman_error(GetLastError(), EPERM);
157
+
158
+ return -1;
159
+ }
160
+
161
+ // Portable clock_gettime for Windows; GetTickCount is a 32-bit millisecond counter (wraps every ~49.7 days), which is fine for short benchmarking deltas
162
+ int clock_gettime(int clk_id, struct timespec *tp) {
163
+ uint32_t ticks = GetTickCount();
164
+ tp->tv_sec = ticks / 1000;
165
+ tp->tv_nsec = (ticks % 1000) * 1000000;
166
+ return 0;
167
+ }
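As a quick illustration, this shim lets the checkpoint-loading code keep its POSIX-style calls on Windows. A minimal read-only mapping looks like this (a sketch; `fd` is assumed to be a valid file descriptor for the opened checkpoint and `file_size` its size in bytes):

```c
#include "win.h"

// map an already-open file read-only, the way the model checkpoint is mapped
void *map_checkpoint(int fd, size_t file_size) {
    void *data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
    return (data == MAP_FAILED) ? NULL : data;
}
```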
win.h ADDED
@@ -0,0 +1,69 @@
1
+ #ifndef _WIN_H_
2
+ #define _WIN_H_
3
+
4
+ #define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers
5
+ #include <windows.h>
6
+ #include <time.h>
7
+ #include <stdint.h>
8
+
9
+ #define ssize_t int64_t
10
+ #define ftell _ftelli64
11
+
12
+ // Below code is originally from mman-win32
13
+ //
14
+ /*
15
+ * sys/mman.h
16
+ * mman-win32
17
+ */
18
+
19
+ #ifndef _WIN32_WINNT // Allow use of features specific to Windows XP or later.
20
+ #define _WIN32_WINNT 0x0501 // Change this to the appropriate value to target other versions of Windows.
21
+ #endif
22
+
23
+ /* All the headers include this file. */
24
+ #ifndef _MSC_VER
25
+ #include <_mingw.h>
26
+ #endif
27
+
28
+ #include <sys/types.h>
29
+
30
+ #ifdef __cplusplus
31
+ extern "C" {
32
+ #endif
33
+
34
+ #define PROT_NONE 0
35
+ #define PROT_READ 1
36
+ #define PROT_WRITE 2
37
+ #define PROT_EXEC 4
38
+
39
+ #define MAP_FILE 0
40
+ #define MAP_SHARED 1
41
+ #define MAP_PRIVATE 2
42
+ #define MAP_TYPE 0xf
43
+ #define MAP_FIXED 0x10
44
+ #define MAP_ANONYMOUS 0x20
45
+ #define MAP_ANON MAP_ANONYMOUS
46
+
47
+ #define MAP_FAILED ((void *)-1)
48
+
49
+ /* Flags for msync. */
50
+ #define MS_ASYNC 1
51
+ #define MS_SYNC 2
52
+ #define MS_INVALIDATE 4
53
+
54
+ /* Flags for portable clock_gettime call. */
55
+ #define CLOCK_REALTIME 0
56
+
57
+ void* mmap(void *addr, size_t len, int prot, int flags, int fildes, ssize_t off);
58
+ int munmap(void *addr, size_t len);
59
+ int mprotect(void *addr, size_t len, int prot);
60
+ int msync(void *addr, size_t len, int flags);
61
+ int mlock(const void *addr, size_t len);
62
+ int munlock(const void *addr, size_t len);
63
+ int clock_gettime(int clk_id, struct timespec *tp);
64
+
65
+ #ifdef __cplusplus
66
+ };
67
+ #endif
68
+
69
+ #endif /* _WIN_H_ */