radames committed
Commit
c40c794
1 Parent(s): c614a13
Files changed (2)
  1. gradio_app.py +16 -66
  2. llama2.mojo +264 -483
gradio_app.py CHANGED
@@ -1,86 +1,36 @@
 import gradio as gr
 import subprocess
 import sys
-from pathlib import Path
+import os
 
 
-async def generate(prompt, model_name, seed=0, temperature=0.5, num_tokens=256):
+async def generate(prompt):
+    # os.environ["PROMPT"] = prompt
     # stream stout
     process = subprocess.Popen(
-        [
-            "mojo",
-            "llama2.mojo",
-            Path(model_name),
-            "-s",
-            str(seed),
-            "-n",
-            str(num_tokens),
-            "-t",
-            str(temperature),
-            "-i",
-            prompt,
-        ],
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
+        ["mojo", "llama2.mojo"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
     )
     text = ""
     for char in iter(lambda: process.stdout.read(1), b""):
-        char_decoded = char.decode("utf-8", errors="ignore")
+        char_decoded = char.decode()
+        sys.stdout.write(char_decoded)
         text += char_decoded
         yield text
 
 
-with gr.Blocks() as demo:
-    gr.Markdown(
-        """
+output_text = gr.Textbox(label="Generated Text")
+
+demo = gr.Interface(
+    fn=generate,
+    inputs=None,
+    outputs=output_text,
+    description="""
 # llama2.🔥
 ## [Mojo](https://docs.modular.com/mojo/) implementation of [llama2.c](https://github.com/karpathy/llama2.c) by [@tairov](https://github.com/tairov)
 Source: https://github.com/tairov/llama2.mojo
-"""
-    )
-    with gr.Row():
-        with gr.Column():
-            prompt = gr.Textbox(label="Prompt", placeholder="Add your prompt here...")
-            seed = gr.Slider(
-                minimum=0,
-                maximum=2**53,
-                value=0,
-                step=1,
-                label="Seed",
-                randomize=True,
-            )
-            temperature = gr.Slider(
-                minimum=0.0, maximum=2.0, step=0.01, value=0.5, label="Temperature"
-            )
-            num_tokens = gr.Slider(
-                minimum=1, maximum=256, value=256, label="Number of tokens"
-            )
-            model_name = gr.Dropdown(
-                ["stories15M.bin", "stories42M.bin", "stories110M.bin"],
-                value="stories15M.bin",
-                label="Model Size",
-            )
-            with gr.Row():
-                stop = gr.Button("Stop")
-                run = gr.Button("Run")
-        with gr.Column(scale=2):
-            output_text = gr.Textbox(label="Generated Text")
-
-    # update maximum number of tokens based on model size
-    model_name.change(
-        lambda x: gr.update(maximum=1024)
-        if x == "stories110M.bin" or x == "stories42M.bin"
-        else gr.update(maximum=256),
-        model_name,
-        num_tokens,
-        queue=False,
-    )
-    click_event = run.click(
-        fn=generate,
-        inputs=[prompt, model_name, seed, temperature, num_tokens],
-        outputs=output_text,
-    )
-    stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event])
+""",
+    allow_flagging="never",
+)
 
 demo.queue()
 demo.launch(server_name="0.0.0.0")
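Both versions hinge on the same streaming pattern: `generate` is a generator, and Gradio re-renders the output component each time it yields. A minimal self-contained sketch of that pattern outside Gradio (the `["echo", "hello"]` command is a stand-in for the `mojo llama2.mojo` invocation and is not part of this commit):

```python
import subprocess

def stream_chars(cmd):
    # Read the child's stdout one byte at a time and yield the
    # accumulated text after every byte, as generate() does above.
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    text = ""
    for char in iter(lambda: process.stdout.read(1), b""):
        text += char.decode("utf-8", errors="ignore")
        yield text

for partial in stream_chars(["echo", "hello"]):
    print(partial)  # "h", "he", "hel", ... "hello\n"
```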
llama2.mojo CHANGED
@@ -1,22 +1,25 @@
-from algorithm import sum
-from algorithm import vectorize, parallelize
-from builtin import string
 from math import round
+import math
+
 from memory import memset_zero, memcpy
-from memory.buffer import Buffer
 from memory.unsafe import DTypePointer
-from python import Python
 from random import rand
+from sys.info import simdwidthof
+from builtin import string
+import time
+import random
+import os
+
+from runtime.llcl import num_cores
+
 from read import BufReader, File
-from runtime.llcl import num_cores, Runtime
-from sys import argv
+from memory.buffer import Buffer
+
+from python import Python
 
 # The SIMD vector width.
-from sys.info import simdwidthof
-import math
-import os
-import random
-import time
+from algorithm import vectorize, parallelize
+from algorithm import sum
 
 alias nelts = (2 * simdwidthof[DType.float32]())
 
@@ -26,51 +29,98 @@ alias BufferPtrFloat32 = DTypePointer[DType.float32]
 alias PointerStrings = Pointer[PointerString]
 
 
-struct Matrix:
+struct Matrix3:
     var data: BufferPtrFloat32
     var rows: Int
     var cols: Int
     var layers: Int
     var allocated: Int
 
+    fn __init__(inout self, layers: Int, rows: Int, cols: Int):
+        self.data = BufferPtrFloat32.alloc(0)
+        self.rows = rows
+        self.cols = cols
+        self.layers = layers
+        self.allocated = 0
+
+    @always_inline
+    fn alloc(inout self, fill: Int = 0):
+        self.data = BufferPtrFloat32.alloc(self.size())
+        self.allocated = 1
+        if fill == 1:
+            self.zero()
+
+    @always_inline
+    fn alloc_zero(inout self):
+        self.alloc(1)
+
+    @always_inline
+    fn set_buf_ptr(inout self, ptr: BufferPtrFloat32):
+        self.data = ptr
+
+    fn __del__(owned self):
+        if self.allocated == 1:
+            self.data.free()
+
+    @always_inline
+    fn zero(inout self):
+        memset_zero(self.data, self.layers * self.rows * self.cols)
+
+    @always_inline
+    fn size(inout self) -> Int:
+        return self.layers * self.cols * self.rows
+
+    @always_inline
+    fn __getitem__(self, z: Int, y: Int, x: Int) -> Float32:
+        return self.load[1](z, y, x)
+
+    @always_inline
+    fn load[nelts: Int](self, z: Int, y: Int, x: Int) -> SIMD[DType.float32, nelts]:
+        return self.data.simd_load[nelts](z * self.layers + y * self.cols + x)
+
+    @always_inline
+    fn __setitem__(self, z: Int, y: Int, x: Int, val: Float32):
+        return self.store[1](z, y, x, val)
+
+    @always_inline
+    fn store[nelts: Int](self, z: Int, y: Int, x: Int, val: SIMD[DType.float32, nelts]):
+        self.data.simd_store[nelts](z * self.layers + y * self.cols + x, val)
+
+
+struct Matrix:
+    var data: BufferPtrFloat32
+    var rows: Int
+    var cols: Int
+    var allocated: Int
+
     fn __init__(inout self, rows: Int, cols: Int):
         self.data = BufferPtrFloat32.alloc(0)
         self.rows = rows
         self.cols = cols
-        self.layers = 1
         self.allocated = 0
 
     fn __init__(inout self, cols: Int):
         self.data = BufferPtrFloat32.alloc(0)
         self.rows = 1
-        self.layers = 1
         self.cols = cols
         self.allocated = 0
 
-    fn __init__(inout self, layers: Int, rows: Int, cols: Int):
-        self.__init__(rows, cols)
-        self.layers = layers
-
     fn __del__(owned self):
         if self.allocated == 1:
             self.data.free()
 
-    @always_inline
     fn alloc(inout self, fill: Int = 0):
         self.data = BufferPtrFloat32.alloc(self.size())
         self.allocated = 1
         if fill == 1:
             self.zero()
 
-    @always_inline
     fn alloc_zero(inout self):
         self.alloc(1)
 
-    @always_inline
     fn zero(inout self):
-        memset_zero(self.data, self.size())
+        memset_zero(self.data, self.rows * self.cols)
 
-    @always_inline
     fn set_buf_ptr(inout self, ptr: BufferPtrFloat32):
         self.data = ptr
 
@@ -80,9 +130,8 @@ struct Matrix:
         self.rows = rows
         self.cols = cols
 
-    @always_inline
     fn size(inout self) -> Int:
-        return self.cols * self.rows * self.layers
+        return self.cols * self.rows
 
     @always_inline
     fn __getitem__(self, y: Int, x: Int) -> Float32:
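The new `Matrix3` stores a (layers, rows, cols) tensor in one flat buffer. For reference, the conventional row-major flattening of an index (z, y, x) uses a layer stride of rows * cols; a small Python sketch (the `Tensor3` class is hypothetical, not from this commit):

```python
class Tensor3:
    """Flat row-major storage for a (layers, rows, cols) tensor."""

    def __init__(self, layers, rows, cols):
        self.layers, self.rows, self.cols = layers, rows, cols
        self.data = [0.0] * (layers * rows * cols)

    def _index(self, z, y, x):
        # Row-major: consecutive x values are adjacent in memory.
        return z * self.rows * self.cols + y * self.cols + x

    def __getitem__(self, zyx):
        return self.data[self._index(*zyx)]

    def __setitem__(self, zyx, val):
        self.data[self._index(*zyx)] = val

t = Tensor3(2, 3, 4)
t[1, 2, 3] = 42.0
assert t.data[1 * 3 * 4 + 2 * 4 + 3] == 42.0
```

Note that the sketch shows the textbook `z * rows * cols + y * cols + x` form, while `Matrix3.load`/`store` above compute `z * self.layers + y * self.cols + x`.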
@@ -116,131 +165,32 @@ struct Matrix:
     fn store[nelts: Int](self, x: Int, val: SIMD[DType.float32, nelts]):
         self.data.simd_store[nelts](x, val)
 
-    @always_inline
-    fn __getitem__(self, z: Int, y: Int, x: Int) -> Float32:
-        return self.load[1](z, y, x)
-
-    @always_inline
-    fn load[nelts: Int](self, z: Int, y: Int, x: Int) -> SIMD[DType.float32, nelts]:
-        return self.data.simd_load[nelts](z * self.layers + y * self.cols + x)
-
-    @always_inline
-    fn __setitem__(self, z: Int, y: Int, x: Int, val: Float32):
-        return self.store[1](z, y, x, val)
-
-    @always_inline
-    fn store[nelts: Int](self, z: Int, y: Int, x: Int, val: SIMD[DType.float32, nelts]):
-        self.data.simd_store[nelts](z * self.layers + y * self.cols + x, val)
 
-
-fn read_val_int(inout buf: FileBuf) raises -> Int:
+
+fn read_val_int(inout buf: FileBuf) -> Int:
     # DTypePointer[DType.ui8](buf.data).bitcast[DType.ui8]()
-    let data = buf.data.offset(buf.get_offset()).bitcast[DType.uint32]()
-    let result = data.load(0)
-    buf.move_offset(4)
+    let data = buf.data.offset(buf.offset).bitcast[DType.uint32]()
+    let result = data.simd_load[1](0)
+    buf.offset += 4
     return result.to_int()
 
 
-fn read_val_float32(inout buf: FileBuf) raises -> Float32:
+fn read_val_float32(inout buf: FileBuf) -> Float32:
     # DTypePointer[DType.ui8](buf.data).bitcast[DType.ui8]()
-    let val = buf.data.offset(buf.get_offset()).bitcast[DType.float32]().load(0)
-    buf.move_offset(4)
+    let val = buf.data.offset(buf.offset).bitcast[DType.float32]().simd_load[1](0)
+    buf.offset += 4
     return val
 
 
-fn read_val_str(inout buf: FileBuf, slen: Int) raises -> PointerString:
-
+fn read_val_str(inout buf: FileBuf, slen: Int) -> PointerString:
     let str = PointerString.alloc(slen + 1)
     for i in range(slen):
-        str.store(i, buf.data.load(buf.get_offset()))
-        buf.move_offset(1)
+        str.store(i, buf.data.simd_load[1](buf.offset))
+        buf.offset += 1
     str.store(slen, 0)
 
     return str
 
 
-# not optimal concat
-fn str_concat(s1: PointerString, s2: PointerString) -> PointerString:
-    var l1 = 0
-    var l2 = 0
-
-    while s1[l1] != 0:
-        l1 += 1
-    while s2[l2] != 0:
-        l2 += 1
-
-    let str = PointerString.alloc(l1 + l2 + 1)
-    memcpy[UInt8](str, s1, l1)
-    memcpy[UInt8](str.offset(l1), s2, l2)
-    str.store(l1 + l2, 0)
-    return str
-
-
-fn str_to_ptr(s: String) -> PointerString:
-    let ret = PointerString.alloc(len(s) + 1)
-    for i in range(len(s)):
-        ret.store(i, ord(s[i]))
-    ret.store(len(s), 0)
-    return ret
-
-
-fn string_compare(a: PointerString, b: PointerString) -> Int:
-    var index = 0
-    while a[index] != 0 and b[index] != 0:
-        if a[index] < b[index]:
-            return -1
-        if a[index] > b[index]:
-            return 1
-
-        index += 1
-
-    if a[index] != 0 and b[index] == 0:
-        return 1
-
-    if a[index] == 0 and b[index] != 0:
-        return -1
-
-    return 0
-
-
-# Quicksort helper function to find the partition position
-fn partition(
-    inout array: PointerStrings, inout indices: DynamicVector[Int], low: Int, high: Int
-) -> Int:
-    let pivot = array[high]
-    var ii = low - 1
-    for jj in range(low, high):
-        if string_compare(pivot, array[jj]) == 1:
-            # If element smaller than pivot, swap
-            ii = ii + 1
-
-            let tmp = array[ii]
-            let tmp_idx = indices[ii]
-            array.store(ii, array[jj])
-            indices[ii] = indices[jj]
-            array.store(jj, tmp)
-            indices[jj] = tmp_idx
-
-    # Swap the pivot element
-    let tmp = array[ii + 1]
-    let tmp_idx = indices[ii + 1]
-    array.store(ii + 1, array[high])
-    indices[ii + 1] = indices[high]
-    array.store(high, tmp)
-    indices[high] = tmp_idx
-
-    return ii + 1
-
-
-fn quicksort(
-    inout array: PointerStrings, inout indices: DynamicVector[Int], low: Int, high: Int
-):
-    if low < high:
-        let pi = partition(array, indices, low, high)
-        quicksort(array, indices, low, pi - 1)
-        quicksort(array, indices, pi + 1, high)
-
-
 struct FileBuf:
     var data: BufferPtrType
     var offset: Int
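Both versions of the readers do the same thing: interpret the next 4 bytes at the current offset as a little-endian int32 or float32 and advance the offset. A Python equivalent using `struct`, illustrative only, following the layout the tokenizer loader reads (a max-token-length int, then per token a float32 score, an int32 length, and the raw bytes):

```python
import struct

def read_i32(buf, off):
    # Interpret 4 bytes at off as little-endian int32; return value and new offset.
    return struct.unpack_from("<i", buf, off)[0], off + 4

def read_f32(buf, off):
    return struct.unpack_from("<f", buf, off)[0], off + 4

def read_str(buf, off, slen):
    return buf[off : off + slen], off + slen

# A tiny fabricated buffer with one token entry.
buf = struct.pack("<ifi", 7, 0.5, 2) + b"hi"
max_len, off = read_i32(buf, 0)
score, off = read_f32(buf, off)
slen, off = read_i32(buf, off)
token, off = read_str(buf, off, slen)
assert (max_len, score, slen, token) == (7, 0.5, 2, b"hi")
```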
@@ -251,95 +201,36 @@ struct FileBuf:
         self.offset = 0
         self.size = 0
 
-    fn move_offset(inout self, size: Int) raises:
-        let new_offset = self.offset + size
-        if new_offset > self.size:
-            raise Error("Resulting offset will be past the end of the FileBuf")
-        if new_offset < 0:
-            raise Error("Resulting offset will be before the beginning of the FileBuf")
-        self.offset = new_offset
+    fn move_offset(inout self, size: Int):
+        self.offset += size
 
-    fn bitcast_offset_float32(inout self, size: Int) raises -> BufferPtrFloat32:
+    fn bitcast_offset_float32(inout self, size: Int) -> BufferPtrFloat32:
         let ret = self.data.offset(self.offset).bitcast[DType.float32]()
-        self.move_offset(size * sizeof[DType.float32]())
+        self.offset += size * sizeof[DType.float32]()
         return ret
 
-    fn get_offset(self) raises -> Int:
-        if self.offset > self.size:
-            raise Error("Offset is past the end of the FileBuf")
-        if self.offset < 0:
-            raise Error("Offset is before the beginning of the FileBuf")
-        return self.offset
-
 
 struct Tokenizer:
     var vocab: PointerStrings
     var vocab_scores: BufferPtrFloat32
     var max_token_length: Int
     var vocab_size: Int
-    var sorted_vocab: PointerStrings
-    var sorted_indices: DynamicVector[Int]
 
-    fn __init__(inout self, vocab_size: Int, inout buf: FileBuf) raises -> None:
+    fn __init__(inout self, vocab_size: Int):
         self.vocab_size = vocab_size
-        self.max_token_length = read_val_int(buf)
-        self.vocab_scores = BufferPtrFloat32.alloc(self.vocab_size)
-        self.vocab = PointerStrings.alloc(self.vocab_size)
-        # lazy load sorted vocab
-        self.sorted_vocab = PointerStrings.alloc(0)
-        self.sorted_indices = DynamicVector[Int](0)
-
-        # read vocab_scores & vocab values (tokens)
-        for i in range(0, self.vocab_size):
-            self.vocab_scores.store(i, read_val_float32(buf))
-            let slen = read_val_int(buf)
-            self.vocab.store(i, read_val_str(buf, slen))
-
-        return None
-
-    # sort vocab by string_compare
-    fn sort(inout self) -> None:
-        if len(self.sorted_indices) < self.vocab_size:
-            self.sorted_indices = DynamicVector[Int](self.vocab_size)
-            self.sorted_vocab = PointerStrings.alloc(self.vocab_size)
-            for ii in range(self.vocab_size):
-                self.sorted_vocab.store(ii, self.vocab[ii])
-                self.sorted_indices.push_back(ii)
-
-        let n = self.vocab_size
-        quicksort(self.sorted_vocab, self.sorted_indices, 0, n - 1)
-        return None
-
-    # Binary search that returns -1 if string is not found
-    fn find(inout self, token: PointerString) -> Int:
-        let n = self.vocab_size
-        if len(self.sorted_indices) < n:
-            self.sort()
-        var left = 0
-        var right = n - 1
-        while left <= right:
-            let mid = left + (right - left) // 2
-            let comparison = string_compare(self.sorted_vocab[mid], token)
-            if comparison == 0:
-                return self.sorted_indices[mid]
-            if comparison < 0:
-                left = mid + 1
-            else:
-                right = mid - 1
-        return -1
+        self.vocab = PointerStrings.alloc(vocab_size)
+        self.vocab_scores = BufferPtrFloat32.alloc(vocab_size)
+        self.max_token_length = 0
 
 
 struct Config:
     var dim: Int
-    var kv_dim: Int
     var hidden_dim: Int
     var n_layers: Int
     var n_heads: Int
     var n_kv_heads: Int
-    var kv_mul: Int
     var vocab_size: Int
     var seq_len: Int
-    var head_size: Int
 
     fn __init__(inout self):
         self.dim = 0
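The removed `sort`/`find` pair looked tokens up by binary search over a lexicographically sorted copy of the vocabulary, carrying the original ids alongside. A Python sketch of the same idea (names here are illustrative, not from the Mojo code):

```python
def build_order(vocab):
    # Sort vocabulary entries, remembering each token's original id.
    return sorted(range(len(vocab)), key=lambda i: vocab[i])

def find(vocab, order, token):
    # Binary search over the sorted view; return the original id or -1.
    left, right = 0, len(order) - 1
    while left <= right:
        mid = left + (right - left) // 2
        probe = vocab[order[mid]]
        if probe == token:
            return order[mid]
        if probe < token:
            left = mid + 1
        else:
            right = mid - 1
    return -1

vocab = ["the", "a", "llama", "is"]
order = build_order(vocab)
assert find(vocab, order, "llama") == 2
assert find(vocab, order, "x") == -1
```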
@@ -349,9 +240,6 @@ struct Config:
         self.n_kv_heads = 0
         self.vocab_size = 0
         self.seq_len = 0
-        self.kv_dim = 0
-        self.kv_mul = 0
-        self.head_size = 0
 
 
 struct RunState:
@@ -361,13 +249,12 @@ struct RunState:
     var hb: Matrix # buffer for hidden dimension in the ffn (hidden_dim,)
     var hb2: Matrix # buffer for hidden dimension in the ffn (hidden_dim,)
     var q: Matrix # query (dim,)
-    var k: Matrix # key (kv_dim,)
-    var v: Matrix # value (kv_dim,)
+    var k: Matrix # key (dim,)
+    var v: Matrix # value (dim,)
     var att: Matrix # buffer for scores/attention values (n_heads, seq_len)
     var logits: Matrix # output logits
-    var key_cache: Matrix # (layer, seq_len, dim)
-    var value_cache: Matrix # (layer, seq_len, dim)
-    var rt: Runtime
+    var key_cache: Matrix3 # (layer, seq_len, dim)
+    var value_cache: Matrix3 # (layer, seq_len, dim)
 
     fn __init__(inout self, config: Config):
         self.x = Matrix(config.dim)
@@ -382,17 +269,18 @@ struct RunState:
         self.hb2.alloc_zero()
         self.q = Matrix(config.dim)
         self.q.alloc_zero()
-        self.k = Matrix(0, 0)
-        self.v = Matrix(0, 0)
+        self.k = Matrix(config.dim)
+        self.k.alloc_zero()
+        self.v = Matrix(config.dim)
+        self.v.alloc_zero()
         self.att = Matrix(config.n_heads, config.seq_len)
         self.att.alloc_zero()
         self.logits = Matrix(config.vocab_size)
         self.logits.alloc_zero()
-        self.key_cache = Matrix(config.n_layers, config.seq_len, config.kv_dim)
+        self.key_cache = Matrix3(config.n_layers, config.seq_len, config.dim)
         self.key_cache.alloc_zero()
-        self.value_cache = Matrix(config.n_layers, config.seq_len, config.kv_dim)
+        self.value_cache = Matrix3(config.n_layers, config.seq_len, config.dim)
         self.value_cache.alloc_zero()
-        self.rt = Runtime(num_cores() // 2)
 
 
 struct TransformerWeights:
@@ -400,18 +288,18 @@ struct TransformerWeights:
     var freq_cis_real: Matrix
     var freq_cis_imag: Matrix
     var rms_att_weight: Matrix
-    var wq: Matrix
-    var wk: Matrix
-    var wv: Matrix
-    var wo: Matrix
+    var wq: Matrix3
+    var wk: Matrix3
+    var wv: Matrix3
+    var wo: Matrix3
     var rms_ffn_weight: Matrix
-    var w1: Matrix
-    var w3: Matrix
-    var w2: Matrix
+    var w1: Matrix3
+    var w3: Matrix3
+    var w2: Matrix3
     var rms_final_weight: Matrix
     var wcls: Matrix
 
-    fn __init__(inout self, config: Config, shared_weights: Int, inout buf: FileBuf) raises:
+    fn __init__(inout self, config: Config, shared_weights: Int, inout buf: FileBuf):
         self.token_embedding_table = Matrix(config.vocab_size, config.dim)
         # set buf ptr to buf data from file
         self.token_embedding_table.set_buf_ptr(
@@ -421,23 +309,23 @@ struct TransformerWeights:
         self.rms_att_weight.set_buf_ptr(
             buf.bitcast_offset_float32(self.rms_att_weight.size())
         )
-        self.wq = Matrix(config.n_layers, config.dim, config.dim)
+        self.wq = Matrix3(config.n_layers, config.dim, config.dim)
         self.wq.set_buf_ptr(buf.bitcast_offset_float32(self.wq.size()))
-        self.wk = Matrix(config.n_layers, config.dim, config.kv_dim)
+        self.wk = Matrix3(config.n_layers, config.dim, config.dim)
         self.wk.set_buf_ptr(buf.bitcast_offset_float32(self.wk.size()))
-        self.wv = Matrix(config.n_layers, config.dim, config.kv_dim)
+        self.wv = Matrix3(config.n_layers, config.dim, config.dim)
         self.wv.set_buf_ptr(buf.bitcast_offset_float32(self.wv.size()))
-        self.wo = Matrix(config.n_layers, config.dim, config.dim)
+        self.wo = Matrix3(config.n_layers, config.dim, config.dim)
         self.wo.set_buf_ptr(buf.bitcast_offset_float32(self.wo.size()))
         self.rms_ffn_weight = Matrix(config.n_layers, config.dim)
         self.rms_ffn_weight.set_buf_ptr(
             buf.bitcast_offset_float32(self.rms_ffn_weight.size())
         )
-        self.w1 = Matrix(config.n_layers, config.dim, config.hidden_dim)
+        self.w1 = Matrix3(config.n_layers, config.dim, config.hidden_dim)
         self.w1.set_buf_ptr(buf.bitcast_offset_float32(self.w1.size()))
-        self.w2 = Matrix(config.n_layers, config.dim, config.hidden_dim)
+        self.w2 = Matrix3(config.n_layers, config.dim, config.hidden_dim)
         self.w2.set_buf_ptr(buf.bitcast_offset_float32(self.w2.size()))
-        self.w3 = Matrix(config.n_layers, config.dim, config.hidden_dim)
+        self.w3 = Matrix3(config.n_layers, config.dim, config.hidden_dim)
         self.w3.set_buf_ptr(buf.bitcast_offset_float32(self.w3.size()))
         self.rms_final_weight = Matrix(config.dim)
         self.rms_final_weight.set_buf_ptr(
@@ -487,87 +375,82 @@ fn config_init(inout config: Config, inout buf: FileBuf) raises:
     config.n_kv_heads = read_val_int(buf)
     config.vocab_size = read_val_int(buf)
    config.seq_len = read_val_int(buf)
-    config.head_size = config.dim // config.n_heads
-    config.kv_dim = (config.n_kv_heads * config.dim) // config.n_heads
-    config.kv_mul = config.n_heads // config.n_kv_heads
     return None
 
 
-fn accum(inout a: BufferPtrFloat32, b: BufferPtrFloat32, size: Int) -> None:
-    @parameter
-    fn _acc[_nelts: Int](j: Int):
-        a.offset(j).simd_store[_nelts](
-            0, a.offset(j).simd_load[_nelts](0) + b.offset(j).simd_load[_nelts](0)
-        )
+fn tokenizer_init(inout tok: Tokenizer, inout buf: FileBuf) -> None:
+    tok.max_token_length = read_val_int(buf)
+    tok.vocab_scores = BufferPtrFloat32.alloc(tok.vocab_size)
+    tok.vocab = PointerStrings.alloc(tok.vocab_size)
+
+    # read vocab_scores & vocab values (tokens)
+    for i in range(0, tok.vocab_size):
+        tok.vocab_scores.simd_store[1](i, read_val_float32(buf))
+        let slen = read_val_int(buf)
+        tok.vocab.store(i, read_val_str(buf, slen))
+
+    tok.vocab_scores = buf.data.offset(buf.offset).bitcast[DType.float32]()
+    buf.offset += tok.vocab_size * 4
+    return None
 
-    vectorize[nelts, _acc](size)
+
+fn accum(inout a: BufferPtrFloat32, b: BufferPtrFloat32, size: Int) -> None:
+    for i in range(size):
+        let val = a.offset(i).simd_load[1](0) + b.offset(i).simd_load[1](0)
+        a.offset(i).simd_store[1](0, val)
 
 
 fn rmsnorm(
     inout o: BufferPtrFloat32, x: BufferPtrFloat32, weight: BufferPtrFloat32, size: Int
 ) -> None:
     # Calculate sum of squares
-    var tmp = SIMD[DType.float32, nelts](0)
-
-    @parameter
-    fn _sum2[_nelts: Int](j: Int):
-        if _nelts < nelts:
-            tmp[0] += (x.offset(j).simd_load[_nelts](0) ** 2).reduce_add()
-        else:
-            tmp += x.offset(j).simd_load[nelts](0) ** 2
-
-    vectorize[nelts, _sum2](size)
-
-    var ss: Float32 = tmp.reduce_add()
+    var ss: Float32 = 0.0
+    for i in range(size):
+        let xx = x.offset(i).simd_load[1](0) ** 2
+        ss += xx
     ss = ss / size + 1e-5
     ss = 1.0 / math.sqrt(ss)
-
     # Normalize and scale
-    @parameter
-    fn _norm[_nelts: Int](j: Int):
-        let val = weight.simd_load[_nelts](j) * ss * x.simd_load[_nelts](j)
-        o.offset(j).simd_store[_nelts](0, val)
-
-    vectorize[nelts, _norm](size)
+    for j in range(size):
+        let val = weight.offset(j).simd_load[1](0) * (ss * x.offset(j).simd_load[1](0))
+        o.offset(j).simd_store[1](0, val)
 
 
 fn softmax(inout x: BufferPtrFloat32, size: Int) -> None:
     # Find max value (for numerical stability)
-    var max_val: Float32 = -1e9
-
-    @parameter
-    fn _max[_nelts: Int](j: Int):
-        let val = x.simd_load[_nelts](j).reduce_max()
-        if val > max_val:
-            max_val = val
-
-    vectorize[nelts, _max](size)
-
+    var max_val: Float32 = x.offset(0).simd_load[1](0)
+    for i in range(size):
+        let xi = x.offset(i).simd_load[1](0)
+        if xi > max_val:
+            max_val = xi
     # Exp and sum
     var ssum: Float32 = 0.0
-
-    @parameter
-    fn _sum_exp[_nelts: Int](j: Int):
-        x.simd_store[_nelts](j, math.exp(x.simd_load[_nelts](j) - max_val))
-        ssum += x.simd_load[_nelts](j).reduce_add()
-
-    vectorize[nelts, _sum_exp](size)
-
-    @parameter
-    fn _norm[_nelts: Int](j: Int):
-        x.simd_store[_nelts](j, x.simd_load[_nelts](j) / ssum)
-
-    vectorize[nelts, _norm](size)
-
-
-fn matmul_parallelized(C: Matrix, A: Matrix, B: Matrix, rt: Runtime):
-    @parameter
-    fn compute_row(i: Int):
+    for i in range(size):
+        let xi = x.offset(i).simd_load[1](0)
+        x.offset(i).simd_store[1](0, math.exp(xi - max_val))
+        ssum += x.offset(i).simd_load[1](0)
+    # Normalize
+    for i in range(size):
+        let xi = x.offset(i).simd_load[1](0)
+        x.offset(i).simd_store[1](0, xi / ssum)
+
+
+fn matmul_naive(C: Matrix, x: Matrix, w: Matrix) -> None:
+    # W(d,n) @ X(n,) -> C (d,)
+    # By far the most amount of time is spent inside this little function
+    for i in range(w.rows):
+        C[i] = 0.0
+        for j in range(w.cols):
+            C[i] += x[j] * w[i, j]
+
+
+fn matmul_vectorized(C: Matrix, A: Matrix, B: Matrix):
+    for i in range(0, B.rows):
         var tmp = SIMD[DType.float32, nelts](0)
 
         @parameter
         fn dot[_nelts: Int](j: Int):
-            if _nelts < nelts: # take care of tail array elements with length < nelts
+            if _nelts < nelts: # take care of tail array elements with length < nelts
                 tmp[0] += (A.load[_nelts](j) * B.load[_nelts](i, j)).reduce_add()
             else:
                 tmp += A.load[nelts](j) * B.load[nelts](i, j)
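Numerically, both sides compute the same two kernels; the old side is SIMD-vectorized, the new side scalar. A plain-Python reference for both, illustrative only:

```python
import math

def rmsnorm(x, weight):
    # o[j] = weight[j] * x[j] / sqrt(mean(x^2) + eps), eps = 1e-5 as above.
    ss = sum(v * v for v in x) / len(x) + 1e-5
    inv = 1.0 / math.sqrt(ss)
    return [w * (inv * v) for w, v in zip(weight, x)]

def softmax(x):
    # Subtract the max before exponentiating for numerical stability.
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([1.0, 2.0, 3.0])
assert abs(sum(probs) - 1.0) < 1e-9
```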
@@ -575,12 +458,28 @@ fn matmul_parallelized(C: Matrix, A: Matrix, B: Matrix, rt: Runtime):
         vectorize[nelts, dot](B.cols)
         C[i] = tmp.reduce_add()
 
-    parallelize[compute_row](rt, B.rows, rt.parallelism_level())
+
+fn matmul_parallelized(C: Matrix, A: Matrix, B: Matrix):
+    @parameter
+    fn calc_row(i: Int):
+        var T = BufferPtrFloat32.alloc(nelts)
+        var Tbuf = Buffer[nelts, DType.float32](T)
+        memset_zero(T, nelts)
+        @parameter
+        fn dot[nelts: Int](j: Int):
+            T.simd_store[nelts](
+                0, T.simd_load[nelts](0) + A.load[nelts](j) * B.load[nelts](i, j)
+            )
+
+        vectorize[nelts, dot](B.cols)
+        C[i] = sum[nelts, DType.float32](Tbuf)
+
+    parallelize[calc_row](B.rows)
 
 
-fn matmul(inout C: Matrix, A: Matrix, B: Matrix, rt: Runtime) -> None:
+fn matmul(inout C: Matrix, A: Matrix, B: Matrix) -> None:
     # B (d,n) @ A (n,) -> C (d,)
-    matmul_parallelized(C, A, B, rt)
+    matmul_vectorized(C, A, B)
+    # matmul_parallelized(C, A, B)
 
 
 fn transformer(
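All three matmul variants compute the same matrix-vector product C = B @ A, with B of shape (rows, cols); the vectorized version only changes how the per-row dot product is accumulated. The math in plain Python:

```python
def matvec(B, A):
    # C[i] = dot(B[i, :], A); the only shape the model ever needs.
    return [sum(b * a for b, a in zip(row, A)) for row in B]

B = [[1.0, 2.0], [3.0, 4.0]]
A = [10.0, 100.0]
assert matvec(B, A) == [210.0, 430.0]
```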
@@ -594,9 +493,7 @@ fn transformer(
    var x = state.x.data
     let dim = config.dim
     let hidden_dim = config.hidden_dim
-    let head_size = config.head_size
-    let kv_dim = config.kv_dim
-    let kv_mul = config.kv_mul
+    let head_size = dim // config.n_heads
 
     # tmp matrix for matmul operations
     var tmpw = Matrix(0, 0)
@@ -616,43 +513,39 @@ fn transformer(
 
         # QKV matmuls for this position
         tmpw.set_buf_ptr(weights.wq.data.offset(l * dim * dim), dim, dim)
-        matmul(state.q, state.xb, tmpw, state.rt)
+        matmul(state.q, state.xb, tmpw)
 
-        let loff = l * config.seq_len * kv_dim
-        state.k.set_buf_ptr(state.key_cache.data.offset(loff + pos * kv_dim), 1, kv_dim)
-        tmpw.set_buf_ptr(weights.wk.data.offset(l * dim * kv_dim), kv_dim, dim)
-        matmul(state.k, state.xb, tmpw, state.rt)
+        tmpw.set_buf_ptr(weights.wk.data.offset(l * dim * dim), dim, dim)
+        matmul(state.k, state.xb, tmpw)
 
-        state.v.set_buf_ptr(
-            state.value_cache.data.offset(loff + pos * kv_dim), 1, kv_dim
-        )
-        tmpw.set_buf_ptr(weights.wv.data.offset(l * dim * kv_dim), kv_dim, dim)
-        matmul(state.v, state.xb, tmpw, state.rt)
+        tmpw.set_buf_ptr(weights.wv.data.offset(l * dim * dim), dim, dim)
+        matmul(state.v, state.xb, tmpw)
 
         # Apply RoPE rotation to the q and k vectors for each head
-        let q = state.q.data
-        let k = state.k.data
-        for i in range(0, head_size * config.n_kv_heads, 2):
-            let head_dim_half = i % head_size // 2
-            let fcr = freq_cis_real_row.offset(head_dim_half).load(0)
-            let fci = freq_cis_imag_row.offset(head_dim_half).load(0)
-            let q0 = q.offset(i).load(0)
-            let q1 = q.offset(i + 1).load(0)
-            let k0 = k.offset(i).load(0)
-            let k1 = k.offset(i + 1).load(0)
-            q.offset(i).store(0, q0 * fcr - q1 * fci)
-            q.offset(i + 1).store(0, q0 * fci + q1 * fcr)
-            k.offset(i).store(0, k0 * fcr - k1 * fci)
-            k.offset(i + 1).store(0, k0 * fci + k1 * fcr)
-
-        for i in range(head_size * config.n_kv_heads, dim, 2):
-            let head_dim_half = i % head_size // 2
-            let fcr = freq_cis_real_row.offset(head_dim_half).load(0)
-            let fci = freq_cis_imag_row.offset(head_dim_half).load(0)
-            let q0 = q.offset(i).load(0)
-            let q1 = q.offset(i + 1).load(0)
-            q.offset(i).store(0, q0 * fcr - q1 * fci)
-            q.offset(i + 1).store(0, q0 * fci + q1 * fcr)
+        for h in range(config.n_heads):
+            # Get the q and k vectors for this head
+            let q = state.q.data.offset(h * head_size)
+            let k = state.k.data.offset(h * head_size)
+
+            # Rotate q and k by the freq_cis_real and freq_cis_imag
+            for i in range(0, head_size, 2):
+                let q0 = q.offset(i).simd_load[1](0)
+                let q1 = q.offset(i + 1).simd_load[1](0)
+                let k0 = k.offset(i).simd_load[1](0)
+                let k1 = k.offset(i + 1).simd_load[1](0)
+                let fcr = freq_cis_real_row.offset(i // 2).simd_load[1](0)
+                let fci = freq_cis_imag_row.offset(i // 2).simd_load[1](0)
+                q.offset(i).simd_store[1](0, q0 * fcr - q1 * fci)
+                q.offset(i + 1).simd_store[1](0, q0 * fci + q1 * fcr)
+                k.offset(i).simd_store[1](0, k0 * fcr - k1 * fci)
+                k.offset(i + 1).simd_store[1](0, k0 * fci + k1 * fcr)
+
+        # Save key,value at this time step (pos) to our kv cache
+        let loff = l * config.seq_len * dim # kv cache layer offset for convenience
+        let key_cache_row = state.key_cache.data.offset(loff + pos * dim)
+        let value_cache_row = state.value_cache.data.offset(loff + pos * dim)
+        memcpy[DType.float32](key_cache_row, state.k.data, config.dim)
+        memcpy[DType.float32](value_cache_row, state.v.data, config.dim)
 
         # Multihead attention. Iterate over all heads
         for h in range(config.n_heads):
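Both RoPE variants rotate consecutive (even, odd) pairs of q and k by a per-position angle; `freq_cis_real_row` and `freq_cis_imag_row` hold the cosines and sines for the current position. A Python sketch of the per-head rotation as done on the new side:

```python
import math

def rope_rotate(vec, cos_row, sin_row):
    # Rotate each pair (vec[i], vec[i+1]) by the angle whose cosine and
    # sine sit at index i // 2, as the loop above does.
    out = list(vec)
    for i in range(0, len(vec), 2):
        fcr, fci = cos_row[i // 2], sin_row[i // 2]
        v0, v1 = vec[i], vec[i + 1]
        out[i] = v0 * fcr - v1 * fci
        out[i + 1] = v0 * fci + v1 * fcr
    return out

theta = math.pi / 2  # 90 degrees: (1, 0) should rotate to (0, 1)
rotated = rope_rotate([1.0, 0.0], [math.cos(theta)], [math.sin(theta)])
assert [round(v, 9) for v in rotated] == [0.0, 1.0]
```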
@@ -665,17 +558,15 @@ fn transformer(
            # Iterate over all timesteps, including the current one
            for t in range(pos + 1):
                # Get the key vector for this head and at this timestep
-                let k = state.key_cache.data.offset(
-                    loff + t * kv_dim + (h // kv_mul) * head_size
-                )
+                let k = state.key_cache.data.offset(loff + t * dim + h * head_size)
                # Calculate the attention score as the dot product of q and k
                var score: Float32 = 0.0
                for i in range(head_size):
-                    score += q.offset(i).load(0) * k.offset(i).load(0)
+                    score += q.offset(i).simd_load[1](0) * k.offset(i).simd_load[1](0)
                score /= math.sqrt[DType.float32, 1](head_size)
 
                # Save the score to the attention buffer
-                att.offset(t).store(0, score)
+                att.offset(t).simd_store[1](0, score)
 
            # Softmax the scores to get attention weights, from 0..pos inclusively
            softmax(att, pos + 1)
@@ -685,18 +576,18 @@ fn transformer(
            memset_zero(xb, head_size)
            for t in range(pos + 1):
                # Get the value vector for this head and at this timestep
-                let v = state.value_cache.data.offset(
-                    loff + t * kv_dim + (h // kv_mul) * head_size
-                )
+                let v = state.value_cache.data.offset(loff + t * dim + h * head_size)
                # Get the attention weight for this timestep
-                let a = att.offset(t).load(0)
+                let a = att.offset(t).simd_load[1](0)
                # Accumulate the weighted value into xb
                for i in range(head_size):
-                    let xbi = xb.offset(i).load(0) + a * v.offset(i).load(0)
-                    xb.offset(i).store(0, xbi)
+                    let xbi = xb.offset(i).simd_load[1](0) + a * v.offset(i).simd_load[
+                        1
+                    ](0)
+                    xb.offset(i).simd_store[1](0, xbi)
        # Final matrix multiplication to get the output of the attention
        tmpw.set_buf_ptr(weights.wo.data.offset(l * dim * dim), dim, dim)
-        matmul(state.xb2, state.xb, tmpw, state.rt)
+        matmul(state.xb2, state.xb, tmpw)
 
        # Residual connection back into x
        accum(x, state.xb2.data, dim)
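The attention loop is the textbook single-query form: scores are scaled dot products of q against the cached keys, softmaxed, then used to weight the cached values. A compact Python reference for one head (an illustrative sketch, not the Mojo API):

```python
import math

def attend(q, keys, values):
    # One head at one query position; `keys` and `values` are the cached
    # per-timestep vectors for this head (lists of head_size floats).
    scale = math.sqrt(len(q))
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    att = [e / total for e in exps]
    return [sum(a * v[i] for a, v in zip(att, values)) for i in range(len(q))]

out = attend([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]])
assert len(out) == 2
```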
@@ -706,10 +597,10 @@ fn transformer(
 
        # Calculate self.w1(x) and self.w3(x) for FFN
        tmpw.set_buf_ptr(weights.w1.data.offset(l * dim * hidden_dim), hidden_dim, dim)
-        matmul(state.hb, state.xb, tmpw, state.rt)
+        matmul(state.hb, state.xb, tmpw)
 
        tmpw.set_buf_ptr(weights.w3.data.offset(l * dim * hidden_dim), hidden_dim, dim)
-        matmul(state.hb2, state.xb, tmpw, state.rt)
+        matmul(state.hb2, state.xb, tmpw)
 
        # Apply SiLU activation function (silu(x) = x * sigmoid(x))
        for i in range(hidden_dim):
@@ -722,7 +613,7 @@ fn transformer(
 
        # Final matrix multiplication to get the output of the FFN
        tmpw.set_buf_ptr(weights.w2.data.offset(l * dim * hidden_dim), dim, hidden_dim)
-        matmul(state.xb, state.hb, tmpw, state.rt)
+        matmul(state.xb, state.hb, tmpw)
 
        # Residual connection
        accum(x, state.xb.data, dim)
@@ -732,7 +623,7 @@ fn transformer(
 
    # Classifier into logits
    tmpw.set_buf_ptr(weights.wcls.data, config.vocab_size, dim)
-    matmul(state.logits, state.x, tmpw, state.rt)
+    matmul(state.logits, state.x, tmpw)
 
 
 fn argmax(v: Matrix) -> Int:
  fn argmax(v: Matrix) -> Int:
@@ -755,64 +646,12 @@ fn sample(probabilities: Matrix) -> Int:
755
  var cdf: Float32 = 0.0
756
  for i in range(n):
757
  cdf += probabilities[i]
758
- if r.load(0) < cdf:
759
  return i
760
  return n - 1 # In case of rounding errors
761
 
762
 
763
- fn bpe_encode(inout tokens: DynamicVector[Int], text: String, inout tok: Tokenizer):
764
- for pos in range(len(text)):
765
- let char = str_to_ptr(text[pos])
766
- let id = tok.find(char)
767
-
768
- if id == -1:
769
- print("Not a good prompt token at pos ", pos)
770
- return
771
- tokens.push_back(id)
772
-
773
- while True:
774
- var best_score = Float32(-1e10)
775
- var best_id = -1
776
- var best_idx = -1
777
-
778
- for i in range(len(tokens) - 1):
779
- # Check if we can merge the pair (tokens[i], tokens[i+1])
780
- let str = str_concat(tok.vocab[tokens[i]], tok.vocab[tokens[i + 1]])
781
- let id = tok.find(str)
782
- if id != -1 and tok.vocab_scores.load(id) > best_score:
783
- best_score = tok.vocab_scores.load(id)
784
- best_id = id
785
- best_idx = i
786
-
787
- if best_idx == -1:
788
- # We couldn't find any more pairs to merge, so we're done
789
- break
790
-
791
- # Merge the consecutive pair (best_idx, best_idx+1) into new token best_id
792
- tokens[best_idx] = best_id
793
- # Delete token at position best_idx+1, shift the entire sequence back 1
794
- var _tokens = DynamicVector[Int]()
795
- for i in range(0, best_idx + 1):
796
- _tokens.push_back(tokens[i])
797
- for i in range(best_idx + 2, len(tokens)):
798
- _tokens.push_back(tokens[i])
799
- tokens = _tokens
800
-
801
-
802
- fn str2num(d: Int) -> Int:
803
- # covert Hex to decimal
804
- if d >= ord("A"):
805
- return d - ord("A") + 10
806
- return d - ord("0")
807
-
808
-
809
  fn print_str(s: PointerString):
810
- # print raw byte like <0x0A>
811
- if (s[1].to_int() == ord("0")) and (s[2].to_int() == ord("x")):
812
- let d1: Int = s[3].to_int()
813
- let d2: Int = s[4].to_int()
814
- print_no_newline(chr(str2num(d1) * 16 + str2num(d2)))
815
- return
816
  # print all chars till null character
817
  var p: Int = 0
818
  while s[p].to_int() != 0:
@@ -825,73 +664,22 @@ fn time_in_ms() -> Int:
    return time.now() // 1_000_000
 
 
-fn print_usage():
-    print("Usage: mojo llama2.mojo <checkpoint> [options]")
-    print(
-        'Example: mojo llama2.mojo stories15M.bin -s 99 -n 256 -t 0.5 -i "Llama is an'
-        ' animal"'
-    )
-    print("Options:")
-    print(" -s <int> random seed, default time.now()")
-    print(" -t <float> temperature in [0,1.0], default 1.0")
-    print(" -n <int> number of steps to run for, default 256. 0 = max_seq_len")
-    print(" -i <string> input prompt")
-
-
 fn main() raises:
-    print("num hardware threads: ", num_cores())
-    print("SIMD vector width: ", nelts)
-    var tokenizer = StringRef("tokenizer.bin")
-    var checkpoint = StringRef("stories15M.bin")
-    var temperature = 0.9
+    print("num hardware threads: ", num_cores(), " SIMD vector width: ", nelts)
+    let checkpoint = "stories15M.bin"
+    # let checkpoint = "stories110M.bin"
+    let tokenizer = "tokenizer.bin"
+    let temperature = 0.0
    var steps = 256
-    var prompt = String("")
-    var rng_seed: Int = time.now()
-
-    @parameter
-    fn argparse() raises -> Int:
-        let args = argv()
-        if len(args) < 2:
-            return 0
-        checkpoint = args[1]
-        for i in range(2, len(args), 2):
-            if args[i] == "-p":
-                print("Option not supported: ", args[i])
-            if args[i] == "-n":
-                steps = atol(args[i + 1])
-            if args[i] == "-tk":
-                tokenizer = args[i + 1]
-            if args[i] == "-s":
-                rng_seed = atol(args[i + 1])
-            if args[i] == "-i":
-                prompt = args[i + 1]
-            if args[i] == "-t":
-                let val = args[i + 1]
-                temperature = 0.0
-                # hacky parse float, keep only 1 digit
-                for c in range(0, len(val)):
-                    if val[c] == ".":
-                        temperature += atol(val[c + 1]) * 0.1
-                        break
-                    else:
-                        temperature = atol(val[c])
-                if temperature < -1e9 or temperature > (1 + 1e9):
-                    print("Wrong temperature value", temperature)
-                    return 0
-        return 1
-
-    let res = argparse()
-    if res == 0:
-        print_usage()
-        return
-
+    let prompt = ""
+    let rng_seed: Int = time.now()
    random.seed(rng_seed)
    var fbuf: FileBuf = FileBuf()
    var tbuf: FileBuf = FileBuf()
    var config: Config = Config()
 
    read_file(checkpoint, fbuf)
-    print("checkpoint size: ", fbuf.size, "[", fbuf.size // 1024 // 1024, "MB ]")
+    print("checkpoint size: ", fbuf.size)
    config_init(config, fbuf)
 
    # negative vocab size is hacky way of signaling unshared weights. bit yikes.
@@ -902,58 +690,51 @@ fn main() raises:
 
    let weights: TransformerWeights = TransformerWeights(config, shared_weights, fbuf)
 
+    var tok: Tokenizer = Tokenizer(config.vocab_size)
+
    if steps <= 0 or steps > config.seq_len:
        steps = config.seq_len
 
    # Read in the tokenizer.bin file
    read_file(tokenizer, tbuf)
-    var tok = Tokenizer(config.vocab_size, tbuf)
+    tokenizer_init(tok, tbuf)
 
    # Create and initialize the application RunState
    var state = RunState(config)
 
-    # Process the prompt, if any
-    var prompt_tokens = DynamicVector[Int]()
-
-    if prompt:
-        bpe_encode(prompt_tokens, prompt, tok)
-
    # Start the main loop
    var start = 0 # Used to time our code, only initialized after the first iteration
    var next_token = 0 # Will store the next token in the sequence
    # Initialize with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
    var token = 1
+    var pos = 0 # Position in the sequence
+    # Explicitly print the initial BOS token for stylistic symmetry reasons
+
+    print("<s>")
 
-    # Position in the sequence
-    var pos = 0
    while pos < steps:
        # Forward the transformer to get logits for the next token
        transformer(token, pos, config, state, weights)
 
-        if pos < len(prompt_tokens):
-            next_token = prompt_tokens[pos]
+        # Sample the next token
+        if temperature == 0.0:
+            # Greedy argmax sampling: take the token with the highest probability
+            next_token = argmax(state.logits)
        else:
-            # Sample the next token
-            if temperature == 0.0:
-                # Greedy argmax sampling: take the token with the highest probability
-                next_token = argmax(state.logits)
-            else:
-                # Apply the temperature to the logits
-                for q in range(config.vocab_size):
-                    state.logits[q] = state.logits[q] / temperature
-                # Apply softmax to the logits to get the probabilities for the next token
-                softmax(state.logits.data, config.vocab_size)
-                # Sample from this distribution to get the next token
-                next_token = sample(state.logits)
-
-        # Finish generating when EOS, BOS appear
-        if next_token == 1 or next_token == 2:
-            break
+            # Apply the temperature to the logits
+            for q in range(config.vocab_size):
+                state.logits[q] = state.logits[q] / temperature
+            # Apply softmax to the logits to get the probabilities for the next token
+            softmax(state.logits.data, config.vocab_size)
+            # Sample from this distribution to get the next token
+            next_token = sample(state.logits)
+
        var token_str: PointerString = tok.vocab[next_token]
        if token == 1 and token_str[0] == ord(" "):
            token_str = token_str.offset(1)
 
        print_str(token_str)
+        # flush?
 
        # Advance forward
        token = next_token
@@ -963,4 +744,4 @@ fn main() raises:
            start = time_in_ms()
 
    let end = time_in_ms()
-    print("\nachieved tok/s: ", (pos - 1) / (end - start) * 1000)
+    print("\nachieved tok/s: ", (steps - 1) / (end - start) * 1000)
 
 
 
 