radames HF staff commited on
Commit
0021428
1 Parent(s): c40c794
Files changed (3) hide show
  1. Dockerfile +2 -0
  2. gradio_app.py +72 -17
  3. llama2.mojo +704 -448
Dockerfile CHANGED
@@ -66,6 +66,8 @@ COPY --chown=user . $HOME/app
66
  RUN wget -c -nv https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin
67
  RUN wget -c -nv https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin
68
  RUN wget -c -nv https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin
 
 
69
 
70
  # CMD ["mojo", "llama2.mojo"]
71
  CMD ["python3", "gradio_app.py"]
 
66
  RUN wget -c -nv https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin
67
  RUN wget -c -nv https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin
68
  RUN wget -c -nv https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin
69
+ RUN wget -c -nv https://huggingface.co/kirp/TinyLlama-1.1B-Chat-v0.2-bin/resolve/main/tok_tl-chat.bin
70
+ RUN wget -c -nv https://huggingface.co/kirp/TinyLlama-1.1B-Chat-v0.2-bin/resolve/main/tl-chat.bin
71
 
72
  # CMD ["mojo", "llama2.mojo"]
73
  CMD ["python3", "gradio_app.py"]
gradio_app.py CHANGED
@@ -1,36 +1,91 @@
1
  import gradio as gr
2
  import subprocess
3
  import sys
4
- import os
5
 
6
-
7
- async def generate(prompt):
8
- # os.environ["PROMPT"] = prompt
9
  # stream stout
 
 
 
 
10
  process = subprocess.Popen(
11
- ["mojo", "llama2.mojo"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  )
13
  text = ""
14
  for char in iter(lambda: process.stdout.read(1), b""):
15
- char_decoded = char.decode()
16
- sys.stdout.write(char_decoded)
17
  text += char_decoded
18
  yield text
19
 
20
 
21
- output_text = gr.Textbox(label="Generated Text")
22
-
23
- demo = gr.Interface(
24
- fn=generate,
25
- inputs=None,
26
- outputs=output_text,
27
- description="""
28
  # llama2.🔥
29
  ## [Mojo](https://docs.modular.com/mojo/) implementation of [llama2.c](https://github.com/karpathy/llama2.c) by [@tairov](https://github.com/tairov)
30
  Source: https://github.com/tairov/llama2.mojo
31
- """,
32
- allow_flagging="never",
33
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  demo.queue()
36
  demo.launch(server_name="0.0.0.0")
 
1
  import gradio as gr
2
  import subprocess
3
  import sys
4
+ from pathlib import Path
5
 
6
+ async def generate(prompt, model_name, seed=0, temperature=0.5, num_tokens=256):
 
 
7
  # stream stout
8
+ base = ""#"../model/"
9
+ tokenizer_name = "tokenizer.bin"
10
+ if model_name == "tl-chat.bin":
11
+ tokenizer_name = 'tok_tl-chat.bin'
12
  process = subprocess.Popen(
13
+ [
14
+ "mojo",
15
+ "llama2.mojo",
16
+ Path(base + model_name),
17
+ "-s",
18
+ str(seed),
19
+ "-n",
20
+ str(num_tokens),
21
+ "-t",
22
+ str(temperature),
23
+ "-i",
24
+ prompt,
25
+ "-z",
26
+ Path(base + tokenizer_name)
27
+ ],
28
+ stdout=subprocess.PIPE,
29
+ stderr=subprocess.PIPE,
30
  )
31
  text = ""
32
  for char in iter(lambda: process.stdout.read(1), b""):
33
+ char_decoded = char.decode("utf-8", errors="ignore")
 
34
  text += char_decoded
35
  yield text
36
 
37
 
38
+ with gr.Blocks() as demo:
39
+ gr.Markdown(
40
+ """
 
 
 
 
41
  # llama2.🔥
42
  ## [Mojo](https://docs.modular.com/mojo/) implementation of [llama2.c](https://github.com/karpathy/llama2.c) by [@tairov](https://github.com/tairov)
43
  Source: https://github.com/tairov/llama2.mojo
44
+ """
45
+ )
46
+ with gr.Row():
47
+ with gr.Column():
48
+ prompt = gr.Textbox(label="Prompt", placeholder="Add your prompt here...")
49
+ seed = gr.Slider(
50
+ minimum=0,
51
+ maximum=2**53,
52
+ value=0,
53
+ step=1,
54
+ label="Seed",
55
+ randomize=True,
56
+ )
57
+ temperature = gr.Slider(
58
+ minimum=0.0, maximum=2.0, step=0.01, value=0.0, label="Temperature"
59
+ )
60
+ num_tokens = gr.Slider(
61
+ minimum=1, maximum=256, value=256, label="Number of tokens"
62
+ )
63
+ model_name = gr.Dropdown(
64
+ ["stories15M.bin", "stories42M.bin", "stories110M.bin", "tl-chat.bin"],
65
+ value="stories15M.bin",
66
+ label="Model Size",
67
+ )
68
+ with gr.Row():
69
+ stop = gr.Button("Stop")
70
+ run = gr.Button("Run")
71
+ with gr.Column(scale=2):
72
+ output_text = gr.Textbox(label="Generated Text")
73
+
74
+ # update maximum number of tokens based on model size
75
+ model_name.change(
76
+ lambda x: gr.update(maximum=1024)
77
+ if x == "stories110M.bin" or x == "stories42M.bin" or x == "tl-chat.bin"
78
+ else gr.update(maximum=256),
79
+ model_name,
80
+ num_tokens,
81
+ queue=False,
82
+ )
83
+ click_event = run.click(
84
+ fn=generate,
85
+ inputs=[prompt, model_name, seed, temperature, num_tokens],
86
+ outputs=output_text,
87
+ )
88
+ stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event])
89
 
90
  demo.queue()
91
  demo.launch(server_name="0.0.0.0")
llama2.mojo CHANGED
@@ -1,194 +1,220 @@
 
 
 
1
  from math import round
2
- import math
3
-
4
  from memory import memset_zero, memcpy
 
5
  from memory.unsafe import DTypePointer
 
6
  from random import rand
7
- from sys.info import simdwidthof
8
- from builtin import string
9
- import time
10
- import random
11
- import os
12
-
13
- from runtime.llcl import num_cores
14
-
15
  from read import BufReader, File
16
- from memory.buffer import Buffer
17
-
18
- from python import Python
19
 
20
  # The SIMD vector width.
21
- from algorithm import vectorize, parallelize
22
- from algorithm import sum
 
 
 
23
 
24
- alias nelts = (2 * simdwidthof[DType.float32]())
 
 
25
 
26
  alias PointerString = Pointer[UInt8]
27
  alias BufferPtrType = DTypePointer[DType.uint8]
28
  alias BufferPtrFloat32 = DTypePointer[DType.float32]
29
  alias PointerStrings = Pointer[PointerString]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
 
31
 
32
- struct Matrix3:
33
- var data: BufferPtrFloat32
34
- var rows: Int
35
- var cols: Int
36
- var layers: Int
37
- var allocated: Int
38
-
39
- fn __init__(inout self, layers: Int, rows: Int, cols: Int):
40
- self.data = BufferPtrFloat32.alloc(0)
41
- self.rows = rows
42
- self.cols = cols
43
- self.layers = layers
44
- self.allocated = 0
45
 
46
- @always_inline
47
- fn alloc(inout self, fill: Int = 0):
48
- self.data = BufferPtrFloat32.alloc(self.size())
49
- self.allocated = 1
50
- if fill == 1:
51
- self.zero()
52
 
53
- @always_inline
54
- fn alloc_zero(inout self):
55
- self.alloc(1)
56
 
57
- @always_inline
58
- fn set_buf_ptr(inout self, ptr: BufferPtrFloat32):
59
- self.data = ptr
60
 
61
- fn __del__(owned self):
62
- if self.allocated == 1:
63
- self.data.free()
64
 
65
- @always_inline
66
- fn zero(inout self):
67
- memset_zero(self.data, self.layers * self.rows * self.cols)
 
 
 
 
68
 
69
- @always_inline
70
- fn size(inout self) -> Int:
71
- return self.layers * self.cols * self.rows
 
72
 
73
- @always_inline
74
- fn __getitem__(self, z: Int, y: Int, x: Int) -> Float32:
75
- return self.load[1](z, y, x)
76
 
77
- @always_inline
78
- fn load[nelts: Int](self, z: Int, y: Int, x: Int) -> SIMD[DType.float32, nelts]:
79
- return self.data.simd_load[nelts](z * self.layers + y * self.cols + x)
80
 
81
- @always_inline
82
- fn __setitem__(self, z: Int, y: Int, x: Int, val: Float32):
83
- return self.store[1](z, y, x, val)
84
 
85
- @always_inline
86
- fn store[nelts: Int](self, z: Int, y: Int, x: Int, val: SIMD[DType.float32, nelts]):
87
- self.data.simd_store[nelts](z * self.layers + y * self.cols + x, val)
88
 
 
 
 
 
 
 
89
 
90
- struct Matrix:
91
- var data: BufferPtrFloat32
92
- var rows: Int
93
- var cols: Int
94
- var allocated: Int
95
 
96
- fn __init__(inout self, rows: Int, cols: Int):
97
- self.data = BufferPtrFloat32.alloc(0)
98
- self.rows = rows
99
- self.cols = cols
100
- self.allocated = 0
101
 
102
- fn __init__(inout self, cols: Int):
103
- self.data = BufferPtrFloat32.alloc(0)
104
- self.rows = 1
105
- self.cols = cols
106
- self.allocated = 0
107
 
108
- fn __del__(owned self):
109
- if self.allocated == 1:
110
- self.data.free()
111
-
112
- fn alloc(inout self, fill: Int = 0):
113
- self.data = BufferPtrFloat32.alloc(self.size())
114
- self.allocated = 1
115
- if fill == 1:
116
- self.zero()
117
 
