radames HF staff committed on
Commit 5c30f1d
1 Parent(s): 264c8c8
Files changed (2)
  1. llama2.mojo +273 -129
  2. t260.bin +3 -0
llama2.mojo CHANGED
@@ -133,26 +133,27 @@ struct Matrix:
         self.data.simd_store[nelts](z * self.layers + y * self.cols + x, val)


-fn read_val_int(inout buf: FileBuf) -> Int:
+fn read_val_int(inout buf: FileBuf) raises -> Int:
     # DTypePointer[DType.ui8](buf.data).bitcast[DType.ui8]()
-    let data = buf.data.offset(buf.offset).bitcast[DType.uint32]()
-    let result = data.simd_load[1](0)
-    buf.offset += 4
+    let data = buf.data.offset(buf.get_offset()).bitcast[DType.uint32]()
+    let result = data.load(0)
+    buf.move_offset(4)
     return result.to_int()


-fn read_val_float32(inout buf: FileBuf) -> Float32:
+fn read_val_float32(inout buf: FileBuf) raises -> Float32:
     # DTypePointer[DType.ui8](buf.data).bitcast[DType.ui8]()
-    let val = buf.data.offset(buf.offset).bitcast[DType.float32]().simd_load[1](0)
-    buf.offset += 4
+    let val = buf.data.offset(buf.get_offset()).bitcast[DType.float32]().load(0)
+    buf.move_offset(4)
     return val


-fn read_val_str(inout buf: FileBuf, slen: Int) -> PointerString:
+fn read_val_str(inout buf: FileBuf, slen: Int) raises -> PointerString:
+
     let str = PointerString.alloc(slen + 1)
     for i in range(slen):
-        str.store(i, buf.data.simd_load[1](buf.offset))
-        buf.offset += 1
+        str.store(i, buf.data.load(buf.get_offset()))
+        buf.move_offset(1)
     str.store(slen, 0)

     return str
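
Note: the three readers now go through FileBuf's bounds-checked helpers (get_offset / move_offset, added further down in this commit), so a truncated or corrupt file raises an Error instead of silently reading past the end of the buffer. A minimal sketch of the new read path, assuming a model file on disk:

    var buf = FileBuf()
    read_file("tokenizer.bin", buf)  # fills buf.data and buf.size
    let n = read_val_int(buf)        # raises if the 4-byte read would overrun
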
@@ -168,7 +169,7 @@ fn str_concat(s1: PointerString, s2: PointerString) -> PointerString:
     while s2[l2] != 0:
         l2 += 1

-    let str = PointerString.alloc(l1 + l2)
+    let str = PointerString.alloc(l1 + l2 + 1)
     memcpy[UInt8](str, s1, l1)
     memcpy[UInt8](str.offset(l1), s2, l2)
     str.store(l1 + l2, 0)
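
The extra byte matters: for strings of length l1 and l2 the terminator str.store(l1 + l2, 0) writes to index l1 + l2, which only exists if l1 + l2 + 1 bytes are allocated (e.g. l1 = 2, l2 = 2 writes index 4 and therefore needs 5 bytes, not 4).
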
@@ -183,6 +184,63 @@ fn str_to_ptr(s: String) -> PointerString:
     return ret


+fn string_compare(a: PointerString, b: PointerString) -> Int:
+    var index = 0
+    while a[index] != 0 and b[index] != 0:
+        if a[index] < b[index]:
+            return -1
+        if a[index] > b[index]:
+            return 1
+
+        index += 1
+
+    if a[index] != 0 and b[index] == 0:
+        return 1
+
+    if a[index] == 0 and b[index] != 0:
+        return -1
+
+    return 0
+
+
+# Quicksort helper function to find the partition position
+fn partition(
+    inout array: PointerStrings, inout indices: DynamicVector[Int], low: Int, high: Int
+) -> Int:
+    let pivot = array[high]
+    var ii = low - 1
+    for jj in range(low, high):
+        if string_compare(pivot, array[jj]) == 1:
+            # If element smaller than pivot, swap
+            ii = ii + 1
+
+            let tmp = array[ii]
+            let tmp_idx = indices[ii]
+            array.store(ii, array[jj])
+            indices[ii] = indices[jj]
+            array.store(jj, tmp)
+            indices[jj] = tmp_idx
+
+    # Swap the pivot element
+    let tmp = array[ii + 1]
+    let tmp_idx = indices[ii + 1]
+    array.store(ii + 1, array[high])
+    indices[ii + 1] = indices[high]
+    array.store(high, tmp)
+    indices[high] = tmp_idx
+
+    return ii + 1
+
+
+fn quicksort(
+    inout array: PointerStrings, inout indices: DynamicVector[Int], low: Int, high: Int
+):
+    if low < high:
+        let pi = partition(array, indices, low, high)
+        quicksort(array, indices, low, pi - 1)
+        quicksort(array, indices, pi + 1, high)
+
+
 struct FileBuf:
     var data: BufferPtrType
     var offset: Int
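
The indices vector is permuted in lockstep with the strings, so after sorting, sorted_vocab[i] still maps back to its original token id via sorted_indices[i]. A small hypothetical example of the parallel sort, using only constructs from this file:

    # three made-up tokens, ids 0..2
    var strs = PointerStrings.alloc(3)
    strs.store(0, str_to_ptr("b"))
    strs.store(1, str_to_ptr("a"))
    strs.store(2, str_to_ptr("c"))
    var idx = DynamicVector[Int](3)
    for i in range(3):
        idx.push_back(i)
    quicksort(strs, idx, 0, 2)
    # strs is now ["a", "b", "c"] and idx is [1, 0, 2] (original ids)
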
@@ -193,36 +251,95 @@ struct FileBuf:
         self.offset = 0
         self.size = 0

-    fn move_offset(inout self, size: Int):
-        self.offset += size
+    fn move_offset(inout self, size: Int) raises:
+        let new_offset = self.offset + size
+        if new_offset > self.size:
+            raise Error("Resulting offset will be past the end of the FileBuf")
+        if new_offset < 0:
+            raise Error("Resulting offset will be before the beginning of the FileBuf")
+        self.offset = new_offset

-    fn bitcast_offset_float32(inout self, size: Int) -> BufferPtrFloat32:
+    fn bitcast_offset_float32(inout self, size: Int) raises -> BufferPtrFloat32:
         let ret = self.data.offset(self.offset).bitcast[DType.float32]()
-        self.offset += size * sizeof[DType.float32]()
+        self.move_offset(size * sizeof[DType.float32]())
         return ret

+    fn get_offset(self) raises -> Int:
+        if self.offset > self.size:
+            raise Error("Offset is past the end of the FileBuf")
+        if self.offset < 0:
+            raise Error("Offset is before the beginning of the FileBuf")
+        return self.offset
+

 struct Tokenizer:
     var vocab: PointerStrings
     var vocab_scores: BufferPtrFloat32
     var max_token_length: Int
     var vocab_size: Int
+    var sorted_vocab: PointerStrings
+    var sorted_indices: DynamicVector[Int]

-    fn __init__(inout self, vocab_size: Int):
+    fn __init__(inout self, vocab_size: Int, inout buf: FileBuf) raises -> None:
         self.vocab_size = vocab_size
-        self.vocab = PointerStrings.alloc(vocab_size)
-        self.vocab_scores = BufferPtrFloat32.alloc(vocab_size)
-        self.max_token_length = 0
+        self.max_token_length = read_val_int(buf)
+        self.vocab_scores = BufferPtrFloat32.alloc(self.vocab_size)
+        self.vocab = PointerStrings.alloc(self.vocab_size)
+        # lazy load sorted vocab
+        self.sorted_vocab = PointerStrings.alloc(0)
+        self.sorted_indices = DynamicVector[Int](0)
+
+        # read vocab_scores & vocab values (tokens)
+        for i in range(0, self.vocab_size):
+            self.vocab_scores.store(i, read_val_float32(buf))
+            let slen = read_val_int(buf)
+            self.vocab.store(i, read_val_str(buf, slen))
+
+        return None
+
+    # sort vocab by string_compare
+    fn sort(inout self) -> None:
+        if len(self.sorted_indices) < self.vocab_size:
+            self.sorted_indices = DynamicVector[Int](self.vocab_size)
+            self.sorted_vocab = PointerStrings.alloc(self.vocab_size)
+            for ii in range(self.vocab_size):
+                self.sorted_vocab.store(ii, self.vocab[ii])
+                self.sorted_indices.push_back(ii)
+
+        let n = self.vocab_size
+        quicksort(self.sorted_vocab, self.sorted_indices, 0, n - 1)
+        return None
+
+    # Binary search that returns -1 if string is not found
+    fn find(inout self, token: PointerString) -> Int:
+        let n = self.vocab_size
+        if len(self.sorted_indices) < n:
+            self.sort()
+        var left = 0
+        var right = n - 1
+        while left <= right:
+            let mid = left + (right - left) // 2
+            let comparison = string_compare(self.sorted_vocab[mid], token)
+            if comparison == 0:
+                return self.sorted_indices[mid]
+            if comparison < 0:
+                left = mid + 1
+            else:
+                right = mid - 1
+        return -1


 struct Config:
     var dim: Int
+    var kv_dim: Int
     var hidden_dim: Int
     var n_layers: Int
     var n_heads: Int
     var n_kv_heads: Int
+    var kv_mul: Int
     var vocab_size: Int
     var seq_len: Int
+    var head_size: Int

     fn __init__(inout self):
         self.dim = 0
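
This replaces the linear scan in the old str_lookup with a sort-once, binary-search-per-lookup scheme: the vocab is sorted lazily on the first find(), and each lookup is then O(log vocab_size) string comparisons instead of O(vocab_size). A usage sketch (the token text here is hypothetical):

    var tok = Tokenizer(config.vocab_size, tbuf)
    let id = tok.find(str_to_ptr("ll"))   # first call sorts, then bisects
    if id != -1:
        print(tok.vocab_scores.load(id))  # score of the matched token
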
@@ -232,6 +349,9 @@ struct Config:
         self.n_kv_heads = 0
         self.vocab_size = 0
         self.seq_len = 0
+        self.kv_dim = 0
+        self.kv_mul = 0
+        self.head_size = 0


 struct RunState:
@@ -241,8 +361,8 @@ struct RunState:
     var hb: Matrix  # buffer for hidden dimension in the ffn (hidden_dim,)
     var hb2: Matrix  # buffer for hidden dimension in the ffn (hidden_dim,)
     var q: Matrix  # query (dim,)
-    var k: Matrix  # key (dim,)
-    var v: Matrix  # value (dim,)
+    var k: Matrix  # key (kv_dim,)
+    var v: Matrix  # value (kv_dim,)
     var att: Matrix  # buffer for scores/attention values (n_heads, seq_len)
     var logits: Matrix  # output logits
     var key_cache: Matrix  # (layer, seq_len, dim)
@@ -262,17 +382,15 @@ struct RunState:
         self.hb2.alloc_zero()
         self.q = Matrix(config.dim)
         self.q.alloc_zero()
-        self.k = Matrix(config.dim)
-        self.k.alloc_zero()
-        self.v = Matrix(config.dim)
-        self.v.alloc_zero()
+        self.k = Matrix(0, 0)
+        self.v = Matrix(0, 0)
         self.att = Matrix(config.n_heads, config.seq_len)
         self.att.alloc_zero()
         self.logits = Matrix(config.vocab_size)
         self.logits.alloc_zero()
-        self.key_cache = Matrix(config.n_layers, config.seq_len, config.dim)
+        self.key_cache = Matrix(config.n_layers, config.seq_len, config.kv_dim)
         self.key_cache.alloc_zero()
-        self.value_cache = Matrix(config.n_layers, config.seq_len, config.dim)
+        self.value_cache = Matrix(config.n_layers, config.seq_len, config.kv_dim)
         self.value_cache.alloc_zero()
         self.rt = Runtime(num_cores() // 2)
@@ -293,7 +411,7 @@ struct TransformerWeights:
     var rms_final_weight: Matrix
     var wcls: Matrix

-    fn __init__(inout self, config: Config, shared_weights: Int, inout buf: FileBuf):
+    fn __init__(inout self, config: Config, shared_weights: Int, inout buf: FileBuf) raises:
         self.token_embedding_table = Matrix(config.vocab_size, config.dim)
         # set buf ptr to buf data from file
         self.token_embedding_table.set_buf_ptr(
@@ -305,9 +423,9 @@
         )
         self.wq = Matrix(config.n_layers, config.dim, config.dim)
         self.wq.set_buf_ptr(buf.bitcast_offset_float32(self.wq.size()))
-        self.wk = Matrix(config.n_layers, config.dim, config.dim)
+        self.wk = Matrix(config.n_layers, config.dim, config.kv_dim)
         self.wk.set_buf_ptr(buf.bitcast_offset_float32(self.wk.size()))
-        self.wv = Matrix(config.n_layers, config.dim, config.dim)
+        self.wv = Matrix(config.n_layers, config.dim, config.kv_dim)
         self.wv.set_buf_ptr(buf.bitcast_offset_float32(self.wv.size()))
         self.wo = Matrix(config.n_layers, config.dim, config.dim)
         self.wo.set_buf_ptr(buf.bitcast_offset_float32(self.wo.size()))
@@ -369,64 +487,77 @@ fn config_init(inout config: Config, inout buf: FileBuf) raises:
     config.n_kv_heads = read_val_int(buf)
     config.vocab_size = read_val_int(buf)
     config.seq_len = read_val_int(buf)
-    return None
-
-
-fn tokenizer_init(inout tok: Tokenizer, inout buf: FileBuf) -> None:
-    tok.max_token_length = read_val_int(buf)
-    tok.vocab_scores = BufferPtrFloat32.alloc(tok.vocab_size)
-    tok.vocab = PointerStrings.alloc(tok.vocab_size)
-
-    # read vocab_scores & vocab values (tokens)
-    for i in range(0, tok.vocab_size):
-        tok.vocab_scores.simd_store[1](i, read_val_float32(buf))
-        let slen = read_val_int(buf)
-        tok.vocab.store(i, read_val_str(buf, slen))
-
-    tok.vocab_scores = buf.data.offset(buf.offset).bitcast[DType.float32]()
-    buf.offset += tok.vocab_size * 4
+    config.head_size = config.dim // config.n_heads
+    config.kv_dim = (config.n_kv_heads * config.dim) // config.n_heads
+    config.kv_mul = config.n_heads // config.n_kv_heads
     return None


 fn accum(inout a: BufferPtrFloat32, b: BufferPtrFloat32, size: Int) -> None:
-    for i in range(size):
-        let val = a.offset(i).simd_load[1](0) + b.offset(i).simd_load[1](0)
-        a.offset(i).simd_store[1](0, val)
+    @parameter
+    fn _acc[_nelts: Int](j: Int):
+        a.offset(j).simd_store[_nelts](
+            0, a.offset(j).simd_load[_nelts](0) + b.offset(j).simd_load[_nelts](0)
+        )
+
+    vectorize[nelts, _acc](size)


 fn rmsnorm(
     inout o: BufferPtrFloat32, x: BufferPtrFloat32, weight: BufferPtrFloat32, size: Int
 ) -> None:
     # Calculate sum of squares
-    var ss: Float32 = 0.0
-    for i in range(size):
-        let xx = x.offset(i).simd_load[1](0) ** 2
-        ss += xx
+    var tmp = SIMD[DType.float32, nelts](0)
+
+    @parameter
+    fn _sum2[_nelts: Int](j: Int):
+        if _nelts < nelts:
+            tmp[0] += (x.offset(j).simd_load[_nelts](0) ** 2).reduce_add()
+        else:
+            tmp += x.offset(j).simd_load[nelts](0) ** 2
+
+    vectorize[nelts, _sum2](size)
+
+    var ss: Float32 = tmp.reduce_add()
     ss = ss / size + 1e-5
     ss = 1.0 / math.sqrt(ss)
+
     # Normalize and scale
-    for j in range(size):
-        let val = weight.offset(j).simd_load[1](0) * (ss * x.offset(j).simd_load[1](0))
-        o.offset(j).simd_store[1](0, val)
+    @parameter
+    fn _norm[_nelts: Int](j: Int):
+        let val = weight.simd_load[_nelts](j) * ss * x.simd_load[_nelts](j)
+        o.offset(j).simd_store[_nelts](0, val)
+
+    vectorize[nelts, _norm](size)


 fn softmax(inout x: BufferPtrFloat32, size: Int) -> None:
     # Find max value (for numerical stability)
-    var max_val: Float32 = x.offset(0).simd_load[1](0)
-    for i in range(size):
-        let xi = x.offset(i).simd_load[1](0)
-        if xi > max_val:
-            max_val = xi
+    var max_val: Float32 = -1e9
+
+    @parameter
+    fn _max[_nelts: Int](j: Int):
+        let val = x.simd_load[_nelts](j).reduce_max()
+        if val > max_val:
+            max_val = val
+
+    vectorize[nelts, _max](size)
+
     # Exp and sum
     var ssum: Float32 = 0.0
-    for i in range(size):
-        let xi = x.offset(i).simd_load[1](0)
-        x.offset(i).simd_store[1](0, math.exp(xi - max_val))
-        ssum += x.offset(i).simd_load[1](0)
-    # Normalize
-    for i in range(size):
-        let xi = x.offset(i).simd_load[1](0)
-        x.offset(i).simd_store[1](0, xi / ssum)
+
+    @parameter
+    fn _sum_exp[_nelts: Int](j: Int):
+        x.simd_store[_nelts](j, math.exp(x.simd_load[_nelts](j) - max_val))
+        ssum += x.simd_load[_nelts](j).reduce_add()
+
+    vectorize[nelts, _sum_exp](size)
+
+    @parameter
+    fn _norm[_nelts: Int](j: Int):
+        x.simd_store[_nelts](j, x.simd_load[_nelts](j) / ssum)
+
+    vectorize[nelts, _norm](size)


 fn matmul_parallelized(C: Matrix, A: Matrix, B: Matrix, rt: Runtime):
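
Worked through for concrete shapes: with dim = 288 and n_heads = 6 (the stories15M layout), head_size = 48, kv_dim = (6 * 288) // 6 = 288 and kv_mul = 1, so classic multi-head checkpoints behave exactly as before. For a hypothetical grouped-query checkpoint with n_heads = 8 and n_kv_heads = 4, kv_dim = dim // 2 and kv_mul = 2, halving the K/V projections and the KV cache.
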
@@ -463,7 +594,9 @@ fn transformer(
     var x = state.x.data
     let dim = config.dim
     let hidden_dim = config.hidden_dim
-    let head_size = dim // config.n_heads
+    let head_size = config.head_size
+    let kv_dim = config.kv_dim
+    let kv_mul = config.kv_mul

     # tmp matrix for matmul operations
     var tmpw = Matrix(0, 0)
@@ -485,37 +618,41 @@
         tmpw.set_buf_ptr(weights.wq.data.offset(l * dim * dim), dim, dim)
         matmul(state.q, state.xb, tmpw, state.rt)

-        tmpw.set_buf_ptr(weights.wk.data.offset(l * dim * dim), dim, dim)
+        let loff = l * config.seq_len * kv_dim
+        state.k.set_buf_ptr(state.key_cache.data.offset(loff + pos * kv_dim), 1, kv_dim)
+        tmpw.set_buf_ptr(weights.wk.data.offset(l * dim * kv_dim), kv_dim, dim)
         matmul(state.k, state.xb, tmpw, state.rt)

-        tmpw.set_buf_ptr(weights.wv.data.offset(l * dim * dim), dim, dim)
+        state.v.set_buf_ptr(
+            state.value_cache.data.offset(loff + pos * kv_dim), 1, kv_dim
+        )
+        tmpw.set_buf_ptr(weights.wv.data.offset(l * dim * kv_dim), kv_dim, dim)
         matmul(state.v, state.xb, tmpw, state.rt)

         # Apply RoPE rotation to the q and k vectors for each head
-        for h in range(config.n_heads):
-            # Get the q and k vectors for this head
-            let q = state.q.data.offset(h * head_size)
-            let k = state.k.data.offset(h * head_size)
-
-            # Rotate q and k by the freq_cis_real and freq_cis_imag
-            for i in range(0, head_size, 2):
-                let q0 = q.offset(i).simd_load[1](0)
-                let q1 = q.offset(i + 1).simd_load[1](0)
-                let k0 = k.offset(i).simd_load[1](0)
-                let k1 = k.offset(i + 1).simd_load[1](0)
-                let fcr = freq_cis_real_row.offset(i // 2).simd_load[1](0)
-                let fci = freq_cis_imag_row.offset(i // 2).simd_load[1](0)
-                q.offset(i).simd_store[1](0, q0 * fcr - q1 * fci)
-                q.offset(i + 1).simd_store[1](0, q0 * fci + q1 * fcr)
-                k.offset(i).simd_store[1](0, k0 * fcr - k1 * fci)
-                k.offset(i + 1).simd_store[1](0, k0 * fci + k1 * fcr)
-
-        # Save key,value at this time step (pos) to our kv cache
-        let loff = l * config.seq_len * dim  # kv cache layer offset for convenience
-        let key_cache_row = state.key_cache.data.offset(loff + pos * dim)
-        let value_cache_row = state.value_cache.data.offset(loff + pos * dim)
-        memcpy[DType.float32](key_cache_row, state.k.data, config.dim)
-        memcpy[DType.float32](value_cache_row, state.v.data, config.dim)
+        let q = state.q.data
+        let k = state.k.data
+        for i in range(0, head_size * config.n_kv_heads, 2):
+            let head_dim_half = i % head_size // 2
+            let fcr = freq_cis_real_row.offset(head_dim_half).load(0)
+            let fci = freq_cis_imag_row.offset(head_dim_half).load(0)
+            let q0 = q.offset(i).load(0)
+            let q1 = q.offset(i + 1).load(0)
+            let k0 = k.offset(i).load(0)
+            let k1 = k.offset(i + 1).load(0)
+            q.offset(i).store(0, q0 * fcr - q1 * fci)
+            q.offset(i + 1).store(0, q0 * fci + q1 * fcr)
+            k.offset(i).store(0, k0 * fcr - k1 * fci)
+            k.offset(i + 1).store(0, k0 * fci + k1 * fcr)
+
+        for i in range(head_size * config.n_kv_heads, dim, 2):
+            let head_dim_half = i % head_size // 2
+            let fcr = freq_cis_real_row.offset(head_dim_half).load(0)
+            let fci = freq_cis_imag_row.offset(head_dim_half).load(0)
+            let q0 = q.offset(i).load(0)
+            let q1 = q.offset(i + 1).load(0)
+            q.offset(i).store(0, q0 * fcr - q1 * fci)
+            q.offset(i + 1).store(0, q0 * fci + q1 * fcr)

         # Multihead attention. Iterate over all heads
         for h in range(config.n_heads):
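
For reference, each even/odd pair is rotated as a complex multiplication by e^(i*theta): with fcr = cos(theta) and fci = sin(theta), (x0, x1) becomes (x0 * fcr - x1 * fci, x0 * fci + x1 * fcr). The first loop rotates the first head_size * n_kv_heads positions of both q and k; the second loop covers the remaining q positions, which have no k counterpart once n_kv_heads < n_heads. Note also that k and v are now projected directly into the cache rows via set_buf_ptr, replacing the old memcpy into the cache.
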
@@ -528,15 +665,17 @@
             # Iterate over all timesteps, including the current one
             for t in range(pos + 1):
                 # Get the key vector for this head and at this timestep
-                let k = state.key_cache.data.offset(loff + t * dim + h * head_size)
+                let k = state.key_cache.data.offset(
+                    loff + t * kv_dim + (h // kv_mul) * head_size
+                )
                 # Calculate the attention score as the dot product of q and k
                 var score: Float32 = 0.0
                 for i in range(head_size):
-                    score += q.offset(i).simd_load[1](0) * k.offset(i).simd_load[1](0)
+                    score += q.offset(i).load(0) * k.offset(i).load(0)
                 score /= math.sqrt[DType.float32, 1](head_size)

                 # Save the score to the attention buffer
-                att.offset(t).simd_store[1](0, score)
+                att.offset(t).store(0, score)

             # Softmax the scores to get attention weights, from 0..pos inclusively
             softmax(att, pos + 1)
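
This is the grouped-query indexing: query head h attends over the cache of key/value head h // kv_mul, at row offset loff + t * kv_dim + (h // kv_mul) * head_size. With the hypothetical n_heads = 8, n_kv_heads = 4 shapes from above, kv_mul = 2, so query heads 0 and 1 share kv head 0, heads 2 and 3 share kv head 1, and so on.
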
@@ -546,15 +685,15 @@
             memset_zero(xb, head_size)
             for t in range(pos + 1):
                 # Get the value vector for this head and at this timestep
-                let v = state.value_cache.data.offset(loff + t * dim + h * head_size)
+                let v = state.value_cache.data.offset(
+                    loff + t * kv_dim + (h // kv_mul) * head_size
+                )
                 # Get the attention weight for this timestep
-                let a = att.offset(t).simd_load[1](0)
+                let a = att.offset(t).load(0)
                 # Accumulate the weighted value into xb
                 for i in range(head_size):
-                    let xbi = xb.offset(i).simd_load[1](0) + a * v.offset(i).simd_load[
-                        1
-                    ](0)
-                    xb.offset(i).simd_store[1](0, xbi)
+                    let xbi = xb.offset(i).load(0) + a * v.offset(i).load(0)
+                    xb.offset(i).store(0, xbi)
         # Final matrix multiplication to get the output of the attention
         tmpw.set_buf_ptr(weights.wo.data.offset(l * dim * dim), dim, dim)
         matmul(state.xb2, state.xb, tmpw, state.rt)
@@ -616,29 +755,15 @@ fn sample(probabilities: Matrix) -> Int:
     var cdf: Float32 = 0.0
     for i in range(n):
         cdf += probabilities[i]
-        if r.simd_load[1](0) < cdf:
+        if r.load(0) < cdf:
             return i
     return n - 1  # In case of rounding errors


-fn str_lookup(str: PointerString, tok: Tokenizer) -> Int:
-    for pos in range(tok.vocab_size):
-        let s1 = tok.vocab[pos]
-        var p1 = 0
-        while s1[p1] != 0 and str[p1] != 0:
-            if s1[p1] != str[p1]:
-                break
-            p1 += 1
-        if s1[p1] != 0 or str[p1] != 0:
-            continue
-        return pos
-    return -1
-
-
-fn bpe_encode(inout tokens: DynamicVector[Int], text: String, tok: Tokenizer):
+fn bpe_encode(inout tokens: DynamicVector[Int], text: String, inout tok: Tokenizer):
     for pos in range(len(text)):
         let char = str_to_ptr(text[pos])
-        let id = str_lookup(char, tok)
+        let id = tok.find(char)

         if id == -1:
             print("Not a good prompt token at pos ", pos)
@@ -653,7 +778,7 @@ fn bpe_encode(inout tokens: DynamicVector[Int], text: String, tok: Tokenizer):
         for i in range(len(tokens) - 1):
             # Check if we can merge the pair (tokens[i], tokens[i+1])
             let str = str_concat(tok.vocab[tokens[i]], tok.vocab[tokens[i + 1]])
-            let id = str_lookup(str, tok)
+            let id = tok.find(str)
             if id != -1 and tok.vocab_scores.load(id) > best_score:
                 best_score = tok.vocab_scores.load(id)
                 best_id = id
@@ -674,7 +799,20 @@
         tokens = _tokens


+fn str2num(d: Int) -> Int:
+    # convert hex digit to decimal
+    if d >= ord("A"):
+        return d - ord("A") + 10
+    return d - ord("0")
+
+
 fn print_str(s: PointerString):
+    # print raw byte like <0x0A>
+    if (s[1].to_int() == ord("0")) and (s[2].to_int() == ord("x")):
+        let d1: Int = s[3].to_int()
+        let d2: Int = s[4].to_int()
+        print_no_newline(chr(str2num(d1) * 16 + str2num(d2)))
+        return
     # print all chars till null character
     var p: Int = 0
     while s[p].to_int() != 0:
@@ -689,7 +827,10 @@ fn time_in_ms() -> Int:

 fn print_usage():
     print("Usage: mojo llama2.mojo <checkpoint> [options]")
-    print("Example: mojo llama2.mojo stories15M.bin -s 99 -n 256 -t 0.5 -i \"Llama is an animal\"")
+    print(
+        'Example: mojo llama2.mojo stories15M.bin -s 99 -n 256 -t 0.5 -i "Llama is an'
+        ' animal"'
+    )
     print("Options:")
     print(" -s <int> random seed, default time.now()")
     print(" -t <float> temperature in [0,1.0], default 1.0")
@@ -718,6 +859,8 @@ fn main() raises:
             print("Option not supported: ", args[i])
         if args[i] == "-n":
             steps = atol(args[i + 1])
+        if args[i] == "-tk":
+            tokenizer = args[i + 1]
         if args[i] == "-s":
             rng_seed = atol(args[i + 1])
         if args[i] == "-i":
@@ -748,7 +891,7 @@
     var config: Config = Config()

     read_file(checkpoint, fbuf)
-    print("checkpoint size: ", fbuf.size)
+    print("checkpoint size: ", fbuf.size, "[", fbuf.size // 1024 // 1024, "MB ]")
     config_init(config, fbuf)

     # negative vocab size is hacky way of signaling unshared weights. bit yikes.
@@ -759,14 +902,12 @@

     let weights: TransformerWeights = TransformerWeights(config, shared_weights, fbuf)

-    var tok: Tokenizer = Tokenizer(config.vocab_size)
-
     if steps <= 0 or steps > config.seq_len:
         steps = config.seq_len

     # Read in the tokenizer.bin file
     read_file(tokenizer, tbuf)
-    tokenizer_init(tok, tbuf)
+    var tok = Tokenizer(config.vocab_size, tbuf)

     # Create and initialize the application RunState
     var state = RunState(config)
@@ -805,6 +946,9 @@
             # Sample from this distribution to get the next token
             next_token = sample(state.logits)

+        # Finish generating when EOS, BOS appear
+        if next_token == 1 or next_token == 2:
+            break
         var token_str: PointerString = tok.vocab[next_token]
         if token == 1 and token_str[0] == ord(" "):
             token_str = token_str.offset(1)
@@ -819,4 +963,4 @@
         start = time_in_ms()

     let end = time_in_ms()
-    print("\nachieved tok/s: ", (steps - 1) / (end - start) * 1000)
+    print("\nachieved tok/s: ", (pos - 1) / (end - start) * 1000)
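
Since generation can now stop early on BOS/EOS, throughput is computed from the final position actually reached: tok/s = (pos - 1) / (end - start) * 1000, rather than assuming all steps tokens were produced.
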
t260.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:037cb335abb25d1fa9e8ecae30ed2a3a8ace9302862ebcdc05d51a6bbb10c312
+size 6227