radames committed on
Commit 264c8c8
1 Parent(s): c7eecb3
Files changed (3)
  1. Dockerfile +3 -1
  2. gradio_app.py +66 -16
  3. llama2.mojo +230 -155
Dockerfile CHANGED
@@ -64,7 +64,9 @@ USER user
 WORKDIR $HOME/app
 
 COPY --chown=user . $HOME/app
-RUN wget -c https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin
+RUN wget -c https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin
+RUN wget -c https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin
+RUN wget -c https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin
 
 # CMD ["mojo", "llama2.mojo"]
 CMD ["python3", "gradio_app.py"]
gradio_app.py CHANGED
@@ -1,36 +1,86 @@
 import gradio as gr
 import subprocess
 import sys
-import os
+from pathlib import Path
 
 
-async def generate(prompt):
-    # os.environ["PROMPT"] = prompt
+async def generate(prompt, model_name, seed=0, temperature=0.5, num_tokens=256):
     # stream stout
     process = subprocess.Popen(
-        ["mojo", "llama2.mojo"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        [
+            "mojo",
+            "llama2.mojo",
+            Path(model_name),
+            "-s",
+            str(seed),
+            "-n",
+            str(num_tokens),
+            "-t",
+            str(temperature),
+            "-i",
+            prompt,
+        ],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
     )
     text = ""
     for char in iter(lambda: process.stdout.read(1), b""):
-        char_decoded = char.decode()
-        sys.stdout.write(char_decoded)
+        char_decoded = char.decode("utf-8", errors="ignore")
         text += char_decoded
         yield text
 
 
-output_text = gr.Textbox(label="Generated Text")
-
-demo = gr.Interface(
-    fn=generate,
-    inputs=None,
-    outputs=output_text,
-    description="""
+with gr.Blocks() as demo:
+    gr.Markdown(
+        """
 # llama2.🔥
 ## [Mojo](https://docs.modular.com/mojo/) implementation of [llama2.c](https://github.com/karpathy/llama2.c) by [@tairov](https://github.com/tairov)
 Source: https://github.com/tairov/llama2.mojo
-    """,
-    allow_flagging="never",
-)
+"""
+    )
+    with gr.Row():
+        with gr.Column():
+            prompt = gr.Textbox(label="Prompt", placeholder="Add your prompt here...")
+            seed = gr.Slider(
+                minimum=0,
+                maximum=2**53,
+                value=0,
+                step=1,
+                label="Seed",
+                randomize=True,
+            )
+            temperature = gr.Slider(
+                minimum=0.0, maximum=2.0, step=0.01, value=0.5, label="Temperature"
+            )
+            num_tokens = gr.Slider(
+                minimum=1, maximum=256, value=256, label="Number of tokens"
+            )
+            model_name = gr.Dropdown(
+                ["stories15M.bin", "stories42M.bin", "stories110M.bin"],
+                value="stories15M.bin",
+                label="Model Size",
+            )
+            with gr.Row():
+                stop = gr.Button("Stop")
+                run = gr.Button("Run")
+        with gr.Column(scale=2):
+            output_text = gr.Textbox(label="Generated Text")
+
+    # update maximum number of tokens based on model size
+    model_name.change(
+        lambda x: gr.update(maximum=1024)
+        if x == "stories110M.bin" or x == "stories42M.bin"
+        else gr.update(maximum=256),
+        model_name,
+        num_tokens,
+        queue=False,
+    )
+    click_event = run.click(
+        fn=generate,
+        inputs=[prompt, model_name, seed, temperature, num_tokens],
+        outputs=output_text,
+    )
+    stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event])
 
 demo.queue()
 demo.launch(server_name="0.0.0.0")
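The Run button simply shells out to `mojo llama2.mojo` with the flags the UI collects, so the command can be smoke-tested without Gradio. A sketch (assuming `mojo` is on PATH and the checkpoints sit in the working directory; the prompt is taken from the usage example in llama2.mojo):

```python
# stream_check.py - hypothetical smoke test, not part of this commit
import subprocess

cmd = [
    "mojo", "llama2.mojo", "stories15M.bin",
    "-s", "0", "-n", "64", "-t", "0.5",
    "-i", "Llama is an animal",
]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# read one byte at a time, mirroring generate() in gradio_app.py, to confirm streaming
for chunk in iter(lambda: proc.stdout.read(1), b""):
    print(chunk.decode("utf-8", errors="ignore"), end="", flush=True)
proc.wait()
```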
llama2.mojo CHANGED
@@ -1,25 +1,22 @@
+from algorithm import sum
+from algorithm import vectorize, parallelize
+from builtin import string
 from math import round
-import math
-
 from memory import memset_zero, memcpy
+from memory.buffer import Buffer
 from memory.unsafe import DTypePointer
+from python import Python
 from random import rand
-from sys.info import simdwidthof
-from builtin import string
-import time
-import random
-import os
-
-from runtime.llcl import num_cores
-
 from read import BufReader, File
-from memory.buffer import Buffer
-
-from python import Python
+from runtime.llcl import num_cores, Runtime
+from sys import argv
 
 # The SIMD vector width.
-from algorithm import vectorize, parallelize
-from algorithm import sum
+from sys.info import simdwidthof
+import math
+import os
+import random
+import time
 
 alias nelts = (2 * simdwidthof[DType.float32]())
 
@@ -29,98 +26,51 @@ alias BufferPtrFloat32 = DTypePointer[DType.float32]
 alias PointerStrings = Pointer[PointerString]
 
 
-struct Matrix3:
-    var data: BufferPtrFloat32
-    var rows: Int
-    var cols: Int
-    var layers: Int
-    var allocated: Int
-
-    fn __init__(inout self, layers: Int, rows: Int, cols: Int):
-        self.data = BufferPtrFloat32.alloc(0)
-        self.rows = rows
-        self.cols = cols
-        self.layers = layers
-        self.allocated = 0
-
-    @always_inline
-    fn alloc(inout self, fill: Int = 0):
-        self.data = BufferPtrFloat32.alloc(self.size())
-        self.allocated = 1
-        if fill == 1:
-            self.zero()
-
-    @always_inline
-    fn alloc_zero(inout self):
-        self.alloc(1)
-
-    @always_inline
-    fn set_buf_ptr(inout self, ptr: BufferPtrFloat32):
-        self.data = ptr
-
-    fn __del__(owned self):
-        if self.allocated == 1:
-            self.data.free()
-
-    @always_inline
-    fn zero(inout self):
-        memset_zero(self.data, self.layers * self.rows * self.cols)
-
-    @always_inline
-    fn size(inout self) -> Int:
-        return self.layers * self.cols * self.rows
-
-    @always_inline
-    fn __getitem__(self, z: Int, y: Int, x: Int) -> Float32:
-        return self.load[1](z, y, x)
-
-    @always_inline
-    fn load[nelts: Int](self, z: Int, y: Int, x: Int) -> SIMD[DType.float32, nelts]:
-        return self.data.simd_load[nelts](z * self.layers + y * self.cols + x)
-
-    @always_inline
-    fn __setitem__(self, z: Int, y: Int, x: Int, val: Float32):
-        return self.store[1](z, y, x, val)
-
-    @always_inline
-    fn store[nelts: Int](self, z: Int, y: Int, x: Int, val: SIMD[DType.float32, nelts]):
-        self.data.simd_store[nelts](z * self.layers + y * self.cols + x, val)
-
-
 struct Matrix:
     var data: BufferPtrFloat32
     var rows: Int
     var cols: Int
+    var layers: Int
     var allocated: Int
 
     fn __init__(inout self, rows: Int, cols: Int):
         self.data = BufferPtrFloat32.alloc(0)
         self.rows = rows
         self.cols = cols
+        self.layers = 1
         self.allocated = 0
 
     fn __init__(inout self, cols: Int):
         self.data = BufferPtrFloat32.alloc(0)
         self.rows = 1
+        self.layers = 1
         self.cols = cols
         self.allocated = 0
 
+    fn __init__(inout self, layers: Int, rows: Int, cols: Int):
+        self.__init__(rows, cols)
+        self.layers = layers
+
     fn __del__(owned self):
         if self.allocated == 1:
             self.data.free()
 
+    @always_inline
     fn alloc(inout self, fill: Int = 0):
         self.data = BufferPtrFloat32.alloc(self.size())
         self.allocated = 1
         if fill == 1:
             self.zero()
 
+    @always_inline
     fn alloc_zero(inout self):
         self.alloc(1)
 
+    @always_inline
     fn zero(inout self):
-        memset_zero(self.data, self.rows * self.cols)
+        memset_zero(self.data, self.size())
 
+    @always_inline
     fn set_buf_ptr(inout self, ptr: BufferPtrFloat32):
         self.data = ptr
 
@@ -130,8 +80,9 @@ struct Matrix:
         self.rows = rows
         self.cols = cols
 
+    @always_inline
     fn size(inout self) -> Int:
-        return self.cols * self.rows
+        return self.cols * self.rows * self.layers
 
     @always_inline
     fn __getitem__(self, y: Int, x: Int) -> Float32:
@@ -165,6 +116,22 @@ struct Matrix:
     fn store[nelts: Int](self, x: Int, val: SIMD[DType.float32, nelts]):
         self.data.simd_store[nelts](x, val)
 
+    @always_inline
+    fn __getitem__(self, z: Int, y: Int, x: Int) -> Float32:
+        return self.load[1](z, y, x)
+
+    @always_inline
+    fn load[nelts: Int](self, z: Int, y: Int, x: Int) -> SIMD[DType.float32, nelts]:
+        return self.data.simd_load[nelts](z * self.layers + y * self.cols + x)
+
+    @always_inline
+    fn __setitem__(self, z: Int, y: Int, x: Int, val: Float32):
+        return self.store[1](z, y, x, val)
+
+    @always_inline
+    fn store[nelts: Int](self, z: Int, y: Int, x: Int, val: SIMD[DType.float32, nelts]):
+        self.data.simd_store[nelts](z * self.layers + y * self.cols + x, val)
+
 
 fn read_val_int(inout buf: FileBuf) -> Int:
     # DTypePointer[DType.ui8](buf.data).bitcast[DType.ui8]()
@@ -191,6 +158,31 @@ fn read_val_str(inout buf: FileBuf, slen: Int) -> PointerString:
     return str
 
 
+# not optimal concat
+fn str_concat(s1: PointerString, s2: PointerString) -> PointerString:
+    var l1 = 0
+    var l2 = 0
+
+    while s1[l1] != 0:
+        l1 += 1
+    while s2[l2] != 0:
+        l2 += 1
+
+    let str = PointerString.alloc(l1 + l2)
+    memcpy[UInt8](str, s1, l1)
+    memcpy[UInt8](str.offset(l1), s2, l2)
+    str.store(l1 + l2, 0)
+    return str
+
+
+fn str_to_ptr(s: String) -> PointerString:
+    let ret = PointerString.alloc(len(s) + 1)
+    for i in range(len(s)):
+        ret.store(i, ord(s[i]))
+    ret.store(len(s), 0)
+    return ret
+
+
 struct FileBuf:
     var data: BufferPtrType
     var offset: Int
@@ -253,8 +245,9 @@ struct RunState:
     var v: Matrix # value (dim,)
     var att: Matrix # buffer for scores/attention values (n_heads, seq_len)
     var logits: Matrix # output logits
-    var key_cache: Matrix3 # (layer, seq_len, dim)
-    var value_cache: Matrix3 # (layer, seq_len, dim)
+    var key_cache: Matrix # (layer, seq_len, dim)
+    var value_cache: Matrix # (layer, seq_len, dim)
+    var rt: Runtime
 
     fn __init__(inout self, config: Config):
         self.x = Matrix(config.dim)
@@ -277,10 +270,11 @@ struct RunState:
         self.att.alloc_zero()
         self.logits = Matrix(config.vocab_size)
        self.logits.alloc_zero()
-        self.key_cache = Matrix3(config.n_layers, config.seq_len, config.dim)
+        self.key_cache = Matrix(config.n_layers, config.seq_len, config.dim)
         self.key_cache.alloc_zero()
-        self.value_cache = Matrix3(config.n_layers, config.seq_len, config.dim)
+        self.value_cache = Matrix(config.n_layers, config.seq_len, config.dim)
         self.value_cache.alloc_zero()
+        self.rt = Runtime(num_cores() // 2)
 
 
 struct TransformerWeights:
@@ -288,14 +282,14 @@ struct TransformerWeights:
     var freq_cis_real: Matrix
     var freq_cis_imag: Matrix
     var rms_att_weight: Matrix
-    var wq: Matrix3
-    var wk: Matrix3
-    var wv: Matrix3
-    var wo: Matrix3
+    var wq: Matrix
+    var wk: Matrix
+    var wv: Matrix
+    var wo: Matrix
     var rms_ffn_weight: Matrix
-    var w1: Matrix3
-    var w3: Matrix3
-    var w2: Matrix3
+    var w1: Matrix
+    var w3: Matrix
+    var w2: Matrix
     var rms_final_weight: Matrix
     var wcls: Matrix
 
@@ -309,23 +303,23 @@ struct TransformerWeights:
         self.rms_att_weight.set_buf_ptr(
             buf.bitcast_offset_float32(self.rms_att_weight.size())
         )
-        self.wq = Matrix3(config.n_layers, config.dim, config.dim)
+        self.wq = Matrix(config.n_layers, config.dim, config.dim)
         self.wq.set_buf_ptr(buf.bitcast_offset_float32(self.wq.size()))
-        self.wk = Matrix3(config.n_layers, config.dim, config.dim)
+        self.wk = Matrix(config.n_layers, config.dim, config.dim)
         self.wk.set_buf_ptr(buf.bitcast_offset_float32(self.wk.size()))
-        self.wv = Matrix3(config.n_layers, config.dim, config.dim)
+        self.wv = Matrix(config.n_layers, config.dim, config.dim)
         self.wv.set_buf_ptr(buf.bitcast_offset_float32(self.wv.size()))
-        self.wo = Matrix3(config.n_layers, config.dim, config.dim)
+        self.wo = Matrix(config.n_layers, config.dim, config.dim)
         self.wo.set_buf_ptr(buf.bitcast_offset_float32(self.wo.size()))
         self.rms_ffn_weight = Matrix(config.n_layers, config.dim)
         self.rms_ffn_weight.set_buf_ptr(
             buf.bitcast_offset_float32(self.rms_ffn_weight.size())
         )
-        self.w1 = Matrix3(config.n_layers, config.dim, config.hidden_dim)
+        self.w1 = Matrix(config.n_layers, config.dim, config.hidden_dim)
         self.w1.set_buf_ptr(buf.bitcast_offset_float32(self.w1.size()))
-        self.w2 = Matrix3(config.n_layers, config.dim, config.hidden_dim)
+        self.w2 = Matrix(config.n_layers, config.dim, config.hidden_dim)
         self.w2.set_buf_ptr(buf.bitcast_offset_float32(self.w2.size()))
-        self.w3 = Matrix3(config.n_layers, config.dim, config.hidden_dim)
+        self.w3 = Matrix(config.n_layers, config.dim, config.hidden_dim)
         self.w3.set_buf_ptr(buf.bitcast_offset_float32(self.w3.size()))
         self.rms_final_weight = Matrix(config.dim)
         self.rms_final_weight.set_buf_ptr(
@@ -435,22 +429,14 @@ fn softmax(inout x: BufferPtrFloat32, size: Int) -> None:
         x.offset(i).simd_store[1](0, xi / ssum)
 
 
-fn matmul_naive(C: Matrix, x: Matrix, w: Matrix) -> None:
-    # W(d,n) @ X(n,) -> C (d,)
-    # By far the most amount of time is spent inside this little function
-    for i in range(w.rows):
-        C[i] = 0.0
-        for j in range(w.cols):
-            C[i] += x[j] * w[i, j]
-
-
-fn matmul_vectorized(C: Matrix, A: Matrix, B: Matrix):
-    for i in range(0, B.rows):
+fn matmul_parallelized(C: Matrix, A: Matrix, B: Matrix, rt: Runtime):
+    @parameter
+    fn compute_row(i: Int):
         var tmp = SIMD[DType.float32, nelts](0)
 
         @parameter
         fn dot[_nelts: Int](j: Int):
-            if _nelts < nelts:  # take care of tail array elements with length < nelts
+            if _nelts < nelts:  # take care of tail array elements with length < nelts
                 tmp[0] += (A.load[_nelts](j) * B.load[_nelts](i, j)).reduce_add()
             else:
                 tmp += A.load[nelts](j) * B.load[nelts](i, j)
@@ -458,28 +444,12 @@ fn matmul_vectorized(C: Matrix, A: Matrix, B: Matrix):
         vectorize[nelts, dot](B.cols)
         C[i] = tmp.reduce_add()
 
-fn matmul_parallelized(C: Matrix, A: Matrix, B: Matrix):
-    @parameter
-    fn calc_row(i: Int):
-        var T = BufferPtrFloat32.alloc(nelts)
-        var Tbuf = Buffer[nelts, DType.float32](T)
-        memset_zero(T, nelts)
-        @parameter
-        fn dot[nelts: Int](j: Int):
-            T.simd_store[nelts](
-                0, T.simd_load[nelts](0) + A.load[nelts](j) * B.load[nelts](i, j)
-            )
-
-        vectorize[nelts, dot](B.cols)
-        C[i] = sum[nelts, DType.float32](Tbuf)
+    parallelize[compute_row](rt, B.rows, rt.parallelism_level())
 
-    parallelize[calc_row](B.rows)
 
-
-fn matmul(inout C: Matrix, A: Matrix, B: Matrix) -> None:
+fn matmul(inout C: Matrix, A: Matrix, B: Matrix, rt: Runtime) -> None:
     # B (d,n) @ A (n,) -> C (d,)
-    matmul_vectorized(C, A, B)
-    # matmul_parallelized(C, A, B)
+    matmul_parallelized(C, A, B, rt)
 
 
 fn transformer(
@@ -513,13 +483,13 @@ fn transformer(
 
         # QKV matmuls for this position
         tmpw.set_buf_ptr(weights.wq.data.offset(l * dim * dim), dim, dim)
-        matmul(state.q, state.xb, tmpw)
+        matmul(state.q, state.xb, tmpw, state.rt)
 
         tmpw.set_buf_ptr(weights.wk.data.offset(l * dim * dim), dim, dim)
-        matmul(state.k, state.xb, tmpw)
+        matmul(state.k, state.xb, tmpw, state.rt)
 
         tmpw.set_buf_ptr(weights.wv.data.offset(l * dim * dim), dim, dim)
-        matmul(state.v, state.xb, tmpw)
+        matmul(state.v, state.xb, tmpw, state.rt)
 
         # Apply RoPE rotation to the q and k vectors for each head
        for h in range(config.n_heads):
@@ -587,7 +557,7 @@ fn transformer(
             xb.offset(i).simd_store[1](0, xbi)
         # Final matrix multiplication to get the output of the attention
         tmpw.set_buf_ptr(weights.wo.data.offset(l * dim * dim), dim, dim)
-        matmul(state.xb2, state.xb, tmpw)
+        matmul(state.xb2, state.xb, tmpw, state.rt)
 
         # Residual connection back into x
         accum(x, state.xb2.data, dim)
@@ -597,10 +567,10 @@ fn transformer(
 
         # Calculate self.w1(x) and self.w3(x) for FFN
         tmpw.set_buf_ptr(weights.w1.data.offset(l * dim * hidden_dim), hidden_dim, dim)
-        matmul(state.hb, state.xb, tmpw)
+        matmul(state.hb, state.xb, tmpw, state.rt)
 
         tmpw.set_buf_ptr(weights.w3.data.offset(l * dim * hidden_dim), hidden_dim, dim)
-        matmul(state.hb2, state.xb, tmpw)
+        matmul(state.hb2, state.xb, tmpw, state.rt)
 
         # Apply SiLU activation function (silu(x) = x * sigmoid(x))
         for i in range(hidden_dim):
@@ -613,7 +583,7 @@ fn transformer(
 
         # Final matrix multiplication to get the output of the FFN
         tmpw.set_buf_ptr(weights.w2.data.offset(l * dim * hidden_dim), dim, hidden_dim)
-        matmul(state.xb, state.hb, tmpw)
+        matmul(state.xb, state.hb, tmpw, state.rt)
 
         # Residual connection
         accum(x, state.xb.data, dim)
@@ -623,7 +593,7 @@ fn transformer(
 
     # Classifier into logits
     tmpw.set_buf_ptr(weights.wcls.data, config.vocab_size, dim)
-    matmul(state.logits, state.x, tmpw)
+    matmul(state.logits, state.x, tmpw, state.rt)
 
 
 fn argmax(v: Matrix) -> Int:
@@ -651,6 +621,59 @@ fn sample(probabilities: Matrix) -> Int:
     return n - 1 # In case of rounding errors
 
 
+fn str_lookup(str: PointerString, tok: Tokenizer) -> Int:
+    for pos in range(tok.vocab_size):
+        let s1 = tok.vocab[pos]
+        var p1 = 0
+        while s1[p1] != 0 and str[p1] != 0:
+            if s1[p1] != str[p1]:
+                break
+            p1 += 1
+        if s1[p1] != 0 or str[p1] != 0:
+            continue
+        return pos
+    return -1
+
+
+fn bpe_encode(inout tokens: DynamicVector[Int], text: String, tok: Tokenizer):
+    for pos in range(len(text)):
+        let char = str_to_ptr(text[pos])
+        let id = str_lookup(char, tok)
+
+        if id == -1:
+            print("Not a good prompt token at pos ", pos)
+            return
+        tokens.push_back(id)
+
+    while True:
+        var best_score = Float32(-1e10)
+        var best_id = -1
+        var best_idx = -1
+
+        for i in range(len(tokens) - 1):
+            # Check if we can merge the pair (tokens[i], tokens[i+1])
+            let str = str_concat(tok.vocab[tokens[i]], tok.vocab[tokens[i + 1]])
+            let id = str_lookup(str, tok)
+            if id != -1 and tok.vocab_scores.load(id) > best_score:
+                best_score = tok.vocab_scores.load(id)
+                best_id = id
+                best_idx = i
+
+        if best_idx == -1:
+            # We couldn't find any more pairs to merge, so we're done
+            break
+
+        # Merge the consecutive pair (best_idx, best_idx+1) into new token best_id
+        tokens[best_idx] = best_id
+        # Delete token at position best_idx+1, shift the entire sequence back 1
+        var _tokens = DynamicVector[Int]()
+        for i in range(0, best_idx + 1):
+            _tokens.push_back(tokens[i])
+        for i in range(best_idx + 2, len(tokens)):
+            _tokens.push_back(tokens[i])
+        tokens = _tokens
+
+
 fn print_str(s: PointerString):
     # print all chars till null character
     var p: Int = 0
@@ -664,15 +687,61 @@ fn time_in_ms() -> Int:
     return time.now() // 1_000_000
 
 
+fn print_usage():
+    print("Usage: mojo llama2.mojo <checkpoint> [options]")
+    print("Example: mojo llama2.mojo stories15M.bin -s 99 -n 256 -t 0.5 -i \"Llama is an animal\"")
+    print("Options:")
+    print(" -s <int> random seed, default time.now()")
+    print(" -t <float> temperature in [0,1.0], default 1.0")
+    print(" -n <int> number of steps to run for, default 256. 0 = max_seq_len")
+    print(" -i <string> input prompt")
+
+
 fn main() raises:
-    print("num hardware threads: ", num_cores(), " SIMD vector width: ", nelts)
-    let checkpoint = "stories15M.bin"
-    # let checkpoint = "stories110M.bin"
-    let tokenizer = "tokenizer.bin"
-    let temperature = 0.0
+    print("num hardware threads: ", num_cores())
+    print("SIMD vector width: ", nelts)
+    var tokenizer = StringRef("tokenizer.bin")
+    var checkpoint = StringRef("stories15M.bin")
+    var temperature = 0.9
     var steps = 256
-    let prompt = ""
-    let rng_seed: Int = time.now()
+    var prompt = String("")
+    var rng_seed: Int = time.now()
+
+    @parameter
+    fn argparse() raises -> Int:
+        let args = argv()
+        if len(args) < 2:
+            return 0
+        checkpoint = args[1]
+        for i in range(2, len(args), 2):
+            if args[i] == "-p":
+                print("Option not supported: ", args[i])
+            if args[i] == "-n":
+                steps = atol(args[i + 1])
+            if args[i] == "-s":
+                rng_seed = atol(args[i + 1])
+            if args[i] == "-i":
+                prompt = args[i + 1]
+            if args[i] == "-t":
+                let val = args[i + 1]
+                temperature = 0.0
+                # hacky parse float, keep only 1 digit
+                for c in range(0, len(val)):
+                    if val[c] == ".":
+                        temperature += atol(val[c + 1]) * 0.1
+                        break
+                    else:
+                        temperature = atol(val[c])
+                if temperature < -1e9 or temperature > (1 + 1e9):
+                    print("Wrong temperature value", temperature)
+                    return 0
        return 1
+
+    let res = argparse()
+    if res == 0:
+        print_usage()
+        return
+
     random.seed(rng_seed)
     var fbuf: FileBuf = FileBuf()
     var tbuf: FileBuf = FileBuf()
@@ -702,39 +771,45 @@ fn main() raises:
     # Create and initialize the application RunState
     var state = RunState(config)
 
+    # Process the prompt, if any
+    var prompt_tokens = DynamicVector[Int]()
+
+    if prompt:
+        bpe_encode(prompt_tokens, prompt, tok)
+
     # Start the main loop
     var start = 0 # Used to time our code, only initialized after the first iteration
     var next_token = 0 # Will store the next token in the sequence
     # Initialize with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
     var token = 1
-    var pos = 0 # Position in the sequence
-    # Explicitly print the initial BOS token for stylistic symmetry reasons
-
-    print("<s>")
 
+    # Position in the sequence
+    var pos = 0
     while pos < steps:
         # Forward the transformer to get logits for the next token
         transformer(token, pos, config, state, weights)
 
-        # Sample the next token
-        if temperature == 0.0:
-            # Greedy argmax sampling: take the token with the highest probability
-            next_token = argmax(state.logits)
+        if pos < len(prompt_tokens):
+            next_token = prompt_tokens[pos]
         else:
-            # Apply the temperature to the logits
-            for q in range(config.vocab_size):
-                state.logits[q] = state.logits[q] / temperature
-            # Apply softmax to the logits to get the probabilities for the next token
-            softmax(state.logits.data, config.vocab_size)
-            # Sample from this distribution to get the next token
-            next_token = sample(state.logits)
+            # Sample the next token
+            if temperature == 0.0:
+                # Greedy argmax sampling: take the token with the highest probability
+                next_token = argmax(state.logits)
+            else:
+                # Apply the temperature to the logits
+                for q in range(config.vocab_size):
+                    state.logits[q] = state.logits[q] / temperature
+                # Apply softmax to the logits to get the probabilities for the next token
+                softmax(state.logits.data, config.vocab_size)
+                # Sample from this distribution to get the next token
+                next_token = sample(state.logits)
 
         var token_str: PointerString = tok.vocab[next_token]
         if token == 1 and token_str[0] == ord(" "):
             token_str = token_str.offset(1)
 
         print_str(token_str)
-        # flush?
 
         # Advance forward
         token = next_token
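The new prompt support comes down to a greedy BPE loop: map each character to a vocab id, then repeatedly merge the adjacent pair whose concatenation has the highest vocab score. A Python sketch of the same logic (illustrative only; `vocab` and `vocab_scores` stand in for the fields the Mojo code reads from tokenizer.bin):

```python
def bpe_encode(text, vocab, vocab_scores):
    lookup = {tok: i for i, tok in enumerate(vocab)}
    # per-character lookup, as str_lookup does in llama2.mojo
    tokens = [lookup[ch] for ch in text]

    while True:
        best_score, best_id, best_idx = -1e10, -1, -1
        # find the adjacent pair whose merged string scores highest
        for i in range(len(tokens) - 1):
            merged = vocab[tokens[i]] + vocab[tokens[i + 1]]
            j = lookup.get(merged, -1)
            if j != -1 and vocab_scores[j] > best_score:
                best_score, best_id, best_idx = vocab_scores[j], j, i
        if best_idx == -1:
            break  # nothing left to merge
        # replace the pair (best_idx, best_idx + 1) with the merged token
        tokens = tokens[:best_idx] + [best_id] + tokens[best_idx + 2:]
    return tokens
```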