118
- fn alloc_zero(inout self):
119
- self.alloc(1)
120
 
121
- fn zero(inout self):
122
- memset_zero(self.data, self.rows * self.cols)
123
 
124
- fn set_buf_ptr(inout self, ptr: BufferPtrFloat32):
125
- self.data = ptr
 
 
 
126
 
127
- # set buf ptr with redefined rows, colss
128
- fn set_buf_ptr(inout self, ptr: BufferPtrFloat32, rows: Int, cols: Int):
129
- self.data = ptr
130
- self.rows = rows
131
- self.cols = cols
132
 
133
- fn size(inout self) -> Int:
134
- return self.cols * self.rows
 
 
 
 
 
 
 
135
 
136
- @always_inline
137
- fn __getitem__(self, y: Int, x: Int) -> Float32:
138
- return self.load[1](y, x)
139
 
140
- @always_inline
141
- fn __getitem__(self, x: Int) -> Float32:
142
- return self.load[1](0, x)
 
 
 
143
 
144
- @always_inline
145
- fn load[nelts: Int](self, y: Int, x: Int) -> SIMD[DType.float32, nelts]:
146
- return self.data.simd_load[nelts](y * self.cols + x)
147
 
148
- @always_inline
149
- fn __setitem__(self, y: Int, x: Int, val: Float32):
150
- return self.store[1](y, x, val)
 
 
 
 
151
 
152
- @always_inline
153
- fn __setitem__(self, x: Int, val: Float32):
154
- return self.store[1](0, x, val)
155
 
156
- @always_inline
157
- fn store[nelts: Int](self, y: Int, x: Int, val: SIMD[DType.float32, nelts]):
158
- self.data.simd_store[nelts](y * self.cols + x, val)
159
 
160
- @always_inline
161
- fn load[nelts: Int](self, x: Int) -> SIMD[DType.float32, nelts]:
162
- return self.data.simd_load[nelts](x)
163
 
164
- @always_inline
165
- fn store[nelts: Int](self, x: Int, val: SIMD[DType.float32, nelts]):
166
- self.data.simd_store[nelts](x, val)
167
 
168
 
169
- fn read_val_int(inout buf: FileBuf) -> Int:
170
- # DTypePointer[DType.ui8](buf.data).bitcast[DType.ui8]()
171
- let data = buf.data.offset(buf.offset).bitcast[DType.uint32]()
172
- let result = data.simd_load[1](0)
173
- buf.offset += 4
174
- return result.to_int()
 
 
 
 
175
 
 
 
 
 
 
 
176
 
177
- fn read_val_float32(inout buf: FileBuf) -> Float32:
178
- # DTypePointer[DType.ui8](buf.data).bitcast[DType.ui8]()
179
- let val = buf.data.offset(buf.offset).bitcast[DType.float32]().simd_load[1](0)
180
- buf.offset += 4
181
- return val
 
 
182
 
 
183
 
184
- fn read_val_str(inout buf: FileBuf, slen: Int) -> PointerString:
185
- let str = PointerString.alloc(slen + 1)
186
- for i in range(slen):
187
- str.store(i, buf.data.simd_load[1](buf.offset))
188
- buf.offset += 1
189
- str.store(slen, 0)
190
 
191
- return str
 
 
 
 
 
 
192
 
193
 
194
  struct FileBuf:
@@ -201,36 +227,111 @@ struct FileBuf:
201
  self.offset = 0
202
  self.size = 0
203
 
204
- fn move_offset(inout self, size: Int):
205
- self.offset += size
 
 
 
 
 
 
 
 
206
 
207
- fn bitcast_offset_float32(inout self, size: Int) -> BufferPtrFloat32:
208
  let ret = self.data.offset(self.offset).bitcast[DType.float32]()
209
- self.offset += size * sizeof[DType.float32]()
210
  return ret
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  struct Tokenizer:
214
  var vocab: PointerStrings
215
  var vocab_scores: BufferPtrFloat32
216
  var max_token_length: Int
217
  var vocab_size: Int
 
 
218
 
219
- fn __init__(inout self, vocab_size: Int):
220
  self.vocab_size = vocab_size
221
- self.vocab = PointerStrings.alloc(vocab_size)
222
- self.vocab_scores = BufferPtrFloat32.alloc(vocab_size)
223
- self.max_token_length = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
 
226
  struct Config:
227
  var dim: Int
 
228
  var hidden_dim: Int
229
  var n_layers: Int
230
  var n_heads: Int
231
  var n_kv_heads: Int
 
232
  var vocab_size: Int
233
  var seq_len: Int
 
234
 
235
  fn __init__(inout self):
236
  self.dim = 0
@@ -240,109 +341,93 @@ struct Config:
240
  self.n_kv_heads = 0
241
  self.vocab_size = 0
242
  self.seq_len = 0
 
 
 
243
 
244
 
245
  struct RunState:
246
- var x: Matrix # activation at current time stamp (dim,)
247
- var xb: Matrix # same, but inside a residual branch (dim,)
248
- var xb2: Matrix # an additional buffer just for convenience (dim,)
249
- var hb: Matrix # buffer for hidden dimension in the ffn (hidden_dim,)
250
- var hb2: Matrix # buffer for hidden dimension in the ffn (hidden_dim,)
251
- var q: Matrix # query (dim,)
252
- var k: Matrix # key (dim,)
253
- var v: Matrix # value (dim,)
254
- var att: Matrix # buffer for scores/attention values (n_heads, seq_len)
255
- var logits: Matrix # output logits
256
- var key_cache: Matrix3 # (layer, seq_len, dim)
257
- var value_cache: Matrix3 # (layer, seq_len, dim)
258
-
259
- fn __init__(inout self, config: Config):
260
- self.x = Matrix(config.dim)
261
- self.x.alloc_zero()
262
- self.xb = Matrix(config.dim)
263
- self.xb.alloc_zero()
264
- self.xb2 = Matrix(config.dim)
265
- self.xb2.alloc_zero()
266
- self.hb = Matrix(config.hidden_dim)
267
- self.hb.alloc_zero()
268
- self.hb2 = Matrix(config.hidden_dim)
269
- self.hb2.alloc_zero()
270
- self.q = Matrix(config.dim)
271
- self.q.alloc_zero()
272
- self.k = Matrix(config.dim)
273
- self.k.alloc_zero()
274
- self.v = Matrix(config.dim)
275
- self.v.alloc_zero()
276
- self.att = Matrix(config.n_heads, config.seq_len)
277
- self.att.alloc_zero()
278
- self.logits = Matrix(config.vocab_size)
279
- self.logits.alloc_zero()
280
- self.key_cache = Matrix3(config.n_layers, config.seq_len, config.dim)
281
- self.key_cache.alloc_zero()
282
- self.value_cache = Matrix3(config.n_layers, config.seq_len, config.dim)
283
- self.value_cache.alloc_zero()
284
 
285
 
286
  struct TransformerWeights:
