Spaces:
Runtime error
Runtime error
update
Browse files- llama2.mojo +273 -129
- t260.bin +3 -0
llama2.mojo
CHANGED
@@ -133,26 +133,27 @@ struct Matrix:
|
|
133 |
self.data.simd_store[nelts](z * self.layers + y * self.cols + x, val)
|
134 |
|
135 |
|
136 |
-
fn read_val_int(inout buf: FileBuf) -> Int:
|
137 |
# DTypePointer[DType.ui8](buf.data).bitcast[DType.ui8]()
|
138 |
-
let data = buf.data.offset(buf.
|
139 |
-
let result = data.
|
140 |
-
buf.
|
141 |
return result.to_int()
|
142 |
|
143 |
|
144 |
-
fn read_val_float32(inout buf: FileBuf) -> Float32:
|
145 |
# DTypePointer[DType.ui8](buf.data).bitcast[DType.ui8]()
|
146 |
-
let val = buf.data.offset(buf.
|
147 |
-
buf.
|
148 |
return val
|
149 |
|
150 |
|
151 |
-
fn read_val_str(inout buf: FileBuf, slen: Int) -> PointerString:
|
|
|
152 |
let str = PointerString.alloc(slen + 1)
|
153 |
for i in range(slen):
|
154 |
-
str.store(i, buf.data.
|
155 |
-
buf.
|
156 |
str.store(slen, 0)
|
157 |
|
158 |
return str
|
@@ -168,7 +169,7 @@ fn str_concat(s1: PointerString, s2: PointerString) -> PointerString:
|
|
168 |
while s2[l2] != 0:
|
169 |
l2 += 1
|
170 |
|
171 |
-
let str = PointerString.alloc(l1 + l2)
|
172 |
memcpy[UInt8](str, s1, l1)
|
173 |
memcpy[UInt8](str.offset(l1), s2, l2)
|
174 |
str.store(l1 + l2, 0)
|
@@ -183,6 +184,63 @@ fn str_to_ptr(s: String) -> PointerString:
|
|
183 |
return ret
|
184 |
|
185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
struct FileBuf:
|
187 |
var data: BufferPtrType
|
188 |
var offset: Int
|
@@ -193,36 +251,95 @@ struct FileBuf:
|
|
193 |
self.offset = 0
|
194 |
self.size = 0
|
195 |
|
196 |
-
fn move_offset(inout self, size: Int):
|
197 |
-
self.offset
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
-
fn bitcast_offset_float32(inout self, size: Int) -> BufferPtrFloat32:
|
200 |
let ret = self.data.offset(self.offset).bitcast[DType.float32]()
|
201 |
-
self.
|
202 |
return ret
|
203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
struct Tokenizer:
|
206 |
var vocab: PointerStrings
|
207 |
var vocab_scores: BufferPtrFloat32
|
208 |
var max_token_length: Int
|
209 |
var vocab_size: Int
|
|
|
|
|
210 |
|
211 |
-
fn __init__(inout self, vocab_size: Int):
|
212 |
self.vocab_size = vocab_size
|
213 |
-
self.
|
214 |
-
self.vocab_scores = BufferPtrFloat32.alloc(vocab_size)
|
215 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
|
217 |
|
218 |
struct Config:
|
219 |
var dim: Int
|
|
|
220 |
var hidden_dim: Int
|
221 |
var n_layers: Int
|
222 |
var n_heads: Int
|
223 |
var n_kv_heads: Int
|
|
|
224 |
var vocab_size: Int
|
225 |
var seq_len: Int
|
|
|
226 |
|
227 |
fn __init__(inout self):
|
228 |
self.dim = 0
|
@@ -232,6 +349,9 @@ struct Config:
|
|
232 |
self.n_kv_heads = 0
|
233 |
self.vocab_size = 0
|
234 |
self.seq_len = 0
|
|
|
|
|
|
|
235 |
|
236 |
|
237 |
struct RunState:
|
@@ -241,8 +361,8 @@ struct RunState:
|
|
241 |
var hb: Matrix # buffer for hidden dimension in the ffn (hidden_dim,)
|
242 |
var hb2: Matrix # buffer for hidden dimension in the ffn (hidden_dim,)
|
243 |
var q: Matrix # query (dim,)
|
244 |
-
var k: Matrix # key (
|
245 |
-
var v: Matrix # value (
|
246 |
var att: Matrix # buffer for scores/attention values (n_heads, seq_len)
|
247 |
var logits: Matrix # output logits
|
248 |
var key_cache: Matrix # (layer, seq_len, dim)
|
@@ -262,17 +382,15 @@ struct RunState:
|
|
262 |
self.hb2.alloc_zero()
|
263 |
self.q = Matrix(config.dim)
|
264 |
self.q.alloc_zero()
|
265 |
-
self.k = Matrix(
|
266 |
-
self.
|
267 |
-
self.v = Matrix(config.dim)
|
268 |
-
self.v.alloc_zero()
|
269 |
self.att = Matrix(config.n_heads, config.seq_len)
|
270 |
self.att.alloc_zero()
|
271 |
self.logits = Matrix(config.vocab_size)
|
272 |
self.logits.alloc_zero()
|
273 |
-
self.key_cache = Matrix(config.n_layers, config.seq_len, config.
|
274 |
self.key_cache.alloc_zero()
|
275 |
-
self.value_cache = Matrix(config.n_layers, config.seq_len, config.
|
276 |
self.value_cache.alloc_zero()
|
277 |
self.rt = Runtime(num_cores() // 2)
|
278 |
|
@@ -293,7 +411,7 @@ struct TransformerWeights:
|
|
293 |
var rms_final_weight: Matrix
|
294 |
var wcls: Matrix
|
295 |
|
296 |
-
fn __init__(inout self, config: Config, shared_weights: Int, inout buf: FileBuf):
|
297 |
self.token_embedding_table = Matrix(config.vocab_size, config.dim)
|
298 |
# set buf ptr to buf data from file
|
299 |
self.token_embedding_table.set_buf_ptr(
|
@@ -305,9 +423,9 @@ struct TransformerWeights:
|
|
305 |
)
|
306 |
self.wq = Matrix(config.n_layers, config.dim, config.dim)
|
307 |
self.wq.set_buf_ptr(buf.bitcast_offset_float32(self.wq.size()))
|
308 |
-
self.wk = Matrix(config.n_layers, config.dim, config.
|
309 |
self.wk.set_buf_ptr(buf.bitcast_offset_float32(self.wk.size()))
|
310 |
-
self.wv = Matrix(config.n_layers, config.dim, config.
|
311 |
self.wv.set_buf_ptr(buf.bitcast_offset_float32(self.wv.size()))
|
312 |
self.wo = Matrix(config.n_layers, config.dim, config.dim)
|
313 |
self.wo.set_buf_ptr(buf.bitcast_offset_float32(self.wo.size()))
|
@@ -369,64 +487,77 @@ fn config_init(inout config: Config, inout buf: FileBuf) raises:
|
|
369 |
config.n_kv_heads = read_val_int(buf)
|
370 |
config.vocab_size = read_val_int(buf)
|
371 |
config.seq_len = read_val_int(buf)
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
fn tokenizer_init(inout tok: Tokenizer, inout buf: FileBuf) -> None:
|
376 |
-
tok.max_token_length = read_val_int(buf)
|
377 |
-
tok.vocab_scores = BufferPtrFloat32.alloc(tok.vocab_size)
|
378 |
-
tok.vocab = PointerStrings.alloc(tok.vocab_size)
|
379 |
-
|
380 |
-
# read vocab_scores & vocab values (tokens)
|
381 |
-
for i in range(0, tok.vocab_size):
|
382 |
-
tok.vocab_scores.simd_store[1](i, read_val_float32(buf))
|
383 |
-
let slen = read_val_int(buf)
|
384 |
-
tok.vocab.store(i, read_val_str(buf, slen))
|
385 |
-
|
386 |
-
tok.vocab_scores = buf.data.offset(buf.offset).bitcast[DType.float32]()
|
387 |
-
buf.offset += tok.vocab_size * 4
|
388 |
return None
|
389 |
|
390 |
|
391 |
fn accum(inout a: BufferPtrFloat32, b: BufferPtrFloat32, size: Int) -> None:
|
392 |
-
|
393 |
-
|
394 |
-
a.offset(
|
|
|
|
|
|
|
|
|
395 |
|
396 |
|
397 |
fn rmsnorm(
|
398 |
inout o: BufferPtrFloat32, x: BufferPtrFloat32, weight: BufferPtrFloat32, size: Int
|
399 |
) -> None:
|
400 |
# Calculate sum of squares
|
401 |
-
var
|
402 |
-
|
403 |
-
|
404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
405 |
ss = ss / size + 1e-5
|
406 |
ss = 1.0 / math.sqrt(ss)
|
|
|
407 |
# Normalize and scale
|
408 |
-
|
409 |
-
|
410 |
-
|
|
|
|
|
|
|
411 |
|
412 |
|
413 |
fn softmax(inout x: BufferPtrFloat32, size: Int) -> None:
|
414 |
# Find max value (for numerical stability)
|
415 |
-
var max_val: Float32 =
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
|
|
|
|
|
|
|
|
|
|
420 |
# Exp and sum
|
421 |
var ssum: Float32 = 0.0
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
|
|
|
|
|
|
|
|
|
|
430 |
|
431 |
|
432 |
fn matmul_parallelized(C: Matrix, A: Matrix, B: Matrix, rt: Runtime):
|
@@ -463,7 +594,9 @@ fn transformer(
|
|
463 |
var x = state.x.data
|
464 |
let dim = config.dim
|
465 |
let hidden_dim = config.hidden_dim
|
466 |
-
let head_size =
|
|
|
|
|
467 |
|
468 |
# tmp matrix for matmul operations
|
469 |
var tmpw = Matrix(0, 0)
|
@@ -485,37 +618,41 @@ fn transformer(
|
|
485 |
tmpw.set_buf_ptr(weights.wq.data.offset(l * dim * dim), dim, dim)
|
486 |
matmul(state.q, state.xb, tmpw, state.rt)
|
487 |
|
488 |
-
|
|
|
|
|
489 |
matmul(state.k, state.xb, tmpw, state.rt)
|
490 |
|
491 |
-
|
|
|
|
|
|
|
492 |
matmul(state.v, state.xb, tmpw, state.rt)
|
493 |
|
494 |
# Apply RoPE rotation to the q and k vectors for each head
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
let
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
memcpy[DType.float32](value_cache_row, state.v.data, config.dim)
|
519 |
|
520 |
# Multihead attention. Iterate over all heads
|
521 |
for h in range(config.n_heads):
|
@@ -528,15 +665,17 @@ fn transformer(
|
|
528 |
# Iterate over all timesteps, including the current one
|
529 |
for t in range(pos + 1):
|
530 |
# Get the key vector for this head and at this timestep
|
531 |
-
let k = state.key_cache.data.offset(
|
|
|
|
|
532 |
# Calculate the attention score as the dot product of q and k
|
533 |
var score: Float32 = 0.0
|
534 |
for i in range(head_size):
|
535 |
-
score += q.offset(i).
|
536 |
score /= math.sqrt[DType.float32, 1](head_size)
|
537 |
|
538 |
# Save the score to the attention buffer
|
539 |
-
att.offset(t).
|
540 |
|
541 |
# Softmax the scores to get attention weights, from 0..pos inclusively
|
542 |
softmax(att, pos + 1)
|
@@ -546,15 +685,15 @@ fn transformer(
|
|
546 |
memset_zero(xb, head_size)
|
547 |
for t in range(pos + 1):
|
548 |
# Get the value vector for this head and at this timestep
|
549 |
-
let v = state.value_cache.data.offset(
|
|
|
|
|
550 |
# Get the attention weight for this timestep
|
551 |
-
let a = att.offset(t).
|
552 |
# Accumulate the weighted value into xb
|
553 |
for i in range(head_size):
|
554 |
-
let xbi = xb.offset(i).
|
555 |
-
|
556 |
-
](0)
|
557 |
-
xb.offset(i).simd_store[1](0, xbi)
|
558 |
# Final matrix multiplication to get the output of the attention
|
559 |
tmpw.set_buf_ptr(weights.wo.data.offset(l * dim * dim), dim, dim)
|
560 |
matmul(state.xb2, state.xb, tmpw, state.rt)
|
@@ -616,29 +755,15 @@ fn sample(probabilities: Matrix) -> Int:
|
|
616 |
var cdf: Float32 = 0.0
|
617 |
for i in range(n):
|
618 |
cdf += probabilities[i]
|
619 |
-
if r.
|
620 |
return i
|
621 |
return n - 1 # In case of rounding errors
|
622 |
|
623 |
|
624 |
-
fn
|
625 |
-
for pos in range(tok.vocab_size):
|
626 |
-
let s1 = tok.vocab[pos]
|
627 |
-
var p1 = 0
|
628 |
-
while s1[p1] != 0 and str[p1] != 0:
|
629 |
-
if s1[p1] != str[p1]:
|
630 |
-
break
|
631 |
-
p1 += 1
|
632 |
-
if s1[p1] != 0 or str[p1] != 0:
|
633 |
-
continue
|
634 |
-
return pos
|
635 |
-
return -1
|
636 |
-
|
637 |
-
|
638 |
-
fn bpe_encode(inout tokens: DynamicVector[Int], text: String, tok: Tokenizer):
|
639 |
for pos in range(len(text)):
|
640 |
let char = str_to_ptr(text[pos])
|
641 |
-
let id =
|
642 |
|
643 |
if id == -1:
|
644 |
print("Not a good prompt token at pos ", pos)
|
@@ -653,7 +778,7 @@ fn bpe_encode(inout tokens: DynamicVector[Int], text: String, tok: Tokenizer):
|
|
653 |
for i in range(len(tokens) - 1):
|
654 |
# Check if we can merge the pair (tokens[i], tokens[i+1])
|
655 |
let str = str_concat(tok.vocab[tokens[i]], tok.vocab[tokens[i + 1]])
|
656 |
-
let id =
|
657 |
if id != -1 and tok.vocab_scores.load(id) > best_score:
|
658 |
best_score = tok.vocab_scores.load(id)
|
659 |
best_id = id
|
@@ -674,7 +799,20 @@ fn bpe_encode(inout tokens: DynamicVector[Int], text: String, tok: Tokenizer):
|
|
674 |
tokens = _tokens
|
675 |
|
676 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
677 |
fn print_str(s: PointerString):
|
|
|
|
|
|
|
|
|
|
|
|
|
678 |
# print all chars till null character
|
679 |
var p: Int = 0
|
680 |
while s[p].to_int() != 0:
|
@@ -689,7 +827,10 @@ fn time_in_ms() -> Int:
|
|
689 |
|
690 |
fn print_usage():
|
691 |
print("Usage: mojo llama2.mojo <checkpoint> [options]")
|
692 |
-
print(
|
|
|
|
|
|
|
693 |
print("Options:")
|
694 |
print(" -s <int> random seed, default time.now()")
|
695 |
print(" -t <float> temperature in [0,1.0], default 1.0")
|
@@ -718,6 +859,8 @@ fn main() raises:
|
|
718 |
print("Option not supported: ", args[i])
|
719 |
if args[i] == "-n":
|
720 |
steps = atol(args[i + 1])
|
|
|
|
|
721 |
if args[i] == "-s":
|
722 |
rng_seed = atol(args[i + 1])
|
723 |
if args[i] == "-i":
|
@@ -748,7 +891,7 @@ fn main() raises:
|
|
748 |
var config: Config = Config()
|
749 |
|
750 |
read_file(checkpoint, fbuf)
|
751 |
-
print("checkpoint size: ", fbuf.size)
|
752 |
config_init(config, fbuf)
|
753 |
|
754 |
# negative vocab size is hacky way of signaling unshared weights. bit yikes.
|
@@ -759,14 +902,12 @@ fn main() raises:
|
|
759 |
|
760 |
let weights: TransformerWeights = TransformerWeights(config, shared_weights, fbuf)
|
761 |
|
762 |
-
var tok: Tokenizer = Tokenizer(config.vocab_size)
|
763 |
-
|
764 |
if steps <= 0 or steps > config.seq_len:
|
765 |
steps = config.seq_len
|
766 |
|
767 |
# Read in the tokenizer.bin file
|
768 |
read_file(tokenizer, tbuf)
|
769 |
-
|
770 |
|
771 |
# Create and initialize the application RunState
|
772 |
var state = RunState(config)
|
@@ -805,6 +946,9 @@ fn main() raises:
|
|
805 |
# Sample from this distribution to get the next token
|
806 |
next_token = sample(state.logits)
|
807 |
|
|
|
|
|
|
|
808 |
var token_str: PointerString = tok.vocab[next_token]
|
809 |
if token == 1 and token_str[0] == ord(" "):
|
810 |
token_str = token_str.offset(1)
|
@@ -819,4 +963,4 @@ fn main() raises:
|
|
819 |
start = time_in_ms()
|
820 |
|
821 |
let end = time_in_ms()
|
822 |
-
print("\nachieved tok/s: ", (
|
|
|
133 |
self.data.simd_store[nelts](z * self.layers + y * self.cols + x, val)
|
134 |
|
135 |
|
136 |
+
fn read_val_int(inout buf: FileBuf) raises -> Int:
    # Read one 4-byte signed integer at the current position of `buf` and
    # advance the position by 4.
    #
    # Raises whatever `get_offset` / `move_offset` raise when the position is
    # invalid or fewer than 4 units remain in the buffer.
    let start = buf.get_offset()
    # Advance first: move_offset validates the resulting position, so the load
    # below is only executed when the 4 bytes are known to be inside the buffer.
    buf.move_offset(4)
    # Interpret the bytes as int32 rather than uint32 so that negative values
    # keep their sign (the checkpoint header uses a negative vocab size to
    # signal unshared weights).
    let data = buf.data.offset(start).bitcast[DType.int32]()
    return data.load(0).to_int()
|
142 |
|
143 |
|
144 |
+
fn read_val_float32(inout buf: FileBuf) raises -> Float32:
    # Read one 4-byte float32 at the current position of `buf` and advance the
    # position by 4.
    #
    # Raises whatever `get_offset` / `move_offset` raise when the position is
    # invalid or fewer than 4 units remain in the buffer.
    let start = buf.get_offset()
    # Validate (and advance) before touching memory so a truncated file raises
    # instead of being read past its end.
    buf.move_offset(4)
    return buf.data.offset(start).bitcast[DType.float32]().load(0)
|
149 |
|
150 |
|
151 |
+
fn read_val_str(inout buf: FileBuf, slen: Int) raises -> PointerString:
    # Copy `slen` bytes from the current position of `buf` into a freshly
    # allocated, zero-terminated string and advance the position by `slen`.
    #
    # Raises an Error when `slen` is negative, and whatever `get_offset` /
    # `move_offset` raise when the requested range does not fit in the buffer.
    # The caller owns the returned allocation (slen + 1 bytes).
    if slen < 0:
        raise Error("String length must not be negative")

    let start = buf.get_offset()
    # Validate the whole range up front: nothing is allocated or read unless
    # all `slen` bytes are available, so a failure cannot leak the allocation
    # or read past the end of the buffer.
    buf.move_offset(slen)

    let str = PointerString.alloc(slen + 1)
    for i in range(slen):
        str.store(i, buf.data.load(start + i))
    # Terminating zero so the result can be scanned like a C string.
    str.store(slen, 0)

    return str
|
|
|
169 |
while s2[l2] != 0:
|
170 |
l2 += 1
|
171 |
|
172 |
+
let str = PointerString.alloc(l1 + l2 + 1)
|
173 |
memcpy[UInt8](str, s1, l1)
|
174 |
memcpy[UInt8](str.offset(l1), s2, l2)
|
175 |
str.store(l1 + l2, 0)
|
|
|
184 |
return ret
|
185 |
|
186 |
|
187 |
+
fn string_compare(a: PointerString, b: PointerString) -> Int:
    # Three-way, byte-wise comparison of two zero-terminated strings.
    # Returns -1 when `a` orders before `b`, 1 when it orders after, and 0
    # when both strings are identical. A string that is a strict prefix of the
    # other orders first.
    var pos = 0

    # Walk the common prefix; the first differing byte decides the result.
    while a[pos] != 0 and b[pos] != 0:
        if a[pos] != b[pos]:
            if a[pos] < b[pos]:
                return -1
            return 1
        pos += 1

    # At least one string has ended here; the one with bytes left is greater.
    if a[pos] != 0:
        return 1
    if b[pos] != 0:
        return -1
    return 0
|
204 |
+
|
205 |
+
|
206 |
+
# Quicksort helper function to find the partition position
|
207 |
+
fn partition(
    inout array: PointerStrings, inout indices: DynamicVector[Int], low: Int, high: Int
) -> Int:
    # Partition step of quicksort over array[low..high] (inclusive), using
    # array[high] as the pivot. Every swap performed on `array` is mirrored on
    # `indices`, so both stay aligned element by element.
    # Returns the final position of the pivot.
    let pivot = array[high]
    # Index of the last slot known to hold a string ordering before the pivot.
    var boundary = low - 1

    for cur in range(low, high):
        # Only strings strictly smaller than the pivot move to the left part.
        if string_compare(pivot, array[cur]) != 1:
            continue

        boundary += 1
        let held_str = array[boundary]
        let held_idx = indices[boundary]
        array.store(boundary, array[cur])
        indices[boundary] = indices[cur]
        array.store(cur, held_str)
        indices[cur] = held_idx

    # Drop the pivot into the slot right after the smaller-than-pivot part.
    let dest = boundary + 1
    let displaced_str = array[dest]
    let displaced_idx = indices[dest]
    array.store(dest, array[high])
    indices[dest] = indices[high]
    array.store(high, displaced_str)
    indices[high] = displaced_idx

    return dest
|
233 |
+
|
234 |
+
|
235 |
+
fn quicksort(
    inout array: PointerStrings, inout indices: DynamicVector[Int], low: Int, high: Int
):
    # Recursively sort array[low..high] (inclusive) in place, keeping
    # `indices` permuted the same way as `array`.
    if low >= high:
        # Zero or one element: nothing to do.
        return

    let split = partition(array, indices, low, high)
    # The pivot at `split` is already in its final place; sort both sides.
    quicksort(array, indices, low, split - 1)
    quicksort(array, indices, split + 1, high)
|
242 |
+
|
243 |
+
|
244 |
struct FileBuf:
|
245 |
var data: BufferPtrType
|
246 |
var offset: Int
|
|
|
251 |
self.offset = 0
|
252 |
self.size = 0
|
253 |
|
254 |
+
fn move_offset(inout self, size: Int) raises:
    # Shift the read position by `size` (may be negative).
    # Raises an Error, leaving the position untouched, when the new position
    # would fall outside the range [0, self.size].
    let target = self.offset + size
    if target > self.size:
        raise Error("Resulting offset will be past the end of the FileBuf")
    if target < 0:
        raise Error("Resulting offset will be before the beginning of the FileBuf")
    self.offset = target
|
261 |
|
262 |
+
fn bitcast_offset_float32(inout self, size: Int) raises -> BufferPtrFloat32:
    # Return a float32 pointer to the current read position and then advance
    # the position past `size` float32 values. No data is copied.
    # Raises (via move_offset) when the advanced position would be invalid.
    let view = self.data.offset(self.offset).bitcast[DType.float32]()
    let span = size * sizeof[DType.float32]()
    self.move_offset(span)
    return view
|
266 |
|
267 |
+
fn get_offset(self) raises -> Int:
    # Return the current read position after checking that it lies within
    # [0, self.size]; raises an Error otherwise.
    let current = self.offset
    if current > self.size:
        raise Error("Offset is past the end of the FileBuf")
    if current < 0:
        raise Error("Offset is before the beginning of the FileBuf")
    return current
|
273 |
+
|
274 |
|
275 |
struct Tokenizer:
    # Vocabulary loaded from a tokenizer file, plus a lazily built sorted copy
    # of the vocabulary that lets `find` use binary search.
    var vocab: PointerStrings  # token strings, indexed by token id
    var vocab_scores: BufferPtrFloat32  # one score per token id
    var max_token_length: Int
    var vocab_size: Int
    var sorted_vocab: PointerStrings  # vocab ordered by string_compare
    var sorted_indices: DynamicVector[Int]  # token id of each sorted_vocab entry

    fn __init__(inout self, vocab_size: Int, inout buf: FileBuf) raises -> None:
        # Read `vocab_size` (score, length, bytes) records from `buf`, after a
        # leading integer holding the maximum token length.
        self.vocab_size = vocab_size
        self.max_token_length = read_val_int(buf)
        self.vocab_scores = BufferPtrFloat32.alloc(self.vocab_size)
        self.vocab = PointerStrings.alloc(self.vocab_size)
        # The sorted view starts empty and is filled on the first `find`.
        self.sorted_vocab = PointerStrings.alloc(0)
        self.sorted_indices = DynamicVector[Int](0)

        # Each record: float32 score, integer byte length, then the bytes.
        for token_id in range(0, self.vocab_size):
            let score = read_val_float32(buf)
            self.vocab_scores.store(token_id, score)
            let token_len = read_val_int(buf)
            let token = read_val_str(buf, token_len)
            self.vocab.store(token_id, token)

    fn sort(inout self) -> None:
        # Build (if not built yet) and sort the sorted_vocab / sorted_indices
        # pair using string_compare ordering.
        let count = self.vocab_size

        if len(self.sorted_indices) < count:
            # First use: start from the identity mapping over the vocabulary.
            self.sorted_indices = DynamicVector[Int](count)
            self.sorted_vocab = PointerStrings.alloc(count)
            for token_id in range(count):
                self.sorted_vocab.store(token_id, self.vocab[token_id])
                self.sorted_indices.push_back(token_id)

        quicksort(self.sorted_vocab, self.sorted_indices, 0, count - 1)

    fn find(inout self, token: PointerString) -> Int:
        # Binary search for `token` in the sorted vocabulary.
        # Returns its token id, or -1 if the string is not in the vocabulary.
        let count = self.vocab_size
        if len(self.sorted_indices) < count:
            self.sort()

        var lo = 0
        var hi = count - 1
        while lo <= hi:
            let probe = lo + (hi - lo) // 2
            let order = string_compare(self.sorted_vocab[probe], token)
            if order == 0:
                return self.sorted_indices[probe]
            elif order < 0:
                # Probed entry orders before the token: search the upper half.
                lo = probe + 1
            else:
                hi = probe - 1
        return -1
|
330 |
|
331 |
|
332 |
struct Config:
|
333 |
var dim: Int
|
334 |
+
var kv_dim: Int
|
335 |
var hidden_dim: Int
|
336 |
var n_layers: Int
|
337 |
var n_heads: Int
|
338 |
var n_kv_heads: Int
|
339 |
+
var kv_mul: Int
|
340 |
var vocab_size: Int
|
341 |
var seq_len: Int
|
342 |
+
var head_size: Int
|
343 |
|
344 |
fn __init__(inout self):
|
345 |
self.dim = 0
|
|
|
349 |
self.n_kv_heads = 0
|
350 |
self.vocab_size = 0
|
351 |
self.seq_len = 0
|
352 |
+
self.kv_dim = 0
|
353 |
+
self.kv_mul = 0
|
354 |
+
self.head_size = 0
|
355 |
|
356 |
|
357 |
struct RunState:
|
|
|
361 |
var hb: Matrix # buffer for hidden dimension in the ffn (hidden_dim,)
|
362 |
var hb2: Matrix # buffer for hidden dimension in the ffn (hidden_dim,)
|
363 |
var q: Matrix # query (dim,)
|
364 |
+
var k: Matrix # key (kv_dim,)
|
365 |
+
var v: Matrix # value (kv_dim,)
|
366 |
var att: Matrix # buffer for scores/attention values (n_heads, seq_len)
|
367 |
var logits: Matrix # output logits
|
368 |
var key_cache: Matrix # (layer, seq_len, dim)
|
|
|
382 |
self.hb2.alloc_zero()
|
383 |
self.q = Matrix(config.dim)
|
384 |
self.q.alloc_zero()
|
385 |
+
self.k = Matrix(0, 0)
|
386 |
+
self.v = Matrix(0, 0)
|
|
|
|
|
387 |
self.att = Matrix(config.n_heads, config.seq_len)
|
388 |
self.att.alloc_zero()
|
389 |
self.logits = Matrix(config.vocab_size)
|
390 |
self.logits.alloc_zero()
|
391 |
+
self.key_cache = Matrix(config.n_layers, config.seq_len, config.kv_dim)
|
392 |
self.key_cache.alloc_zero()
|
393 |
+
self.value_cache = Matrix(config.n_layers, config.seq_len, config.kv_dim)
|
394 |
self.value_cache.alloc_zero()
|
395 |
self.rt = Runtime(num_cores() // 2)
|
396 |
|
|
|
411 |
var rms_final_weight: Matrix
|
412 |
var wcls: Matrix
|
413 |
|
414 |
+
fn __init__(inout self, config: Config, shared_weights: Int, inout buf: FileBuf) raises:
|
415 |
self.token_embedding_table = Matrix(config.vocab_size, config.dim)
|
416 |
# set buf ptr to buf data from file
|
417 |
self.token_embedding_table.set_buf_ptr(
|
|
|
423 |
)
|
424 |
self.wq = Matrix(config.n_layers, config.dim, config.dim)
|
425 |
self.wq.set_buf_ptr(buf.bitcast_offset_float32(self.wq.size()))
|
426 |
+
self.wk = Matrix(config.n_layers, config.dim, config.kv_dim)
|
427 |
self.wk.set_buf_ptr(buf.bitcast_offset_float32(self.wk.size()))
|
428 |
+
self.wv = Matrix(config.n_layers, config.dim, config.kv_dim)
|
429 |
self.wv.set_buf_ptr(buf.bitcast_offset_float32(self.wv.size()))
|
430 |
self.wo = Matrix(config.n_layers, config.dim, config.dim)
|
431 |
self.wo.set_buf_ptr(buf.bitcast_offset_float32(self.wo.size()))
|
|
|
487 |
config.n_kv_heads = read_val_int(buf)
|
488 |
config.vocab_size = read_val_int(buf)
|
489 |
config.seq_len = read_val_int(buf)
|
490 |
+
config.head_size = config.dim // config.n_heads
|
491 |
+
config.kv_dim = (config.n_kv_heads * config.dim) // config.n_heads
|
492 |
+
config.kv_mul = config.n_heads // config.n_kv_heads
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
493 |
return None
|
494 |
|
495 |
|
496 |
fn accum(inout a: BufferPtrFloat32, b: BufferPtrFloat32, size: Int) -> None:
    # Element-wise in-place addition: a[i] += b[i] for the first `size`
    # elements, processed in SIMD chunks by `vectorize`.
    @parameter
    fn _add_chunk[width: Int](start: Int):
        let total = a.simd_load[width](start) + b.simd_load[width](start)
        a.simd_store[width](start, total)

    vectorize[nelts, _add_chunk](size)
|
504 |
|
505 |
|
506 |
fn rmsnorm(
    inout o: BufferPtrFloat32, x: BufferPtrFloat32, weight: BufferPtrFloat32, size: Int
) -> None:
    # RMS normalization over the first `size` elements:
    #   o[i] = weight[i] * x[i] / sqrt(mean(x^2) + 1e-5)

    # Per-lane partial sums of x^2; collapsed to a scalar after the pass.
    var acc = SIMD[DType.float32, nelts](0)

    @parameter
    fn _square_sum[width: Int](start: Int):
        if width < nelts:
            # Narrower chunk: fold it into lane 0 because the vector widths
            # do not match the accumulator.
            acc[0] += (x.simd_load[width](start) ** 2).reduce_add()
        else:
            acc += x.simd_load[nelts](start) ** 2

    vectorize[nelts, _square_sum](size)

    # Turn the sum of squares into the reciprocal root-mean-square.
    var scale: Float32 = acc.reduce_add()
    scale = scale / size + 1e-5
    scale = 1.0 / math.sqrt(scale)

    # Normalize and apply the learned per-element weight.
    @parameter
    fn _scale[width: Int](start: Int):
        let scaled = weight.simd_load[width](start) * scale * x.simd_load[width](start)
        o.simd_store[width](start, scaled)

    vectorize[nelts, _scale](size)
|
532 |
|
533 |
|
534 |
fn softmax(inout x: BufferPtrFloat32, size: Int) -> None:
    # In-place softmax over the first `size` elements of `x`.

    # Pass 1: largest element, subtracted before exp for numerical stability.
    # The search starts from -1e9 rather than from the first element.
    var peak: Float32 = -1e9

    @parameter
    fn _track_max[width: Int](start: Int):
        let chunk_max = x.simd_load[width](start).reduce_max()
        if chunk_max > peak:
            peak = chunk_max

    vectorize[nelts, _track_max](size)

    # Pass 2: replace each element with exp(x - peak) and sum the results.
    var total: Float32 = 0.0

    @parameter
    fn _exp_and_sum[width: Int](start: Int):
        let shifted = math.exp(x.simd_load[width](start) - peak)
        x.simd_store[width](start, shifted)
        total += shifted.reduce_add()

    vectorize[nelts, _exp_and_sum](size)

    # Pass 3: divide by the sum so the elements add up to 1.
    @parameter
    fn _divide[width: Int](start: Int):
        x.simd_store[width](start, x.simd_load[width](start) / total)

    vectorize[nelts, _divide](size)
|
561 |
|
562 |
|
563 |
fn matmul_parallelized(C: Matrix, A: Matrix, B: Matrix, rt: Runtime):
|
|
|
594 |
var x = state.x.data
|
595 |
let dim = config.dim
|
596 |
let hidden_dim = config.hidden_dim
|
597 |
+
let head_size = config.head_size
|
598 |
+
let kv_dim = config.kv_dim
|
599 |
+
let kv_mul = config.kv_mul
|
600 |
|
601 |
# tmp matrix for matmul operations
|
602 |
var tmpw = Matrix(0, 0)
|
|
|
618 |
tmpw.set_buf_ptr(weights.wq.data.offset(l * dim * dim), dim, dim)
|
619 |
matmul(state.q, state.xb, tmpw, state.rt)
|
620 |
|
621 |
+
let loff = l * config.seq_len * kv_dim
|
622 |
+
state.k.set_buf_ptr(state.key_cache.data.offset(loff + pos * kv_dim), 1, kv_dim)
|
623 |
+
tmpw.set_buf_ptr(weights.wk.data.offset(l * dim * kv_dim), kv_dim, dim)
|
624 |
matmul(state.k, state.xb, tmpw, state.rt)
|
625 |
|
626 |
+
state.v.set_buf_ptr(
|
627 |
+
state.value_cache.data.offset(loff + pos * kv_dim), 1, kv_dim
|
628 |
+
)
|
629 |
+
tmpw.set_buf_ptr(weights.wv.data.offset(l * dim * kv_dim), kv_dim, dim)
|
630 |
matmul(state.v, state.xb, tmpw, state.rt)
|
631 |
|
632 |
# Apply RoPE rotation to the q and k vectors for each head
|
633 |
+
let q = state.q.data
|
634 |
+
let k = state.k.data
|
635 |
+
for i in range(0, head_size * config.n_kv_heads, 2):
|
636 |
+
let head_dim_half = i % head_size // 2
|
637 |
+
let fcr = freq_cis_real_row.offset(head_dim_half).load(0)
|
638 |
+
let fci = freq_cis_imag_row.offset(head_dim_half).load(0)
|
639 |
+
let q0 = q.offset(i).load(0)
|
640 |
+
let q1 = q.offset(i + 1).load(0)
|
641 |
+
let k0 = k.offset(i).load(0)
|
642 |
+
let k1 = k.offset(i + 1).load(0)
|
643 |
+
q.offset(i).store(0, q0 * fcr - q1 * fci)
|
644 |
+
q.offset(i + 1).store(0, q0 * fci + q1 * fcr)
|
645 |
+
k.offset(i).store(0, k0 * fcr - k1 * fci)
|
646 |
+
k.offset(i + 1).store(0, k0 * fci + k1 * fcr)
|
647 |
+
|
648 |
+
for i in range(head_size * config.n_kv_heads, dim, 2):
|
649 |
+
let head_dim_half = i % head_size // 2
|
650 |
+
let fcr = freq_cis_real_row.offset(head_dim_half).load(0)
|
651 |
+
let fci = freq_cis_imag_row.offset(head_dim_half).load(0)
|
652 |
+
let q0 = q.offset(i).load(0)
|
653 |
+
let q1 = q.offset(i + 1).load(0)
|
654 |
+
q.offset(i).store(0, q0 * fcr - q1 * fci)
|
655 |
+
q.offset(i + 1).store(0, q0 * fci + q1 * fcr)
|
|
|
656 |
|
657 |
# Multihead attention. Iterate over all heads
|
658 |
for h in range(config.n_heads):
|
|
|
665 |
# Iterate over all timesteps, including the current one
|
666 |
for t in range(pos + 1):
|
667 |
# Get the key vector for this head and at this timestep
|
668 |
+
let k = state.key_cache.data.offset(
|
669 |
+
loff + t * kv_dim + (h // kv_mul) * head_size
|
670 |
+
)
|
671 |
# Calculate the attention score as the dot product of q and k
|
672 |
var score: Float32 = 0.0
|
673 |
for i in range(head_size):
|
674 |
+
score += q.offset(i).load(0) * k.offset(i).load(0)
|
675 |
score /= math.sqrt[DType.float32, 1](head_size)
|
676 |
|
677 |
# Save the score to the attention buffer
|
678 |
+
att.offset(t).store(0, score)
|
679 |
|
680 |
# Softmax the scores to get attention weights, from 0..pos inclusively
|
681 |
softmax(att, pos + 1)
|
|
|
685 |
memset_zero(xb, head_size)
|
686 |
for t in range(pos + 1):
|
687 |
# Get the value vector for this head and at this timestep
|
688 |
+
let v = state.value_cache.data.offset(
|
689 |
+
loff + t * kv_dim + (h // kv_mul) * head_size
|
690 |
+
)
|
691 |
# Get the attention weight for this timestep
|
692 |
+
let a = att.offset(t).load(0)
|
693 |
# Accumulate the weighted value into xb
|
694 |
for i in range(head_size):
|
695 |
+
let xbi = xb.offset(i).load(0) + a * v.offset(i).load(0)
|
696 |
+
xb.offset(i).store(0, xbi)
|
|
|
|
|
697 |
# Final matrix multiplication to get the output of the attention
|
698 |
tmpw.set_buf_ptr(weights.wo.data.offset(l * dim * dim), dim, dim)
|
699 |
matmul(state.xb2, state.xb, tmpw, state.rt)
|
|
|
755 |
var cdf: Float32 = 0.0
|
756 |
for i in range(n):
|
757 |
cdf += probabilities[i]
|
758 |
+
if r.load(0) < cdf:
|
759 |
return i
|
760 |
return n - 1 # In case of rounding errors
|
761 |
|
762 |
|
763 |
+
fn bpe_encode(inout tokens: DynamicVector[Int], text: String, inout tok: Tokenizer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
764 |
for pos in range(len(text)):
|
765 |
let char = str_to_ptr(text[pos])
|
766 |
+
let id = tok.find(char)
|
767 |
|
768 |
if id == -1:
|
769 |
print("Not a good prompt token at pos ", pos)
|
|
|
778 |
for i in range(len(tokens) - 1):
|
779 |
# Check if we can merge the pair (tokens[i], tokens[i+1])
|
780 |
let str = str_concat(tok.vocab[tokens[i]], tok.vocab[tokens[i + 1]])
|
781 |
+
let id = tok.find(str)
|
782 |
if id != -1 and tok.vocab_scores.load(id) > best_score:
|
783 |
best_score = tok.vocab_scores.load(id)
|
784 |
best_id = id
|
|
|
799 |
tokens = _tokens
|
800 |
|
801 |
|
802 |
+
fn str2num(d: Int) -> Int:
    # Convert the character code of a single hexadecimal digit to its numeric
    # value (0-15). Accepts '0'-'9', 'A'-'F' and 'a'-'f'.
    # Codes outside those ranges are not validated and yield an
    # out-of-range result.
    if d >= ord("a"):
        # Lowercase digits sort after uppercase ones, so test them first;
        # otherwise 'a' would go through the uppercase branch and give 42.
        return d - ord("a") + 10
    if d >= ord("A"):
        return d - ord("A") + 10
    return d - ord("0")
|
807 |
+
|
808 |
+
|
809 |
fn print_str(s: PointerString):
|
810 |
+
# print raw byte like <0x0A>
|
811 |
+
if (s[1].to_int() == ord("0")) and (s[2].to_int() == ord("x")):
|
812 |
+
let d1: Int = s[3].to_int()
|
813 |
+
let d2: Int = s[4].to_int()
|
814 |
+
print_no_newline(chr(str2num(d1) * 16 + str2num(d2)))
|
815 |
+
return
|
816 |
# print all chars till null character
|
817 |
var p: Int = 0
|
818 |
while s[p].to_int() != 0:
|
|
|
827 |
|
828 |
fn print_usage():
|
829 |
print("Usage: mojo llama2.mojo <checkpoint> [options]")
|
830 |
+
print(
|
831 |
+
'Example: mojo llama2.mojo stories15M.bin -s 99 -n 256 -t 0.5 -i "Llama is an'
|
832 |
+
' animal"'
|
833 |
+
)
|
834 |
print("Options:")
|
835 |
print(" -s <int> random seed, default time.now()")
|
836 |
print(" -t <float> temperature in [0,1.0], default 1.0")
|
|
|
859 |
print("Option not supported: ", args[i])
|
860 |
if args[i] == "-n":
|
861 |
steps = atol(args[i + 1])
|
862 |
+
if args[i] == "-tk":
|
863 |
+
tokenizer = args[i + 1]
|
864 |
if args[i] == "-s":
|
865 |
rng_seed = atol(args[i + 1])
|
866 |
if args[i] == "-i":
|
|
|
891 |
var config: Config = Config()
|
892 |
|
893 |
read_file(checkpoint, fbuf)
|
894 |
+
print("checkpoint size: ", fbuf.size, "[", fbuf.size // 1024 // 1024, "MB ]")
|
895 |
config_init(config, fbuf)
|
896 |
|
897 |
# negative vocab size is hacky way of signaling unshared weights. bit yikes.
|
|
|
902 |
|
903 |
let weights: TransformerWeights = TransformerWeights(config, shared_weights, fbuf)
|
904 |
|
|
|
|
|
905 |
if steps <= 0 or steps > config.seq_len:
|
906 |
steps = config.seq_len
|
907 |
|
908 |
# Read in the tokenizer.bin file
|
909 |
read_file(tokenizer, tbuf)
|
910 |
+
var tok = Tokenizer(config.vocab_size, tbuf)
|
911 |
|
912 |
# Create and initialize the application RunState
|
913 |
var state = RunState(config)
|
|
|
946 |
# Sample from this distribution to get the next token
|
947 |
next_token = sample(state.logits)
|
948 |
|
949 |
+
# Finish generating when EOS, BOS appear
|
950 |
+
if next_token == 1 or next_token == 2:
|
951 |
+
break
|
952 |
var token_str: PointerString = tok.vocab[next_token]
|
953 |
if token == 1 and token_str[0] == ord(" "):
|
954 |
token_str = token_str.offset(1)
|
|
|
963 |
start = time_in_ms()
|
964 |
|
965 |
let end = time_in_ms()
|
966 |
+
print("\nachieved tok/s: ", (pos - 1) / (end - start) * 1000)
|
t260.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:037cb335abb25d1fa9e8ecae30ed2a3a8ace9302862ebcdc05d51a6bbb10c312
|
3 |
+
size 6227
|