287
- var token_embedding_table: Matrix
288
- var freq_cis_real: Matrix
289
- var freq_cis_imag: Matrix
290
- var rms_att_weight: Matrix
291
- var wq: Matrix3
292
- var wk: Matrix3
293
- var wv: Matrix3
294
- var wo: Matrix3
295
- var rms_ffn_weight: Matrix
296
- var w1: Matrix3
297
- var w3: Matrix3
298
- var w2: Matrix3
299
- var rms_final_weight: Matrix
300
- var wcls: Matrix
301
-
302
- fn __init__(inout self, config: Config, shared_weights: Int, inout buf: FileBuf):
303
- self.token_embedding_table = Matrix(config.vocab_size, config.dim)
304
- # set buf ptr to buf data from file
305
- self.token_embedding_table.set_buf_ptr(
306
- buf.bitcast_offset_float32(self.token_embedding_table.size())
307
- )
308
- self.rms_att_weight = Matrix(config.n_layers, config.dim)
309
- self.rms_att_weight.set_buf_ptr(
310
- buf.bitcast_offset_float32(self.rms_att_weight.size())
311
- )
312
- self.wq = Matrix3(config.n_layers, config.dim, config.dim)
313
- self.wq.set_buf_ptr(buf.bitcast_offset_float32(self.wq.size()))
314
- self.wk = Matrix3(config.n_layers, config.dim, config.dim)
315
- self.wk.set_buf_ptr(buf.bitcast_offset_float32(self.wk.size()))
316
- self.wv = Matrix3(config.n_layers, config.dim, config.dim)
317
- self.wv.set_buf_ptr(buf.bitcast_offset_float32(self.wv.size()))
318
- self.wo = Matrix3(config.n_layers, config.dim, config.dim)
319
- self.wo.set_buf_ptr(buf.bitcast_offset_float32(self.wo.size()))
320
- self.rms_ffn_weight = Matrix(config.n_layers, config.dim)
321
- self.rms_ffn_weight.set_buf_ptr(
322
- buf.bitcast_offset_float32(self.rms_ffn_weight.size())
323
- )
324
- self.w1 = Matrix3(config.n_layers, config.dim, config.hidden_dim)
325
- self.w1.set_buf_ptr(buf.bitcast_offset_float32(self.w1.size()))
326
- self.w2 = Matrix3(config.n_layers, config.dim, config.hidden_dim)
327
- self.w2.set_buf_ptr(buf.bitcast_offset_float32(self.w2.size()))
328
- self.w3 = Matrix3(config.n_layers, config.dim, config.hidden_dim)
329
- self.w3.set_buf_ptr(buf.bitcast_offset_float32(self.w3.size()))
330
- self.rms_final_weight = Matrix(config.dim)
331
- self.rms_final_weight.set_buf_ptr(
332
- buf.bitcast_offset_float32(self.rms_final_weight.size())
333
- )
334
- self.freq_cis_real = Matrix(config.seq_len, (config.dim // config.n_heads) // 2)
335
- self.freq_cis_real.set_buf_ptr(
336
- buf.bitcast_offset_float32(self.freq_cis_real.size())
337
- )
338
- self.freq_cis_imag = Matrix(config.seq_len, (config.dim // config.n_heads) // 2)
339
- self.freq_cis_imag.set_buf_ptr(
340
- buf.bitcast_offset_float32(self.freq_cis_imag.size())
341
- )
342
- self.wcls = Matrix(
343
- config.vocab_size, config.dim
344
- ) # if shared_weights else rest_floats
345
- self.wcls.set_buf_ptr(self.token_embedding_table.data)
346
 
347
 
348
  fn read_file(file_name: String, inout buf: FileBuf) raises:
@@ -375,270 +460,323 @@ fn config_init(inout config: Config, inout buf: FileBuf) raises:
375
  config.n_kv_heads = read_val_int(buf)
376
  config.vocab_size = read_val_int(buf)
377
  config.seq_len = read_val_int(buf)
 
 
 
378
  return None
379
 
380
 
381
- fn tokenizer_init(inout tok: Tokenizer, inout buf: FileBuf) -> None:
382
- tok.max_token_length = read_val_int(buf)
383
- tok.vocab_scores = BufferPtrFloat32.alloc(tok.vocab_size)
384
- tok.vocab = PointerStrings.alloc(tok.vocab_size)
385
-
386
- # read vocab_scores & vocab values (tokens)
387
- for i in range(0, tok.vocab_size):
388
- tok.vocab_scores.simd_store[1](i, read_val_float32(buf))
389
- let slen = read_val_int(buf)
390
- tok.vocab.store(i, read_val_str(buf, slen))
391
-
392
- tok.vocab_scores = buf.data.offset(buf.offset).bitcast[DType.float32]()
393
- buf.offset += tok.vocab_size * 4
394
- return None
395
 
 
 
 
396
 
397
- fn accum(inout a: BufferPtrFloat32, b: BufferPtrFloat32, size: Int) -> None:
398
- for i in range(size):
399
- let val = a.offset(i).simd_load[1](0) + b.offset(i).simd_load[1](0)
400
- a.offset(i).simd_store[1](0, val)
401
 
402
 
 
403
  fn rmsnorm(
404
  inout o: BufferPtrFloat32, x: BufferPtrFloat32, weight: BufferPtrFloat32, size: Int
405
  ) -> None:
406
  # Calculate sum of squares
407
- var ss: Float32 = 0.0
408
- for i in range(size):
409
- let xx = x.offset(i).simd_load[1](0) ** 2
410
- ss += xx
 
 
 
 
 
 
 
 
411
  ss = ss / size + 1e-5
412
  ss = 1.0 / math.sqrt(ss)
 
413
  # Normalize and scale
414
- for j in range(size):
415
- let val = weight.offset(j).simd_load[1](0) * (ss * x.offset(j).simd_load[1](0))
416
- o.offset(j).simd_store[1](0, val)
417
-
418
-
419
- fn softmax(inout x: BufferPtrFloat32, size: Int) -> None:
420
- # Find max value (for numerical stability)
421
- var max_val: Float32 = x.offset(0).simd_load[1](0)
422
- for i in range(size):
423
- let xi = x.offset(i).simd_load[1](0)
424
- if xi > max_val:
425
- max_val = xi
426
- # Exp and sum
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  var ssum: Float32 = 0.0
428
- for i in range(size):
429
- let xi = x.offset(i).simd_load[1](0)
430
- x.offset(i).simd_store[1](0, math.exp(xi - max_val))
431
- ssum += x.offset(i).simd_load[1](0)
432
- # Normalize
433
- for i in range(size):
434
- let xi = x.offset(i).simd_load[1](0)
435
- x.offset(i).simd_store[1](0, xi / ssum)
436
-
437
-
438
- fn matmul_naive(C: Matrix, x: Matrix, w: Matrix) -> None:
439
- # W(d,n) @ X(n,) -> C (d,)
440
- # By far the most amount of time is spent inside this little function
441
- for i in range(w.rows):
442
- C[i] = 0.0
443
- for j in range(w.cols):
444
- C[i] += x[j] * w[i, j]
445
-
446
-
447
- fn matmul_vectorized(C: Matrix, A: Matrix, B: Matrix):
448
- for i in range(0, B.rows):
449
  var tmp = SIMD[DType.float32, nelts](0)
450
 
451
  @parameter
452
  fn dot[_nelts: Int](j: Int):
453
- if _nelts < nelts: # take care of tail array elements with length < nelts
454
- tmp[0] += (A.load[_nelts](j) * B.load[_nelts](i, j)).reduce_add()
 
 
455
  else:
456
- tmp += A.load[nelts](j) * B.load[nelts](i, j)
457
 
458
- vectorize[nelts, dot](B.cols)
459
- C[i] = tmp.reduce_add()
 
460
 
461
- fn matmul_parallelized(C: Matrix, A: Matrix, B: Matrix):
462
- @parameter
463
- fn calc_row(i: Int):
464
- var T = BufferPtrFloat32.alloc(nelts)
465
- var Tbuf = Buffer[nelts, DType.float32](T)
466
- memset_zero(T, nelts)
467
- @parameter
468
- fn dot[nelts: Int](j: Int):
469
- T.simd_store[nelts](
470
- 0, T.simd_load[nelts](0) + A.load[nelts](j) * B.load[nelts](i, j)
471
- )
472
 
473
- vectorize[nelts, dot](B.cols)
474
- C[i] = sum[nelts, DType.float32](Tbuf)
475
 
476
- parallelize[calc_row](B.rows)
 
 
 
 
 
 
477
 
478
 
479
- fn matmul(inout C: Matrix, A: Matrix, B: Matrix) -> None:
 
480
  # B (d,n) @ A (n,) -> C (d,)
481
- matmul_vectorized(C, A, B)
482
- # matmul_parallelized(C, A, B)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
 
484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
  fn transformer(
486
  token: Int,
487
  pos: Int,
488
  config: Config,
489
  inout state: RunState,
490
  weights: TransformerWeights,
491
- ) -> None:
492
  # A few convenience variables
493
- var x = state.x.data
494
  let dim = config.dim
495
  let hidden_dim = config.hidden_dim
496
- let head_size = dim // config.n_heads
497
-
498
- # tmp matrix for matmul operations
499
- var tmpw = Matrix(0, 0)
500
 
501
  # Copy the token embedding into x
502
- let content_row = weights.token_embedding_table.data.offset(token * dim)
503
- memcpy[DType.float32](x, content_row, config.dim)
504
 
505
  # Pluck out the "pos" row of freq_cis_real and freq_cis_imag
506
- let freq_cis_real_row = weights.freq_cis_real.data.offset(pos * head_size // 2)
507
- let freq_cis_imag_row = weights.freq_cis_imag.data.offset(pos * head_size // 2)
508
 
509
  # Forward all the layers
510
  for l in range(config.n_layers):
511
  # Attention rmsnorm
512
- rmsnorm(state.xb.data, x, weights.rms_att_weight.data.offset(l * dim), dim)
513
-
514
  # QKV matmuls for this position
515
- tmpw.set_buf_ptr(weights.wq.data.offset(l * dim * dim), dim, dim)
516
- matmul(state.q, state.xb, tmpw)
517
 
518
- tmpw.set_buf_ptr(weights.wk.data.offset(l * dim * dim), dim, dim)
519
- matmul(state.k, state.xb, tmpw)
 
520
 
521
- tmpw.set_buf_ptr(weights.wv.data.offset(l * dim * dim), dim, dim)
522
- matmul(state.v, state.xb, tmpw)
523
 
524
  # Apply RoPE rotation to the q and k vectors for each head
525
- for h in range(config.n_heads):
526
- # Get the q and k vectors for this head
527
- let q = state.q.data.offset(h * head_size)
528
- let k = state.k.data.offset(h * head_size)
529
-
530
- # Rotate q and k by the freq_cis_real and freq_cis_imag
531
- for i in range(0, head_size, 2):
532
- let q0 = q.offset(i).simd_load[1](0)
533
- let q1 = q.offset(i + 1).simd_load[1](0)
534
- let k0 = k.offset(i).simd_load[1](0)
535
- let k1 = k.offset(i + 1).simd_load[1](0)
536
- let fcr = freq_cis_real_row.offset(i // 2).simd_load[1](0)
537
- let fci = freq_cis_imag_row.offset(i // 2).simd_load[1](0)
538
- q.offset(i).simd_store[1](0, q0 * fcr - q1 * fci)
539
- q.offset(i + 1).simd_store[1](0, q0 * fci + q1 * fcr)
540
- k.offset(i).simd_store[1](0, k0 * fcr - k1 * fci)
541
- k.offset(i + 1).simd_store[1](0, k0 * fci + k1 * fcr)
542
-
543
- # Save key,value at this time step (pos) to our kv cache
544
- let loff = l * config.seq_len * dim # kv cache layer offset for convenience
545
- let key_cache_row = state.key_cache.data.offset(loff + pos * dim)
546
- let value_cache_row = state.value_cache.data.offset(loff + pos * dim)
547
- memcpy[DType.float32](key_cache_row, state.k.data, config.dim)
548
- memcpy[DType.float32](value_cache_row, state.v.data, config.dim)
549
-
550
- # Multihead attention. Iterate over all heads
551
- for h in range(config.n_heads):
552
  # Get the query vector for this head
553
- let q = state.q.data.offset(h * head_size)
554
 
555
- # Attention scores for this head
556
- var att = state.att.data.offset(h * config.seq_len)
557
 
558
  # Iterate over all timesteps, including the current one
559
  for t in range(pos + 1):
560
- # Get the key vector for this head and at this timestep
561
- let k = state.key_cache.data.offset(loff + t * dim + h * head_size)
562
  # Calculate the attention score as the dot product of q and k
563
  var score: Float32 = 0.0
564
- for i in range(head_size):
565
- score += q.offset(i).simd_load[1](0) * k.offset(i).simd_load[1](0)
 
 
 
 
 
 
 
566
  score /= math.sqrt[DType.float32, 1](head_size)
567
 
568
  # Save the score to the attention buffer
569
- att.offset(t).simd_store[1](0, score)
570
 
571
  # Softmax the scores to get attention weights, from 0..pos inclusively
572
- softmax(att, pos + 1)
573
-
574
  # Weighted sum of the values, store back into xb
575
- let xb = state.xb.data.offset(h * head_size)
576
- memset_zero(xb, head_size)
577
  for t in range(pos + 1):
578
- # Get the value vector for this head and at this timestep
579
- let v = state.value_cache.data.offset(loff + t * dim + h * head_size)
 
580
  # Get the attention weight for this timestep
581
- let a = att.offset(t).simd_load[1](0)
582
  # Accumulate the weighted value into xb
583
- for i in range(head_size):
584
- let xbi = xb.offset(i).simd_load[1](0) + a * v.offset(i).simd_load[
585
- 1
586
- ](0)
587
- xb.offset(i).simd_store[1](0, xbi)
588
- # Final matrix multiplication to get the output of the attention
589
- tmpw.set_buf_ptr(weights.wo.data.offset(l * dim * dim), dim, dim)
590
- matmul(state.xb2, state.xb, tmpw)
591
 
592
- # Residual connection back into x
593
- accum(x, state.xb2.data, dim)
 
 
 
 
594
 
 
 
 
 
 
 
 
595
  # FFN rmsnorm
596
- rmsnorm(state.xb.data, x, weights.rms_ffn_weight.data.offset(l * dim), dim)
597
 
598
  # Calculate self.w1(x) and self.w3(x) for FFN
599
- tmpw.set_buf_ptr(weights.w1.data.offset(l * dim * hidden_dim), hidden_dim, dim)
600
- matmul(state.hb, state.xb, tmpw)
601
-
602
- tmpw.set_buf_ptr(weights.w3.data.offset(l * dim * hidden_dim), hidden_dim, dim)
603
- matmul(state.hb2, state.xb, tmpw)
604
-
605
- # Apply SiLU activation function (silu(x) = x * sigmoid(x))
606
- for i in range(hidden_dim):
607
- let hbi = state.hb[i]
608
- state.hb[i] = hbi * (1.0 / (1.0 + math.exp(-hbi)))
609
 
610
- # Elementwise multiply with w3(x)
611
- for i in range(hidden_dim):
612
- state.hb[i] = state.hb[i] * state.hb2[i]
613
 
 
 
 
 
 
 
 
 
 
614
  # Final matrix multiplication to get the output of the FFN
615
- tmpw.set_buf_ptr(weights.w2.data.offset(l * dim * hidden_dim), dim, hidden_dim)
616
- matmul(state.xb, state.hb, tmpw)
617
 
618
  # Residual connection
619
- accum(x, state.xb.data, dim)
620
 
621
  # Final rmsnorm
622
- rmsnorm(x, x, weights.rms_final_weight.data, dim)
623
 
624
  # Classifier into logits
625
- tmpw.set_buf_ptr(weights.wcls.data, config.vocab_size, dim)
626
- matmul(state.logits, state.x, tmpw)
627
 
628
 
629
- fn argmax(v: Matrix) -> Int:
630
  # return argmax of v
631
  var max_i: Int = 0
632
  var max_p: Float32 = v[0]
633
- for i in range(v.cols):
634
  if v[i] > max_p:
635
  max_i = i
636
  max_p = v[i]
637
  return max_i
638
 
639
 
640
- fn sample(probabilities: Matrix) -> Int:
641
- let n = probabilities.cols
642
  # Sample index from probabilities, they must sum to 1
643
  # get random value within (min, max) float32 range
644
  let r = DTypePointer[DType.float32].alloc(1)
@@ -646,12 +784,64 @@ fn sample(probabilities: Matrix) -> Int:
646
  var cdf: Float32 = 0.0
647
  for i in range(n):
648
  cdf += probabilities[i]
649
- if r.simd_load[1](0) < cdf:
650
  return i
651
  return n - 1 # In case of rounding errors
652
 
653
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654
  fn print_str(s: PointerString):
 
 
 
 
 
 
655
  # print all chars till null character
656
  var p: Int = 0
657
  while s[p].to_int() != 0:
@@ -664,22 +854,76 @@ fn time_in_ms() -> Int:
664
  return time.now() // 1_000_000
665
 
666
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
667
  fn main() raises:
668
- print("num hardware threads: ", num_cores(), " SIMD vector width: ", nelts)
669
- let checkpoint = "stories15M.bin"
670
- # let checkpoint = "stories110M.bin"
671
- let tokenizer = "tokenizer.bin"
672
- let temperature = 0.0
673
  var steps = 256
674
- let prompt = ""
675
- let rng_seed: Int = time.now()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676
  random.seed(rng_seed)
677
  var fbuf: FileBuf = FileBuf()
678
  var tbuf: FileBuf = FileBuf()
679
  var config: Config = Config()
680
 
681
  read_file(checkpoint, fbuf)
682
- print("checkpoint size: ", fbuf.size)
683
  config_init(config, fbuf)
684
 
685
  # negative vocab size is hacky way of signaling unshared weights. bit yikes.
@@ -690,51 +934,63 @@ fn main() raises:
690
 
691
  let weights: TransformerWeights = TransformerWeights(config, shared_weights, fbuf)
692
 
693
- var tok: Tokenizer = Tokenizer(config.vocab_size)
694
-
695
  if steps <= 0 or steps > config.seq_len:
696
  steps = config.seq_len
697
 
698
  # Read in the tokenizer.bin file
699
  read_file(tokenizer, tbuf)
700
- tokenizer_init(tok, tbuf)
 
 
 
 
701
 
702
  # Create and initialize the application RunState
703
  var state = RunState(config)
704
 
 
 
 
 
 
 
705
  # Start the main loop
706
  var start = 0 # Used to time our code, only initialized after the first iteration
707
  var next_token = 0 # Will store the next token in the sequence
708
  # Initialize with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
709
  var token = 1
710
- var pos = 0 # Position in the sequence
711
- # Explicitly print the initial BOS token for stylistic symmetry reasons
712
-
713
- print("<s>")
714
 
 
 
715
  while pos < steps:
716
  # Forward the transformer to get logits for the next token
717
  transformer(token, pos, config, state, weights)
718
 
719
- # Sample the next token
720
- if temperature == 0.0:
721
- # Greedy argmax sampling: take the token with the highest probability
722
- next_token = argmax(state.logits)
723
  else:
724
- # Apply the temperature to the logits
725
- for q in range(config.vocab_size):
726
- state.logits[q] = state.logits[q] / temperature
727
- # Apply softmax to the logits to get the probabilities for the next token
728
- softmax(state.logits.data, config.vocab_size)
729
- # Sample from this distribution to get the next token
730
- next_token = sample(state.logits)
731
-
 
 
 
 
 
 
 
 
 
732
  var token_str: PointerString = tok.vocab[next_token]
733
  if token == 1 and token_str[0] == ord(" "):
734
  token_str = token_str.offset(1)
735
 
736
  print_str(token_str)
737
- # flush?
738
 
739
  # Advance forward
740
  token = next_token
@@ -744,4 +1000,4 @@ fn main() raises:
744
  start = time_in_ms()
745
 
746
  let end = time_in_ms()
747
- print("\nachieved tok/s: ", (steps - 1) / (end - start) * 1000)
 
1
+ from algorithm import sum
2
+ from algorithm import vectorize, parallelize
3
+ from builtin import string
4
  from math import round
 
 
5
  from memory import memset_zero, memcpy
6
+ from memory.buffer import Buffer
7
  from memory.unsafe import DTypePointer
8
+ from python import Python
9
  from random import rand
 
 
 
 
 
 
 
 
10
  from read import BufReader, File
11
+ from runtime.llcl import num_cores, Runtime
12
+ from sys import argv
13
+ from tensor import Tensor, TensorShape, TensorSpec
14
 
15
  # The SIMD vector width.
16
+ from sys.info import simdwidthof
17
+ import math
18
+ import os
19
+ import random
20
+ import time
21
 
22
+ var workers = 0
23
+
24
+ alias nelts = (2*simdwidthof[DType.float32]())
25
 
26
  alias PointerString = Pointer[UInt8]
27
  alias BufferPtrType = DTypePointer[DType.uint8]
28
  alias BufferPtrFloat32 = DTypePointer[DType.float32]
29
  alias PointerStrings = Pointer[PointerString]
30
+ alias TensorF32 = Tensor[DType.float32]
31
+
32
+
33
+ struct TensorSlice:
34
+ # Provides a view into a tensor representing a 1D slice on its first or first 2 dimensions.
35
+ # Same function signatures as Tensor but without owning the data.
36
+ var _data: BufferPtrFloat32
37
+ var _shape: TensorShape
38
+
39
+ fn __init__(inout self, t: TensorF32, layer: Int) raises:
40
+ let elements_per_layer = t.num_elements() // t.dim(0)
41
+ self._data = t.data().offset(layer * elements_per_layer)
42
+ if t.rank() == 2:
43
+ self._shape = TensorShape(t.dim(1))
44
+ elif t.rank() == 3:
45
+ self._shape = TensorShape(t.dim(1), t.dim(2))
46
+ else:
47
+ # Compiler complains if _shape not defined
48
+ self._shape = TensorShape(1)
49
+ raise Error("TensorSlice: rank greater than 3 not implemented.")
50
+
51
+ fn __init__(inout self, t: TensorF32, layer: Int, row: Int) raises:
52
+ let elements_per_layer = t.num_elements() // t.dim(0)
53
+ let elements_per_row = elements_per_layer // t.dim(1)
54
+ self._data = t.data().offset(
55
+ layer * elements_per_layer + row * elements_per_row
56
+ )
57
+ if t.rank() == 3:
58
+ self._shape = TensorShape(t.dim(2))
59
+ elif t.rank() == 1:
60
+ # Compiler complains if _shape not defined
61
+ self._shape = TensorShape(1)
62
+ raise Error(
63
+ "Trying to slice a 1D Tensor by layer and row. This requires a 3D"
64
+ " Tensor."
65
+ )
66
+ else:
67
+ # Compiler complains if _shape not defined
68
+ self._shape = TensorShape(1)
69
+ raise Error("TensorSlice: rank greater than 3 not implemented.")
70
 
71
+ fn data(self) -> BufferPtrFloat32:
72
+ return self._data
73
 
74
+ fn shape(self) -> TensorShape:
75
+ return self._shape
 
 
 
 
 
 
 
 
 
 
 
76
 
77
+ fn num_elements(self) -> Int:
78
+ return self._shape.num_elements()
 
 
 
 
79
 
80
+ fn dim(self, idx: Int) -> Int:
81
+ return self._shape[idx]
 
82
 
83
+ fn rank(self) -> Int:
84
+ return self._shape.rank()
 
85
 
86
+ fn simd_load[nelts: Int](self, idx: Int) -> SIMD[DType.float32, nelts]:
87
+ return self._data.simd_load[nelts](idx)
 
88
 
89
+ fn simd_load[nelts: Int](self, *indices: Int) -> SIMD[DType.float32, nelts]:
90
+ if len(VariadicList(indices)) > 2:
91
+ print(
92
+ "Warning: TensorSlice only supports 1D and 2D indexing. Results are"
93
+ " unlikely to be correct."
94
+ )
95
+ return self.simd_load[nelts](indices[0] * self._shape[1] + indices[1])
96
 
97
+ fn simd_load[
98
+ nelts: Int
99
+ ](self, indices: StaticIntTuple[2]) -> SIMD[DType.float32, nelts]:
100
+ return self._data.simd_load[nelts](indices[0] * self._shape[1] + indices[1])
101
 
102
+ fn __getitem__(self, idx: Int) -> SIMD[DType.float32, 1]:
103
+ return self._data.simd_load[1](idx)
 
104
 
105
+ fn simd_store[nelts: Int](self, idx: Int, val: SIMD[DType.float32, nelts]):
106
+ return self._data.simd_store[nelts](idx, val)
 
107
 
108
+ fn __setitem__(self, idx: Int, val: SIMD[DType.float32, 1]):
109
+ return self.simd_store[1](idx, val)
 
110
 
 
 
 
111
 
112
+ fn read_val_int(inout buf: FileBuf) raises -> Int:
113
+ # DTypePointer[DType.ui8](buf.data).bitcast[DType.ui8]()
114
+ let data = buf.data.offset(buf.get_offset()).bitcast[DType.int32]()
115
+ let result = data.load(0)
116
+ buf.move_offset(4)
117
+ return result.to_int()
118
 
 
 
 
 
 
119
 
120
+ fn read_val_float32(inout buf: FileBuf) raises -> Float32:
121
+ # DTypePointer[DType.ui8](buf.data).bitcast[DType.ui8]()
122
+ let val = buf.data.offset(buf.get_offset()).bitcast[DType.float32]().load(0)
123
+ buf.move_offset(4)
124
+ return val
125
 
 
 
 
 
 
126
 
127
+ fn read_val_str(inout buf: FileBuf, slen: Int) raises -> PointerString:
128
+ let str = PointerString.alloc(slen + 1)
129
+ for i in range(slen):
130
+ str.store(i, buf.data.load(buf.get_offset()))
131
+ buf.move_offset(1)
132
+ str.store(slen, 0)
 
 
 
133
 
134
+ return str
 
135
 
 
 
136
 
137
+ fn str_len(s: PointerString) -> Int:
138
+ var len = 0
139
+ while s[len] != 0:
140
+ len += 1
141
+ return len
142
 
 
 
 
 
 
143
 
144
+ # not optimal concat
145
+ fn str_concat(s1: PointerString, s2: PointerString) -> PointerString:
146
+ let l1 = str_len(s1)
147
+ let l2 = str_len(s2)
148
+ let str = PointerString.alloc(l1 + l2 + 1)
149
+ memcpy[UInt8](str, s1, l1)
150
+ memcpy[UInt8](str.offset(l1), s2, l2)
151
+ str.store(l1 + l2, 0)
152
+ return str
153
 
 
 
 
154
 
155
+ fn str_to_ptr(s: String) -> PointerString:
156
+ let ret = PointerString.alloc(len(s) + 1)
157
+ for i in range(len(s)):
158
+ ret.store(i, ord(s[i]))
159
+ ret.store(len(s), 0)
160
+ return ret
161
 
 
 
 
162
 
163
+ fn string_compare(a: PointerString, b: PointerString) -> Int:
164
+ var index = 0
165
+ while a[index] != 0 and b[index] != 0:
166
+ if a[index] < b[index]:
167
+ return -1
168
+ if a[index] > b[index]:
169
+ return 1
170
 
171
+ index += 1
 
 
172
 
173
+ if a[index] != 0 and b[index] == 0:
174
+ return 1
 
175
 
176
+ if a[index] == 0 and b[index] != 0:
177
+ return -1
 
178
 
179
+ return 0
 
 
180
 
181
 
182
+ # Quicksort helper function to find the partition position
183
+ fn partition(
184
+ inout array: PointerStrings, inout indices: DynamicVector[Int], low: Int, high: Int
185
+ ) -> Int:
186
+ let pivot = array[high]
187
+ var ii = low - 1
188
+ for jj in range(low, high):
189
+ if string_compare(pivot, array[jj]) == 1:
190
+ # If element smaller than pivot, swap
191
+ ii = ii + 1
192
 
193
+ let tmp = array[ii]
194
+ let tmp_idx = indices[ii]
195
+ array.store(ii, array[jj])
196
+ indices[ii] = indices[jj]
197
+ array.store(jj, tmp)
198
+ indices[jj] = tmp_idx
199
 
200
+ # Swap the pivot element
201
+ let tmp = array[ii + 1]
202
+ let tmp_idx = indices[ii + 1]
203
+ array.store(ii + 1, array[high])
204
+ indices[ii + 1] = indices[high]
205
+ array.store(high, tmp)
206
+ indices[high] = tmp_idx
207
 
208
+ return ii + 1
209
 
 
 
 
 
 
 
210
 
211
+ fn quicksort(
212
+ inout array: PointerStrings, inout indices: DynamicVector[Int], low: Int, high: Int
213
+ ):
214
+ if low < high:
215
+ let pi = partition(array, indices, low, high)
216
+ quicksort(array, indices, low, pi - 1)
217
+ quicksort(array, indices, pi + 1, high)
218
 
219
 
220
  struct FileBuf:
 
227
  self.offset = 0
228
  self.size = 0
229
 
230
+ fn __del__(owned self):
231
+ self.data.free()
232
+
233
+ fn move_offset(inout self, size: Int) raises:
234
+ let new_offset = self.offset + size
235
+ if new_offset > self.size:
236
+ raise Error("Resulting offset will be past the end of the FileBuf")
237
+ if new_offset < 0:
238
+ raise Error("Resulting offset will be before the beginning of the FileBuf")
239
+ self.offset = new_offset
240
 
241
+ fn bitcast_offset_f32(inout self, size: Int) raises -> BufferPtrFloat32:
242
  let ret = self.data.offset(self.offset).bitcast[DType.float32]()
243
+ self.move_offset(size * sizeof[DType.float32]())
244
  return ret
245
 
246
+ fn get_offset(self) raises -> Int:
247
+ if self.offset > self.size:
248
+ raise Error("Offset is past the end of the FileBuf")
249
+ if self.offset < 0:
250
+ raise Error("Offset is before the beginning of the FileBuf")
251
+ return self.offset
252
+
253
+
254
+ fn wrap(token: PointerString) -> PointerString:
255
+ if string_compare(token, str_to_ptr("\\n")) == 0:
256
+ return str_to_ptr("<0x0A>")
257
+ if string_compare(token, str_to_ptr("\\t")) == 0:
258
+ return str_to_ptr("<0x09>")
259
+ if string_compare(token, str_to_ptr("'")) == 0:
260
+ return str_to_ptr("<0x27>")
261
+ elif string_compare(token, str_to_ptr('"')) == 0:
262
+ return str_to_ptr("<0x22>")
263
+ return token
264
+
265
 
266
  struct Tokenizer:
267
  var vocab: PointerStrings
268
  var vocab_scores: BufferPtrFloat32
269
  var max_token_length: Int
270
  var vocab_size: Int
271
+ var sorted_vocab: PointerStrings
272
+ var sorted_indices: DynamicVector[Int]
273
 
274
+ fn __init__(inout self, vocab_size: Int, inout buf: FileBuf) raises -> None:
275
  self.vocab_size = vocab_size
276
+ self.max_token_length = read_val_int(buf)
277
+ self.vocab_scores = BufferPtrFloat32.alloc(self.vocab_size)
278
+ self.vocab = PointerStrings.alloc(self.vocab_size)
279
+ # lazy load sorted vocab
280
+ self.sorted_vocab = PointerStrings.alloc(0)
281
+ self.sorted_indices = DynamicVector[Int](0)
282
+
283
+ # read vocab_scores & vocab values (tokens)
284
+ for i in range(0, self.vocab_size):
285
+ self.vocab_scores.store(i, read_val_float32(buf))
286
+ let slen = read_val_int(buf)
287
+ self.vocab.store(i, read_val_str(buf, slen))
288
+
289
+ return None
290
+
291
+ # sort vocab by string_compare
292
+ fn sort(inout self) -> None:
293
+ if len(self.sorted_indices) < self.vocab_size:
294
+ self.sorted_indices = DynamicVector[Int](self.vocab_size)
295
+ self.sorted_vocab = PointerStrings.alloc(self.vocab_size)
296
+ for ii in range(self.vocab_size):
297
+ self.sorted_vocab.store(ii, self.vocab[ii])
298
+ self.sorted_indices.push_back(ii)
299
+
300
+ let n = self.vocab_size
301
+ quicksort(self.sorted_vocab, self.sorted_indices, 0, n - 1)
302
+ return None
303
+
304
+ # Binary search that returns -1 if string is not found
305
+ fn find(inout self, token_o: PointerString) -> Int:
306
+ let token = wrap(token_o)
307
+ let n = self.vocab_size
308
+ if len(self.sorted_indices) < n:
309
+ self.sort()
310
+ var left = 0
311
+ var right = n - 1
312
+ while left <= right:
313
+ let mid = left + (right - left) // 2
314
+ let comparison = string_compare(self.sorted_vocab[mid], token)
315
+ if comparison == 0:
316
+ return self.sorted_indices[mid]
317
+ if comparison < 0:
318
+ left = mid + 1
319
+ else:
320
+ right = mid - 1
321
+ return -1
322
 
323
 
324
  struct Config:
325
  var dim: Int
326
+ var kv_dim: Int
327
  var hidden_dim: Int
328
  var n_layers: Int
329
  var n_heads: Int
330
  var n_kv_heads: Int
331
+ var kv_mul: Int
332
  var vocab_size: Int
333
  var seq_len: Int
334
+ var head_size: Int
335
 
336
  fn __init__(inout self):
337
  self.dim = 0
 
341
  self.n_kv_heads = 0
342
  self.vocab_size = 0
343
  self.seq_len = 0
344
+ self.kv_dim = 0
345
+ self.kv_mul = 0
346
+ self.head_size = 0
347
 
348
 
349
  struct RunState:
350
+ var x: TensorF32 # activation at current time stamp (dim,)
351
+ var xb: TensorF32 # same, but inside a residual branch (dim,)
352
+ var xb2: TensorF32 # an additional buffer just for convenience (dim,)
353
+ var hb: TensorF32 # buffer for hidden dimension in the ffn (hidden_dim,)
354
+ var hb2: TensorF32 # buffer for hidden dimension in the ffn (hidden_dim,)
355
+ var q: TensorF32 # query (dim,)
356
+ var k: TensorSlice # key (kv_dim,)
357
+ var v: TensorSlice # value (kv_dim,)
358
+ var att: TensorF32 # buffer for scores/attention values (n_heads, seq_len)
359
+ var logits: TensorF32 # output logits
360
+ var key_cache: TensorF32 # (layer, seq_len, dim)
361
+ var value_cache: TensorF32 # (layer, seq_len, dim)
362
+
363
+
364
+ fn __init__(inout self, config: Config) raises:
365
+ self.x = TensorF32(config.dim)
366
+ self.xb = TensorF32(config.dim)
367
+ self.xb2 = TensorF32(config.dim)
368
+ self.hb = TensorF32(config.hidden_dim)
369
+ self.hb2 = TensorF32(config.hidden_dim)
370
+ self.q = TensorF32(config.dim)
371
+ self.att = TensorF32(config.n_heads, config.seq_len)
372
+ self.logits = TensorF32(config.vocab_size)
373
+ self.key_cache = TensorF32(config.n_layers, config.seq_len, config.kv_dim)
374
+ self.value_cache = TensorF32(config.n_layers, config.seq_len, config.kv_dim)
375
+ # So their updates flow to the caches, k and v are slices with shared memory.
376
+ # Initialize with placeholders. The real tensors reference layer and position during forward pass.
377
+ self.k = TensorSlice(TensorF32(TensorShape(1, config.kv_dim)), 1)
378
+ self.v = TensorSlice(TensorF32(TensorShape(1, config.kv_dim)), 1)
379
+
 
 
 
 
 
 
 
 
380
 
381
 
382
  struct TransformerWeights:
383
+ var token_embedding_table: TensorF32
384
+ var freq_cis_real: TensorF32
385
+ var freq_cis_imag: TensorF32
386
+ var rms_att_weight: TensorF32
387
+ var wq: TensorF32
388
+ var wk: TensorF32
389
+ var wv: TensorF32
390
+ var wo: TensorF32
391
+ var rms_ffn_weight: TensorF32
392
+ var w1: TensorF32
393
+ var w3: TensorF32
394
+ var w2: TensorF32
395
+ var rms_final_weight: TensorF32
396
+ var wcls: TensorF32
397
+
398
+ fn __init__(
399
+ inout self, config: Config, shared_weights: Int, inout buf: FileBuf
400
+ ) raises:
401
+ fn load_weights(inout buf: FileBuf, *dims: Int) raises -> TensorF32:
402
+ # Ensure returned Tensor doesn't share a pointer with FileBuf
403
+ let shape = TensorShape(dims)
404
+ let result_data = BufferPtrFloat32.alloc(shape.num_elements())
405
+ memcpy(
406
+ result_data,
407
+ buf.bitcast_offset_f32(shape.num_elements()),
408
+ shape.num_elements(),
409
+ )
410
+ return TensorF32(result_data, shape)
411
+
412
+ self.token_embedding_table = load_weights(buf, config.vocab_size, config.dim)
413
+ self.rms_att_weight = load_weights(buf, config.n_layers, config.dim)
414
+ self.wq = load_weights(buf, config.n_layers, config.dim, config.dim)
415
+ self.wk = load_weights(buf, config.n_layers, config.kv_dim, config.dim)
416
+ self.wv = load_weights(buf, config.n_layers, config.kv_dim, config.dim)
417
+ self.wo = load_weights(buf, config.n_layers, config.dim, config.dim)
418
+ self.rms_ffn_weight = load_weights(buf, config.n_layers, config.dim)
419
+ self.w1 = load_weights(buf, config.n_layers, config.hidden_dim, config.dim)
420
+ self.w2 = load_weights(buf, config.n_layers, config.dim, config.hidden_dim)
421
+ self.w3 = load_weights(buf, config.n_layers, config.hidden_dim, config.dim)
422
+ self.rms_final_weight = load_weights(buf, config.dim)
423
+ # maybe need modifying for different model
424
+ # config.head_size // 2 for stories and tinyllama-1.1
425
+ self.freq_cis_real = load_weights(buf, config.seq_len, config.head_size // 2)
426
+ self.freq_cis_imag = load_weights(buf, config.seq_len, config.head_size // 2)
427
+ if shared_weights:
428
+ self.wcls = self.token_embedding_table
429
+ else:
430
+ self.wcls = load_weights(buf, config.vocab_size, config.dim)
 
 
 
 
 
 
 
 
 
 
 
431
 
432
 
433
  fn read_file(file_name: String, inout buf: FileBuf) raises:
 
460
  config.n_kv_heads = read_val_int(buf)
461
  config.vocab_size = read_val_int(buf)
462
  config.seq_len = read_val_int(buf)
463
+ config.head_size = config.dim // config.n_heads
464
+ config.kv_dim = (config.n_kv_heads * config.dim) // config.n_heads
465
+ config.kv_mul = config.n_heads // config.n_kv_heads
466
  return None
467
 
468
 
469
+ @always_inline
470
+ fn accum(inout a: TensorF32, b: TensorF32) -> None:
471
+ let size = a.dim(0)
 
 
 
 
 
 
 
 
 
 
 
472
 
473
+ @parameter
474
+ fn _acc[_nelts: Int](j: Int):
475
+ a.simd_store[_nelts](j, a.simd_load[_nelts](j) + b.simd_load[_nelts](j))
476
 
477
+ vectorize[nelts, _acc](size)
 
 
 
478
 
479
 
480
+ @always_inline
481
  fn rmsnorm(
482
  inout o: BufferPtrFloat32, x: BufferPtrFloat32, weight: BufferPtrFloat32, size: Int
483
  ) -> None:
484
  # Calculate sum of squares
485
+ var tmp = SIMD[DType.float32, nelts](0)
486
+
487
+ @parameter
488
+ fn _sum2[_nelts: Int](j: Int):
489
+ if _nelts < nelts:
490
+ tmp[0] += (x.offset(j).simd_load[_nelts](0) ** 2).reduce_add()
491
+ else:
492
+ tmp += x.offset(j).simd_load[nelts](0) ** 2
493
+
494
+ vectorize[nelts, _sum2](size)
495
+
496
+ var ss: Float32 = tmp.reduce_add()
497
  ss = ss / size + 1e-5
498
  ss = 1.0 / math.sqrt(ss)
499
+
500
  # Normalize and scale
501
+ @parameter
502
+ fn _norm[_nelts: Int](j: Int):
503
+ let val = weight.simd_load[_nelts](j) * ss * x.simd_load[_nelts](j)
504
+ o.offset(j).simd_store[_nelts](0, val)
505
+
506
+ vectorize[nelts, _norm](size)
507
+
508
+
509
+ @always_inline
510
+ fn rmsnorm(inout o: TensorF32, x: TensorF32, weight: TensorF32):
511
+ rmsnorm(o._ptr, x.data(), weight.data(), weight.dim(weight.rank() - 1))
512
+
513
+
514
+ @always_inline
515
+ fn rmsnorm(inout o: TensorF32, x: TensorF32, weight: TensorSlice):
516
+ rmsnorm(o._ptr, x.data(), weight.data(), weight.dim(weight.rank() - 1))
517
+
518
+
519
+ @always_inline
520
+ fn softmax(inout x: TensorF32) -> None:
521
+ softmax(x, 0, x.dim(0))
522
+
523
+
524
+ @always_inline
525
+ fn softmax(inout x: TensorF32, start: Int, end: Int):
526
+ var max_val: Float32 = -1e9
527
+
528
+ @parameter
529
+ fn _max[_nelts: Int](ii: Int):
530
+ let val = x.simd_load[_nelts](start + ii).reduce_max()
531
+ if val > max_val:
532
+ max_val = val
533
+
534
+ vectorize[nelts, _max](end - start)
535
+
536
  var ssum: Float32 = 0.0
537
+
538
+ @parameter
539
+ fn _exp[_nelts: Int](ii: Int):
540
+ x.simd_store[_nelts](
541
+ start + ii, math.exp(x.simd_load[_nelts](start + ii) - max_val)
542
+ )
543
+ ssum += x.simd_load[_nelts](start + ii).reduce_add()
544
+
545
+ vectorize[nelts, _exp](end - start)
546
+
547
+ @parameter
548
+ fn _norm[_nelts: Int](ii: Int):
549
+ x.simd_store[_nelts](start + ii, x.simd_load[_nelts](start + ii) / ssum)
550
+
551
+ vectorize[nelts, _norm](end - start)
552
+
553
+
554
+ @always_inline
555
+ fn matmul_parallelized(C: BufferPtrFloat32,A: BufferPtrFloat32,B: BufferPtrFloat32,rows: Int,cols: Int,):
556
+ @parameter
557
+ fn compute_row(i: Int):
558
  var tmp = SIMD[DType.float32, nelts](0)
559
 
560
  @parameter
561
  fn dot[_nelts: Int](j: Int):
562
+ if _nelts < nelts: # take care of tail array elements with length < nelts
563
+ tmp[0] += (
564
+ A.simd_load[_nelts](j) * B.simd_load[_nelts](i * cols + j)
565
+ ).reduce_add()
566
  else:
567
+ tmp += A.simd_load[nelts](j) * B.simd_load[nelts](i * cols + j)
568
 
569
+ vectorize[nelts, dot](cols)
570
+ C.store(i, tmp.reduce_add())
571
+
572
 
573
+ parallelize[compute_row](rows, workers)
 
 
 
 
 
 
 
 
 
 
574
 
575
+
 
576
 
577
+
578
+
579
+ @always_inline
580
+ fn matmul(C: TensorF32, A: TensorF32, B: TensorF32) raises:
581
+ # B (d,n) @ A (n,) -> C (d,)
582
+ matmul_dimension_checks(A.shape(), B.shape())
583
+ matmul_parallelized(C.data(), A.data(), B.data(), B.dim(0), B.dim(1))
584
 
585
 
586
+ @always_inline
587
+ fn matmul(C: TensorF32, A: TensorF32, B: TensorSlice) raises:
588
  # B (d,n) @ A (n,) -> C (d,)
589
+ matmul_dimension_checks(A.shape(), B.shape())
590
+ matmul_parallelized(C.data(), A.data(), B.data(), B.dim(0), B.dim(1))
591
+
592
+
593
+ @always_inline
594
+ fn matmul(C: TensorSlice, A: TensorF32, B: TensorSlice) raises:
595
+ # B (d,n) @ A (n,) -> C (d,)
596
+ matmul_dimension_checks(A.shape(), B.shape())
597
+ matmul_parallelized(C.data(), A.data(), B.data(), B.dim(0), B.dim(1))
598
+
599
+
600
+ fn matmul_dimension_checks(a: TensorShape, b: TensorShape) raises:
601
+ if a[0] != b[1]:
602
+ raise Error(
603
+ "matmul dimension mismatch. A rows (dim 0) not equal to B columns (dim 1)"
604
+ )
605
+ if b.rank() != 2:
606
+ raise Error("matmul expects B to be a 2D matrix")
607
 
608
 
609
+ # Apply RoPE rotation to the q and k vectors for each head
610
+ # rotate odd and even dim
611
+ @always_inline
612
+ fn rope_rotation_llama(
613
+ inout state: RunState,
614
+ freq_cis_real_row: TensorSlice,
615
+ freq_cis_imag_row: TensorSlice,
616
+ config: Config,
617
+ ) -> None:
618
+ # stories model, llama2
619
+ let head_size = config.head_size
620
+ @parameter
621
+ fn head_loop(i:Int):
622
+ # Simple vectorization with (head_size // 2) steps gave junk transformer output.
623
+ # Maybe because the nelt ranges end up overlapping between the steps.
624
+ for j in range(0, config.head_size, 2):
625
+ let fcr = freq_cis_real_row[j // 2]
626
+ let fci = freq_cis_imag_row[j // 2]
627
+ let q0 = state.q[i * head_size + j]
628
+ let q1 = state.q[i * head_size + j + 1]
629
+ state.q[i * head_size + j] = q0 * fcr - q1 * fci
630
+ state.q[i * head_size + j + 1] = q0 * fci + q1 * fcr
631
+ if i < config.n_kv_heads:
632
+ let k0 = state.k[i * head_size + j]
633
+ let k1 = state.k[i * head_size + j + 1]
634
+ state.k[i * head_size + j] = k0 * fcr - k1 * fci
635
+ state.k[i * head_size + j + 1] = k0 * fci + k1 * fcr
636
+ parallelize[head_loop](config.n_heads, workers)
637
+
638
+
639
+
640
+ @always_inline
641
  fn transformer(
642
  token: Int,
643
  pos: Int,
644
  config: Config,
645
  inout state: RunState,
646
  weights: TransformerWeights,
647
+ ) raises -> None:
648
  # A few convenience variables
 
649
  let dim = config.dim
650
  let hidden_dim = config.hidden_dim
651
+ let head_size = config.head_size
652
+ let kv_dim = config.kv_dim
653
+ let kv_mul = config.kv_mul
 
654
 
655
  # Copy the token embedding into x
656
+ let content_row = weights.token_embedding_table.data().offset(token * dim)
657
+ memcpy[DType.float32](state.x.data(), content_row, dim)
658
 
659
  # Pluck out the "pos" row of freq_cis_real and freq_cis_imag
660
+ let freq_cis_real_row = TensorSlice(weights.freq_cis_real, pos)
661
+ let freq_cis_imag_row = TensorSlice(weights.freq_cis_imag, pos)
662
 
663
  # Forward all the layers
664
  for l in range(config.n_layers):
665
  # Attention rmsnorm
666
+ rmsnorm(state.xb, state.x, TensorSlice(weights.rms_att_weight, l))
 
667
  # QKV matmuls for this position
668
+ matmul(state.q, state.xb, TensorSlice(weights.wq, l))
 
669
 
670
+ let loff = l * config.seq_len * config.kv_dim
671
+ state.k = TensorSlice(state.key_cache, l, pos)
672
+ matmul(state.k, state.xb, TensorSlice(weights.wk, l))
673
 
674
+ state.v = TensorSlice(state.value_cache, l, pos)
675
+ matmul(state.v, state.xb, TensorSlice(weights.wv, l))
676
 
677
  # Apply RoPE rotation to the q and k vectors for each head
678
+ rope_rotation_llama(state, freq_cis_real_row, freq_cis_imag_row, config)
679
+
680
+ memset_zero(state.xb.data(), state.xb.num_elements())
681
+
682
+ # Multihead attention. Iterate over all heads in parallel.
683
+ @parameter
684
+ fn loop_over_heads(h:Int):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
685
  # Get the query vector for this head
686
+ let q_offset = h * head_size
687
 
688
+ # Index of attention scores for this head
689
+ let att_offset = h * config.seq_len
690
 
691
  # Iterate over all timesteps, including the current one
692
  for t in range(pos + 1):
693
+ # Starting index of the key vector for this head and at this timestep
694
+ let k_offset = loff + t * kv_dim + (h // kv_mul) * head_size
695
  # Calculate the attention score as the dot product of q and k
696
  var score: Float32 = 0.0
697
+
698
+ @parameter
699
+ fn score_fn[_nelts: Int](i: Int):
700
+ score += (
701
+ state.q.simd_load[_nelts](q_offset + i)
702
+ * state.key_cache.simd_load[_nelts](k_offset + i)
703
+ ).reduce_add()
704
+
705
+ vectorize[nelts, score_fn](head_size)
706
  score /= math.sqrt[DType.float32, 1](head_size)
707
 
708
  # Save the score to the attention buffer
709
+ state.att[att_offset + t] = score
710
 
711
  # Softmax the scores to get attention weights, from 0..pos inclusively
712
+ softmax(state.att, att_offset, att_offset + pos + 1)
 
713
  # Weighted sum of the values, store back into xb
714
+ let xb_offset = h * head_size
 
715
  for t in range(pos + 1):
716
+ # Starting index of the value vector for this head and at this timestep
717
+ let v_offset = loff + t * kv_dim + (h // kv_mul) * head_size
718
+
719
  # Get the attention weight for this timestep
720
+ let a = state.att[att_offset + t]
721
  # Accumulate the weighted value into xb
 
 
 
 
 
 
 
 
722
 
723
+ @parameter
724
+ fn xb_accumulate[_nelts: Int](i: Int):
725
+ let xbi = state.xb.simd_load[_nelts](
726
+ xb_offset + i
727
+ ) + a * state.value_cache.simd_load[_nelts](v_offset + i)
728
+ state.xb.simd_store[_nelts](xb_offset + i, xbi)
729
 
730
+ vectorize[nelts, xb_accumulate](head_size)
731
+
732
+ parallelize[loop_over_heads](config.n_heads, workers)
733
+ # Final matrix multiplication to get the output of the attention
734
+ matmul(state.xb2, state.xb, TensorSlice(weights.wo, l))
735
+ # Residual connection back into x
736
+ accum(state.x, state.xb2)
737
  # FFN rmsnorm
738
+ rmsnorm(state.xb, state.x, TensorSlice(weights.rms_ffn_weight, l))
739
 
740
  # Calculate self.w1(x) and self.w3(x) for FFN
741
+ matmul(state.hb, state.xb, TensorSlice(weights.w1, l))
 
 
 
 
 
 
 
 
 
742
 
743
+ matmul(state.hb2, state.xb, TensorSlice(weights.w3, l))
 
 
744
 
745
+ @parameter
746
+ fn silu[_nelts: Int](i: Int):
747
+ let initial_hb = state.hb.simd_load[_nelts](i)
748
+ # Apply SiLU activation function (silu(x) = x * sigmoid(x))
749
+ let hbi = initial_hb * (1.0 / (1.0 + math.exp(-initial_hb)))
750
+ # Elementwise multiply with w3(x)
751
+ state.hb.simd_store[_nelts](i, hbi * state.hb2.simd_load[_nelts](i))
752
+
753
+ vectorize[nelts, silu](hidden_dim)
754
  # Final matrix multiplication to get the output of the FFN
755
+ matmul(state.xb, state.hb, TensorSlice(weights.w2, l))
 
756
 
757
  # Residual connection
758
+ accum(state.x, state.xb)
759
 
760
  # Final rmsnorm
761
+ rmsnorm(state.x, state.x, weights.rms_final_weight)
762
 
763
  # Classifier into logits
764
+ matmul(state.logits, state.x, weights.wcls)
 
765
 
766
 
767
+ fn argmax(v: TensorF32) -> Int:
768
  # return argmax of v
769
  var max_i: Int = 0
770
  var max_p: Float32 = v[0]
771
+ for i in range(v.dim(0)):
772
  if v[i] > max_p:
773
  max_i = i
774
  max_p = v[i]
775
  return max_i
776
 
777
 
778
+ fn sample(probabilities: TensorF32) -> Int:
779
+ let n = probabilities.dim(0)
780
  # Sample index from probabilities, they must sum to 1
781
  # get random value within (min, max) float32 range
782
  let r = DTypePointer[DType.float32].alloc(1)
 
784
  var cdf: Float32 = 0.0
785
  for i in range(n):
786
  cdf += probabilities[i]
787
+ if r.load(0) < cdf:
788
  return i
789
  return n - 1 # In case of rounding errors
790
 
791
 
792
+ fn bpe_encode(inout tokens: DynamicVector[Int], text: String, inout tok: Tokenizer):
793
+ for pos in range(len(text)):
794
+ let char = str_to_ptr(text[pos])
795
+ let id = tok.find(char)
796
+
797
+ if id == -1:
798
+ print("Not a good prompt token at pos ", pos)
799
+ return
800
+ tokens.push_back(id)
801
+
802
+ while True:
803
+ var best_score = Float32(-1e10)
804
+ var best_id = -1
805
+ var best_idx = -1
806
+
807
+ for i in range(len(tokens) - 1):
808
+ # Check if we can merge the pair (tokens[i], tokens[i+1])
809
+ let str = str_concat(tok.vocab[tokens[i]], tok.vocab[tokens[i + 1]])
810
+ let id = tok.find(str)
811
+ if id != -1 and tok.vocab_scores.load(id) > best_score:
812
+ best_score = tok.vocab_scores.load(id)
813
+ best_id = id
814
+ best_idx = i
815
+
816
+ if best_idx == -1:
817
+ # We couldn't find any more pairs to merge, so we're done
818
+ break
819
+
820
+ # Merge the consecutive pair (best_idx, best_idx+1) into new token best_id
821
+ tokens[best_idx] = best_id
822
+ # Delete token at position best_idx+1, shift the entire sequence back 1
823
+ var _tokens = DynamicVector[Int]()
824
+ for i in range(0, best_idx + 1):
825
+ _tokens.push_back(tokens[i])
826
+ for i in range(best_idx + 2, len(tokens)):
827
+ _tokens.push_back(tokens[i])
828
+ tokens = _tokens
829
+
830
+
831
+ fn str2num(d: Int) -> Int:
832
+ # covert Hex to decimal
833
+ if d >= ord("A"):
834
+ return d - ord("A") + 10
835
+ return d - ord("0")
836
+
837
+
838
  fn print_str(s: PointerString):
839
+ # print raw byte like <0x0A>
840
+ if (s[1].to_int() == ord("0")) and (s[2].to_int() == ord("x")):
841
+ let d1: Int = s[3].to_int()
842
+ let d2: Int = s[4].to_int()
843
+ print_no_newline(chr(str2num(d1) * 16 + str2num(d2)))
844
+ return
845
  # print all chars till null character
846
  var p: Int = 0
847
  while s[p].to_int() != 0:
 
854
  return time.now() // 1_000_000
855
 
856
 
857
+ fn print_usage():
858
+ print("Usage: mojo llama2.mojo <checkpoint> [options]")
859
+ print(
860
+ 'Example: mojo llama2.mojo stories15M.bin -s 99 -n 256 -t 0.5 -i "Llama is an'
861
+ ' animal"'
862
+ )
863
+ print("Options:")
864
+ print(" -s <int> random seed, default time.now()")
865
+ print(" -t <float> temperature in [0,1.0], default 1.0")
866
+ print(" -n <int> number of steps to run for, default 256. 0 = max_seq_len")
867
+ print(" -i <string> input prompt")
868
+ print(" -z tokenizer path")
869
+ print(" -j number of workers to use, default num_cores()")
870
+
871
+
872
  fn main() raises:
873
+ workers = num_cores()
874
+ var tokenizer = StringRef("tokenizer.bin")
875
+ var checkpoint = StringRef("stories15M.bin")
876
+ var temperature = 0.9
 
877
  var steps = 256
878
+ var prompt = String("")
879
+ var rng_seed: Int = time.now()
880
+
881
+ @parameter
882
+ fn argparse() raises -> Int:
883
+ let args = argv()
884
+ if len(args) < 2:
885
+ return 0
886
+ checkpoint = args[1]
887
+ for i in range(2, len(args), 2):
888
+ if args[i] == "-p":
889
+ print("Option not supported: ", args[i])
890
+ if args[i] == "-n":
891
+ steps = atol(args[i + 1])
892
+ if args[i] == "-z":
893
+ tokenizer = args[i + 1]
894
+ if args[i] == "-s":
895
+ rng_seed = atol(args[i + 1])
896
+ if args[i] == "-i":
897
+ prompt = args[i + 1]
898
+ if args[i] == "-j":
899
+ workers = atol(args[i + 1])
900
+ if args[i] == "-t":
901
+ let val = args[i + 1]
902
+ temperature = 0.0
903
+ # hacky parse float, keep only 1 digit
904
+ for c in range(0, len(val)):
905
+ if val[c] == ".":
906
+ temperature += atol(val[c + 1]) * 0.1
907
+ break
908
+ else:
909
+ temperature = atol(val[c])
910
+ if temperature < -1e9 or temperature > (1 + 1e9):
911
+ print("Wrong temperature value", temperature)
912
+ return 0
913
+ return 1
914
+
915
+ let res = argparse()
916
+ if res == 0:
917
+ print_usage()
918
+ return
919
+
920
+ print("num parallel workers:", workers, " SIMD width:", nelts)
921
  random.seed(rng_seed)
922
  var fbuf: FileBuf = FileBuf()
923
  var tbuf: FileBuf = FileBuf()
924
  var config: Config = Config()
925
 
926
  read_file(checkpoint, fbuf)
 
927
  config_init(config, fbuf)
928
 
929
  # negative vocab size is hacky way of signaling unshared weights. bit yikes.
 
934
 
935
  let weights: TransformerWeights = TransformerWeights(config, shared_weights, fbuf)
936
 
 
 
937
  if steps <= 0 or steps > config.seq_len:
938
  steps = config.seq_len
939
 
940
  # Read in the tokenizer.bin file
941
  read_file(tokenizer, tbuf)
942
+ var tok = Tokenizer(config.vocab_size, tbuf)
943
+
944
+ # print the layers number and vocab size
945
+ print("checkpoint size: ", fbuf.size, "[", fbuf.size // 1024 // 1024, "MB ]",
946
+ "| n layers:", config.n_layers, "| vocab size:", tok.vocab_size)
947
 
948
  # Create and initialize the application RunState
949
  var state = RunState(config)
950
 
951
+ # Process the prompt, if any
952
+ var prompt_tokens = DynamicVector[Int]()
953
+
954
+ if prompt:
955
+ bpe_encode(prompt_tokens, prompt, tok)
956
+
957
  # Start the main loop
958
  var start = 0 # Used to time our code, only initialized after the first iteration
959
  var next_token = 0 # Will store the next token in the sequence
960
  # Initialize with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
961
  var token = 1
 
 
 
 
962
 
963
+ # Position in the sequence
964
+ var pos = 0
965
  while pos < steps:
966
  # Forward the transformer to get logits for the next token
967
  transformer(token, pos, config, state, weights)
968
 
969
+ if pos < len(prompt_tokens):
970
+ next_token = prompt_tokens[pos]
 
 
971
  else:
972
+ # Sample the next token
973
+ if temperature == 0.0:
974
+ # Greedy argmax sampling: take the token with the highest probability
975
+ next_token = argmax(state.logits)
976
+ else:
977
+ # Apply the temperature to the logits
978
+ for q in range(config.vocab_size):
979
+ state.logits[q] = state.logits[q] / temperature
980
+
981
+ # Apply softmax to the logits to get the probabilities for the next token
982
+ softmax(state.logits)
983
+ # Sample from this distribution to get the next token
984
+ next_token = sample(state.logits)
985
+
986
+ # Finish generating when EOS, BOS appear
987
+ if next_token == 1 or next_token == 2:
988
+ break
989
  var token_str: PointerString = tok.vocab[next_token]
990
  if token == 1 and token_str[0] == ord(" "):
991
  token_str = token_str.offset(1)
992
 
993
  print_str(token_str)
 
994
 
995
  # Advance forward
996
  token = next_token
 
1000
  start = time_in_ms()
1001
 
1002
  let end = time_in_ms()
1003
+ print("\nachieved tok/s: ", (pos - 1) / (end - start) * 1000)