Kernels
github-actions[bot] committed
Commit 6ec5093 · 1 parent: 811726c

Add built binary [skip-build]

Files changed (41)
  1. build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  2. build/torch28-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_b0230e7_dirty.abi3.so → _optimizer_811726c_dirty.abi3.so} +2 -2
  3. build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +5 -0
  4. build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  5. build/torch28-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_b0230e7_dirty.abi3.so → _optimizer_811726c_dirty.abi3.so} +2 -2
  6. build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +5 -0
  7. build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
  8. build/torch28-cxx11-cu129-x86_64-linux/optimizer/{_optimizer_b0230e7_dirty.abi3.so → _optimizer_811726c_dirty.abi3.so} +2 -2
  9. build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +5 -0
  10. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  11. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_b0230e7_dirty.abi3.so → _optimizer_811726c_dirty.abi3.so} +2 -2
  12. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +5 -0
  13. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
  14. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +3 -0
  15. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so +0 -3
  16. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +5 -0
  17. build/torch29-cxx11-cu126-x86_64-linux/optimizer/__init__.py +5 -0
  18. build/torch29-cxx11-cu126-x86_64-linux/optimizer/_ops.py +9 -0
  19. build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +3 -0
  20. build/torch29-cxx11-cu126-x86_64-linux/optimizer/matmul_transpose_triton.py +128 -0
  21. build/torch29-cxx11-cu126-x86_64-linux/optimizer/muon.py +1069 -0
  22. build/torch29-cxx11-cu128-x86_64-linux/optimizer/__init__.py +5 -0
  23. build/torch29-cxx11-cu128-x86_64-linux/optimizer/_ops.py +9 -0
  24. build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +3 -0
  25. build/torch29-cxx11-cu128-x86_64-linux/optimizer/matmul_transpose_triton.py +128 -0
  26. build/torch29-cxx11-cu128-x86_64-linux/optimizer/muon.py +1069 -0
  27. build/torch29-cxx11-cu130-x86_64-linux/optimizer/__init__.py +5 -0
  28. build/torch29-cxx11-cu130-x86_64-linux/optimizer/_ops.py +9 -0
  29. build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +3 -0
  30. build/torch29-cxx11-cu130-x86_64-linux/optimizer/matmul_transpose_triton.py +128 -0
  31. build/torch29-cxx11-cu130-x86_64-linux/optimizer/muon.py +1069 -0
  32. build/torch29-cxx11-rocm63-x86_64-linux/optimizer/__init__.py +5 -0
  33. build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +9 -0
  34. build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +3 -0
  35. build/torch29-cxx11-rocm63-x86_64-linux/optimizer/matmul_transpose_triton.py +128 -0
  36. build/torch29-cxx11-rocm63-x86_64-linux/optimizer/muon.py +1069 -0
  37. build/torch29-cxx11-rocm64-x86_64-linux/optimizer/__init__.py +5 -0
  38. build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +9 -0
  39. build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +3 -0
  40. build/torch29-cxx11-rocm64-x86_64-linux/optimizer/matmul_transpose_triton.py +128 -0
  41. build/torch29-cxx11-rocm64-x86_64-linux/optimizer/muon.py +1069 -0
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b0230e7_dirty
-ops = torch.ops._optimizer_b0230e7_dirty
+from . import _optimizer_811726c_dirty
+ops = torch.ops._optimizer_811726c_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b0230e7_dirty::{op_name}"
+    return f"_optimizer_811726c_dirty::{op_name}"
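The only change here is the build-hash suffix of the native extension's namespace. As a tiny illustration (my own, not from the repository; `"foo"` is a placeholder op name), the helper simply qualifies op names with that namespace so they can be looked up under `torch.ops`:

```python
def add_op_namespace_prefix(op_name: str) -> str:
    # Same helper as above: qualify an op name with the rebuilt extension's namespace.
    return f"_optimizer_811726c_dirty::{op_name}"

assert add_op_namespace_prefix("foo") == "_optimizer_811726c_dirty::foo"  # "foo" is a placeholder
```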
build/torch28-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_b0230e7_dirty.abi3.so → _optimizer_811726c_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:69525fcbfbe640264f4d52c9843b395b17f1828d38e1eceb97cec6bf46b0d8d0
-size 1824256
+oid sha256:511199ac2ae46febc8aeeb96e843a748da7d6fdea4922572ccf27ee5eabe312d
+size 1816064
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -606,6 +606,11 @@ class Muon(torch.optim.Optimizer):
 
         if p.placements == (Shard(dim=0), ):
             # Case for FSDP
+            process_group = p.device_mesh.get_group(mesh_dim=0)
+            if self.rank is None:
+                self.rank = dist.get_rank(group=process_group)
+            else:
+                assert self.rank == dist.get_rank(group=process_group)
             return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
         elif p.placements == (Replicate(), Shard(dim=0)):
             # Case for HSDP
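The same five-line guard is added to every build variant below: the optimizer caches the caller's rank within the shard process group the first time it sees a `Shard(dim=0)` parameter and asserts that every later parameter resolves to the same rank. A standalone restatement of that logic (my own sketch, assuming an initialized torch.distributed process group and a DTensor parameter `p`):

```python
import torch.distributed as dist

def resolve_shard_rank(p, cached_rank=None):
    # FSDP case: the parameter is Shard(dim=0) on a 1-D device mesh.
    process_group = p.device_mesh.get_group(mesh_dim=0)
    rank = dist.get_rank(group=process_group)
    if cached_rank is None:
        return rank              # first DTensor parameter seen: cache this rank
    assert cached_rank == rank   # later parameters must resolve to the same rank
    return cached_rank
```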
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b0230e7_dirty
-ops = torch.ops._optimizer_b0230e7_dirty
+from . import _optimizer_811726c_dirty
+ops = torch.ops._optimizer_811726c_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b0230e7_dirty::{op_name}"
+    return f"_optimizer_811726c_dirty::{op_name}"
build/torch28-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_b0230e7_dirty.abi3.so → _optimizer_811726c_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:331cc0bc5ee469afdfe0fc590bf52910c118cd0cec62ccbf85778c12ae367a95
-size 1883344
+oid sha256:b3cdb515b6c56204224cc307b66d34fcee1cd5e27b4117197a71b784d34fadc5
+size 1871056
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -606,6 +606,11 @@ class Muon(torch.optim.Optimizer):
 
         if p.placements == (Shard(dim=0), ):
             # Case for FSDP
+            process_group = p.device_mesh.get_group(mesh_dim=0)
+            if self.rank is None:
+                self.rank = dist.get_rank(group=process_group)
+            else:
+                assert self.rank == dist.get_rank(group=process_group)
             return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
         elif p.placements == (Replicate(), Shard(dim=0)):
             # Case for HSDP
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b0230e7_dirty
-ops = torch.ops._optimizer_b0230e7_dirty
+from . import _optimizer_811726c_dirty
+ops = torch.ops._optimizer_811726c_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b0230e7_dirty::{op_name}"
+    return f"_optimizer_811726c_dirty::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/optimizer/{_optimizer_b0230e7_dirty.abi3.so → _optimizer_811726c_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f6ba7ad9228edcce4bf49173562b0796f1657eb734ddd6e23ca773c153eefce2
-size 1883344
+oid sha256:b957f60eab442d3ff5a5525d16a1b4b71e8c6be32edb874d9a5681953c61f0c2
+size 1871056
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py CHANGED
@@ -606,6 +606,11 @@ class Muon(torch.optim.Optimizer):
 
         if p.placements == (Shard(dim=0), ):
             # Case for FSDP
+            process_group = p.device_mesh.get_group(mesh_dim=0)
+            if self.rank is None:
+                self.rank = dist.get_rank(group=process_group)
+            else:
+                assert self.rank == dist.get_rank(group=process_group)
             return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
         elif p.placements == (Replicate(), Shard(dim=0)):
             # Case for HSDP
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b0230e7_dirty
-ops = torch.ops._optimizer_b0230e7_dirty
+from . import _optimizer_811726c_dirty
+ops = torch.ops._optimizer_811726c_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b0230e7_dirty::{op_name}"
+    return f"_optimizer_811726c_dirty::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_b0230e7_dirty.abi3.so → _optimizer_811726c_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:649c9c1ca7360650167cc191e373b271a4138161ec40b1e881a87515f82a613f
-size 1750000
+oid sha256:898ff08457f77c2f6ef504c73570cc87c5c5fd9a144528dbf8af4c03ffc21049
+size 1749232
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -606,6 +606,11 @@ class Muon(torch.optim.Optimizer):
 
         if p.placements == (Shard(dim=0), ):
             # Case for FSDP
+            process_group = p.device_mesh.get_group(mesh_dim=0)
+            if self.rank is None:
+                self.rank = dist.get_rank(group=process_group)
+            else:
+                assert self.rank == dist.get_rank(group=process_group)
             return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
         elif p.placements == (Replicate(), Shard(dim=0)):
             # Case for HSDP
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b0230e7_dirty
-ops = torch.ops._optimizer_b0230e7_dirty
+from . import _optimizer_811726c_dirty
+ops = torch.ops._optimizer_811726c_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b0230e7_dirty::{op_name}"
+    return f"_optimizer_811726c_dirty::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:72d100180fd73094f7b1c6e765eb4a77f103ad392fdee571687cb0c66d304177
+size 1749320
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_b0230e7_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:42b60753dab0948f4009893fcf3a8b080ad00e0436cbdaf0995dc29ae066c0c7
-size 1750088
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py CHANGED
@@ -606,6 +606,11 @@ class Muon(torch.optim.Optimizer):
 
         if p.placements == (Shard(dim=0), ):
             # Case for FSDP
+            process_group = p.device_mesh.get_group(mesh_dim=0)
+            if self.rank is None:
+                self.rank = dist.get_rank(group=process_group)
+            else:
+                assert self.rank == dist.get_rank(group=process_group)
             return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
         elif p.placements == (Replicate(), Shard(dim=0)):
             # Case for HSDP
build/torch29-cxx11-cu126-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,5 @@
+from .muon import Muon
+
+__all__ = [
+    "Muon",
+]
build/torch29-cxx11-cu126-x86_64-linux/optimizer/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _optimizer_811726c_dirty
+ops = torch.ops._optimizer_811726c_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_optimizer_811726c_dirty::{op_name}"
build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:87c8e75ead1c831dabfce1abbd7c100aa72c9b2988dfc0e1554216ca8005267c
+size 1816064
build/torch29-cxx11-cu126-x86_64-linux/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,128 @@
+# MIT License
+#
+# Copyright (c) 2025 Tianyang Lin
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import torch
+import triton
+import triton.language as tl
+
+
+def get_autotune_config():
+    return [
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': blk_m,
+                'BLOCK_SIZE_K': blk_k,
+                'GROUP_SIZE_M': grp_sz
+            },
+            num_stages=n_stages,
+            num_warps=n_warps) for blk_m in [32, 64, 128]
+        for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
+        for n_warps in [4, 8]
+    ]
+
+
+@triton.autotune(
+    configs=get_autotune_config(),
+    key=['M', 'K'],
+)
+@triton.jit
+def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
+               BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+               GROUP_SIZE_M: tl.constexpr):
+    """
+    Core kernel jit function of matmul_transpose that computes y = x @ x.T
+    The code is a simple adaptation from the triton `matmul` tutorial:
+    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
+    """
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+    if pid_m > pid_n:
+        return
+
+    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    # we use a & b ptrs to denote different rows of x.
+    a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
+    b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs,
+                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
+                    other=0.0)
+        b = tl.load(b_ptrs,
+                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
+                    other=0.0)
+        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
+        a_ptrs += BLOCK_SIZE_K * stride_xk
+        b_ptrs += BLOCK_SIZE_K * stride_xk
+    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
+    # https://github.com/triton-lang/triton/issues/2252
+    c = accumulator.to(x.dtype.element_ty)
+
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+    # transpose and copy
+    if pid_m < pid_n:
+        ct_ptrs = y + stride_ym * offs_cn[:,
+                                          None] + stride_yn * offs_cm[None, :]
+        ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
+        tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
+
+
+def matmul_transpose_assign(d_in, d_out):
+    assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
+    assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
+    assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
+    assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
+    assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
+    assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
+    assert d_in.size(0) == d_out.size(0) == d_out.size(0), \
+        "First dimension of `d_in` must match first and second dimension of `d_out`"
+
+    d_in = d_in.contiguous()
+    M, K = d_in.shape
+    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
+        M, META['BLOCK_SIZE_M']), )
+    with torch.cuda.device(d_in.device.index):
+        mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
+                         d_out.stride(0), d_out.stride(1))
+
+
+def matmul_transpose(d_in):
+    M, _ = d_in.shape
+    d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
+    matmul_transpose_assign(d_in, d_out)
+    return d_out
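For context, a small usage sketch of the Triton helper above (my own, not part of this commit). It assumes a CUDA device, an installed Triton, and that the module is importable under the `optimizer` package as built here; `matmul_transpose(x)` is intended to match `x @ x.T`:

```python
import torch
from optimizer.matmul_transpose_triton import matmul_transpose  # assumed import path

x = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
y = matmul_transpose(x)  # (512, 512); the kernel computes the upper triangle
                         # and mirrors it into the lower triangle
ref = x @ x.T
print((y.float() - ref.float()).abs().max())  # expect a small bf16-level difference
```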
build/torch29-cxx11-cu126-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,1069 @@
+import logging
+import math
+import types
+from dataclasses import dataclass
+from typing import List, Optional, Union, cast
+
+import torch
+import torch.distributed as dist
+from torch.distributed._tensor import DTensor, Replicate, Shard
+
+from .matmul_transpose_triton import matmul_transpose_assign
+
+logger = logging.getLogger(__name__)
+
+COMM_DTYPE = torch.bfloat16
+
+
+# This code snippet is a modified version adapted from the following GitHub repositories:
+# https://github.com/KellerJordan/Muon/blob/master/muon.py
+# Muon's Newton–Schulz iteration causes high variance in singular values
+# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
+@torch.no_grad()
+# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
+def _zeropower_via_newtonschulz5(G, steps):
+    """
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert len(G.shape) == 2
+    assert G.dtype == COMM_DTYPE
+    X = G  # no manual typecast
+
+    if G.size(0) > G.size(1):
+        X = X.T
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm() + 1e-7)
+    buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
+    buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
+    # Perform the NS iterations
+    for a, b, c in [
+        (4.0848, -6.8946, 2.9270),
+        (3.9505, -6.3029, 2.6377),
+        (3.7418, -5.5913, 2.3037),
+        (2.8769, -3.1427, 1.2046),
+        (2.8366, -3.0525, 1.2012),
+    ]:
+        matmul_transpose_assign(X, buf1)
+        matmul_transpose_assign(buf1, buf2)
+        buf1.mul_(b).add_(buf2, alpha=c)
+        X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
+
+    if G.size(0) > G.size(1):
+        X = X.T
+    return X
+
+
+@dataclass
+class _muon_state:
+    # TODO: use Optional
+    worker_rank: int | None = None
+    gathered_grad: torch.Tensor | None = None
+    scattered_u: DTensor | None = None
+    computed_u: torch.Tensor | None = None
+    gather_event: torch.cuda.Event | None = None
+    compute_event: torch.cuda.Event | None = None
+    scatter_event: torch.cuda.Event | None = None
+    process_group = None
+    qk_clip_state = None
+
+
+def split_elems_for_src(param, src_rank, num_ranks) -> int:
+    rows = param.shape[0]
+    cols = int(param.numel() // rows)
+    base, rem = divmod(rows, num_ranks)
+    my_rows = base + (1 if src_rank < rem else 0)
+    return my_rows * cols
+
+
+@torch.no_grad()
+def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
+    """
+    Pre-allocate gathered_grad buffer on compute_stream
+    before launching all2all gather
+    """
+    with torch.cuda.stream(compute_stream):
+        for p in params:
+            state = param_to_state[id(p)]
+            if rank == state.worker_rank:
+                num_ranks = dist.get_world_size(group=state.process_group)
+                state.gathered_grad = torch.empty(p.grad.numel(),
+                                                  dtype=COMM_DTYPE,
+                                                  device="cuda")
+            else:
+                state.gathered_grad = None
+
+    alloc_event = torch.cuda.Event()
+    alloc_event.record(compute_stream)
+    return alloc_event
+
+
+@torch.no_grad()
+def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
+                    alloc_event):
+    """
+    All2all gathers shards so each owner rank reconstructs its full gradient
+    """
+    with torch.cuda.stream(comm_stream):
+        process_group = param_to_state[id(params[0])].process_group
+        num_ranks = dist.get_world_size(group=process_group)
+
+        # Construct sending buffers
+        per_dst = [[] for _ in range(num_ranks)]
+        send_counts = [0] * num_ranks
+
+        for p in params:
+            state = param_to_state[id(p)]
+            dst = state.worker_rank
+            assert dst < num_ranks
+            shard_elems = split_elems_for_src(p, rank, num_ranks)
+            g = p.grad
+            g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
+            assert g.numel() == shard_elems
+            per_dst[dst].append(g)
+            send_counts[dst] += shard_elems
+
+        assert any(
+            len(v) > 0 for v in per_dst
+        ), "At least one destination rank must receive a sharded tensor"
+        # list[list[Tensor]] -> list[Tensor]
+        per_dst = [t for dst in per_dst for t in dst]
+
+        send_buf = torch.cat(per_dst, dim=0)
+
+        owned_params = [
+            p for p in params if param_to_state[id(p)].worker_rank == rank
+        ]
+
+        # Compute receive sizes and allocate receiving buffers
+        recv_counts = [0] * num_ranks
+
+        for src in range(num_ranks):
+            total = 0
+            for p in owned_params:
+                state = param_to_state[id(p)]
+                assert state.worker_rank == rank
+                total += split_elems_for_src(p, src, num_ranks)
+            recv_counts[src] = total
+
+        recv_total = sum(recv_counts)
+        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
+
+        #All2All
+        dist.all_to_all_single(
+            recv_buf,
+            send_buf,
+            output_split_sizes=recv_counts,
+            input_split_sizes=send_counts,
+            group=process_group,
+        )
+
+        # Reconstructs gathered grad from the received buffer
+        #
+        # recv_buf (num ranks = 3)
+        #
+        #    From rank 0         From rank 1         From rank 2
+        # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
+        #
+        # Outer loop:
+        #   rank 0 -> rank 1 -> rank2
+        #
+        # Inner loop:
+        #   p1_n -> p2_n -> p3_n
+
+        comm_stream.wait_event(alloc_event)
+
+        off = 0
+        write_offsets = {id(p): 0 for p in owned_params}
+        for src in range(num_ranks):
+            if recv_counts[src] == 0:
+                continue
+
+            block = recv_counts[src]
+            inner_off = 0
+            for p in owned_params:
+                state = param_to_state[id(p)]
+                assert state.worker_rank == rank
+                n = split_elems_for_src(p, src, num_ranks)
+                assert n > 0
+
+                sg = recv_buf.narrow(0, off + inner_off, n)
+                woff = write_offsets[id(p)]
+                dst = state.gathered_grad.narrow(0, woff, n)
+                dst.copy_(sg)
+
+                write_offsets[id(p)] += n
+                inner_off += n
+            off += block
+
+        for p in params:
+            state = param_to_state[id(p)]
+            if state.worker_rank == rank:
+                state.gathered_grad = state.gathered_grad.view_as(p)
+                state.gather_event = torch.cuda.Event()
+                state.gather_event.record(comm_stream)
+            else:
+                state.gathered_grad = None
+                state.gather_event = None
+            if none_grad:
+                p.grad = None
+
+
+@torch.no_grad()
+def _compute_u(p, state, steps, rank, compute_stream):
+    """
+    On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
+    """
+    with torch.cuda.stream(compute_stream):
+        if rank == state.worker_rank:
+            if state.gather_event is None:
+                raise RuntimeError("Gather event must be set before compute.")
+            compute_stream.wait_event(state.gather_event)
+            u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
+            state.gathered_grad = None
+            state.computed_u = u
+            state.compute_event = torch.cuda.Event()
+            state.compute_event.record()
+        else:
+            state.computed_u = None
+            state.compute_event = None
+
+
+@torch.no_grad()
+def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
+    """
+    Pre-allocate scattered_u buffer on compute_stream
+    before launching all2all gather
+    """
+    with torch.cuda.stream(compute_stream):
+        for p in params:
+            state = param_to_state[id(p)]
+            state.scattered_u = torch.empty_like(p.to_local(),
+                                                 dtype=COMM_DTYPE)
+
+    alloc_event = torch.cuda.Event()
+    alloc_event.record(compute_stream)
+    return alloc_event
+
+
+def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
+    """
+    All2all scatters full gradients to all ranks
+    """
+    with torch.cuda.stream(comm_stream):
+        process_group = param_to_state[id(params[0])].process_group
+        num_ranks = dist.get_world_size(group=process_group)
+        owned_params = [
+            p for p in params if param_to_state[id(p)].worker_rank == rank
+        ]
+
+        # Construct sending buffer
+        per_dst = [[] for _ in range(num_ranks)]
+        send_counts = [0] * num_ranks
+
+        if owned_params:
+            for p in owned_params:
+                state = param_to_state[id(p)]
+                if state.compute_event is None:
+                    raise RuntimeError(
+                        "Compute event must be set before scatter.")
+                comm_stream.wait_event(state.compute_event)
+                state.gathered_grad = None
+
+                assert state.computed_u is not None
+
+                u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
+
+                offset = 0
+                for dst in range(num_ranks):
+                    n = split_elems_for_src(p, dst, num_ranks)
+                    assert n > 0
+
+                    su = u_full.narrow(0, offset, n)
+                    per_dst[dst].append(su)
+                    send_counts[dst] += n
+                    offset += n
+
+                assert offset == u_full.numel()
+
+        lengths = [len(v) for v in per_dst]
+        if all(l > 0 for l in lengths):
+            assert all(
+                l == lengths[0] for l in lengths
+            ), "All destination ranks must have the same number of sharded tensor"
+            # list[list[Tensor]] -> list[Tensor]
+            per_dst = [t for dst in per_dst for t in dst]
+            send_buf = torch.cat(per_dst, dim=0)
+        else:
+            # all_to_all requires participation from all ranks
+            # Even non-owner ranks must join the collective call
+            send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
+
+        # Compute receive sizes and allocate receiving buffers
+        recv_counts = [0] * num_ranks
+
+        for src in range(num_ranks):
+            total = 0
+            for p in params:
+                state = param_to_state[id(p)]
+                if state.worker_rank != src:
+                    continue
+                total += split_elems_for_src(p, rank, num_ranks)
+            recv_counts[src] = total
+
+        recv_total = sum(recv_counts)
+        assert recv_total > 0
+        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
+
+        #All2All
+        dist.all_to_all_single(
+            recv_buf,
+            send_buf,
+            output_split_sizes=recv_counts,
+            input_split_sizes=send_counts,
+            group=process_group,
+        )
+
+        # Copy to pre-allocated scattered_u buffer from the received buffer
+        #
+        # recv_buf (num ranks = 3, local_rank = 0)
+        #
+        #    From rank 0     From rank 1   From rank 2
+        # | p1_0, p2_0, p3_0 |    p4_0    | p5_0, p6_0 |
+        #
+        # Outer loop:
+        #   rank 0 -> rank 1 -> rank2
+        #
+        # Inner loop:
+        #   src(0) : p1_0 -> p2_0 -> p3_0
+        #   src(1) : p4_0
+        #   src(2) : p5_0 -> p6_0
+
+        comm_stream.wait_event(alloc_event)
+
+        off = 0
+        for src in range(num_ranks):
+            block = recv_counts[src]
+            if block == 0:
+                continue
+
+            inner_off = 0
+            for p in params:
+                state = param_to_state[id(p)]
+                if state.worker_rank != src:
+                    continue
+                n = split_elems_for_src(p, rank, num_ranks)
+                assert n > 0
+
+                flat_local = recv_buf.narrow(0, off + inner_off,
+                                             n).view_as(p.to_local())
+                state.scattered_u.copy_(flat_local)
+
+                state.scatter_event = torch.cuda.Event()
+                state.scatter_event.record(comm_stream)
+                inner_off += n
+
+            assert inner_off == block
+            off += block
+
+
+def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
+                  compute_stream):
+    """
+    Update sharded parameter p with the scattered_u.
+    Only worker_rank frees computed_u.
+    """
+    with torch.cuda.stream(compute_stream):
+        if state.scatter_event is None:
+            raise RuntimeError("Scatter event must be set before update")
+        compute_stream.wait_event(state.scatter_event)
+        u_dtensor = DTensor.from_local(
+            state.scattered_u,
+            placements=p.placements,
+            device_mesh=p.device_mesh,
+        )
+
+        state.scattered_u = u_dtensor
+
+        if rank == state.worker_rank:
+            # Free computed_u
+            state.computed_u = None
+
+        Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
+        state.scattered_u = None
+        u_dtensor = None
+
+        scales_full = Muon._compute_scales(p, state.qk_clip_state)
+        if scales_full is not None:
+            num_ranks = dist.get_world_size(group=state.process_group)
+            local_rank = dist.get_rank(group=state.process_group)
+            scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
+            scales_local = DTensor.from_local(
+                scales_local,
+                placements=p.placements,
+                device_mesh=p.device_mesh,
+            )
+            Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
+
+
+def default_is_muon(name, x):
+    skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
+    return x.ndim >= 2 and not any(key in name for key in skip_keys)
+
+
+def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
+    muon_params, muon_names = [], []
+    non_muon_params = []
+
+    for n, p in model.named_parameters():
+        if not p.requires_grad:
+            continue
+        if is_muon_func(n, p):
+            muon_params.append(p)
+            muon_names.append(n)
+        else:
+            non_muon_params.append(p)
+
+    return [
+        {
+            "params": muon_params,
+            "names": muon_names,
+            "use_muon": True,
+        },
+        {
+            "params": non_muon_params,
+            "use_muon": False,
+        },
+    ]
+
+
+def parse_qk_layer(name: str) -> tuple[str | None, int]:
+    """
+    Parse a parameter name to check if it is a query/key projection layer
+    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+
+    Returns:
+        (kind, layer_idx) or (None, -1) if not matched.
+
+    Example:
+        'model.3.attn.wq.weight' -> ('wq', 3)
+        'model.5.attn.wk.weight' -> ('wk', 5)
+        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
+        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+        'model.4.attn.v_proj.weight' -> (None, -1)
+    """
+    parts = name.split('.')
+    if len(parts) < 3:
+        return None, -1
+
+    kind = parts[-2]
+
+    layer_idx = -1
+    for part in reversed(parts):
+        if part.isdigit():
+            layer_idx = int(part)
+            break
+
+    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+        return kind, layer_idx
+
+    return None, -1
+
+
+@dataclass
+class QKClipInfo:
+    """Per-parameter dynamic info computed from config + runtime logits."""
+    kind: Optional[str]  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+    indices: List[int]  # which heads to consider for clipping
+    head_dim: int  # from config
+    threshold: float  # from config
+    logit: Optional[torch.Tensor]
+
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Some warnings:
+    - We believe this optimizer is unlikely to work well for training with small batch size.
+    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
+
+    Arguments:
+        model: The model to be optimized by Muon.
+        is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
+        lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        weight_decay: The weight decay for Muon and AdamW.
+        {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
+        adamw_lr: The learning rate for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
+        debug: Whether to print debug information.
+        clip_info : Configuration for QK clipping. Expected keys:
+            - "q_indices" (list[int]): Indices of query heads to consider.
+            - "k_indices" (list[int]): Indices of key heads to consider.
+            - "head_dim" (int): Dimensionality of each attention head.
+            - "threshold" (float): Threshold value; heads whose QK logits exceed
+              this value will be scaled down.
+            Default is:
+            {
+                "q_indices": [],
+                "k_indices": [],
+                "head_dim": 128,
+                "threshold": 100
+            }
+        overlap_step : How many all2all gather, compute operations are launched in advance
+            before the corresponding all2all scatter steps begin.
+            A higher overlap_step increases memory usage but can improve
+            performance by overlapping communication.
+            Parallel muon only.
+    """
+
+    def __init__(self,
+                 params,
+                 lr=1e-3,
+                 momentum=0.95,
+                 nesterov=True,
+                 ns_steps=5,
+                 weight_decay=0.1,
+                 adamw_betas=(0.9, 0.95),
+                 adamw_eps=1e-8,
+                 none_grad=True,
+                 debug=False,
+                 clip_config={
+                     "q_indices": [],
+                     "k_indices": [],
+                     "head_dim": 128,
+                     "threshold": 100
+                 },
+                 overlap_step=5):
+        defaults = dict(
+            lr=lr,
+            weight_decay=weight_decay,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+            none_grad=none_grad,
+            use_muon=True,
+        )
+        error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
+        instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
+
+        if isinstance(params, types.GeneratorType):
+            raise ValueError(error_message.format(idx=0) + instruction_code)
+        for _idx, param_group in enumerate(params):
+            if param_group.get("use_muon", None) is None:
+                raise ValueError(
+                    error_message.format(idx=_idx) + instruction_code)
+
+        super().__init__(params, defaults)
+
+        self.rank = None
+
+        self.comm_stream = torch.cuda.Stream()
+        self.compute_stream = torch.cuda.Stream()
+        self.debug = debug
+        self.clip_config = clip_config
+        self.overlap_step = overlap_step
+
+    def _calc_flops(self, G, steps):
+        assert len(G.shape) == 2
+        M, N = G.shape
+        if M > N:
+            M, N = N, M
+
+        return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
+
+    def adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate and weight decay based on the size of the parameter matrix
+        # as describted in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
+        return adjusted_lr
+
+    def get_shard_mesh(self, p):
+        """
+        Get the shard mesh for a parameter p on the given rank.
+        """
+        assert isinstance(
+            p, DTensor), "Parallel Muon only supports DTensor parameters."
+
+        if p.placements == (Shard(dim=0), ):
+            # Case for FSDP
+            process_group = p.device_mesh.get_group(mesh_dim=0)
+            if self.rank is None:
+                self.rank = dist.get_rank(group=process_group)
+            else:
+                assert self.rank == dist.get_rank(group=process_group)
+            return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
+        elif p.placements == (Replicate(), Shard(dim=0)):
+            # Case for HSDP
+            process_group = p.device_mesh.get_group(mesh_dim=1)
+            if self.rank is None:
+                self.rank = dist.get_rank(group=process_group)
+            else:
+                assert self.rank == dist.get_rank(group=process_group)
+            for i, shard_mesh in enumerate(p.device_mesh.mesh):
+                if self.rank in shard_mesh:
+                    return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
+        else:
+            raise ValueError(f"Unsupported placements ({p.placements}).")
+
+    def init_state_and_assign_params(self, names, params, group, qk_logits):
+        param_to_state = {}
+        param_to_flops = {}
+
+        total_flops = 0
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            assert g.ndim == 2, "Muon only supports 2D parameters."
+
+            flops = self._calc_flops(g, group["ns_steps"])
+            param_to_flops[id(p)] = flops
+            total_flops += flops
+
+        if self.debug:
+            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
+                  flush=True)
+
+        paired = list(zip(names, params))
+
+        paired_sorted = sorted(paired,
+                               key=lambda x: param_to_flops[id(x[1])],
+                               reverse=True)
+
+        names_sorted, params_sorted = zip(*paired_sorted)
+        ordered_names = list(names_sorted)
+        ordered_params = list(params_sorted)
+
+        round_robin = 0
+        mesh = None
+        shard_mesh = None
+        process_group = None
+        for n, p in zip(ordered_names, ordered_params):
+            if mesh is None:
+                mesh = p.device_mesh
+                shard_mesh, process_group = self.get_shard_mesh(p)
+            elif mesh != p.device_mesh:
+                raise ValueError("All parameters must be on the same mesh.")
+            num_ranks = dist.get_world_size(group=process_group)
+            param_to_state[id(p)] = _muon_state()
+            param_to_state[id(
+                p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
+            param_to_state[id(p)].process_group = process_group
+            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+            param_to_state[id(p)].qk_clip_state = qk_clip_state
+            round_robin = (round_robin + 1) % len(shard_mesh)
+
+        return param_to_state, ordered_params
+
+    def base(self, names, params, group, lr, weight_decay, momentum,
+             qk_logits):
+        # generate weight updates in distributed fashion
+        for n, p in zip(names, params):
+            g = p.grad
+            if g is None:
+                continue
+            if g.ndim > 2:
+                g = g.view(g.size(0), -1)
+            assert g is not None
+
+            # calc update
+            state = self.state[p]
+            if "momentum_buffer" not in state:
+                state["momentum_buffer"] = torch.zeros_like(g)
+            buf = state["momentum_buffer"]
+            buf.mul_(momentum).add_(g)
+            if group["nesterov"]:
+                g = g.add(buf, alpha=momentum)
+            else:
+                g = buf
+
+            u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
+                                             steps=group["ns_steps"])
+
+            adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+            Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
+
+            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+
+            scales_full = self._compute_scales(p, qk_clip_state)
+            if scales_full is not None:
+                Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
+
+    def _update_g(self, p, g, group, momentum):
+        # calc update
+        state = self.state[p]
+        buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
+        torch.add(g, buf, alpha=momentum, out=buf)
+        if group["nesterov"]:
+            g.add_(buf, alpha=momentum)
+            return g
+        return buf
+
+    @staticmethod
+    def _update_p(p, u, lr, adjusted_lr, weight_decay):
+        # apply weight decay
+        p.data.mul_(1 - lr * weight_decay)
+        # apply update
+        p.data.add_(u, alpha=-adjusted_lr)
+
+    def get_qk_clip_info(self, n, qk_logits):
+        head_dim = self.clip_config.get('head_dim')
+        threshold = self.clip_config.get('threshold')
+        kind, layer_idx = parse_qk_layer(n)
+
+        logit, indices = None, []
+        if qk_logits is not None and kind is not None:
+            logit = qk_logits[layer_idx]
+            indices_key = 'q_indices' if 'q' in kind else 'k_indices'
+            indices = self.clip_config.get(indices_key, []) or []
+
+        return QKClipInfo(
+            kind=kind,
+            indices=indices,
+            head_dim=head_dim,
+            threshold=threshold,
+            logit=logit,
+        )
+
+    @staticmethod
+    def _compute_scales(p, qk_clip_state):
+        kind = qk_clip_state.kind
+        indices = qk_clip_state.indices
+        head_dim = qk_clip_state.head_dim
+        threshold = qk_clip_state.threshold
+        logit = qk_clip_state.logit
+
+        H_global = p.shape[0] // head_dim
+        scales_full = torch.ones(H_global, device=p.data.device)
+        scaling = 0
+
+        for logit_idx, head_idx in enumerate(indices):
+            v_ele = float(logit[logit_idx])
+            if v_ele > threshold:
+                new_scale = math.sqrt(threshold / v_ele)
+                if new_scale < scales_full[head_idx]:
+                    scales_full[head_idx] = new_scale
+                    logger.info(
+                        f"[{kind}] Head {head_idx} exceeded threshold "
+                        f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
+                    )
+                    scaling += 1
+
+        return scales_full if scaling > 0 else None
+
+    @staticmethod
+    def _qk_clip(p, scales, head_dim):
+        W = p.data.view(-1, head_dim, p.data.shape[1])
+        W.mul_(scales.view(-1, 1, 1))
+
+    def parallel(self, names, params, group, lr, weight_decay, momentum,
+                 qk_logits):
+        """
+        Perform a parallel optimization step using Muon.
+        """
+
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            if g.ndim > 2:
+                g = g.view(g.size(0), -1)
+
+            # Update g in the local rank
+            g = self._update_g(
+                p,
+                g,
+                group,
+                momentum=momentum,
+            )
+            p.grad = g
+
+        param_to_state, ordered_params = self.init_state_and_assign_params(
+            names, params, group, qk_logits)
+
+        assert self.rank is not None
+
+        def enqueue_all2all_gather(start_idx, chunk_size):
+            target_params = ordered_params[start_idx:start_idx + chunk_size]
+            if target_params:
+                alloc_event = _alloc_gathered_grad(target_params,
+                                                   param_to_state, self.rank,
+                                                   self.compute_stream)
+                _all2all_gather(target_params, param_to_state, self.rank,
+                                self.comm_stream, group["none_grad"],
+                                alloc_event)
+
+        def enqueue_computes(start_idx, chunk_size):
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
+                state = param_to_state[id(p)]
+                _compute_u(p, state, group["ns_steps"], self.rank,
+                           self.compute_stream)
+
+        def enqueue_all2all_scatter(start_idx, chunk_size):
+            target_params = ordered_params[start_idx:start_idx + chunk_size]
+            if target_params:
+                alloc_event = _alloc_scattered_u(target_params, param_to_state,
+                                                 self.rank,
+                                                 self.compute_stream)
+                _all2all_scatter(target_params, param_to_state, self.rank,
+                                 self.comm_stream, alloc_event)
+
+        def enqueue_update_param(start_idx, chunk_size):
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
+                state = param_to_state[id(p)]
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _update_param(p, state, lr, adjusted_lr, weight_decay,
+                              self.rank, self.compute_stream)
+
+        chunk_size = dist.get_world_size(param_to_state[id(
+            params[0])].process_group)
+
+        # Wait grad update
+        self.comm_stream.wait_stream(torch.cuda.current_stream())
+
+        overlap_step = self.overlap_step
+        for i in range(0, overlap_step):
+            enqueue_all2all_gather(i * chunk_size, chunk_size)
+            enqueue_computes(i * chunk_size, chunk_size)
+
+        for i in range(0, len(params) + chunk_size - 1, chunk_size):
+            enqueue_all2all_scatter(i, chunk_size)
+            enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
+            enqueue_update_param(i, chunk_size)
+            enqueue_computes(i + overlap_step * chunk_size, chunk_size)
+
+        # Wait the last update_param to finish
+        torch.cuda.current_stream().wait_stream(self.compute_stream)
+
+    @staticmethod
+    def _fused_adamw(
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor],
+        exp_avgs: list[torch.Tensor],
+        exp_avg_sqs: list[torch.Tensor],
+        max_exp_avg_sqs: list[torch.Tensor],
+        state_steps: list[torch.Tensor],
+        amsgrad: bool,
+        beta1: float,
+        beta2: float,
+        lr: Union[float, torch.Tensor],
+        weight_decay: float,
+        eps: float,
+        maximize: bool,
+    ) -> None:
+        if not params:
+            return
+
+        # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
+        # treating it as a scalar.
+        lr_dict: Optional[DeviceDict] = ({
+            lr.device: lr
+        } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
+                                         None)
+        grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
+            [
+                params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
+                state_steps
+            ]  # type: ignore[list-item]
+        )
+        for (device, _), (
+            (
+                device_params_,
+                device_grads_,
+                device_exp_avgs_,
+                device_exp_avg_sqs_,
+                device_max_exp_avg_sqs,
+                device_state_steps_,
+            ),
+            _,
+        ) in grouped_tensors.items():
+            device_params = cast(list[torch.Tensor], device_params_)
+            device_grads = cast(list[torch.Tensor], device_grads_)
+            device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
+            device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
+            device_state_steps = cast(list[torch.Tensor], device_state_steps_)
+
+            if lr_dict is not None and device not in lr_dict:
+                lr_dict[device] = lr.to(
+                    device=device,
+                    non_blocking=True)  # type: ignore[union-attr]
+                lr = lr_dict[device]
+            torch._foreach_add_(device_state_steps, 1)
+            func = torch._fused_adamw_
+            func(
+                device_params,
+                device_grads,
+                device_exp_avgs,
+                device_exp_avg_sqs,
+                device_max_exp_avg_sqs,  # type: ignore[arg-type]
+                device_state_steps,
+                amsgrad=amsgrad,
+                lr=lr,  # type: ignore[arg-type]
+                beta1=beta1,
+                beta2=beta2,
+                weight_decay=weight_decay,
+                eps=eps,
+                maximize=maximize,
+            )
+
+    def step(self, closure=None, qk_logits=None):
+        """Perform a single optimization step.
+
+        Args:
+            closure (Callable, optional): A closure that reevaluates the model
+                and returns the loss.
+            qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
+                to 1D tensors of shape (num_heads,), representing the maximum
+                QK logits across all tokens, computed as
+                (1 / sqrt(head_dim)) * (Q @ K^T).
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params = group["params"]
+
+            if group["use_muon"]:
+                ############################
+                #           Muon           #
+                ############################
+                lr = group["lr"]
+                weight_decay = group["weight_decay"]
+                momentum = group["momentum"]
+                names = group["names"]
+
+                param_dtensors = []
+                param_tensors = []
+                name_dtensors = []
+                name_tensors = []
+
+                for n, p in zip(names, params):
+                    if p is None or p.grad is None:
+                        continue
+                    if isinstance(p.data, DTensor):
+                        if all(
+                                isinstance(placement, Replicate)
+                                for placement in p.placements):
+                            param_tensors.append(p)
+                            name_tensors.append(n)
+                        else:
+                            param_dtensors.append(p)
+                            name_dtensors.append(n)
+                    elif isinstance(p.data, torch.Tensor):
+                        param_tensors.append(p)
+                        name_tensors.append(n)
+                    else:
+                        raise TypeError(
+                            f"Unsupported parameter type: {type(p.data)}")
+
+                if self.debug:
+                    print(
+                        f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
+                        flush=True,
+                    )
+
+                if len(param_dtensors) > 0:
+                    if not dist.is_initialized():
+                        raise RuntimeError(
+                            "Parallel Muon requires torch.distributed to be initialized."
+                        )
+
+                    self.parallel(
+                        name_dtensors,
+                        param_dtensors,
+                        group,
+                        lr=lr,
+                        weight_decay=weight_decay,
+                        momentum=momentum,
+                        qk_logits=qk_logits,
+                    )
+
+                if len(param_tensors) > 0:
+                    self.base(
+                        name_tensors,
+                        param_tensors,
+                        group,
+                        lr=lr,
+                        weight_decay=weight_decay,
+                        momentum=momentum,
+                        qk_logits=qk_logits,
+                    )
+
+            else:
+                ############################
+                #       AdamW backup       #
+                ############################
+
+                params_with_grads = []
+                grads = []
+                moment1 = []
+                moment2 = []
+                max_exp_avg_sqs = []
+                state_steps = []
+                lr = group["lr"]
+                beta1, beta2 = group["adamw_betas"]
+                eps = group["adamw_eps"]
+                weight_decay = group["weight_decay"]
+
+                for p in params:
+                    g = p.grad
+                    if g is None:
+                        continue
+                    state = self.state[p]
+                    params_with_grads.append(p)
+                    grads.append(g)
+                    if "step" not in state:
+                        state["step"] = (torch.zeros((),
+                                                     dtype=torch.float32,
+                                                     device=p.device))
+                        state["moment1"] = torch.zeros_like(g)
+                        state["moment2"] = torch.zeros_like(g)
+                    moment1.append(state["moment1"])
+                    moment2.append(state["moment2"])
+                    if not isinstance(state["step"], torch.Tensor):
+                        step_tensor = torch.tensor(state["step"],
+                                                   dtype=torch.float32,
+                                                   device=p.device)
+                    else:
+                        step_tensor = state["step"]
+                    state_steps.append(step_tensor)
+
+                self._fused_adamw(
+                    params_with_grads,
+                    grads,
+                    moment1,
+                    moment2,
+                    max_exp_avg_sqs,
+                    state_steps,
+                    amsgrad=False,
+                    beta1=beta1,
+                    beta2=beta2,
+                    lr=lr,
+                    weight_decay=weight_decay,
+                    eps=eps,
+                    maximize=False,
+                )
+
+        return loss
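A minimal usage sketch of the new optimizer, modeled on the snippet embedded in the `use_muon` error message above (my own wiring; the `kernels.get_kernel` client, the model, and the loss/batch helpers are assumptions, not part of this commit):

```python
import torch
from kernels import get_kernel  # Hugging Face "kernels" client

optimizer_mod = get_kernel("motif-technologies/optimizer")

model = ...  # any torch.nn.Module; DTensor/FSDP-sharded params take the parallel path
params = optimizer_mod.muon.get_default_muon_param_groups(model)
optim = optimizer_mod.Muon(params, lr=0.02, momentum=0.95, weight_decay=0.1)

loss = compute_loss(model, batch)  # hypothetical training-step helper
loss.backward()
optim.step()   # optionally optim.step(qk_logits={layer_idx: per_head_max_logits})
optim.zero_grad()
```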
build/torch29-cxx11-cu128-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,5 @@
+from .muon import Muon
+
+__all__ = [
+    "Muon",
+]
build/torch29-cxx11-cu128-x86_64-linux/optimizer/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _optimizer_811726c_dirty
+ops = torch.ops._optimizer_811726c_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_optimizer_811726c_dirty::{op_name}"
build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab1875be65811d88c407f36077aced58056a4feeb9946d7cd40ec55c7e1025c8
+size 1871056
build/torch29-cxx11-cu128-x86_64-linux/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,128 @@
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2025 Tianyang Lin
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+
28
+ def get_autotune_config():
29
+ return [
30
+ triton.Config(
31
+ {
32
+ 'BLOCK_SIZE_M': blk_m,
33
+ 'BLOCK_SIZE_K': blk_k,
34
+ 'GROUP_SIZE_M': grp_sz
35
+ },
36
+ num_stages=n_stages,
37
+ num_warps=n_warps) for blk_m in [32, 64, 128]
38
+ for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
39
+ for n_warps in [4, 8]
40
+ ]
41
+
42
+
43
+ @triton.autotune(
44
+ configs=get_autotune_config(),
45
+ key=['M', 'K'],
46
+ )
47
+ @triton.jit
48
+ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
49
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
50
+ GROUP_SIZE_M: tl.constexpr):
51
+ """
52
+ Core kernel jit function of matmul_transpose that computes y = x @ x.T
53
+ The code is a simple adaptation from the triton `matmul` tutorial:
54
+ https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
55
+ """
56
+ pid = tl.program_id(axis=0)
57
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
58
+ num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
59
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
60
+ group_id = pid // num_pid_in_group
61
+ first_pid_m = group_id * GROUP_SIZE_M
62
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
63
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
64
+ pid_n = (pid % num_pid_in_group) // group_size_m
65
+ if pid_m > pid_n:
66
+ return
67
+
68
+ offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
69
+ offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
70
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
71
+ # we use a & b ptrs to denote different rows of x.
72
+ a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
73
+ b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
74
+
75
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
76
+
77
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
78
+ a = tl.load(a_ptrs,
79
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
80
+ other=0.0)
81
+ b = tl.load(b_ptrs,
82
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
83
+ other=0.0)
84
+ accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
85
+ a_ptrs += BLOCK_SIZE_K * stride_xk
86
+ b_ptrs += BLOCK_SIZE_K * stride_xk
87
+ # use dtype.element_ty to accommodate different input datatypes as in cpp templates
88
+ # https://github.com/triton-lang/triton/issues/2252
89
+ c = accumulator.to(x.dtype.element_ty)
90
+
91
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
92
+ offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
93
+ c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
94
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
95
+ tl.store(c_ptrs, c, mask=c_mask)
96
+
97
+ # transpose and copy
98
+ if pid_m < pid_n:
99
+ ct_ptrs = y + stride_ym * offs_cn[:,
100
+ None] + stride_yn * offs_cm[None, :]
101
+ ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
102
+ tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
+
104
+
105
+ def matmul_transpose_assign(d_in, d_out):
106
+ assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
+ assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
+ assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
+ assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
+ assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
+ assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
+ assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
113
+ "First dimension of `d_in` must match first and second dimension of `d_out`"
114
+
115
+ d_in = d_in.contiguous()
116
+ M, K = d_in.shape
117
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
118
+ M, META['BLOCK_SIZE_M']), )
119
+ with torch.cuda.device(d_in.device.index):
120
+ mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
+ d_out.stride(0), d_out.stride(1))
122
+
123
+
124
+ def matmul_transpose(d_in):
125
+ M, _ = d_in.shape
126
+ d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
127
+ matmul_transpose_assign(d_in, d_out)
128
+ return d_out
build/torch29-cxx11-cu128-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,1069 @@
1
+ import logging
2
+ import math
3
+ import types
4
+ from dataclasses import dataclass
5
+ from typing import List, Optional, Union, cast
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch.distributed._tensor import DTensor, Replicate, Shard
10
+
11
+ from .matmul_transpose_triton import matmul_transpose_assign
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
+
18
+ # This code snippet is a modified version adapted from the following GitHub repositories:
19
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
20
+ # Muon's Newton–Schulz iteration causes high variance in singular values
21
+ # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
22
+ @torch.no_grad()
23
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
24
+ def _zeropower_via_newtonschulz5(G, steps):
25
+ """
26
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
27
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
28
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
29
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
30
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
31
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
32
+ performance at all relative to UV^T, where USV^T = G is the SVD.
33
+ """
34
+ assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
36
+ X = G # no manual typecast
37
+
38
+ if G.size(0) > G.size(1):
39
+ X = X.T
40
+ # Ensure spectral norm is at most 1
41
+ X = X / (X.norm() + 1e-7)
42
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
43
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
44
+ # Perform the NS iterations
45
+ for a, b, c in [
46
+ (4.0848, -6.8946, 2.9270),
47
+ (3.9505, -6.3029, 2.6377),
48
+ (3.7418, -5.5913, 2.3037),
49
+ (2.8769, -3.1427, 1.2046),
50
+ (2.8366, -3.0525, 1.2012),
51
+ ]:
52
+ matmul_transpose_assign(X, buf1)
53
+ matmul_transpose_assign(buf1, buf2)
54
+ buf1.mul_(b).add_(buf2, alpha=c)
55
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
56
+
57
+ if G.size(0) > G.size(1):
58
+ X = X.T
59
+ return X
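+
+ # A minimal sanity check for the iteration above (illustrative, assuming a CUDA
+ # device with Triton; run it separately rather than at import time): the
+ # singular values of the result should sit in a band around 1, e.g.
+ #   G = torch.randn(1024, 256, device="cuda", dtype=torch.bfloat16)
+ #   U = _zeropower_via_newtonschulz5(G, steps=5)
+ #   print(torch.linalg.svdvals(U.float()).min(), torch.linalg.svdvals(U.float()).max())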
60
+
61
+
62
+ @dataclass
63
+ class _muon_state:
64
+ # TODO: use Optional
65
+ worker_rank: int | None = None
66
+ gathered_grad: torch.Tensor | None = None
67
+ scattered_u: DTensor | None = None
68
+ computed_u: torch.Tensor | None = None
69
+ gather_event: torch.cuda.Event | None = None
70
+ compute_event: torch.cuda.Event | None = None
71
+ scatter_event: torch.cuda.Event | None = None
72
+ process_group = None
73
+ qk_clip_state = None
74
+
75
+
76
+ def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
+ rows = param.shape[0]
78
+ cols = int(param.numel() // rows)
79
+ base, rem = divmod(rows, num_ranks)
80
+ my_rows = base + (1 if src_rank < rem else 0)
81
+ return my_rows * cols
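+
+ # Worked example (illustrative): a (10, 4) parameter split across 3 ranks gives
+ # base=3, rem=1, so rank 0 owns 4 rows (16 elements) while ranks 1 and 2 own
+ # 3 rows (12 elements) each; remainder rows go to the lowest-numbered ranks.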
82
+
83
+
84
+ @torch.no_grad()
85
+ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
86
+ """
87
+ Pre-allocate gathered_grad buffer on compute_stream
88
+ before launching all2all gather
89
+ """
90
+ with torch.cuda.stream(compute_stream):
91
+ for p in params:
92
+ state = param_to_state[id(p)]
93
+ if rank == state.worker_rank:
94
+ num_ranks = dist.get_world_size(group=state.process_group)
95
+ state.gathered_grad = torch.empty(p.grad.numel(),
96
+ dtype=COMM_DTYPE,
97
+ device="cuda")
98
+ else:
99
+ state.gathered_grad = None
100
+
101
+ alloc_event = torch.cuda.Event()
102
+ alloc_event.record(compute_stream)
103
+ return alloc_event
104
+
105
+
106
+ @torch.no_grad()
107
+ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
108
+ alloc_event):
109
+ """
110
+ All2all gathers shards so each owner rank reconstructs its full gradient
111
+ """
112
+ with torch.cuda.stream(comm_stream):
113
+ process_group = param_to_state[id(params[0])].process_group
114
+ num_ranks = dist.get_world_size(group=process_group)
115
+
116
+ # Construct sending buffers
117
+ per_dst = [[] for _ in range(num_ranks)]
118
+ send_counts = [0] * num_ranks
119
+
120
+ for p in params:
121
+ state = param_to_state[id(p)]
122
+ dst = state.worker_rank
123
+ assert dst < num_ranks
124
+ shard_elems = split_elems_for_src(p, rank, num_ranks)
125
+ g = p.grad
126
+ g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
+ assert g.numel() == shard_elems
128
+ per_dst[dst].append(g)
129
+ send_counts[dst] += shard_elems
130
+
131
+ assert any(
132
+ len(v) > 0 for v in per_dst
133
+ ), "At least one destination rank must receive a sharded tensor"
134
+ # list[list[Tensor]] -> list[Tensor]
135
+ per_dst = [t for dst in per_dst for t in dst]
136
+
137
+ send_buf = torch.cat(per_dst, dim=0)
138
+
139
+ owned_params = [
140
+ p for p in params if param_to_state[id(p)].worker_rank == rank
141
+ ]
142
+
143
+ # Compute receive sizes and allocate receiving buffers
144
+ recv_counts = [0] * num_ranks
145
+
146
+ for src in range(num_ranks):
147
+ total = 0
148
+ for p in owned_params:
149
+ state = param_to_state[id(p)]
150
+ assert state.worker_rank == rank
151
+ total += split_elems_for_src(p, src, num_ranks)
152
+ recv_counts[src] = total
153
+
154
+ recv_total = sum(recv_counts)
155
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
156
+
157
+ #All2All
158
+ dist.all_to_all_single(
159
+ recv_buf,
160
+ send_buf,
161
+ output_split_sizes=recv_counts,
162
+ input_split_sizes=send_counts,
163
+ group=process_group,
164
+ )
165
+
166
+ # Reconstructs gathered grad from the received buffer
167
+ #
168
+ # recv_buf (num ranks = 3)
169
+ #
170
+ # From rank 0 From rank 1 From rank 2
171
+ # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
172
+ #
173
+ # Outer loop:
174
+ # rank 0 -> rank 1 -> rank 2
175
+ #
176
+ # Inner loop:
177
+ # p1_n -> p2_n -> p3_n
178
+
179
+ comm_stream.wait_event(alloc_event)
180
+
181
+ off = 0
182
+ write_offsets = {id(p): 0 for p in owned_params}
183
+ for src in range(num_ranks):
184
+ if recv_counts[src] == 0:
185
+ continue
186
+
187
+ block = recv_counts[src]
188
+ inner_off = 0
189
+ for p in owned_params:
190
+ state = param_to_state[id(p)]
191
+ assert state.worker_rank == rank
192
+ n = split_elems_for_src(p, src, num_ranks)
193
+ assert n > 0
194
+
195
+ sg = recv_buf.narrow(0, off + inner_off, n)
196
+ woff = write_offsets[id(p)]
197
+ dst = state.gathered_grad.narrow(0, woff, n)
198
+ dst.copy_(sg)
199
+
200
+ write_offsets[id(p)] += n
201
+ inner_off += n
202
+ off += block
203
+
204
+ for p in params:
205
+ state = param_to_state[id(p)]
206
+ if state.worker_rank == rank:
207
+ state.gathered_grad = state.gathered_grad.view_as(p)
208
+ state.gather_event = torch.cuda.Event()
209
+ state.gather_event.record(comm_stream)
210
+ else:
211
+ state.gathered_grad = None
212
+ state.gather_event = None
213
+ if none_grad:
214
+ p.grad = None
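+
+ # Shape bookkeeping sketch (illustrative): with 2 ranks and two parameters of
+ # shape (8, 4) and (6, 4) owned by the same rank, each rank sends
+ # split_elems_for_src(p, rank, 2) elements per parameter (16 and 12), and the
+ # owner's recv_counts[src] is the sum over its owned params, i.e. 28 elements
+ # received from each source rank.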
215
+
216
+
217
+ @torch.no_grad()
218
+ def _compute_u(p, state, steps, rank, compute_stream):
219
+ """
220
+ On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
221
+ """
222
+ with torch.cuda.stream(compute_stream):
223
+ if rank == state.worker_rank:
224
+ if state.gather_event is None:
225
+ raise RuntimeError("Gather event must be set before compute.")
226
+ compute_stream.wait_event(state.gather_event)
227
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
228
+ state.gathered_grad = None
229
+ state.computed_u = u
230
+ state.compute_event = torch.cuda.Event()
231
+ state.compute_event.record()
232
+ else:
233
+ state.computed_u = None
234
+ state.compute_event = None
235
+
236
+
237
+ @torch.no_grad()
238
+ def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
239
+ """
240
+ Pre-allocate scattered_u buffer on compute_stream
241
+ before launching all2all scatter
242
+ """
243
+ with torch.cuda.stream(compute_stream):
244
+ for p in params:
245
+ state = param_to_state[id(p)]
246
+ state.scattered_u = torch.empty_like(p.to_local(),
247
+ dtype=COMM_DTYPE)
248
+
249
+ alloc_event = torch.cuda.Event()
250
+ alloc_event.record(compute_stream)
251
+ return alloc_event
252
+
253
+
254
+ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
255
+ """
256
+ All2all scatters the computed orthogonalized updates back to all ranks
257
+ """
258
+ with torch.cuda.stream(comm_stream):
259
+ process_group = param_to_state[id(params[0])].process_group
260
+ num_ranks = dist.get_world_size(group=process_group)
261
+ owned_params = [
262
+ p for p in params if param_to_state[id(p)].worker_rank == rank
263
+ ]
264
+
265
+ # Construct sending buffer
266
+ per_dst = [[] for _ in range(num_ranks)]
267
+ send_counts = [0] * num_ranks
268
+
269
+ if owned_params:
270
+ for p in owned_params:
271
+ state = param_to_state[id(p)]
272
+ if state.compute_event is None:
273
+ raise RuntimeError(
274
+ "Compute event must be set before scatter.")
275
+ comm_stream.wait_event(state.compute_event)
276
+ state.gathered_grad = None
277
+
278
+ assert state.computed_u is not None
279
+
280
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
281
+
282
+ offset = 0
283
+ for dst in range(num_ranks):
284
+ n = split_elems_for_src(p, dst, num_ranks)
285
+ assert n > 0
286
+
287
+ su = u_full.narrow(0, offset, n)
288
+ per_dst[dst].append(su)
289
+ send_counts[dst] += n
290
+ offset += n
291
+
292
+ assert offset == u_full.numel()
293
+
294
+ lengths = [len(v) for v in per_dst]
295
+ if all(l > 0 for l in lengths):
296
+ assert all(
297
+ l == lengths[0] for l in lengths
298
+ ), "All destination ranks must have the same number of sharded tensor"
299
+ # list[list[Tensor]] -> list[Tensor]
300
+ per_dst = [t for dst in per_dst for t in dst]
301
+ send_buf = torch.cat(per_dst, dim=0)
302
+ else:
303
+ # all_to_all requires participation from all ranks
304
+ # Even non-owner ranks must join the collective call
305
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
306
+
307
+ # Compute receive sizes and allocate receiving buffers
308
+ recv_counts = [0] * num_ranks
309
+
310
+ for src in range(num_ranks):
311
+ total = 0
312
+ for p in params:
313
+ state = param_to_state[id(p)]
314
+ if state.worker_rank != src:
315
+ continue
316
+ total += split_elems_for_src(p, rank, num_ranks)
317
+ recv_counts[src] = total
318
+
319
+ recv_total = sum(recv_counts)
320
+ assert recv_total > 0
321
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
322
+
323
+ #All2All
324
+ dist.all_to_all_single(
325
+ recv_buf,
326
+ send_buf,
327
+ output_split_sizes=recv_counts,
328
+ input_split_sizes=send_counts,
329
+ group=process_group,
330
+ )
331
+
332
+ # Copy to pre-allocated scattered_u buffer from the received buffer
333
+ #
334
+ # recv_buf (num ranks = 3, local_rank = 0)
335
+ #
336
+ # From rank 0 From rank 1 From rank 2
337
+ # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
338
+ #
339
+ # Outer loop:
340
+ # rank 0 -> rank 1 -> rank 2
341
+ #
342
+ # Inner loop:
343
+ # src(0) : p1_0 -> p2_0 -> p3_0
344
+ # src(1) : p4_0
345
+ # src(2) : p5_0 -> p6_0
346
+
347
+ comm_stream.wait_event(alloc_event)
348
+
349
+ off = 0
350
+ for src in range(num_ranks):
351
+ block = recv_counts[src]
352
+ if block == 0:
353
+ continue
354
+
355
+ inner_off = 0
356
+ for p in params:
357
+ state = param_to_state[id(p)]
358
+ if state.worker_rank != src:
359
+ continue
360
+ n = split_elems_for_src(p, rank, num_ranks)
361
+ assert n > 0
362
+
363
+ flat_local = recv_buf.narrow(0, off + inner_off,
364
+ n).view_as(p.to_local())
365
+ state.scattered_u.copy_(flat_local)
366
+
367
+ state.scatter_event = torch.cuda.Event()
368
+ state.scatter_event.record(comm_stream)
369
+ inner_off += n
370
+
371
+ assert inner_off == block
372
+ off += block
373
+
374
+
375
+ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
376
+ compute_stream):
377
+ """
378
+ Update sharded parameter p with the scattered_u.
379
+ Only worker_rank frees computed_u.
380
+ """
381
+ with torch.cuda.stream(compute_stream):
382
+ if state.scatter_event is None:
383
+ raise RuntimeError("Scatter event must be set before update")
384
+ compute_stream.wait_event(state.scatter_event)
385
+ u_dtensor = DTensor.from_local(
386
+ state.scattered_u,
387
+ placements=p.placements,
388
+ device_mesh=p.device_mesh,
389
+ )
390
+
391
+ state.scattered_u = u_dtensor
392
+
393
+ if rank == state.worker_rank:
394
+ # Free computed_u
395
+ state.computed_u = None
396
+
397
+ Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
398
+ state.scattered_u = None
399
+ u_dtensor = None
400
+
401
+ scales_full = Muon._compute_scales(p, state.qk_clip_state)
402
+ if scales_full is not None:
403
+ num_ranks = dist.get_world_size(group=state.process_group)
404
+ local_rank = dist.get_rank(group=state.process_group)
405
+ scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
406
+ scales_local = DTensor.from_local(
407
+ scales_local,
408
+ placements=p.placements,
409
+ device_mesh=p.device_mesh,
410
+ )
411
+ Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
412
+
413
+
414
+ def default_is_muon(name, x):
415
+ skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
416
+ return x.ndim >= 2 and not any(key in name for key in skip_keys)
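+
+ # Illustrative examples (hypothetical parameter names):
+ #   default_is_muon("layers.0.self_attn.q_proj.weight", torch.empty(4096, 4096)) -> True
+ #   default_is_muon("embed_tokens.weight", torch.empty(32000, 4096)) -> False (skip key)
+ #   default_is_muon("layers.0.input_layernorm.weight", torch.empty(4096)) -> False (1-D)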
417
+
418
+
419
+ def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
420
+ muon_params, muon_names = [], []
421
+ non_muon_params = []
422
+
423
+ for n, p in model.named_parameters():
424
+ if not p.requires_grad:
425
+ continue
426
+ if is_muon_func(n, p):
427
+ muon_params.append(p)
428
+ muon_names.append(n)
429
+ else:
430
+ non_muon_params.append(p)
431
+
432
+ return [
433
+ {
434
+ "params": muon_params,
435
+ "names": muon_names,
436
+ "use_muon": True,
437
+ },
438
+ {
439
+ "params": non_muon_params,
440
+ "use_muon": False,
441
+ },
442
+ ]
443
+
444
+
445
+ def parse_qk_layer(name: str) -> tuple[str | None, int]:
446
+ """
447
+ Parse a parameter name to check if it is a query/key projection layer
448
+ ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
449
+
450
+ Returns:
451
+ (kind, layer_idx) or (None, -1) if not matched.
452
+
453
+ Example:
454
+ 'model.3.attn.wq.weight' -> ('wq', 3)
455
+ 'model.5.attn.wk.weight' -> ('wk', 5)
456
+ 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
457
+ 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
458
+ 'model.4.attn.v_proj.weight' -> (None, -1)
459
+ """
460
+ parts = name.split('.')
461
+ if len(parts) < 3:
462
+ return None, -1
463
+
464
+ kind = parts[-2]
465
+
466
+ layer_idx = -1
467
+ for part in reversed(parts):
468
+ if part.isdigit():
469
+ layer_idx = int(part)
470
+ break
471
+
472
+ if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
473
+ return kind, layer_idx
474
+
475
+ return None, -1
476
+
477
+
478
+ @dataclass
479
+ class QKClipInfo:
480
+ """Per-parameter dynamic info computed from config + runtime logits."""
481
+ kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
482
+ indices: List[int] # which heads to consider for clipping
483
+ head_dim: int # from config
484
+ threshold: float # from config
485
+ logit: Optional[torch.Tensor]
486
+
487
+
488
+ class Muon(torch.optim.Optimizer):
489
+ """
490
+ Muon - MomentUm Orthogonalized by Newton-schulz
491
+
492
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
493
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
494
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
495
+ the advantage that it can be stably run in bfloat16 on the GPU.
496
+
497
+ Some warnings:
498
+ - We believe this optimizer is unlikely to work well for training with small batch size.
499
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
500
+
501
+ Arguments:
502
+ params: Parameter groups to be optimized by Muon (e.g. from get_default_muon_param_groups).
503
+ is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
504
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
505
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
506
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
507
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
508
+ weight_decay: The weight decay for Muon and AdamW.
509
+ Parameters that are {0, 1}-D or detected as the embedding or lm_head layers are optimized by AdamW as well.
510
+ adamw_lr: The learning rate for the internal AdamW.
511
+ adamw_betas: The betas for the internal AdamW.
512
+ adamw_eps: The epsilon for the internal AdamW.
513
+ none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
514
+ debug: Whether to print debug information.
515
+ clip_config : Configuration for QK clipping. Expected keys:
516
+ - "q_indices" (list[int]): Indices of query heads to consider.
517
+ - "k_indices" (list[int]): Indices of key heads to consider.
518
+ - "head_dim" (int): Dimensionality of each attention head.
519
+ - "threshold" (float): Threshold value; heads whose QK logits exceed
520
+ this value will be scaled down.
521
+ Default is:
522
+ {
523
+ "q_indices": [],
524
+ "k_indices": [],
525
+ "head_dim": 128,
526
+ "threshold": 100
527
+ }
528
+ overlap_step : How many all2all gather and compute operations are launched in advance
529
+ before the corresponding all2all scatter steps begin.
530
+ A higher overlap_step increases memory usage but can improve
531
+ performance by overlapping communication.
532
+ Parallel muon only.
533
+ """
534
+
535
+ def __init__(self,
536
+ params,
537
+ lr=1e-3,
538
+ momentum=0.95,
539
+ nesterov=True,
540
+ ns_steps=5,
541
+ weight_decay=0.1,
542
+ adamw_betas=(0.9, 0.95),
543
+ adamw_eps=1e-8,
544
+ none_grad=True,
545
+ debug=False,
546
+ clip_config={
547
+ "q_indices": [],
548
+ "k_indices": [],
549
+ "head_dim": 128,
550
+ "threshold": 100
551
+ },
552
+ overlap_step=5):
553
+ defaults = dict(
554
+ lr=lr,
555
+ weight_decay=weight_decay,
556
+ momentum=momentum,
557
+ nesterov=nesterov,
558
+ ns_steps=ns_steps,
559
+ adamw_betas=adamw_betas,
560
+ adamw_eps=adamw_eps,
561
+ none_grad=none_grad,
562
+ use_muon=True,
563
+ )
564
+ error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
565
+ instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
566
+
567
+ if isinstance(params, types.GeneratorType):
568
+ raise ValueError(error_message.format(idx=0) + instruction_code)
569
+ for _idx, param_group in enumerate(params):
570
+ if param_group.get("use_muon", None) is None:
571
+ raise ValueError(
572
+ error_message.format(idx=_idx) + instruction_code)
573
+
574
+ super().__init__(params, defaults)
575
+
576
+ self.rank = None
577
+
578
+ self.comm_stream = torch.cuda.Stream()
579
+ self.compute_stream = torch.cuda.Stream()
580
+ self.debug = debug
581
+ self.clip_config = clip_config
582
+ self.overlap_step = overlap_step
583
+
584
+ def _calc_flops(self, G, steps):
585
+ assert len(G.shape) == 2
586
+ M, N = G.shape
587
+ if M > N:
588
+ M, N = N, M
589
+
590
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
591
+
592
+ def adjust_lr_for_muon(self, lr, param_shape):
593
+ A, B = param_shape[:2]
594
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
595
+ # as described in the paper
596
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
597
+ adjusted_lr = lr * adjusted_ratio
598
+ return adjusted_lr
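+
+ # Worked example (illustrative): for a (4096, 1024) weight, max(A, B) = 4096,
+ # so adjusted_ratio = 0.2 * sqrt(4096) = 12.8; with lr = 0.02 the effective
+ # Muon step size becomes adjusted_lr = 0.256.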
599
+
600
+ def get_shard_mesh(self, p):
601
+ """
602
+ Get the shard mesh for a parameter p on the given rank.
603
+ """
604
+ assert isinstance(
605
+ p, DTensor), "Parallel Muon only supports DTensor parameters."
606
+
607
+ if p.placements == (Shard(dim=0), ):
608
+ # Case for FSDP
609
+ process_group = p.device_mesh.get_group(mesh_dim=0)
610
+ if self.rank is None:
611
+ self.rank = dist.get_rank(group=process_group)
612
+ else:
613
+ assert self.rank == dist.get_rank(group=process_group)
614
+ return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
+ elif p.placements == (Replicate(), Shard(dim=0)):
616
+ # Case for HSDP
617
+ process_group = p.device_mesh.get_group(mesh_dim=1)
618
+ if self.rank is None:
619
+ self.rank = dist.get_rank(group=process_group)
620
+ else:
621
+ assert self.rank == dist.get_rank(group=process_group)
622
+ for i, shard_mesh in enumerate(p.device_mesh.mesh):
623
+ if self.rank in shard_mesh:
624
+ return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
625
+ else:
626
+ raise ValueError(f"Unsupported placements ({p.placements}).")
627
+
628
+ def init_state_and_assign_params(self, names, params, group, qk_logits):
629
+ param_to_state = {}
630
+ param_to_flops = {}
631
+
632
+ total_flops = 0
633
+ for p in params:
634
+ g = p.grad
635
+ if g is None:
636
+ continue
637
+ assert g.ndim == 2, "Muon only supports 2D parameters."
638
+
639
+ flops = self._calc_flops(g, group["ns_steps"])
640
+ param_to_flops[id(p)] = flops
641
+ total_flops += flops
642
+
643
+ if self.debug:
644
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
645
+ flush=True)
646
+
647
+ paired = list(zip(names, params))
648
+
649
+ paired_sorted = sorted(paired,
650
+ key=lambda x: param_to_flops[id(x[1])],
651
+ reverse=True)
652
+
653
+ names_sorted, params_sorted = zip(*paired_sorted)
654
+ ordered_names = list(names_sorted)
655
+ ordered_params = list(params_sorted)
656
+
657
+ round_robin = 0
658
+ mesh = None
659
+ shard_mesh = None
660
+ process_group = None
661
+ for n, p in zip(ordered_names, ordered_params):
662
+ if mesh is None:
663
+ mesh = p.device_mesh
664
+ shard_mesh, process_group = self.get_shard_mesh(p)
665
+ elif mesh != p.device_mesh:
666
+ raise ValueError("All parameters must be on the same mesh.")
667
+ num_ranks = dist.get_world_size(group=process_group)
668
+ param_to_state[id(p)] = _muon_state()
669
+ param_to_state[id(
670
+ p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
671
+ param_to_state[id(p)].process_group = process_group
672
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
673
+ param_to_state[id(p)].qk_clip_state = qk_clip_state
674
+ round_robin = (round_robin + 1) % len(shard_mesh)
675
+
676
+ return param_to_state, ordered_params
677
+
678
+ def base(self, names, params, group, lr, weight_decay, momentum,
679
+ qk_logits):
680
+ # generate weight updates in distributed fashion
681
+ for n, p in zip(names, params):
682
+ g = p.grad
683
+ if g is None:
684
+ continue
685
+ if g.ndim > 2:
686
+ g = g.view(g.size(0), -1)
687
+ assert g is not None
688
+
689
+ # calc update
690
+ state = self.state[p]
691
+ if "momentum_buffer" not in state:
692
+ state["momentum_buffer"] = torch.zeros_like(g)
693
+ buf = state["momentum_buffer"]
694
+ buf.mul_(momentum).add_(g)
695
+ if group["nesterov"]:
696
+ g = g.add(buf, alpha=momentum)
697
+ else:
698
+ g = buf
699
+
700
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
701
+ steps=group["ns_steps"])
702
+
703
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
704
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
705
+
706
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
707
+
708
+ scales_full = self._compute_scales(p, qk_clip_state)
709
+ if scales_full is not None:
710
+ Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
711
+
712
+ def _update_g(self, p, g, group, momentum):
713
+ # calc update
714
+ state = self.state[p]
715
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
716
+ torch.add(g, buf, alpha=momentum, out=buf)
717
+ if group["nesterov"]:
718
+ g.add_(buf, alpha=momentum)
719
+ return g
720
+ return buf
721
+
722
+ @staticmethod
723
+ def _update_p(p, u, lr, adjusted_lr, weight_decay):
724
+ # apply weight decay
725
+ p.data.mul_(1 - lr * weight_decay)
726
+ # apply update
727
+ p.data.add_(u, alpha=-adjusted_lr)
728
+
729
+ def get_qk_clip_info(self, n, qk_logits):
730
+ head_dim = self.clip_config.get('head_dim')
731
+ threshold = self.clip_config.get('threshold')
732
+ kind, layer_idx = parse_qk_layer(n)
733
+
734
+ logit, indices = None, []
735
+ if qk_logits is not None and kind is not None:
736
+ logit = qk_logits[layer_idx]
737
+ indices_key = 'q_indices' if 'q' in kind else 'k_indices'
738
+ indices = self.clip_config.get(indices_key, []) or []
739
+
740
+ return QKClipInfo(
741
+ kind=kind,
742
+ indices=indices,
743
+ head_dim=head_dim,
744
+ threshold=threshold,
745
+ logit=logit,
746
+ )
747
+
748
+ @staticmethod
749
+ def _compute_scales(p, qk_clip_state):
750
+ kind = qk_clip_state.kind
751
+ indices = qk_clip_state.indices
752
+ head_dim = qk_clip_state.head_dim
753
+ threshold = qk_clip_state.threshold
754
+ logit = qk_clip_state.logit
755
+
756
+ H_global = p.shape[0] // head_dim
757
+ scales_full = torch.ones(H_global, device=p.data.device)
758
+ scaling = 0
759
+
760
+ for logit_idx, head_idx in enumerate(indices):
761
+ v_ele = float(logit[logit_idx])
762
+ if v_ele > threshold:
763
+ new_scale = math.sqrt(threshold / v_ele)
764
+ if new_scale < scales_full[head_idx]:
765
+ scales_full[head_idx] = new_scale
766
+ logger.info(
767
+ f"[{kind}] Head {head_idx} exceeded threshold "
768
+ f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
769
+ )
770
+ scaling += 1
771
+
772
+ return scales_full if scaling > 0 else None
773
+
774
+ @staticmethod
775
+ def _qk_clip(p, scales, head_dim):
776
+ W = p.data.view(-1, head_dim, p.data.shape[1])
777
+ W.mul_(scales.view(-1, 1, 1))
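+
+ # Illustrative numbers: with threshold = 100 and a head whose max QK logit is
+ # 400, _compute_scales returns sqrt(100 / 400) = 0.5 for that head, and
+ # _qk_clip scales that head's head_dim rows of the weight by 0.5. Applying the
+ # same sqrt factor to both the q and k projections scales the head's QK logit
+ # by 0.25, pulling it back down to the threshold.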
778
+
779
+ def parallel(self, names, params, group, lr, weight_decay, momentum,
780
+ qk_logits):
781
+ """
782
+ Perform a parallel optimization step using Muon.
783
+ """
784
+
785
+ for p in params:
786
+ g = p.grad
787
+ if g is None:
788
+ continue
789
+ if g.ndim > 2:
790
+ g = g.view(g.size(0), -1)
791
+
792
+ # Update g in the local rank
793
+ g = self._update_g(
794
+ p,
795
+ g,
796
+ group,
797
+ momentum=momentum,
798
+ )
799
+ p.grad = g
800
+
801
+ param_to_state, ordered_params = self.init_state_and_assign_params(
802
+ names, params, group, qk_logits)
803
+
804
+ assert self.rank is not None
805
+
806
+ def enqueue_all2all_gather(start_idx, chunk_size):
807
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
808
+ if target_params:
809
+ alloc_event = _alloc_gathered_grad(target_params,
810
+ param_to_state, self.rank,
811
+ self.compute_stream)
812
+ _all2all_gather(target_params, param_to_state, self.rank,
813
+ self.comm_stream, group["none_grad"],
814
+ alloc_event)
815
+
816
+ def enqueue_computes(start_idx, chunk_size):
817
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
818
+ state = param_to_state[id(p)]
819
+ _compute_u(p, state, group["ns_steps"], self.rank,
820
+ self.compute_stream)
821
+
822
+ def enqueue_all2all_scatter(start_idx, chunk_size):
823
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
824
+ if target_params:
825
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
826
+ self.rank,
827
+ self.compute_stream)
828
+ _all2all_scatter(target_params, param_to_state, self.rank,
829
+ self.comm_stream, alloc_event)
830
+
831
+ def enqueue_update_param(start_idx, chunk_size):
832
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
833
+ state = param_to_state[id(p)]
834
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
835
+ _update_param(p, state, lr, adjusted_lr, weight_decay,
836
+ self.rank, self.compute_stream)
837
+
838
+ chunk_size = dist.get_world_size(param_to_state[id(
839
+ params[0])].process_group)
840
+
841
+ # Wait grad update
842
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
843
+
844
+ overlap_step = self.overlap_step
845
+ for i in range(0, overlap_step):
846
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
847
+ enqueue_computes(i * chunk_size, chunk_size)
848
+
849
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
850
+ enqueue_all2all_scatter(i, chunk_size)
851
+ enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
852
+ enqueue_update_param(i, chunk_size)
853
+ enqueue_computes(i + overlap_step * chunk_size, chunk_size)
854
+
855
+ # Wait the last update_param to finish
856
+ torch.cuda.current_stream().wait_stream(self.compute_stream)
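+
+ # Scheduling sketch (illustrative): with overlap_step = 2 and chunk_size = 4,
+ # the loops above enqueue roughly
+ #   gather[0:4], compute[0:4], gather[4:8], compute[4:8],            (warm-up)
+ #   scatter[0:4], gather[8:12], update[0:4], compute[8:12],
+ #   scatter[4:8], gather[12:16], update[4:8], compute[12:16], ...
+ # so communication for later chunks overlaps Newton-Schulz compute for earlier
+ # ones, at the cost of keeping overlap_step chunks of gathered gradients and
+ # computed updates alive at once.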
857
+
858
+ @staticmethod
859
+ def _fused_adamw(
860
+ params: list[torch.Tensor],
861
+ grads: list[torch.Tensor],
862
+ exp_avgs: list[torch.Tensor],
863
+ exp_avg_sqs: list[torch.Tensor],
864
+ max_exp_avg_sqs: list[torch.Tensor],
865
+ state_steps: list[torch.Tensor],
866
+ amsgrad: bool,
867
+ beta1: float,
868
+ beta2: float,
869
+ lr: Union[float, torch.Tensor],
870
+ weight_decay: float,
871
+ eps: float,
872
+ maximize: bool,
873
+ ) -> None:
874
+ if not params:
875
+ return
876
+
877
+ # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
878
+ # treating it as a scalar.
879
+ lr_dict: Optional[DeviceDict] = ({
880
+ lr.device: lr
881
+ } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
882
+ None)
883
+ grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
884
+ [
885
+ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
886
+ state_steps
887
+ ] # type: ignore[list-item]
888
+ )
889
+ for (device, _), (
890
+ (
891
+ device_params_,
892
+ device_grads_,
893
+ device_exp_avgs_,
894
+ device_exp_avg_sqs_,
895
+ device_max_exp_avg_sqs,
896
+ device_state_steps_,
897
+ ),
898
+ _,
899
+ ) in grouped_tensors.items():
900
+ device_params = cast(list[torch.Tensor], device_params_)
901
+ device_grads = cast(list[torch.Tensor], device_grads_)
902
+ device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
903
+ device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
904
+ device_state_steps = cast(list[torch.Tensor], device_state_steps_)
905
+
906
+ if lr_dict is not None and device not in lr_dict:
907
+ lr_dict[device] = lr.to(
908
+ device=device,
909
+ non_blocking=True) # type: ignore[union-attr]
910
+ lr = lr_dict[device]
911
+ torch._foreach_add_(device_state_steps, 1)
912
+ func = torch._fused_adamw_
913
+ func(
914
+ device_params,
915
+ device_grads,
916
+ device_exp_avgs,
917
+ device_exp_avg_sqs,
918
+ device_max_exp_avg_sqs, # type: ignore[arg-type]
919
+ device_state_steps,
920
+ amsgrad=amsgrad,
921
+ lr=lr, # type: ignore[arg-type]
922
+ beta1=beta1,
923
+ beta2=beta2,
924
+ weight_decay=weight_decay,
925
+ eps=eps,
926
+ maximize=maximize,
927
+ )
928
+
929
+ def step(self, closure=None, qk_logits=None):
930
+ """Perform a single optimization step.
931
+
932
+ Args:
933
+ closure (Callable, optional): A closure that reevaluates the model
934
+ and returns the loss.
935
+ qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
936
+ to 1D tensors of shape (num_heads,), representing the maximum
937
+ QK logits across all tokens, computed as
938
+ (1 / sqrt(head_dim)) * (Q @ K^T).
939
+ """
940
+ loss = None
941
+ if closure is not None:
942
+ with torch.enable_grad():
943
+ loss = closure()
944
+
945
+ for group in self.param_groups:
946
+ params = group["params"]
947
+
948
+ if group["use_muon"]:
949
+ ############################
950
+ # Muon #
951
+ ############################
952
+ lr = group["lr"]
953
+ weight_decay = group["weight_decay"]
954
+ momentum = group["momentum"]
955
+ names = group["names"]
956
+
957
+ param_dtensors = []
958
+ param_tensors = []
959
+ name_dtensors = []
960
+ name_tensors = []
961
+
962
+ for n, p in zip(names, params):
963
+ if p is None or p.grad is None:
964
+ continue
965
+ if isinstance(p.data, DTensor):
966
+ if all(
967
+ isinstance(placement, Replicate)
968
+ for placement in p.placements):
969
+ param_tensors.append(p)
970
+ name_tensors.append(n)
971
+ else:
972
+ param_dtensors.append(p)
973
+ name_dtensors.append(n)
974
+ elif isinstance(p.data, torch.Tensor):
975
+ param_tensors.append(p)
976
+ name_tensors.append(n)
977
+ else:
978
+ raise TypeError(
979
+ f"Unsupported parameter type: {type(p.data)}")
980
+
981
+ if self.debug:
982
+ print(
983
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
984
+ flush=True,
985
+ )
986
+
987
+ if len(param_dtensors) > 0:
988
+ if not dist.is_initialized():
989
+ raise RuntimeError(
990
+ "Parallel Muon requires torch.distributed to be initialized."
991
+ )
992
+
993
+ self.parallel(
994
+ name_dtensors,
995
+ param_dtensors,
996
+ group,
997
+ lr=lr,
998
+ weight_decay=weight_decay,
999
+ momentum=momentum,
1000
+ qk_logits=qk_logits,
1001
+ )
1002
+
1003
+ if len(param_tensors) > 0:
1004
+ self.base(
1005
+ name_tensors,
1006
+ param_tensors,
1007
+ group,
1008
+ lr=lr,
1009
+ weight_decay=weight_decay,
1010
+ momentum=momentum,
1011
+ qk_logits=qk_logits,
1012
+ )
1013
+
1014
+ else:
1015
+ ############################
1016
+ # AdamW backup #
1017
+ ############################
1018
+
1019
+ params_with_grads = []
1020
+ grads = []
1021
+ moment1 = []
1022
+ moment2 = []
1023
+ max_exp_avg_sqs = []
1024
+ state_steps = []
1025
+ lr = group["lr"]
1026
+ beta1, beta2 = group["adamw_betas"]
1027
+ eps = group["adamw_eps"]
1028
+ weight_decay = group["weight_decay"]
1029
+
1030
+ for p in params:
1031
+ g = p.grad
1032
+ if g is None:
1033
+ continue
1034
+ state = self.state[p]
1035
+ params_with_grads.append(p)
1036
+ grads.append(g)
1037
+ if "step" not in state:
1038
+ state["step"] = (torch.zeros((),
1039
+ dtype=torch.float32,
1040
+ device=p.device))
1041
+ state["moment1"] = torch.zeros_like(g)
1042
+ state["moment2"] = torch.zeros_like(g)
1043
+ moment1.append(state["moment1"])
1044
+ moment2.append(state["moment2"])
1045
+ if not isinstance(state["step"], torch.Tensor):
1046
+ step_tensor = torch.tensor(state["step"],
1047
+ dtype=torch.float32,
1048
+ device=p.device)
1049
+ else:
1050
+ step_tensor = state["step"]
1051
+ state_steps.append(step_tensor)
1052
+
1053
+ self._fused_adamw(
1054
+ params_with_grads,
1055
+ grads,
1056
+ moment1,
1057
+ moment2,
1058
+ max_exp_avg_sqs,
1059
+ state_steps,
1060
+ amsgrad=False,
1061
+ beta1=beta1,
1062
+ beta2=beta2,
1063
+ lr=lr,
1064
+ weight_decay=weight_decay,
1065
+ eps=eps,
1066
+ maximize=False,
1067
+ )
1068
+
1069
+ return loss
build/torch29-cxx11-cu130-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .muon import Muon
2
+
3
+ __all__ = [
4
+ "Muon",
5
+ ]
build/torch29-cxx11-cu130-x86_64-linux/optimizer/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _optimizer_811726c_dirty
3
+ ops = torch.ops._optimizer_811726c_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_optimizer_811726c_dirty::{op_name}"
build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52a744cf30c60fe1e8fc35ebb0d3421d679bb2047fbb4602846bd6902cfa9e52
3
+ size 1872152
build/torch29-cxx11-cu130-x86_64-linux/optimizer/matmul_transpose_triton.py ADDED
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2025 Tianyang Lin
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+
28
+ def get_autotune_config():
29
+ return [
30
+ triton.Config(
31
+ {
32
+ 'BLOCK_SIZE_M': blk_m,
33
+ 'BLOCK_SIZE_K': blk_k,
34
+ 'GROUP_SIZE_M': grp_sz
35
+ },
36
+ num_stages=n_stages,
37
+ num_warps=n_warps) for blk_m in [32, 64, 128]
38
+ for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
39
+ for n_warps in [4, 8]
40
+ ]
41
+
42
+
43
+ @triton.autotune(
44
+ configs=get_autotune_config(),
45
+ key=['M', 'K'],
46
+ )
47
+ @triton.jit
48
+ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
49
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
50
+ GROUP_SIZE_M: tl.constexpr):
51
+ """
52
+ Core kernel jit function of matmul_transpose that computes y = x @ x.T
53
+ The code is a simple adaptation from the triton `matmul` tutorial:
54
+ https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
55
+ """
56
+ pid = tl.program_id(axis=0)
57
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
58
+ num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
59
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
60
+ group_id = pid // num_pid_in_group
61
+ first_pid_m = group_id * GROUP_SIZE_M
62
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
63
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
64
+ pid_n = (pid % num_pid_in_group) // group_size_m
65
+ if pid_m > pid_n:
66
+ return
67
+
68
+ offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
69
+ offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
70
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
71
+ # we use a & b ptrs to denote different rows of x.
72
+ a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
73
+ b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
74
+
75
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
76
+
77
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
78
+ a = tl.load(a_ptrs,
79
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
80
+ other=0.0)
81
+ b = tl.load(b_ptrs,
82
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
83
+ other=0.0)
84
+ accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
85
+ a_ptrs += BLOCK_SIZE_K * stride_xk
86
+ b_ptrs += BLOCK_SIZE_K * stride_xk
87
+ # use dtype.element_ty to accommodate different input datatypes as in cpp templates
88
+ # https://github.com/triton-lang/triton/issues/2252
89
+ c = accumulator.to(x.dtype.element_ty)
90
+
91
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
92
+ offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
93
+ c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
94
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
95
+ tl.store(c_ptrs, c, mask=c_mask)
96
+
97
+ # transpose and copy
98
+ if pid_m < pid_n:
99
+ ct_ptrs = y + stride_ym * offs_cn[:,
100
+ None] + stride_yn * offs_cm[None, :]
101
+ ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
102
+ tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
+
104
+
105
+ def matmul_transpose_assign(d_in, d_out):
106
+ assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
+ assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
+ assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
+ assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
+ assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
+ assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
+ assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
113
+ "First dimension of `d_in` must match first and second dimension of `d_out`"
114
+
115
+ d_in = d_in.contiguous()
116
+ M, K = d_in.shape
117
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
118
+ M, META['BLOCK_SIZE_M']), )
119
+ with torch.cuda.device(d_in.device.index):
120
+ mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
+ d_out.stride(0), d_out.stride(1))
122
+
123
+
124
+ def matmul_transpose(d_in):
125
+ M, _ = d_in.shape
126
+ d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
127
+ matmul_transpose_assign(d_in, d_out)
128
+ return d_out
build/torch29-cxx11-cu130-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,1069 @@
1
+ import logging
2
+ import math
3
+ import types
4
+ from dataclasses import dataclass
5
+ from typing import List, Optional, Union, cast
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch.distributed._tensor import DTensor, Replicate, Shard
10
+
11
+ from .matmul_transpose_triton import matmul_transpose_assign
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
+
18
+ # This code snippet is a modified version adapted from the following GitHub repositories:
19
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
20
+ # Muon's Newton–Schulz iteration causes high variance in singular values
21
+ # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
22
+ @torch.no_grad()
23
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
24
+ def _zeropower_via_newtonschulz5(G, steps):
25
+ """
26
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
27
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
28
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
29
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
30
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
31
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
32
+ performance at all relative to UV^T, where USV^T = G is the SVD.
33
+ """
34
+ assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
36
+ X = G # no manual typecast
37
+
38
+ if G.size(0) > G.size(1):
39
+ X = X.T
40
+ # Ensure spectral norm is at most 1
41
+ X = X / (X.norm() + 1e-7)
42
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
43
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
44
+ # Perform the NS iterations
45
+ for a, b, c in [
46
+ (4.0848, -6.8946, 2.9270),
47
+ (3.9505, -6.3029, 2.6377),
48
+ (3.7418, -5.5913, 2.3037),
49
+ (2.8769, -3.1427, 1.2046),
50
+ (2.8366, -3.0525, 1.2012),
51
+ ]:
52
+ matmul_transpose_assign(X, buf1)
53
+ matmul_transpose_assign(buf1, buf2)
54
+ buf1.mul_(b).add_(buf2, alpha=c)
55
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
56
+
57
+ if G.size(0) > G.size(1):
58
+ X = X.T
59
+ return X
60
+
61
+
62
+ @dataclass
63
+ class _muon_state:
64
+ # TODO: use Optional
65
+ worker_rank: int | None = None
66
+ gathered_grad: torch.Tensor | None = None
67
+ scattered_u: DTensor | None = None
68
+ computed_u: torch.Tensor | None = None
69
+ gather_event: torch.cuda.Event | None = None
70
+ compute_event: torch.cuda.Event | None = None
71
+ scatter_event: torch.cuda.Event | None = None
72
+ process_group = None
73
+ qk_clip_state = None
74
+
75
+
76
+ def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
+ rows = param.shape[0]
78
+ cols = int(param.numel() // rows)
79
+ base, rem = divmod(rows, num_ranks)
80
+ my_rows = base + (1 if src_rank < rem else 0)
81
+ return my_rows * cols
82
+
83
+
84
+ @torch.no_grad()
85
+ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
86
+ """
87
+ Pre-allocate gathered_grad buffer on compute_stream
88
+ before launching all2all gather
89
+ """
90
+ with torch.cuda.stream(compute_stream):
91
+ for p in params:
92
+ state = param_to_state[id(p)]
93
+ if rank == state.worker_rank:
94
+ num_ranks = dist.get_world_size(group=state.process_group)
95
+ state.gathered_grad = torch.empty(p.grad.numel(),
96
+ dtype=COMM_DTYPE,
97
+ device="cuda")
98
+ else:
99
+ state.gathered_grad = None
100
+
101
+ alloc_event = torch.cuda.Event()
102
+ alloc_event.record(compute_stream)
103
+ return alloc_event
104
+
105
+
106
+ @torch.no_grad()
107
+ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
108
+ alloc_event):
109
+ """
110
+ All2all gathers shards so each owner rank reconstructs its full gradient
111
+ """
112
+ with torch.cuda.stream(comm_stream):
113
+ process_group = param_to_state[id(params[0])].process_group
114
+ num_ranks = dist.get_world_size(group=process_group)
115
+
116
+ # Construct sending buffers
117
+ per_dst = [[] for _ in range(num_ranks)]
118
+ send_counts = [0] * num_ranks
119
+
120
+ for p in params:
121
+ state = param_to_state[id(p)]
122
+ dst = state.worker_rank
123
+ assert dst < num_ranks
124
+ shard_elems = split_elems_for_src(p, rank, num_ranks)
125
+ g = p.grad
126
+ g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
+ assert g.numel() == shard_elems
128
+ per_dst[dst].append(g)
129
+ send_counts[dst] += shard_elems
130
+
131
+ assert any(
132
+ len(v) > 0 for v in per_dst
133
+ ), "At least one destination rank must receive a sharded tensor"
134
+ # list[list[Tensor]] -> list[Tensor]
135
+ per_dst = [t for dst in per_dst for t in dst]
136
+
137
+ send_buf = torch.cat(per_dst, dim=0)
138
+
139
+ owned_params = [
140
+ p for p in params if param_to_state[id(p)].worker_rank == rank
141
+ ]
142
+
143
+ # Compute receive sizes and allocate receiving buffers
144
+ recv_counts = [0] * num_ranks
145
+
146
+ for src in range(num_ranks):
147
+ total = 0
148
+ for p in owned_params:
149
+ state = param_to_state[id(p)]
150
+ assert state.worker_rank == rank
151
+ total += split_elems_for_src(p, src, num_ranks)
152
+ recv_counts[src] = total
153
+
154
+ recv_total = sum(recv_counts)
155
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
156
+
157
+ #All2All
158
+ dist.all_to_all_single(
159
+ recv_buf,
160
+ send_buf,
161
+ output_split_sizes=recv_counts,
162
+ input_split_sizes=send_counts,
163
+ group=process_group,
164
+ )
165
+
166
+ # Reconstructs gathered grad from the received buffer
167
+ #
168
+ # recv_buf (num ranks = 3)
169
+ #
170
+ # From rank 0 From rank 1 From rank 2
171
+ # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
172
+ #
173
+ # Outer loop:
174
+ # rank 0 -> rank 1 -> rank 2
175
+ #
176
+ # Inner loop:
177
+ # p1_n -> p2_n -> p3_n
178
+
179
+ comm_stream.wait_event(alloc_event)
180
+
181
+ off = 0
182
+ write_offsets = {id(p): 0 for p in owned_params}
183
+ for src in range(num_ranks):
184
+ if recv_counts[src] == 0:
185
+ continue
186
+
187
+ block = recv_counts[src]
188
+ inner_off = 0
189
+ for p in owned_params:
190
+ state = param_to_state[id(p)]
191
+ assert state.worker_rank == rank
192
+ n = split_elems_for_src(p, src, num_ranks)
193
+ assert n > 0
194
+
195
+ sg = recv_buf.narrow(0, off + inner_off, n)
196
+ woff = write_offsets[id(p)]
197
+ dst = state.gathered_grad.narrow(0, woff, n)
198
+ dst.copy_(sg)
199
+
200
+ write_offsets[id(p)] += n
201
+ inner_off += n
202
+ off += block
203
+
204
+ for p in params:
205
+ state = param_to_state[id(p)]
206
+ if state.worker_rank == rank:
207
+ state.gathered_grad = state.gathered_grad.view_as(p)
208
+ state.gather_event = torch.cuda.Event()
209
+ state.gather_event.record(comm_stream)
210
+ else:
211
+ state.gathered_grad = None
212
+ state.gather_event = None
213
+ if none_grad:
214
+ p.grad = None
215
+
216
+
217
+ @torch.no_grad()
218
+ def _compute_u(p, state, steps, rank, compute_stream):
219
+ """
220
+ On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
221
+ """
222
+ with torch.cuda.stream(compute_stream):
223
+ if rank == state.worker_rank:
224
+ if state.gather_event is None:
225
+ raise RuntimeError("Gather event must be set before compute.")
226
+ compute_stream.wait_event(state.gather_event)
227
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
228
+ state.gathered_grad = None
229
+ state.computed_u = u
230
+ state.compute_event = torch.cuda.Event()
231
+ state.compute_event.record()
232
+ else:
233
+ state.computed_u = None
234
+ state.compute_event = None
235
+
236
+
237
+ @torch.no_grad()
238
+ def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
239
+ """
240
+ Pre-allocate scattered_u buffer on compute_stream
241
+ before launching all2all scatter
242
+ """
243
+ with torch.cuda.stream(compute_stream):
244
+ for p in params:
245
+ state = param_to_state[id(p)]
246
+ state.scattered_u = torch.empty_like(p.to_local(),
247
+ dtype=COMM_DTYPE)
248
+
249
+ alloc_event = torch.cuda.Event()
250
+ alloc_event.record(compute_stream)
251
+ return alloc_event
252
+
253
+
254
+ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
255
+ """
256
+ All2all scatters the computed orthogonalized updates back to all ranks as shards
257
+ """
258
+ with torch.cuda.stream(comm_stream):
259
+ process_group = param_to_state[id(params[0])].process_group
260
+ num_ranks = dist.get_world_size(group=process_group)
261
+ owned_params = [
262
+ p for p in params if param_to_state[id(p)].worker_rank == rank
263
+ ]
264
+
265
+ # Construct sending buffer
266
+ per_dst = [[] for _ in range(num_ranks)]
267
+ send_counts = [0] * num_ranks
268
+
269
+ if owned_params:
270
+ for p in owned_params:
271
+ state = param_to_state[id(p)]
272
+ if state.compute_event is None:
273
+ raise RuntimeError(
274
+ "Compute event must be set before scatter.")
275
+ comm_stream.wait_event(state.compute_event)
276
+ state.gathered_grad = None
277
+
278
+ assert state.computed_u is not None
279
+
280
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
281
+
282
+ offset = 0
283
+ for dst in range(num_ranks):
284
+ n = split_elems_for_src(p, dst, num_ranks)
285
+ assert n > 0
286
+
287
+ su = u_full.narrow(0, offset, n)
288
+ per_dst[dst].append(su)
289
+ send_counts[dst] += n
290
+ offset += n
291
+
292
+ assert offset == u_full.numel()
293
+
294
+ lengths = [len(v) for v in per_dst]
295
+ if all(l > 0 for l in lengths):
296
+ assert all(
297
+ l == lengths[0] for l in lengths
298
+ ), "All destination ranks must have the same number of sharded tensor"
299
+ # list[list[Tensor]] -> list[Tensor]
300
+ per_dst = [t for dst in per_dst for t in dst]
301
+ send_buf = torch.cat(per_dst, dim=0)
302
+ else:
303
+ # all_to_all requires participation from all ranks
304
+ # Even non-owner ranks must join the collective call
305
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
306
+
307
+ # Compute receive sizes and allocate receiving buffers
308
+ recv_counts = [0] * num_ranks
309
+
310
+ for src in range(num_ranks):
311
+ total = 0
312
+ for p in params:
313
+ state = param_to_state[id(p)]
314
+ if state.worker_rank != src:
315
+ continue
316
+ total += split_elems_for_src(p, rank, num_ranks)
317
+ recv_counts[src] = total
318
+
319
+ recv_total = sum(recv_counts)
320
+ assert recv_total > 0
321
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
322
+
323
+ #All2All
324
+ dist.all_to_all_single(
325
+ recv_buf,
326
+ send_buf,
327
+ output_split_sizes=recv_counts,
328
+ input_split_sizes=send_counts,
329
+ group=process_group,
330
+ )
331
+
332
+ # Copy to pre-allocated scattered_u buffer from the received buffer
333
+ #
334
+ # recv_buf (num ranks = 3, local_rank = 0)
335
+ #
336
+ # From rank 0 From rank 1 From rank 2
337
+ # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
338
+ #
339
+ # Outer loop:
340
+ # rank 0 -> rank 1 -> rank 2
341
+ #
342
+ # Inner loop:
343
+ # src(0) : p1_0 -> p2_0 -> p3_0
344
+ # src(1) : p4_0
345
+ # src(2) : p5_0 -> p6_0
346
+
347
+ comm_stream.wait_event(alloc_event)
348
+
349
+ off = 0
350
+ for src in range(num_ranks):
351
+ block = recv_counts[src]
352
+ if block == 0:
353
+ continue
354
+
355
+ inner_off = 0
356
+ for p in params:
357
+ state = param_to_state[id(p)]
358
+ if state.worker_rank != src:
359
+ continue
360
+ n = split_elems_for_src(p, rank, num_ranks)
361
+ assert n > 0
362
+
363
+ flat_local = recv_buf.narrow(0, off + inner_off,
364
+ n).view_as(p.to_local())
365
+ state.scattered_u.copy_(flat_local)
366
+
367
+ state.scatter_event = torch.cuda.Event()
368
+ state.scatter_event.record(comm_stream)
369
+ inner_off += n
370
+
371
+ assert inner_off == block
372
+ off += block
373
+
374
+
375
+ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
376
+ compute_stream):
377
+ """
378
+ Update sharded parameter p with the scattered_u.
379
+ Only worker_rank frees computed_u.
380
+ """
381
+ with torch.cuda.stream(compute_stream):
382
+ if state.scatter_event is None:
383
+ raise RuntimeError("Scatter event must be set before update")
384
+ compute_stream.wait_event(state.scatter_event)
385
+ u_dtensor = DTensor.from_local(
386
+ state.scattered_u,
387
+ placements=p.placements,
388
+ device_mesh=p.device_mesh,
389
+ )
390
+
391
+ state.scattered_u = u_dtensor
392
+
393
+ if rank == state.worker_rank:
394
+ # Free computed_u
395
+ state.computed_u = None
396
+
397
+ Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
398
+ state.scattered_u = None
399
+ u_dtensor = None
400
+
401
+ scales_full = Muon._compute_scales(p, state.qk_clip_state)
402
+ if scales_full is not None:
403
+ num_ranks = dist.get_world_size(group=state.process_group)
404
+ local_rank = dist.get_rank(group=state.process_group)
405
+ scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
406
+ scales_local = DTensor.from_local(
407
+ scales_local,
408
+ placements=p.placements,
409
+ device_mesh=p.device_mesh,
410
+ )
411
+ Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
412
+
413
+
414
+ def default_is_muon(name, x):
415
+ skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
416
+ return x.ndim >= 2 and not any(key in name for key in skip_keys)
417
+
418
+
419
+ def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
420
+ muon_params, muon_names = [], []
421
+ non_muon_params = []
422
+
423
+ for n, p in model.named_parameters():
424
+ if not p.requires_grad:
425
+ continue
426
+ if is_muon_func(n, p):
427
+ muon_params.append(p)
428
+ muon_names.append(n)
429
+ else:
430
+ non_muon_params.append(p)
431
+
432
+ return [
433
+ {
434
+ "params": muon_params,
435
+ "names": muon_names,
436
+ "use_muon": True,
437
+ },
438
+ {
439
+ "params": non_muon_params,
440
+ "use_muon": False,
441
+ },
442
+ ]
443
+
444
+
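A short sketch of how these two helpers combine (illustrative; `model` is any torch.nn.Module of yours and the hyperparameters are placeholders):

    params = get_default_muon_param_groups(model)
    # params[0]: 2D weights (plus their names) routed to Muon
    # params[1]: embeddings, lm_head, biases and norms routed to the AdamW fallback
    optim = Muon(params, lr=2e-2, weight_decay=0.1)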
445
+ def parse_qk_layer(name: str) -> tuple[str | None, int]:
446
+ """
447
+ Parse a parameter name to check if it is a query/key projection layer
448
+ ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
449
+
450
+ Returns:
451
+ (kind, layer_idx) or (None, -1) if not matched.
452
+
453
+ Example:
454
+ 'model.3.attn.wq.weight' -> ('wq', 3)
455
+ 'model.5.attn.wk.weight' -> ('wk', 5)
456
+ 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
457
+ 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
458
+ 'model.4.attn.v_proj.weight' -> (None, -1)
459
+ """
460
+ parts = name.split('.')
461
+ if len(parts) < 3:
462
+ return None, -1
463
+
464
+ kind = parts[-2]
465
+
466
+ layer_idx = -1
467
+ for part in reversed(parts):
468
+ if part.isdigit():
469
+ layer_idx = int(part)
470
+ break
471
+
472
+ if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
473
+ return kind, layer_idx
474
+
475
+ return None, -1
476
+
477
+
478
+ @dataclass
479
+ class QKClipInfo:
480
+ """Per-parameter dynamic info computed from config + runtime logits."""
481
+ kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
482
+ indices: List[int] # which heads to consider for clipping
483
+ head_dim: int # from config
484
+ threshold: float # from config
485
+ logit: Optional[torch.Tensor]
486
+
487
+
488
+ class Muon(torch.optim.Optimizer):
489
+ """
490
+ Muon - MomentUm Orthogonalized by Newton-schulz
491
+
492
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
493
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
494
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
495
+ the advantage that it can be stably run in bfloat16 on the GPU.
496
+
497
+ Some warnings:
498
+ - We believe this optimizer is unlikely to work well for training with small batch size.
499
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
500
+
501
+ Arguments:
502
+ params: Parameter groups to optimize. Build them with get_default_muon_param_groups(model); every group must define the key 'use_muon'.
503
+ is_muon_func: A function that takes a parameter name and the parameter, and returns whether that parameter should be optimized by Muon (used by get_default_muon_param_groups, not passed to this constructor).
504
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
505
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
506
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
507
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
508
+ weight_decay: The weight decay applied by both Muon and the internal AdamW.
509
+ Parameters that are {0, 1}-D, or detected as embedding / lm_head weights, are optimized by AdamW instead of Muon.
510
+ adamw_lr: The learning rate for the internal AdamW.
511
+ adamw_betas: The betas for the internal AdamW.
512
+ adamw_eps: The epsilon for the internal AdamW.
513
+ none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
514
+ debug: Whether to print debug information.
515
+ clip_config : Configuration for QK clipping. Expected keys:
516
+ - "q_indices" (list[int]): Indices of query heads to consider.
517
+ - "k_indices" (list[int]): Indices of key heads to consider.
518
+ - "head_dim" (int): Dimensionality of each attention head.
519
+ - "threshold" (float): Threshold value; heads whose QK logits exceed
520
+ this value will be scaled down.
521
+ Default is:
522
+ {
523
+ "q_indices": [],
524
+ "k_indices": [],
525
+ "head_dim": 128,
526
+ "threshold": 100
527
+ }
528
+ overlap_step : How many all2all gather and compute operations are launched in advance
529
+ before the corresponding all2all scatter steps begin.
530
+ A higher overlap_step increases memory usage but can improve
531
+ performance by overlapping communication.
532
+ Parallel muon only.
533
+ """
534
+
535
+ def __init__(self,
536
+ params,
537
+ lr=1e-3,
538
+ momentum=0.95,
539
+ nesterov=True,
540
+ ns_steps=5,
541
+ weight_decay=0.1,
542
+ adamw_betas=(0.9, 0.95),
543
+ adamw_eps=1e-8,
544
+ none_grad=True,
545
+ debug=False,
546
+ clip_config={
547
+ "q_indices": [],
548
+ "k_indices": [],
549
+ "head_dim": 128,
550
+ "threshold": 100
551
+ },
552
+ overlap_step=5):
553
+ defaults = dict(
554
+ lr=lr,
555
+ weight_decay=weight_decay,
556
+ momentum=momentum,
557
+ nesterov=nesterov,
558
+ ns_steps=ns_steps,
559
+ adamw_betas=adamw_betas,
560
+ adamw_eps=adamw_eps,
561
+ none_grad=none_grad,
562
+ use_muon=True,
563
+ )
564
+ error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
565
+ instruction_code = "\n\nPlease follow this code snippet:\n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
566
+
567
+ if isinstance(params, types.GeneratorType):
568
+ raise ValueError(error_message.format(idx=0) + instruction_code)
569
+ for _idx, param_group in enumerate(params):
570
+ if param_group.get("use_muon", None) is None:
571
+ raise ValueError(
572
+ error_message.format(idx=_idx) + instruction_code)
573
+
574
+ super().__init__(params, defaults)
575
+
576
+ self.rank = None
577
+
578
+ self.comm_stream = torch.cuda.Stream()
579
+ self.compute_stream = torch.cuda.Stream()
580
+ self.debug = debug
581
+ self.clip_config = clip_config
582
+ self.overlap_step = overlap_step
583
+
584
+ def _calc_flops(self, G, steps):
585
+ assert len(G.shape) == 2
586
+ M, N = G.shape
587
+ if M > N:
588
+ M, N = N, M
589
+
590
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
591
+
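Plugging in numbers (illustrative): a 1024 x 4096 gradient with 5 Newton-Schulz steps gives roughly 5 * (2*1024^3 + 4*1024^2*4096 + 2*1024*4096 + 3*1024^2) ≈ 9.7e10 FLOPs, i.e. about 0.1 TFLOPs; this per-parameter count is what the scheduler below uses to balance work across ranks.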
592
+ def adjust_lr_for_muon(self, lr, param_shape):
593
+ A, B = param_shape[:2]
594
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
595
+ # as described in the paper
596
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
597
+ adjusted_lr = lr * adjusted_ratio
598
+ return adjusted_lr
599
+
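Worked example (illustrative numbers): with lr = 0.02 and a (4096, 1024) weight, 0.2 * sqrt(4096) = 12.8, so adjusted_lr = 0.02 * 12.8 = 0.256 is what ultimately multiplies the orthogonalized update for that matrix.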
600
+ def get_shard_mesh(self, p):
601
+ """
602
+ Get the shard mesh for a parameter p on the given rank.
603
+ """
604
+ assert isinstance(
605
+ p, DTensor), "Parallel Muon only supports DTensor parameters."
606
+
607
+ if p.placements == (Shard(dim=0), ):
608
+ # Case for FSDP
609
+ process_group = p.device_mesh.get_group(mesh_dim=0)
610
+ if self.rank is None:
611
+ self.rank = dist.get_rank(group=process_group)
612
+ else:
613
+ assert self.rank == dist.get_rank(group=process_group)
614
+ return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
+ elif p.placements == (Replicate(), Shard(dim=0)):
616
+ # Case for HSDP
617
+ process_group = p.device_mesh.get_group(mesh_dim=1)
618
+ if self.rank is None:
619
+ self.rank = dist.get_rank(group=process_group)
620
+ else:
621
+ assert self.rank == dist.get_rank(group=process_group)
622
+ for i, shard_mesh in enumerate(p.device_mesh.mesh):
623
+ if self.rank in shard_mesh:
624
+ return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
625
+ else:
626
+ raise ValueError(f"Unsupported placements ({p.placements}).")
627
+
628
+ def init_state_and_assign_params(self, names, params, group, qk_logits):
629
+ param_to_state = {}
630
+ param_to_flops = {}
631
+
632
+ total_flops = 0
633
+ for p in params:
634
+ g = p.grad
635
+ if g is None:
636
+ continue
637
+ assert g.ndim == 2, "Muon only supports 2D parameters."
638
+
639
+ flops = self._calc_flops(g, group["ns_steps"])
640
+ param_to_flops[id(p)] = flops
641
+ total_flops += flops
642
+
643
+ if self.debug:
644
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
645
+ flush=True)
646
+
647
+ paired = list(zip(names, params))
648
+
649
+ paired_sorted = sorted(paired,
650
+ key=lambda x: param_to_flops[id(x[1])],
651
+ reverse=True)
652
+
653
+ names_sorted, params_sorted = zip(*paired_sorted)
654
+ ordered_names = list(names_sorted)
655
+ ordered_params = list(params_sorted)
656
+
657
+ round_robin = 0
658
+ mesh = None
659
+ shard_mesh = None
660
+ process_group = None
661
+ for n, p in zip(ordered_names, ordered_params):
662
+ if mesh is None:
663
+ mesh = p.device_mesh
664
+ shard_mesh, process_group = self.get_shard_mesh(p)
665
+ elif mesh != p.device_mesh:
666
+ raise ValueError("All parameters must be on the same mesh.")
667
+ num_ranks = dist.get_world_size(group=process_group)
668
+ param_to_state[id(p)] = _muon_state()
669
+ param_to_state[id(
670
+ p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
671
+ param_to_state[id(p)].process_group = process_group
672
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
673
+ param_to_state[id(p)].qk_clip_state = qk_clip_state
674
+ round_robin = (round_robin + 1) % len(shard_mesh)
675
+
676
+ return param_to_state, ordered_params
677
+
678
+ def base(self, names, params, group, lr, weight_decay, momentum,
679
+ qk_logits):
680
+ # generate weight updates in distributed fashion
681
+ for n, p in zip(names, params):
682
+ g = p.grad
683
+ if g is None:
684
+ continue
685
+ if g.ndim > 2:
686
+ g = g.view(g.size(0), -1)
687
+ assert g is not None
688
+
689
+ # calc update
690
+ state = self.state[p]
691
+ if "momentum_buffer" not in state:
692
+ state["momentum_buffer"] = torch.zeros_like(g)
693
+ buf = state["momentum_buffer"]
694
+ buf.mul_(momentum).add_(g)
695
+ if group["nesterov"]:
696
+ g = g.add(buf, alpha=momentum)
697
+ else:
698
+ g = buf
699
+
700
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
701
+ steps=group["ns_steps"])
702
+
703
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
704
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
705
+
706
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
707
+
708
+ scales_full = self._compute_scales(p, qk_clip_state)
709
+ if scales_full is not None:
710
+ Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
711
+
712
+ def _update_g(self, p, g, group, momentum):
713
+ # calc update
714
+ state = self.state[p]
715
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
716
+ torch.add(g, buf, alpha=momentum, out=buf)
717
+ if group["nesterov"]:
718
+ g.add_(buf, alpha=momentum)
719
+ return g
720
+ return buf
721
+
722
+ @staticmethod
723
+ def _update_p(p, u, lr, adjusted_lr, weight_decay):
724
+ # apply weight decay
725
+ p.data.mul_(1 - lr * weight_decay)
726
+ # apply update
727
+ p.data.add_(u, alpha=-adjusted_lr)
728
+
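Spelled out, the decoupled-weight-decay update performed here is:

    # p <- (1 - lr * weight_decay) * p - adjusted_lr * u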
729
+ def get_qk_clip_info(self, n, qk_logits):
730
+ head_dim = self.clip_config.get('head_dim')
731
+ threshold = self.clip_config.get('threshold')
732
+ kind, layer_idx = parse_qk_layer(n)
733
+
734
+ logit, indices = None, []
735
+ if qk_logits is not None and kind is not None:
736
+ logit = qk_logits[layer_idx]
737
+ indices_key = 'q_indices' if 'q' in kind else 'k_indices'
738
+ indices = self.clip_config.get(indices_key, []) or []
739
+
740
+ return QKClipInfo(
741
+ kind=kind,
742
+ indices=indices,
743
+ head_dim=head_dim,
744
+ threshold=threshold,
745
+ logit=logit,
746
+ )
747
+
748
+ @staticmethod
749
+ def _compute_scales(p, qk_clip_state):
750
+ kind = qk_clip_state.kind
751
+ indices = qk_clip_state.indices
752
+ head_dim = qk_clip_state.head_dim
753
+ threshold = qk_clip_state.threshold
754
+ logit = qk_clip_state.logit
755
+
756
+ H_global = p.shape[0] // head_dim
757
+ scales_full = torch.ones(H_global, device=p.data.device)
758
+ scaling = 0
759
+
760
+ for logit_idx, head_idx in enumerate(indices):
761
+ v_ele = float(logit[logit_idx])
762
+ if v_ele > threshold:
763
+ new_scale = math.sqrt(threshold / v_ele)
764
+ if new_scale < scales_full[head_idx]:
765
+ scales_full[head_idx] = new_scale
766
+ logger.info(
767
+ f"[{kind}] Head {head_idx} exceeded threshold "
768
+ f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
769
+ )
770
+ scaling += 1
771
+
772
+ return scales_full if scaling > 0 else None
773
+
774
+ @staticmethod
775
+ def _qk_clip(p, scales, head_dim):
776
+ W = p.data.view(-1, head_dim, p.data.shape[1])
777
+ W.mul_(scales.view(-1, 1, 1))
778
+
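A small illustration of the clip (hypothetical values): with head_dim = 128 and p of shape (H * 128, in_features), the view above groups rows per head; a head whose max QK logit is 400 against threshold = 100 receives scale sqrt(100 / 400) = 0.5 from _compute_scales, while all other heads keep scale 1.0.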
779
+ def parallel(self, names, params, group, lr, weight_decay, momentum,
780
+ qk_logits):
781
+ """
782
+ Perform a parallel optimization step using Muon.
783
+ """
784
+
785
+ for p in params:
786
+ g = p.grad
787
+ if g is None:
788
+ continue
789
+ if g.ndim > 2:
790
+ g = g.view(g.size(0), -1)
791
+
792
+ # Update g in the local rank
793
+ g = self._update_g(
794
+ p,
795
+ g,
796
+ group,
797
+ momentum=momentum,
798
+ )
799
+ p.grad = g
800
+
801
+ param_to_state, ordered_params = self.init_state_and_assign_params(
802
+ names, params, group, qk_logits)
803
+
804
+ assert self.rank is not None
805
+
806
+ def enqueue_all2all_gather(start_idx, chunk_size):
807
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
808
+ if target_params:
809
+ alloc_event = _alloc_gathered_grad(target_params,
810
+ param_to_state, self.rank,
811
+ self.compute_stream)
812
+ _all2all_gather(target_params, param_to_state, self.rank,
813
+ self.comm_stream, group["none_grad"],
814
+ alloc_event)
815
+
816
+ def enqueue_computes(start_idx, chunk_size):
817
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
818
+ state = param_to_state[id(p)]
819
+ _compute_u(p, state, group["ns_steps"], self.rank,
820
+ self.compute_stream)
821
+
822
+ def enqueue_all2all_scatter(start_idx, chunk_size):
823
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
824
+ if target_params:
825
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
826
+ self.rank,
827
+ self.compute_stream)
828
+ _all2all_scatter(target_params, param_to_state, self.rank,
829
+ self.comm_stream, alloc_event)
830
+
831
+ def enqueue_update_param(start_idx, chunk_size):
832
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
833
+ state = param_to_state[id(p)]
834
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
835
+ _update_param(p, state, lr, adjusted_lr, weight_decay,
836
+ self.rank, self.compute_stream)
837
+
838
+ chunk_size = dist.get_world_size(param_to_state[id(
839
+ params[0])].process_group)
840
+
841
+ # Wait grad update
842
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
843
+
844
+ overlap_step = self.overlap_step
845
+ for i in range(0, overlap_step):
846
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
847
+ enqueue_computes(i * chunk_size, chunk_size)
848
+
849
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
850
+ enqueue_all2all_scatter(i, chunk_size)
851
+ enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
852
+ enqueue_update_param(i, chunk_size)
853
+ enqueue_computes(i + overlap_step * chunk_size, chunk_size)
854
+
855
+ # Wait the last update_param to finish
856
+ torch.cuda.current_stream().wait_stream(self.compute_stream)
857
+
858
+ @staticmethod
859
+ def _fused_adamw(
860
+ params: list[torch.Tensor],
861
+ grads: list[torch.Tensor],
862
+ exp_avgs: list[torch.Tensor],
863
+ exp_avg_sqs: list[torch.Tensor],
864
+ max_exp_avg_sqs: list[torch.Tensor],
865
+ state_steps: list[torch.Tensor],
866
+ amsgrad: bool,
867
+ beta1: float,
868
+ beta2: float,
869
+ lr: Union[float, torch.Tensor],
870
+ weight_decay: float,
871
+ eps: float,
872
+ maximize: bool,
873
+ ) -> None:
874
+ if not params:
875
+ return
876
+
877
+ # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
878
+ # treating it as a scalar.
879
+ lr_dict: Optional[DeviceDict] = ({
880
+ lr.device: lr
881
+ } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
882
+ None)
883
+ grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
884
+ [
885
+ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
886
+ state_steps
887
+ ] # type: ignore[list-item]
888
+ )
889
+ for (device, _), (
890
+ (
891
+ device_params_,
892
+ device_grads_,
893
+ device_exp_avgs_,
894
+ device_exp_avg_sqs_,
895
+ device_max_exp_avg_sqs,
896
+ device_state_steps_,
897
+ ),
898
+ _,
899
+ ) in grouped_tensors.items():
900
+ device_params = cast(list[torch.Tensor], device_params_)
901
+ device_grads = cast(list[torch.Tensor], device_grads_)
902
+ device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
903
+ device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
904
+ device_state_steps = cast(list[torch.Tensor], device_state_steps_)
905
+
906
+ if lr_dict is not None and device not in lr_dict:
907
+ lr_dict[device] = lr.to(
908
+ device=device,
909
+ non_blocking=True) # type: ignore[union-attr]
910
+ lr = lr_dict[device]
911
+ torch._foreach_add_(device_state_steps, 1)
912
+ func = torch._fused_adamw_
913
+ func(
914
+ device_params,
915
+ device_grads,
916
+ device_exp_avgs,
917
+ device_exp_avg_sqs,
918
+ device_max_exp_avg_sqs, # type: ignore[arg-type]
919
+ device_state_steps,
920
+ amsgrad=amsgrad,
921
+ lr=lr, # type: ignore[arg-type]
922
+ beta1=beta1,
923
+ beta2=beta2,
924
+ weight_decay=weight_decay,
925
+ eps=eps,
926
+ maximize=maximize,
927
+ )
928
+
929
+ def step(self, closure=None, qk_logits=None):
930
+ """Perform a single optimization step.
931
+
932
+ Args:
933
+ closure (Callable, optional): A closure that reevaluates the model
934
+ and returns the loss.
935
+ qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
936
+ to 1D tensors of shape (num_heads,), representing the maximum
937
+ QK logits across all tokens, computed as
938
+ (1 / sqrt(head_dim)) * (Q @ K^T).
939
+ """
940
+ loss = None
941
+ if closure is not None:
942
+ with torch.enable_grad():
943
+ loss = closure()
944
+
945
+ for group in self.param_groups:
946
+ params = group["params"]
947
+
948
+ if group["use_muon"]:
949
+ ############################
950
+ # Muon #
951
+ ############################
952
+ lr = group["lr"]
953
+ weight_decay = group["weight_decay"]
954
+ momentum = group["momentum"]
955
+ names = group["names"]
956
+
957
+ param_dtensors = []
958
+ param_tensors = []
959
+ name_dtensors = []
960
+ name_tensors = []
961
+
962
+ for n, p in zip(names, params):
963
+ if p is None or p.grad is None:
964
+ continue
965
+ if isinstance(p.data, DTensor):
966
+ if all(
967
+ isinstance(placement, Replicate)
968
+ for placement in p.placements):
969
+ param_tensors.append(p)
970
+ name_tensors.append(n)
971
+ else:
972
+ param_dtensors.append(p)
973
+ name_dtensors.append(n)
974
+ elif isinstance(p.data, torch.Tensor):
975
+ param_tensors.append(p)
976
+ name_tensors.append(n)
977
+ else:
978
+ raise TypeError(
979
+ f"Unsupported parameter type: {type(p.data)}")
980
+
981
+ if self.debug:
982
+ print(
983
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
984
+ flush=True,
985
+ )
986
+
987
+ if len(param_dtensors) > 0:
988
+ if not dist.is_initialized():
989
+ raise RuntimeError(
990
+ "Parallel Muon requires torch.distributed to be initialized."
991
+ )
992
+
993
+ self.parallel(
994
+ name_dtensors,
995
+ param_dtensors,
996
+ group,
997
+ lr=lr,
998
+ weight_decay=weight_decay,
999
+ momentum=momentum,
1000
+ qk_logits=qk_logits,
1001
+ )
1002
+
1003
+ if len(param_tensors) > 0:
1004
+ self.base(
1005
+ name_tensors,
1006
+ param_tensors,
1007
+ group,
1008
+ lr=lr,
1009
+ weight_decay=weight_decay,
1010
+ momentum=momentum,
1011
+ qk_logits=qk_logits,
1012
+ )
1013
+
1014
+ else:
1015
+ ############################
1016
+ # AdamW backup #
1017
+ ############################
1018
+
1019
+ params_with_grads = []
1020
+ grads = []
1021
+ moment1 = []
1022
+ moment2 = []
1023
+ max_exp_avg_sqs = []
1024
+ state_steps = []
1025
+ lr = group["lr"]
1026
+ beta1, beta2 = group["adamw_betas"]
1027
+ eps = group["adamw_eps"]
1028
+ weight_decay = group["weight_decay"]
1029
+
1030
+ for p in params:
1031
+ g = p.grad
1032
+ if g is None:
1033
+ continue
1034
+ state = self.state[p]
1035
+ params_with_grads.append(p)
1036
+ grads.append(g)
1037
+ if "step" not in state:
1038
+ state["step"] = (torch.zeros((),
1039
+ dtype=torch.float32,
1040
+ device=p.device))
1041
+ state["moment1"] = torch.zeros_like(g)
1042
+ state["moment2"] = torch.zeros_like(g)
1043
+ moment1.append(state["moment1"])
1044
+ moment2.append(state["moment2"])
1045
+ if not isinstance(state["step"], torch.Tensor):
1046
+ step_tensor = torch.tensor(state["step"],
1047
+ dtype=torch.float32,
1048
+ device=p.device)
1049
+ else:
1050
+ step_tensor = state["step"]
1051
+ state_steps.append(step_tensor)
1052
+
1053
+ self._fused_adamw(
1054
+ params_with_grads,
1055
+ grads,
1056
+ moment1,
1057
+ moment2,
1058
+ max_exp_avg_sqs,
1059
+ state_steps,
1060
+ amsgrad=False,
1061
+ beta1=beta1,
1062
+ beta2=beta2,
1063
+ lr=lr,
1064
+ weight_decay=weight_decay,
1065
+ eps=eps,
1066
+ maximize=False,
1067
+ )
1068
+
1069
+ return loss
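Taken together, a minimal end-to-end sketch of driving this optimizer (illustrative, mirroring the snippet embedded in the constructor's error message; `model`, `loader`, `compute_loss` and the hyperparameters are placeholders):

    from kernels import get_kernel

    optimizer = get_kernel('motif-technologies/optimizer')
    params = optimizer.muon.get_default_muon_param_groups(model)
    optim = optimizer.Muon(params, lr=2e-2, momentum=0.95)

    for batch in loader:
        loss = compute_loss(model, batch)
        loss.backward()
        optim.step(qk_logits=None)  # pass per-layer max QK logits here to enable QK clipping
        optim.zero_grad()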
build/torch29-cxx11-rocm63-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .muon import Muon
2
+
3
+ __all__ = [
4
+ "Muon",
5
+ ]
build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _optimizer_811726c_dirty
3
+ ops = torch.ops._optimizer_811726c_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_optimizer_811726c_dirty::{op_name}"
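A hedged usage sketch of this helper (illustrative; `some_op` is a placeholder, not an op this extension necessarily registers):

    from optimizer._ops import ops, add_op_namespace_prefix

    add_op_namespace_prefix("some_op")  # -> "_optimizer_811726c_dirty::some_op"
    # `ops` is the torch.ops namespace bound to the loaded extension, so any
    # registered kernel would be reachable as ops.<op_name>(...)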
build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0661740cd0f97ca56ef83979c5a5fa059bcba411148f89d836e9305065578e73
3
+ size 1749264
build/torch29-cxx11-rocm63-x86_64-linux/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,128 @@
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2025 Tianyang Lin
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+
28
+ def get_autotune_config():
29
+ return [
30
+ triton.Config(
31
+ {
32
+ 'BLOCK_SIZE_M': blk_m,
33
+ 'BLOCK_SIZE_K': blk_k,
34
+ 'GROUP_SIZE_M': grp_sz
35
+ },
36
+ num_stages=n_stages,
37
+ num_warps=n_warps) for blk_m in [32, 64, 128]
38
+ for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
39
+ for n_warps in [4, 8]
40
+ ]
41
+
42
+
43
+ @triton.autotune(
44
+ configs=get_autotune_config(),
45
+ key=['M', 'K'],
46
+ )
47
+ @triton.jit
48
+ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
49
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
50
+ GROUP_SIZE_M: tl.constexpr):
51
+ """
52
+ Core kernel jit function of matmul_transpose that computes y = x @ x.T
53
+ The code is a simple adaptation from the triton `matmul` tutorial:
54
+ https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
55
+ """
56
+ pid = tl.program_id(axis=0)
57
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
58
+ num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
59
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
60
+ group_id = pid // num_pid_in_group
61
+ first_pid_m = group_id * GROUP_SIZE_M
62
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
63
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
64
+ pid_n = (pid % num_pid_in_group) // group_size_m
65
+ if pid_m > pid_n:
66
+ return
67
+
68
+ offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
69
+ offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
70
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
71
+ # we use a & b ptrs to denote different rows of x.
72
+ a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
73
+ b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
74
+
75
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
76
+
77
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
78
+ a = tl.load(a_ptrs,
79
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
80
+ other=0.0)
81
+ b = tl.load(b_ptrs,
82
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
83
+ other=0.0)
84
+ accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
85
+ a_ptrs += BLOCK_SIZE_K * stride_xk
86
+ b_ptrs += BLOCK_SIZE_K * stride_xk
87
+ # use dtype.element_ty to accommodate different input datatypes as in cpp templates
88
+ # https://github.com/triton-lang/triton/issues/2252
89
+ c = accumulator.to(x.dtype.element_ty)
90
+
91
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
92
+ offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
93
+ c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
94
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
95
+ tl.store(c_ptrs, c, mask=c_mask)
96
+
97
+ # transpose and copy
98
+ if pid_m < pid_n:
99
+ ct_ptrs = y + stride_ym * offs_cn[:,
100
+ None] + stride_yn * offs_cm[None, :]
101
+ ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
102
+ tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
+
104
+
105
+ def matmul_transpose_assign(d_in, d_out):
106
+ assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
+ assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
+ assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
+ assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
+ assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
+ assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
+ assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
113
+ "First dimension of `d_in` must match first and second dimension of `d_out`"
114
+
115
+ d_in = d_in.contiguous()
116
+ M, K = d_in.shape
117
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
118
+ M, META['BLOCK_SIZE_M']), )
119
+ with torch.cuda.device(d_in.device.index):
120
+ mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
+ d_out.stride(0), d_out.stride(1))
122
+
123
+
124
+ def matmul_transpose(d_in):
125
+ M, _ = d_in.shape
126
+ d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
127
+ matmul_transpose_assign(d_in, d_out)
128
+ return d_out
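A minimal numerical check for the kernel above (illustrative; assumes a CUDA device with Triton available, and expects only small discrepancies from bfloat16 rounding):

    x = torch.randn(512, 2048, device="cuda", dtype=torch.bfloat16)
    y = matmul_transpose(x)  # (512, 512) Gram matrix, y = x @ x.T
    ref = (x.float() @ x.float().T).to(torch.bfloat16)
    print((y - ref).abs().max())  # expected to be small bfloat16 rounding noise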
build/torch29-cxx11-rocm63-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,1069 @@
1
+ import logging
2
+ import math
3
+ import types
4
+ from dataclasses import dataclass
5
+ from typing import List, Optional, Union, cast
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch.distributed._tensor import DTensor, Replicate, Shard
10
+
11
+ from .matmul_transpose_triton import matmul_transpose_assign
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
+
18
+ # This code snippet is a modified version adapted from the following GitHub repositories:
19
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
20
+ # Muon's Newton–Schulz iteration causes high variance in singular values
21
+ # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
22
+ @torch.no_grad()
23
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
24
+ def _zeropower_via_newtonschulz5(G, steps):
25
+ """
26
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
27
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
28
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
29
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
30
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
31
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
32
+ performance at all relative to UV^T, where USV^T = G is the SVD.
33
+ """
34
+ assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
36
+ X = G # no manual typecast
37
+
38
+ if G.size(0) > G.size(1):
39
+ X = X.T
40
+ # Ensure spectral norm is at most 1
41
+ X = X / (X.norm() + 1e-7)
42
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
43
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
44
+ # Perform the NS iterations
45
+ for a, b, c in [
46
+ (4.0848, -6.8946, 2.9270),
47
+ (3.9505, -6.3029, 2.6377),
48
+ (3.7418, -5.5913, 2.3037),
49
+ (2.8769, -3.1427, 1.2046),
50
+ (2.8366, -3.0525, 1.2012),
51
+ ]:
52
+ matmul_transpose_assign(X, buf1)
53
+ matmul_transpose_assign(buf1, buf2)
54
+ buf1.mul_(b).add_(buf2, alpha=c)
55
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
56
+
57
+ if G.size(0) > G.size(1):
58
+ X = X.T
59
+ return X
60
+
61
+
62
+ @dataclass
63
+ class _muon_state:
64
+ # TODO: use Optional
65
+ worker_rank: int | None = None
66
+ gathered_grad: torch.Tensor | None = None
67
+ scattered_u: DTensor | None = None
68
+ computed_u: torch.Tensor | None = None
69
+ gather_event: torch.cuda.Event | None = None
70
+ compute_event: torch.cuda.Event | None = None
71
+ scatter_event: torch.cuda.Event | None = None
72
+ process_group = None
73
+ qk_clip_state = None
74
+
75
+
76
+ def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
+ rows = param.shape[0]
78
+ cols = int(param.numel() // rows)
79
+ base, rem = divmod(rows, num_ranks)
80
+ my_rows = base + (1 if src_rank < rem else 0)
81
+ return my_rows * cols
82
+
83
+
84
+ @torch.no_grad()
85
+ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
86
+ """
87
+ Pre-allocate gathered_grad buffer on compute_stream
88
+ before launching all2all gather
89
+ """
90
+ with torch.cuda.stream(compute_stream):
91
+ for p in params:
92
+ state = param_to_state[id(p)]
93
+ if rank == state.worker_rank:
94
+ num_ranks = dist.get_world_size(group=state.process_group)
95
+ state.gathered_grad = torch.empty(p.grad.numel(),
96
+ dtype=COMM_DTYPE,
97
+ device="cuda")
98
+ else:
99
+ state.gathered_grad = None
100
+
101
+ alloc_event = torch.cuda.Event()
102
+ alloc_event.record(compute_stream)
103
+ return alloc_event
104
+
105
+
106
+ @torch.no_grad()
107
+ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
108
+ alloc_event):
109
+ """
110
+ All2all gathers shards so each owner rank reconstructs its full gradient
111
+ """
112
+ with torch.cuda.stream(comm_stream):
113
+ process_group = param_to_state[id(params[0])].process_group
114
+ num_ranks = dist.get_world_size(group=process_group)
115
+
116
+ # Construct sending buffers
117
+ per_dst = [[] for _ in range(num_ranks)]
118
+ send_counts = [0] * num_ranks
119
+
120
+ for p in params:
121
+ state = param_to_state[id(p)]
122
+ dst = state.worker_rank
123
+ assert dst < num_ranks
124
+ shard_elems = split_elems_for_src(p, rank, num_ranks)
125
+ g = p.grad
126
+ g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
+ assert g.numel() == shard_elems
128
+ per_dst[dst].append(g)
129
+ send_counts[dst] += shard_elems
130
+
131
+ assert any(
132
+ len(v) > 0 for v in per_dst
133
+ ), "At least one destination rank must receive a sharded tensor"
134
+ # list[list[Tensor]] -> list[Tensor]
135
+ per_dst = [t for dst in per_dst for t in dst]
136
+
137
+ send_buf = torch.cat(per_dst, dim=0)
138
+
139
+ owned_params = [
140
+ p for p in params if param_to_state[id(p)].worker_rank == rank
141
+ ]
142
+
143
+ # Compute receive sizes and allocate receiving buffers
144
+ recv_counts = [0] * num_ranks
145
+
146
+ for src in range(num_ranks):
147
+ total = 0
148
+ for p in owned_params:
149
+ state = param_to_state[id(p)]
150
+ assert state.worker_rank == rank
151
+ total += split_elems_for_src(p, src, num_ranks)
152
+ recv_counts[src] = total
153
+
154
+ recv_total = sum(recv_counts)
155
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
156
+
157
+ #All2All
158
+ dist.all_to_all_single(
159
+ recv_buf,
160
+ send_buf,
161
+ output_split_sizes=recv_counts,
162
+ input_split_sizes=send_counts,
163
+ group=process_group,
164
+ )
165
+
166
+ # Reconstructs gathered grad from the received buffer
167
+ #
168
+ # recv_buf (num ranks = 3)
169
+ #
170
+ # From rank 0 From rank 1 From rank 2
171
+ # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
172
+ #
173
+ # Outer loop:
174
+ # rank 0 -> rank 1 -> rank 2
175
+ #
176
+ # Inner loop:
177
+ # p1_n -> p2_n -> p3_n
178
+
179
+ comm_stream.wait_event(alloc_event)
180
+
181
+ off = 0
182
+ write_offsets = {id(p): 0 for p in owned_params}
183
+ for src in range(num_ranks):
184
+ if recv_counts[src] == 0:
185
+ continue
186
+
187
+ block = recv_counts[src]
188
+ inner_off = 0
189
+ for p in owned_params:
190
+ state = param_to_state[id(p)]
191
+ assert state.worker_rank == rank
192
+ n = split_elems_for_src(p, src, num_ranks)
193
+ assert n > 0
194
+
195
+ sg = recv_buf.narrow(0, off + inner_off, n)
196
+ woff = write_offsets[id(p)]
197
+ dst = state.gathered_grad.narrow(0, woff, n)
198
+ dst.copy_(sg)
199
+
200
+ write_offsets[id(p)] += n
201
+ inner_off += n
202
+ off += block
203
+
204
+ for p in params:
205
+ state = param_to_state[id(p)]
206
+ if state.worker_rank == rank:
207
+ state.gathered_grad = state.gathered_grad.view_as(p)
208
+ state.gather_event = torch.cuda.Event()
209
+ state.gather_event.record(comm_stream)
210
+ else:
211
+ state.gathered_grad = None
212
+ state.gather_event = None
213
+ if none_grad:
214
+ p.grad = None
215
+
216
+
217
+ @torch.no_grad()
218
+ def _compute_u(p, state, steps, rank, compute_stream):
219
+ """
220
+ On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
221
+ """
222
+ with torch.cuda.stream(compute_stream):
223
+ if rank == state.worker_rank:
224
+ if state.gather_event is None:
225
+ raise RuntimeError("Gather event must be set before compute.")
226
+ compute_stream.wait_event(state.gather_event)
227
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
228
+ state.gathered_grad = None
229
+ state.computed_u = u
230
+ state.compute_event = torch.cuda.Event()
231
+ state.compute_event.record()
232
+ else:
233
+ state.computed_u = None
234
+ state.compute_event = None
235
+
236
+
237
+ @torch.no_grad()
238
+ def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
239
+ """
240
+ Pre-allocate scattered_u buffer on compute_stream
241
+ before launching all2all scatter
242
+ """
243
+ with torch.cuda.stream(compute_stream):
244
+ for p in params:
245
+ state = param_to_state[id(p)]
246
+ state.scattered_u = torch.empty_like(p.to_local(),
247
+ dtype=COMM_DTYPE)
248
+
249
+ alloc_event = torch.cuda.Event()
250
+ alloc_event.record(compute_stream)
251
+ return alloc_event
252
+
253
+
254
+ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
255
+ """
256
+ All2all scatters the computed orthogonalized updates back to all ranks as shards
257
+ """
258
+ with torch.cuda.stream(comm_stream):
259
+ process_group = param_to_state[id(params[0])].process_group
260
+ num_ranks = dist.get_world_size(group=process_group)
261
+ owned_params = [
262
+ p for p in params if param_to_state[id(p)].worker_rank == rank
263
+ ]
264
+
265
+ # Construct sending buffer
266
+ per_dst = [[] for _ in range(num_ranks)]
267
+ send_counts = [0] * num_ranks
268
+
269
+ if owned_params:
270
+ for p in owned_params:
271
+ state = param_to_state[id(p)]
272
+ if state.compute_event is None:
273
+ raise RuntimeError(
274
+ "Compute event must be set before scatter.")
275
+ comm_stream.wait_event(state.compute_event)
276
+ state.gathered_grad = None
277
+
278
+ assert state.computed_u is not None
279
+
280
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
281
+
282
+ offset = 0
283
+ for dst in range(num_ranks):
284
+ n = split_elems_for_src(p, dst, num_ranks)
285
+ assert n > 0
286
+
287
+ su = u_full.narrow(0, offset, n)
288
+ per_dst[dst].append(su)
289
+ send_counts[dst] += n
290
+ offset += n
291
+
292
+ assert offset == u_full.numel()
293
+
294
+ lengths = [len(v) for v in per_dst]
295
+ if all(l > 0 for l in lengths):
296
+ assert all(
297
+ l == lengths[0] for l in lengths
298
+ ), "All destination ranks must have the same number of sharded tensor"
299
+ # list[list[Tensor]] -> list[Tensor]
300
+ per_dst = [t for dst in per_dst for t in dst]
301
+ send_buf = torch.cat(per_dst, dim=0)
302
+ else:
303
+ # all_to_all requires participation from all ranks
304
+ # Even non-owner ranks must join the collective call
305
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
306
+
307
+ # Compute receive sizes and allocate receiving buffers
308
+ recv_counts = [0] * num_ranks
309
+
310
+ for src in range(num_ranks):
311
+ total = 0
312
+ for p in params:
313
+ state = param_to_state[id(p)]
314
+ if state.worker_rank != src:
315
+ continue
316
+ total += split_elems_for_src(p, rank, num_ranks)
317
+ recv_counts[src] = total
318
+
319
+ recv_total = sum(recv_counts)
320
+ assert recv_total > 0
321
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
322
+
323
+ #All2All
324
+ dist.all_to_all_single(
325
+ recv_buf,
326
+ send_buf,
327
+ output_split_sizes=recv_counts,
328
+ input_split_sizes=send_counts,
329
+ group=process_group,
330
+ )
331
+
332
+ # Copy to pre-allocated scattered_u buffer from the received buffer
333
+ #
334
+ # recv_buf (num ranks = 3, local_rank = 0)
335
+ #
336
+ # From rank 0 From rank 1 From rank 2
337
+ # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
338
+ #
339
+ # Outer loop:
340
+ # rank 0 -> rank 1 -> rank 2
341
+ #
342
+ # Inner loop:
343
+ # src(0) : p1_0 -> p2_0 -> p3_0
344
+ # src(1) : p4_0
345
+ # src(2) : p5_0 -> p6_0
346
+
347
+ comm_stream.wait_event(alloc_event)
348
+
349
+ off = 0
350
+ for src in range(num_ranks):
351
+ block = recv_counts[src]
352
+ if block == 0:
353
+ continue
354
+
355
+ inner_off = 0
356
+ for p in params:
357
+ state = param_to_state[id(p)]
358
+ if state.worker_rank != src:
359
+ continue
360
+ n = split_elems_for_src(p, rank, num_ranks)
361
+ assert n > 0
362
+
363
+ flat_local = recv_buf.narrow(0, off + inner_off,
364
+ n).view_as(p.to_local())
365
+ state.scattered_u.copy_(flat_local)
366
+
367
+ state.scatter_event = torch.cuda.Event()
368
+ state.scatter_event.record(comm_stream)
369
+ inner_off += n
370
+
371
+ assert inner_off == block
372
+ off += block
373
+
374
+
375
+ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
376
+ compute_stream):
377
+ """
378
+ Update sharded parameter p with the scattered_u.
379
+ Only worker_rank frees computed_u.
380
+ """
381
+ with torch.cuda.stream(compute_stream):
382
+ if state.scatter_event is None:
383
+ raise RuntimeError("Scatter event must be set before update")
384
+ compute_stream.wait_event(state.scatter_event)
385
+ u_dtensor = DTensor.from_local(
386
+ state.scattered_u,
387
+ placements=p.placements,
388
+ device_mesh=p.device_mesh,
389
+ )
390
+
391
+ state.scattered_u = u_dtensor
392
+
393
+ if rank == state.worker_rank:
394
+ # Free computed_u
395
+ state.computed_u = None
396
+
397
+ Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
398
+ state.scattered_u = None
399
+ u_dtensor = None
400
+
401
+ scales_full = Muon._compute_scales(p, state.qk_clip_state)
402
+ if scales_full is not None:
403
+ num_ranks = dist.get_world_size(group=state.process_group)
404
+ local_rank = dist.get_rank(group=state.process_group)
405
+ scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
406
+ scales_local = DTensor.from_local(
407
+ scales_local,
408
+ placements=p.placements,
409
+ device_mesh=p.device_mesh,
410
+ )
411
+ Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
412
+
413
+
414
+ def default_is_muon(name, x):
415
+ skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
416
+ return x.ndim >= 2 and not any(key in name for key in skip_keys)
417
+
418
+
419
+ def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
420
+ muon_params, muon_names = [], []
421
+ non_muon_params = []
422
+
423
+ for n, p in model.named_parameters():
424
+ if not p.requires_grad:
425
+ continue
426
+ if is_muon_func(n, p):
427
+ muon_params.append(p)
428
+ muon_names.append(n)
429
+ else:
430
+ non_muon_params.append(p)
431
+
432
+ return [
433
+ {
434
+ "params": muon_params,
435
+ "names": muon_names,
436
+ "use_muon": True,
437
+ },
438
+ {
439
+ "params": non_muon_params,
440
+ "use_muon": False,
441
+ },
442
+ ]
443
+
444
+
445
+ def parse_qk_layer(name: str) -> tuple[str | None, int]:
446
+ """
447
+ Parse a parameter name to check if it is a query/key projection layer
448
+ ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
449
+
450
+ Returns:
451
+ (kind, layer_idx) or (None, -1) if not matched.
452
+
453
+ Example:
454
+ 'model.3.attn.wq.weight' -> ('wq', 3)
455
+ 'model.5.attn.wk.weight' -> ('wk', 5)
456
+ 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
457
+ 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
458
+ 'model.4.attn.v_proj.weight' -> (None, -1)
459
+ """
460
+ parts = name.split('.')
461
+ if len(parts) < 3:
462
+ return None, -1
463
+
464
+ kind = parts[-2]
465
+
466
+ layer_idx = -1
467
+ for part in reversed(parts):
468
+ if part.isdigit():
469
+ layer_idx = int(part)
470
+ break
471
+
472
+ if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
473
+ return kind, layer_idx
474
+
475
+ return None, -1
476
+
477
+
478
+ @dataclass
479
+ class QKClipInfo:
480
+ """Per-parameter dynamic info computed from config + runtime logits."""
481
+ kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
482
+ indices: List[int] # which heads to consider for clipping
483
+ head_dim: int # from config
484
+ threshold: float # from config
485
+ logit: Optional[torch.Tensor]
486
+
487
+
488
+ class Muon(torch.optim.Optimizer):
489
+ """
490
+ Muon - MomentUm Orthogonalized by Newton-schulz
491
+
492
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
493
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
494
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
495
+ the advantage that it can be stably run in bfloat16 on the GPU.
496
+
497
+ Some warnings:
498
+ - We believe this optimizer is unlikely to work well for training with small batch size.
499
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
500
+
501
+ Arguments:
502
+ params: Parameter groups to optimize. Build them with get_default_muon_param_groups(model); every group must define the key 'use_muon'.
503
+ is_muon_func: A function that takes a parameter name and the parameter, and returns whether that parameter should be optimized by Muon (used by get_default_muon_param_groups, not passed to this constructor).
504
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
505
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
506
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
507
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
508
+ weight_decay: The weight decay applied by both Muon and the internal AdamW.
509
+ Parameters that are {0, 1}-D, or detected as embedding / lm_head weights, are optimized by AdamW instead of Muon.
510
+ adamw_lr: The learning rate for the internal AdamW.
511
+ adamw_betas: The betas for the internal AdamW.
512
+ adamw_eps: The epsilon for the internal AdamW.
513
+ none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
514
+ debug: Whether to print debug information.
515
+ clip_config : Configuration for QK clipping. Expected keys:
516
+ - "q_indices" (list[int]): Indices of query heads to consider.
517
+ - "k_indices" (list[int]): Indices of key heads to consider.
518
+ - "head_dim" (int): Dimensionality of each attention head.
519
+ - "threshold" (float): Threshold value; heads whose QK logits exceed
520
+ this value will be scaled down.
521
+ Default is:
522
+ {
523
+ "q_indices": [],
524
+ "k_indices": [],
525
+ "head_dim": 128,
526
+ "threshold": 100
527
+ }
528
+ overlap_step : How many all2all gather and compute operations are launched in advance
529
+ before the corresponding all2all scatter steps begin.
530
+ A higher overlap_step increases memory usage but can improve
531
+ performance by overlapping communication.
532
+ Parallel muon only.
533
+ """
534
+
535
+ def __init__(self,
536
+ params,
537
+ lr=1e-3,
538
+ momentum=0.95,
539
+ nesterov=True,
540
+ ns_steps=5,
541
+ weight_decay=0.1,
542
+ adamw_betas=(0.9, 0.95),
543
+ adamw_eps=1e-8,
544
+ none_grad=True,
545
+ debug=False,
546
+ clip_config={
547
+ "q_indices": [],
548
+ "k_indices": [],
549
+ "head_dim": 128,
550
+ "threshold": 100
551
+ },
552
+ overlap_step=5):
553
+ defaults = dict(
554
+ lr=lr,
555
+ weight_decay=weight_decay,
556
+ momentum=momentum,
557
+ nesterov=nesterov,
558
+ ns_steps=ns_steps,
559
+ adamw_betas=adamw_betas,
560
+ adamw_eps=adamw_eps,
561
+ none_grad=none_grad,
562
+ use_muon=True,
563
+ )
564
+ error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
565
+ instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
566
+
567
+ if isinstance(params, types.GeneratorType):
568
+ raise ValueError(error_message.format(idx=0) + instruction_code)
569
+ for _idx, param_group in enumerate(params):
570
+ if param_group.get("use_muon", None) is None:
571
+ raise ValueError(
572
+ error_message.format(idx=_idx) + instruction_code)
573
+
574
+ super().__init__(params, defaults)
575
+
576
+ self.rank = None
577
+
578
+ self.comm_stream = torch.cuda.Stream()
579
+ self.compute_stream = torch.cuda.Stream()
580
+ self.debug = debug
581
+ self.clip_config = clip_config
582
+ self.overlap_step = overlap_step
583
+
584
+ def _calc_flops(self, G, steps):
585
+ assert len(G.shape) == 2
586
+ M, N = G.shape
587
+ if M > N:
588
+ M, N = N, M
589
+
590
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
591
+
592
+ def adjust_lr_for_muon(self, lr, param_shape):
593
+ A, B = param_shape[:2]
594
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
595
+ # as described in the paper
596
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
597
+ adjusted_lr = lr * adjusted_ratio
598
+ return adjusted_lr
599
+
600
+ def get_shard_mesh(self, p):
601
+ """
602
+ Get the shard mesh for a parameter p on the given rank.
603
+ """
604
+ assert isinstance(
605
+ p, DTensor), "Parallel Muon only supports DTensor parameters."
606
+
607
+ if p.placements == (Shard(dim=0), ):
608
+ # Case for FSDP
609
+ process_group = p.device_mesh.get_group(mesh_dim=0)
610
+ if self.rank is None:
611
+ self.rank = dist.get_rank(group=process_group)
612
+ else:
613
+ assert self.rank == dist.get_rank(group=process_group)
614
+ return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
+ elif p.placements == (Replicate(), Shard(dim=0)):
616
+ # Case for HSDP
617
+ process_group = p.device_mesh.get_group(mesh_dim=1)
618
+ if self.rank is None:
619
+ self.rank = dist.get_rank(group=process_group)
620
+ else:
621
+ assert self.rank == dist.get_rank(group=process_group)
622
+ for i, shard_mesh in enumerate(p.device_mesh.mesh):
623
+ if self.rank in shard_mesh:
624
+ return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
625
+ else:
626
+ raise ValueError(f"Unsupported placements ({p.placements}).")
627
+
628
+ def init_state_and_assign_params(self, names, params, group, qk_logits):
629
+ param_to_state = {}
630
+ param_to_flops = {}
631
+
632
+ total_flops = 0
633
+ for p in params:
634
+ g = p.grad
635
+ if g is None:
636
+ continue
637
+ assert g.ndim == 2, "Muon only supports 2D parameters."
638
+
639
+ flops = self._calc_flops(g, group["ns_steps"])
640
+ param_to_flops[id(p)] = flops
641
+ total_flops += flops
642
+
643
+ if self.debug:
644
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
645
+ flush=True)
646
+
647
+ paired = list(zip(names, params))
648
+
649
+ paired_sorted = sorted(paired,
650
+ key=lambda x: param_to_flops[id(x[1])],
651
+ reverse=True)
652
+
653
+ names_sorted, params_sorted = zip(*paired_sorted)
654
+ ordered_names = list(names_sorted)
655
+ ordered_params = list(params_sorted)
656
+
657
+ round_robin = 0
658
+ mesh = None
659
+ shard_mesh = None
660
+ process_group = None
661
+ for n, p in zip(ordered_names, ordered_params):
662
+ if mesh is None:
663
+ mesh = p.device_mesh
664
+ shard_mesh, process_group = self.get_shard_mesh(p)
665
+ elif mesh != p.device_mesh:
666
+ raise ValueError("All parameters must be on the same mesh.")
667
+ num_ranks = dist.get_world_size(group=process_group)
668
+ param_to_state[id(p)] = _muon_state()
669
+ param_to_state[id(
670
+ p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
671
+ param_to_state[id(p)].process_group = process_group
672
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
673
+ param_to_state[id(p)].qk_clip_state = qk_clip_state
674
+ round_robin = (round_robin + 1) % len(shard_mesh)
675
+
676
+ return param_to_state, ordered_params
677
+
678
+ def base(self, names, params, group, lr, weight_decay, momentum,
679
+ qk_logits):
680
+ # generate weight updates in distributed fashion
681
+ for n, p in zip(names, params):
682
+ g = p.grad
683
+ if g is None:
684
+ continue
685
+ if g.ndim > 2:
686
+ g = g.view(g.size(0), -1)
687
+ assert g is not None
688
+
689
+ # calc update
690
+ state = self.state[p]
691
+ if "momentum_buffer" not in state:
692
+ state["momentum_buffer"] = torch.zeros_like(g)
693
+ buf = state["momentum_buffer"]
694
+ buf.mul_(momentum).add_(g)
695
+ if group["nesterov"]:
696
+ g = g.add(buf, alpha=momentum)
697
+ else:
698
+ g = buf
699
+
700
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
701
+ steps=group["ns_steps"])
702
+
703
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
704
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
705
+
706
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
707
+
708
+ scales_full = self._compute_scales(p, qk_clip_state)
709
+ if scales_full is not None:
710
+ Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
711
+
712
+ def _update_g(self, p, g, group, momentum):
713
+ # calc update
714
+ state = self.state[p]
715
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
716
+ torch.add(g, buf, alpha=momentum, out=buf)
717
+ if group["nesterov"]:
718
+ g.add_(buf, alpha=momentum)
719
+ return g
720
+ return buf
721
+
722
+ @staticmethod
723
+ def _update_p(p, u, lr, adjusted_lr, weight_decay):
724
+ # apply weight decay
725
+ p.data.mul_(1 - lr * weight_decay)
726
+ # apply update
727
+ p.data.add_(u, alpha=-adjusted_lr)
728
+
729
+ def get_qk_clip_info(self, n, qk_logits):
730
+ head_dim = self.clip_config.get('head_dim')
731
+ threshold = self.clip_config.get('threshold')
732
+ kind, layer_idx = parse_qk_layer(n)
733
+
734
+ logit, indices = None, []
735
+ if qk_logits is not None and kind is not None:
736
+ logit = qk_logits[layer_idx]
737
+ indices_key = 'q_indices' if 'q' in kind else 'k_indices'
738
+ indices = self.clip_config.get(indices_key, []) or []
739
+
740
+ return QKClipInfo(
741
+ kind=kind,
742
+ indices=indices,
743
+ head_dim=head_dim,
744
+ threshold=threshold,
745
+ logit=logit,
746
+ )
747
+
748
+ @staticmethod
749
+ def _compute_scales(p, qk_clip_state):
750
+ kind = qk_clip_state.kind
751
+ indices = qk_clip_state.indices
752
+ head_dim = qk_clip_state.head_dim
753
+ threshold = qk_clip_state.threshold
754
+ logit = qk_clip_state.logit
755
+
756
+ H_global = p.shape[0] // head_dim
757
+ scales_full = torch.ones(H_global, device=p.data.device)
758
+ scaling = 0
759
+
760
+ for logit_idx, head_idx in enumerate(indices):
761
+ v_ele = float(logit[logit_idx])
762
+ if v_ele > threshold:
763
+ new_scale = math.sqrt(threshold / v_ele)
764
+ if new_scale < scales_full[head_idx]:
765
+ scales_full[head_idx] = new_scale
766
+ logger.info(
767
+ f"[{kind}] Head {head_idx} exceeded threshold "
768
+ f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
769
+ )
770
+ scaling += 1
771
+
772
+ return scales_full if scaling > 0 else None
773
+
774
+ @staticmethod
775
+ def _qk_clip(p, scales, head_dim):
776
+ W = p.data.view(-1, head_dim, p.data.shape[1])
777
+ W.mul_(scales.view(-1, 1, 1))
778
+
779
+ def parallel(self, names, params, group, lr, weight_decay, momentum,
780
+ qk_logits):
781
+ """
782
+ Perform a parallel optimization step using Muon.
783
+ """
784
+
785
+ for p in params:
786
+ g = p.grad
787
+ if g is None:
788
+ continue
789
+ if g.ndim > 2:
790
+ g = g.view(g.size(0), -1)
791
+
792
+ # Update g in the local rank
793
+ g = self._update_g(
794
+ p,
795
+ g,
796
+ group,
797
+ momentum=momentum,
798
+ )
799
+ p.grad = g
800
+
801
+ param_to_state, ordered_params = self.init_state_and_assign_params(
802
+ names, params, group, qk_logits)
803
+
804
+ assert self.rank is not None
805
+
806
+ def enqueue_all2all_gather(start_idx, chunk_size):
807
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
808
+ if target_params:
809
+ alloc_event = _alloc_gathered_grad(target_params,
810
+ param_to_state, self.rank,
811
+ self.compute_stream)
812
+ _all2all_gather(target_params, param_to_state, self.rank,
813
+ self.comm_stream, group["none_grad"],
814
+ alloc_event)
815
+
816
+ def enqueue_computes(start_idx, chunk_size):
817
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
818
+ state = param_to_state[id(p)]
819
+ _compute_u(p, state, group["ns_steps"], self.rank,
820
+ self.compute_stream)
821
+
822
+ def enqueue_all2all_scatter(start_idx, chunk_size):
823
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
824
+ if target_params:
825
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
826
+ self.rank,
827
+ self.compute_stream)
828
+ _all2all_scatter(target_params, param_to_state, self.rank,
829
+ self.comm_stream, alloc_event)
830
+
831
+ def enqueue_update_param(start_idx, chunk_size):
832
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
833
+ state = param_to_state[id(p)]
834
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
835
+ _update_param(p, state, lr, adjusted_lr, weight_decay,
836
+ self.rank, self.compute_stream)
837
+
838
+ chunk_size = dist.get_world_size(param_to_state[id(
839
+ params[0])].process_group)
840
+
841
+ # Wait grad update
842
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
843
+
844
+ overlap_step = self.overlap_step
845
+ for i in range(0, overlap_step):
846
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
847
+ enqueue_computes(i * chunk_size, chunk_size)
848
+
849
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
850
+ enqueue_all2all_scatter(i, chunk_size)
851
+ enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
852
+ enqueue_update_param(i, chunk_size)
853
+ enqueue_computes(i + overlap_step * chunk_size, chunk_size)
854
+
855
+ # Wait the last update_param to finish
856
+ torch.cuda.current_stream().wait_stream(self.compute_stream)
857
+
858
+ @staticmethod
859
+ def _fused_adamw(
860
+ params: list[torch.Tensor],
861
+ grads: list[torch.Tensor],
862
+ exp_avgs: list[torch.Tensor],
863
+ exp_avg_sqs: list[torch.Tensor],
864
+ max_exp_avg_sqs: list[torch.Tensor],
865
+ state_steps: list[torch.Tensor],
866
+ amsgrad: bool,
867
+ beta1: float,
868
+ beta2: float,
869
+ lr: Union[float, torch.Tensor],
870
+ weight_decay: float,
871
+ eps: float,
872
+ maximize: bool,
873
+ ) -> None:
874
+ if not params:
875
+ return
876
+
877
+ # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
878
+ # treating it as a scalar.
879
+ lr_dict: Optional[DeviceDict] = ({
880
+ lr.device: lr
881
+ } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
882
+ None)
883
+ grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
884
+ [
885
+ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
886
+ state_steps
887
+ ] # type: ignore[list-item]
888
+ )
889
+ for (device, _), (
890
+ (
891
+ device_params_,
892
+ device_grads_,
893
+ device_exp_avgs_,
894
+ device_exp_avg_sqs_,
895
+ device_max_exp_avg_sqs,
896
+ device_state_steps_,
897
+ ),
898
+ _,
899
+ ) in grouped_tensors.items():
900
+ device_params = cast(list[torch.Tensor], device_params_)
901
+ device_grads = cast(list[torch.Tensor], device_grads_)
902
+ device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
903
+ device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
904
+ device_state_steps = cast(list[torch.Tensor], device_state_steps_)
905
+
906
+ if lr_dict is not None and device not in lr_dict:
907
+ lr_dict[device] = lr.to(
908
+ device=device,
909
+ non_blocking=True) # type: ignore[union-attr]
910
+ lr = lr_dict[device]
911
+ torch._foreach_add_(device_state_steps, 1)
912
+ func = torch._fused_adamw_
913
+ func(
914
+ device_params,
915
+ device_grads,
916
+ device_exp_avgs,
917
+ device_exp_avg_sqs,
918
+ device_max_exp_avg_sqs, # type: ignore[arg-type]
919
+ device_state_steps,
920
+ amsgrad=amsgrad,
921
+ lr=lr, # type: ignore[arg-type]
922
+ beta1=beta1,
923
+ beta2=beta2,
924
+ weight_decay=weight_decay,
925
+ eps=eps,
926
+ maximize=maximize,
927
+ )
928
+
929
+ def step(self, closure=None, qk_logits=None):
930
+ """Perform a single optimization step.
931
+
932
+ Args:
933
+ closure (Callable, optional): A closure that reevaluates the model
934
+ and returns the loss.
935
+ qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
936
+ to 1D tensors of shape (num_heads,), representing the maximum
937
+ QK logits across all tokens, computed as
938
+ (1 / sqrt(head_dim)) * (Q @ K^T).
939
+ """
940
+ loss = None
941
+ if closure is not None:
942
+ with torch.enable_grad():
943
+ loss = closure()
944
+
945
+ for group in self.param_groups:
946
+ params = group["params"]
947
+
948
+ if group["use_muon"]:
949
+ ############################
950
+ # Muon #
951
+ ############################
952
+ lr = group["lr"]
953
+ weight_decay = group["weight_decay"]
954
+ momentum = group["momentum"]
955
+ names = group["names"]
956
+
957
+ param_dtensors = []
958
+ param_tensors = []
959
+ name_dtensors = []
960
+ name_tensors = []
961
+
962
+ for n, p in zip(names, params):
963
+ if p is None or p.grad is None:
964
+ continue
965
+ if isinstance(p.data, DTensor):
966
+ if all(
967
+ isinstance(placement, Replicate)
968
+ for placement in p.placements):
969
+ param_tensors.append(p)
970
+ name_tensors.append(n)
971
+ else:
972
+ param_dtensors.append(p)
973
+ name_dtensors.append(n)
974
+ elif isinstance(p.data, torch.Tensor):
975
+ param_tensors.append(p)
976
+ name_tensors.append(n)
977
+ else:
978
+ raise TypeError(
979
+ f"Unsupported parameter type: {type(p.data)}")
980
+
981
+ if self.debug:
982
+ print(
983
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
984
+ flush=True,
985
+ )
986
+
987
+ if len(param_dtensors) > 0:
988
+ if not dist.is_initialized():
989
+ raise RuntimeError(
990
+ "Parallel Muon requires torch.distributed to be initialized."
991
+ )
992
+
993
+ self.parallel(
994
+ name_dtensors,
995
+ param_dtensors,
996
+ group,
997
+ lr=lr,
998
+ weight_decay=weight_decay,
999
+ momentum=momentum,
1000
+ qk_logits=qk_logits,
1001
+ )
1002
+
1003
+ if len(param_tensors) > 0:
1004
+ self.base(
1005
+ name_tensors,
1006
+ param_tensors,
1007
+ group,
1008
+ lr=lr,
1009
+ weight_decay=weight_decay,
1010
+ momentum=momentum,
1011
+ qk_logits=qk_logits,
1012
+ )
1013
+
1014
+ else:
1015
+ ############################
1016
+ # AdamW backup #
1017
+ ############################
1018
+
1019
+ params_with_grads = []
1020
+ grads = []
1021
+ moment1 = []
1022
+ moment2 = []
1023
+ max_exp_avg_sqs = []
1024
+ state_steps = []
1025
+ lr = group["lr"]
1026
+ beta1, beta2 = group["adamw_betas"]
1027
+ eps = group["adamw_eps"]
1028
+ weight_decay = group["weight_decay"]
1029
+
1030
+ for p in params:
1031
+ g = p.grad
1032
+ if g is None:
1033
+ continue
1034
+ state = self.state[p]
1035
+ params_with_grads.append(p)
1036
+ grads.append(g)
1037
+ if "step" not in state:
1038
+ state["step"] = (torch.zeros((),
1039
+ dtype=torch.float32,
1040
+ device=p.device))
1041
+ state["moment1"] = torch.zeros_like(g)
1042
+ state["moment2"] = torch.zeros_like(g)
1043
+ moment1.append(state["moment1"])
1044
+ moment2.append(state["moment2"])
1045
+ if not isinstance(state["step"], torch.Tensor):
1046
+ step_tensor = torch.tensor(state["step"],
1047
+ dtype=torch.float32,
1048
+ device=p.device)
1049
+ else:
1050
+ step_tensor = state["step"]
1051
+ state_steps.append(step_tensor)
1052
+
1053
+ self._fused_adamw(
1054
+ params_with_grads,
1055
+ grads,
1056
+ moment1,
1057
+ moment2,
1058
+ max_exp_avg_sqs,
1059
+ state_steps,
1060
+ amsgrad=False,
1061
+ beta1=beta1,
1062
+ beta2=beta2,
1063
+ lr=lr,
1064
+ weight_decay=weight_decay,
1065
+ eps=eps,
1066
+ maximize=False,
1067
+ )
1068
+
1069
+ return loss
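The file above ends the rebuilt muon.py for this target. As a minimal usage sketch (assuming the Hugging Face `kernels` package's `get_kernel` helper named in the error message inside `Muon.__init__`; `model`, `loss_fn`, and `batch` are hypothetical placeholders, not part of the committed files):

# Hedged sketch, not part of the committed files.
from kernels import get_kernel

optimizer = get_kernel("motif-technologies/optimizer")
params = optimizer.muon.get_default_muon_param_groups(model)  # Muon groups vs. AdamW fallback
optim = optimizer.Muon(params, lr=1e-3, momentum=0.95, ns_steps=5)

loss = loss_fn(model(batch))
loss.backward()
optim.step()      # pass qk_logits=... here to enable QK clipping
optim.zero_grad()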
build/torch29-cxx11-rocm64-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .muon import Muon
2
+
3
+ __all__ = [
4
+ "Muon",
5
+ ]
build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _optimizer_811726c_dirty
3
+ ops = torch.ops._optimizer_811726c_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_optimizer_811726c_dirty::{op_name}"
build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08b55491319446b12d0d890926506639640414edcba945e0f71afef0fac369d5
3
+ size 1749352
build/torch29-cxx11-rocm64-x86_64-linux/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,128 @@
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2025 Tianyang Lin
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+
28
+ def get_autotune_config():
29
+ return [
30
+ triton.Config(
31
+ {
32
+ 'BLOCK_SIZE_M': blk_m,
33
+ 'BLOCK_SIZE_K': blk_k,
34
+ 'GROUP_SIZE_M': grp_sz
35
+ },
36
+ num_stages=n_stages,
37
+ num_warps=n_warps) for blk_m in [32, 64, 128]
38
+ for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
39
+ for n_warps in [4, 8]
40
+ ]
41
+
42
+
43
+ @triton.autotune(
44
+ configs=get_autotune_config(),
45
+ key=['M', 'K'],
46
+ )
47
+ @triton.jit
48
+ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
49
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
50
+ GROUP_SIZE_M: tl.constexpr):
51
+ """
52
+ Core kernel jit function of matmul_transpose that computes y = x @ x.T
53
+ The code is a simple adaptation from the triton `matmul` tutorial:
54
+ https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
55
+ """
56
+ pid = tl.program_id(axis=0)
57
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
58
+ num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
59
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
60
+ group_id = pid // num_pid_in_group
61
+ first_pid_m = group_id * GROUP_SIZE_M
62
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
63
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
64
+ pid_n = (pid % num_pid_in_group) // group_size_m
65
+ if pid_m > pid_n:
66
+ return
67
+
68
+ offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
69
+ offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
70
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
71
+ # we use a & b ptrs to denote different rows of x.
72
+ a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
73
+ b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
74
+
75
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
76
+
77
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
78
+ a = tl.load(a_ptrs,
79
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
80
+ other=0.0)
81
+ b = tl.load(b_ptrs,
82
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
83
+ other=0.0)
84
+ accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
85
+ a_ptrs += BLOCK_SIZE_K * stride_xk
86
+ b_ptrs += BLOCK_SIZE_K * stride_xk
87
+ # use dtype.element_ty to accommodate different input datatypes as in cpp templates
88
+ # https://github.com/triton-lang/triton/issues/2252
89
+ c = accumulator.to(x.dtype.element_ty)
90
+
91
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
92
+ offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
93
+ c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
94
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
95
+ tl.store(c_ptrs, c, mask=c_mask)
96
+
97
+ # transpose and copy
98
+ if pid_m < pid_n:
99
+ ct_ptrs = y + stride_ym * offs_cn[:,
100
+ None] + stride_yn * offs_cm[None, :]
101
+ ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
102
+ tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
+
104
+
105
+ def matmul_transpose_assign(d_in, d_out):
106
+ assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
+ assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
+ assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
+ assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
+ assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
+ assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
+ assert d_in.size(0) == d_out.size(0) == d_out.size(0), \
113
+ "First dimension of `d_in` must match first and second dimension of `d_out`"
114
+
115
+ d_in = d_in.contiguous()
116
+ M, K = d_in.shape
117
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
118
+ M, META['BLOCK_SIZE_M']), )
119
+ with torch.cuda.device(d_in.device.index):
120
+ mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
+ d_out.stride(0), d_out.stride(1))
122
+
123
+
124
+ def matmul_transpose(d_in):
125
+ M, _ = d_in.shape
126
+ d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
127
+ matmul_transpose_assign(d_in, d_out)
128
+ return d_out
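Given that the kernel is documented to compute y = x @ x.T, a quick hedged sanity check (assuming a CUDA device and that this module is importable as shown) might be:

# Hedged sketch: compare the Triton kernel against plain PyTorch.
import torch
from optimizer.matmul_transpose_triton import matmul_transpose

x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
y = matmul_transpose(x)          # (256, 256), should equal x @ x.T up to bf16 rounding
ref = x @ x.T
print(torch.allclose(y, ref, atol=1e-1, rtol=1e-2))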
build/torch29-cxx11-rocm64-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,1069 @@
1
+ import logging
2
+ import math
3
+ import types
4
+ from dataclasses import dataclass
5
+ from typing import List, Optional, Union, cast
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch.distributed._tensor import DTensor, Replicate, Shard
10
+
11
+ from .matmul_transpose_triton import matmul_transpose_assign
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
+
18
+ # This code snippet is a modified version adapted from the following GitHub repositories:
19
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
20
+ # Muon's Newton–Schulz iteration causes high variance in singular values
21
+ # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
22
+ @torch.no_grad()
23
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
24
+ def _zeropower_via_newtonschulz5(G, steps):
25
+ """
26
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
27
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
28
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
29
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
30
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
31
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
32
+ performance at all relative to UV^T, where USV^T = G is the SVD.
33
+ """
34
+ assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
36
+ X = G # no manual typecast
37
+
38
+ if G.size(0) > G.size(1):
39
+ X = X.T
40
+ # Ensure spectral norm is at most 1
41
+ X = X / (X.norm() + 1e-7)
42
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
43
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
44
+ # Perform the NS iterations
45
+ for a, b, c in [
46
+ (4.0848, -6.8946, 2.9270),
47
+ (3.9505, -6.3029, 2.6377),
48
+ (3.7418, -5.5913, 2.3037),
49
+ (2.8769, -3.1427, 1.2046),
50
+ (2.8366, -3.0525, 1.2012),
51
+ ]:
52
+ matmul_transpose_assign(X, buf1)
53
+ matmul_transpose_assign(buf1, buf2)
54
+ buf1.mul_(b).add_(buf2, alpha=c)
55
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
56
+
57
+ if G.size(0) > G.size(1):
58
+ X = X.T
59
+ return X
60
+
61
+
62
+ @dataclass
63
+ class _muon_state:
64
+ # TODO: use Optional
65
+ worker_rank: int | None = None
66
+ gathered_grad: torch.Tensor | None = None
67
+ scattered_u: DTensor | None = None
68
+ computed_u: torch.Tensor | None = None
69
+ gather_event: torch.cuda.Event | None = None
70
+ compute_event: torch.cuda.Event | None = None
71
+ scatter_event: torch.cuda.Event | None = None
72
+ process_group = None
73
+ qk_clip_state = None
74
+
75
+
76
+ def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
+ rows = param.shape[0]
78
+ cols = int(param.numel() // rows)
79
+ base, rem = divmod(rows, num_ranks)
80
+ my_rows = base + (1 if src_rank < rem else 0)
81
+ return my_rows * cols
82
+
83
+
84
+ @torch.no_grad()
85
+ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
86
+ """
87
+ Pre-allocate gathered_grad buffer on compute_stream
88
+ before launching all2all gather
89
+ """
90
+ with torch.cuda.stream(compute_stream):
91
+ for p in params:
92
+ state = param_to_state[id(p)]
93
+ if rank == state.worker_rank:
94
+ num_ranks = dist.get_world_size(group=state.process_group)
95
+ state.gathered_grad = torch.empty(p.grad.numel(),
96
+ dtype=COMM_DTYPE,
97
+ device="cuda")
98
+ else:
99
+ state.gathered_grad = None
100
+
101
+ alloc_event = torch.cuda.Event()
102
+ alloc_event.record(compute_stream)
103
+ return alloc_event
104
+
105
+
106
+ @torch.no_grad()
107
+ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
108
+ alloc_event):
109
+ """
110
+ All2all gathers shards so each owner rank reconstructs its full gradient
111
+ """
112
+ with torch.cuda.stream(comm_stream):
113
+ process_group = param_to_state[id(params[0])].process_group
114
+ num_ranks = dist.get_world_size(group=process_group)
115
+
116
+ # Construct sending buffers
117
+ per_dst = [[] for _ in range(num_ranks)]
118
+ send_counts = [0] * num_ranks
119
+
120
+ for p in params:
121
+ state = param_to_state[id(p)]
122
+ dst = state.worker_rank
123
+ assert dst < num_ranks
124
+ shard_elems = split_elems_for_src(p, rank, num_ranks)
125
+ g = p.grad
126
+ g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
+ assert g.numel() == shard_elems
128
+ per_dst[dst].append(g)
129
+ send_counts[dst] += shard_elems
130
+
131
+ assert any(
132
+ len(v) > 0 for v in per_dst
133
+ ), "At least one destination rank must receive a sharded tensor"
134
+ # list[list[Tensor]] -> list[Tensor]
135
+ per_dst = [t for dst in per_dst for t in dst]
136
+
137
+ send_buf = torch.cat(per_dst, dim=0)
138
+
139
+ owned_params = [
140
+ p for p in params if param_to_state[id(p)].worker_rank == rank
141
+ ]
142
+
143
+ # Compute receive sizes and allocate receiving buffers
144
+ recv_counts = [0] * num_ranks
145
+
146
+ for src in range(num_ranks):
147
+ total = 0
148
+ for p in owned_params:
149
+ state = param_to_state[id(p)]
150
+ assert state.worker_rank == rank
151
+ total += split_elems_for_src(p, src, num_ranks)
152
+ recv_counts[src] = total
153
+
154
+ recv_total = sum(recv_counts)
155
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
156
+
157
+ #All2All
158
+ dist.all_to_all_single(
159
+ recv_buf,
160
+ send_buf,
161
+ output_split_sizes=recv_counts,
162
+ input_split_sizes=send_counts,
163
+ group=process_group,
164
+ )
165
+
166
+ # Reconstructs gathered grad from the received buffer
167
+ #
168
+ # recv_buf (num ranks = 3)
169
+ #
170
+ # From rank 0 From rank 1 From rank 2
171
+ # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
172
+ #
173
+ # Outer loop:
174
+ # rank 0 -> rank 1 -> rank2
175
+ #
176
+ # Inner loop:
177
+ # p1_n -> p2_n -> p3_n
178
+
179
+ comm_stream.wait_event(alloc_event)
180
+
181
+ off = 0
182
+ write_offsets = {id(p): 0 for p in owned_params}
183
+ for src in range(num_ranks):
184
+ if recv_counts[src] == 0:
185
+ continue
186
+
187
+ block = recv_counts[src]
188
+ inner_off = 0
189
+ for p in owned_params:
190
+ state = param_to_state[id(p)]
191
+ assert state.worker_rank == rank
192
+ n = split_elems_for_src(p, src, num_ranks)
193
+ assert n > 0
194
+
195
+ sg = recv_buf.narrow(0, off + inner_off, n)
196
+ woff = write_offsets[id(p)]
197
+ dst = state.gathered_grad.narrow(0, woff, n)
198
+ dst.copy_(sg)
199
+
200
+ write_offsets[id(p)] += n
201
+ inner_off += n
202
+ off += block
203
+
204
+ for p in params:
205
+ state = param_to_state[id(p)]
206
+ if state.worker_rank == rank:
207
+ state.gathered_grad = state.gathered_grad.view_as(p)
208
+ state.gather_event = torch.cuda.Event()
209
+ state.gather_event.record(comm_stream)
210
+ else:
211
+ state.gathered_grad = None
212
+ state.gather_event = None
213
+ if none_grad:
214
+ p.grad = None
215
+
216
+
217
+ @torch.no_grad()
218
+ def _compute_u(p, state, steps, rank, compute_stream):
219
+ """
220
+ On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
221
+ """
222
+ with torch.cuda.stream(compute_stream):
223
+ if rank == state.worker_rank:
224
+ if state.gather_event is None:
225
+ raise RuntimeError("Gather event must be set before compute.")
226
+ compute_stream.wait_event(state.gather_event)
227
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
228
+ state.gathered_grad = None
229
+ state.computed_u = u
230
+ state.compute_event = torch.cuda.Event()
231
+ state.compute_event.record()
232
+ else:
233
+ state.computed_u = None
234
+ state.compute_event = None
235
+
236
+
237
+ @torch.no_grad()
238
+ def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
239
+ """
240
+ Pre-allocate scattered_u buffer on compute_stream
241
+ before launching all2all scatter
242
+ """
243
+ with torch.cuda.stream(compute_stream):
244
+ for p in params:
245
+ state = param_to_state[id(p)]
246
+ state.scattered_u = torch.empty_like(p.to_local(),
247
+ dtype=COMM_DTYPE)
248
+
249
+ alloc_event = torch.cuda.Event()
250
+ alloc_event.record(compute_stream)
251
+ return alloc_event
252
+
253
+
254
+ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
255
+ """
256
+ All2all scatters full gradients to all ranks
257
+ """
258
+ with torch.cuda.stream(comm_stream):
259
+ process_group = param_to_state[id(params[0])].process_group
260
+ num_ranks = dist.get_world_size(group=process_group)
261
+ owned_params = [
262
+ p for p in params if param_to_state[id(p)].worker_rank == rank
263
+ ]
264
+
265
+ # Construct sending buffer
266
+ per_dst = [[] for _ in range(num_ranks)]
267
+ send_counts = [0] * num_ranks
268
+
269
+ if owned_params:
270
+ for p in owned_params:
271
+ state = param_to_state[id(p)]
272
+ if state.compute_event is None:
273
+ raise RuntimeError(
274
+ "Compute event must be set before scatter.")
275
+ comm_stream.wait_event(state.compute_event)
276
+ state.gathered_grad = None
277
+
278
+ assert state.computed_u is not None
279
+
280
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
281
+
282
+ offset = 0
283
+ for dst in range(num_ranks):
284
+ n = split_elems_for_src(p, dst, num_ranks)
285
+ assert n > 0
286
+
287
+ su = u_full.narrow(0, offset, n)
288
+ per_dst[dst].append(su)
289
+ send_counts[dst] += n
290
+ offset += n
291
+
292
+ assert offset == u_full.numel()
293
+
294
+ lengths = [len(v) for v in per_dst]
295
+ if all(l > 0 for l in lengths):
296
+ assert all(
297
+ l == lengths[0] for l in lengths
298
+ ), "All destination ranks must have the same number of sharded tensor"
299
+ # list[list[Tensor]] -> list[Tensor]
300
+ per_dst = [t for dst in per_dst for t in dst]
301
+ send_buf = torch.cat(per_dst, dim=0)
302
+ else:
303
+ # all_to_all requires participation from all ranks
304
+ # Even non-owner ranks must join the collective call
305
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
306
+
307
+ # Compute receive sizes and allocate receiving buffers
308
+ recv_counts = [0] * num_ranks
309
+
310
+ for src in range(num_ranks):
311
+ total = 0
312
+ for p in params:
313
+ state = param_to_state[id(p)]
314
+ if state.worker_rank != src:
315
+ continue
316
+ total += split_elems_for_src(p, rank, num_ranks)
317
+ recv_counts[src] = total
318
+
319
+ recv_total = sum(recv_counts)
320
+ assert recv_total > 0
321
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
322
+
323
+ #All2All
324
+ dist.all_to_all_single(
325
+ recv_buf,
326
+ send_buf,
327
+ output_split_sizes=recv_counts,
328
+ input_split_sizes=send_counts,
329
+ group=process_group,
330
+ )
331
+
332
+ # Copy to pre-allocated scattered_u buffer from the received buffer
333
+ #
334
+ # recv_buf (num ranks = 3, local_rank = 0)
335
+ #
336
+ # From rank 0 From rank 1 From rank 2
337
+ # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
338
+ #
339
+ # Outer loop:
340
+ # rank 0 -> rank 1 -> rank2
341
+ #
342
+ # Inner loop:
343
+ # src(0) : p1_0 -> p2_0 -> p3_0
344
+ # src(1) : p4_0
345
+ # src(2) : p5_0 -> p6_0
346
+
347
+ comm_stream.wait_event(alloc_event)
348
+
349
+ off = 0
350
+ for src in range(num_ranks):
351
+ block = recv_counts[src]
352
+ if block == 0:
353
+ continue
354
+
355
+ inner_off = 0
356
+ for p in params:
357
+ state = param_to_state[id(p)]
358
+ if state.worker_rank != src:
359
+ continue
360
+ n = split_elems_for_src(p, rank, num_ranks)
361
+ assert n > 0
362
+
363
+ flat_local = recv_buf.narrow(0, off + inner_off,
364
+ n).view_as(p.to_local())
365
+ state.scattered_u.copy_(flat_local)
366
+
367
+ state.scatter_event = torch.cuda.Event()
368
+ state.scatter_event.record(comm_stream)
369
+ inner_off += n
370
+
371
+ assert inner_off == block
372
+ off += block
373
+
374
+
375
+ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
376
+ compute_stream):
377
+ """
378
+ Update sharded parameter p with the scattered_u.
379
+ Only worker_rank frees computed_u.
380
+ """
381
+ with torch.cuda.stream(compute_stream):
382
+ if state.scatter_event is None:
383
+ raise RuntimeError("Scatter event must be set before update")
384
+ compute_stream.wait_event(state.scatter_event)
385
+ u_dtensor = DTensor.from_local(
386
+ state.scattered_u,
387
+ placements=p.placements,
388
+ device_mesh=p.device_mesh,
389
+ )
390
+
391
+ state.scattered_u = u_dtensor
392
+
393
+ if rank == state.worker_rank:
394
+ # Free computed_u
395
+ state.computed_u = None
396
+
397
+ Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
398
+ state.scattered_u = None
399
+ u_dtensor = None
400
+
401
+ scales_full = Muon._compute_scales(p, state.qk_clip_state)
402
+ if scales_full is not None:
403
+ num_ranks = dist.get_world_size(group=state.process_group)
404
+ local_rank = dist.get_rank(group=state.process_group)
405
+ scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
406
+ scales_local = DTensor.from_local(
407
+ scales_local,
408
+ placements=p.placements,
409
+ device_mesh=p.device_mesh,
410
+ )
411
+ Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
412
+
413
+
414
+ def default_is_muon(name, x):
415
+ skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
416
+ return x.ndim >= 2 and not any(key in name for key in skip_keys)
417
+
418
+
419
+ def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
420
+ muon_params, muon_names = [], []
421
+ non_muon_params = []
422
+
423
+ for n, p in model.named_parameters():
424
+ if not p.requires_grad:
425
+ continue
426
+ if is_muon_func(n, p):
427
+ muon_params.append(p)
428
+ muon_names.append(n)
429
+ else:
430
+ non_muon_params.append(p)
431
+
432
+ return [
433
+ {
434
+ "params": muon_params,
435
+ "names": muon_names,
436
+ "use_muon": True,
437
+ },
438
+ {
439
+ "params": non_muon_params,
440
+ "use_muon": False,
441
+ },
442
+ ]
443
+
444
+
445
+ def parse_qk_layer(name: str) -> tuple[str | None, int]:
446
+ """
447
+ Parse a parameter name to check if it is a query/key projection layer
448
+ ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
449
+
450
+ Returns:
451
+ (kind, layer_idx) or (None, -1) if not matched.
452
+
453
+ Example:
454
+ 'model.3.attn.wq.weight' -> ('wq', 3)
455
+ 'model.5.attn.wk.weight' -> ('wk', 5)
456
+ 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
457
+ 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
458
+ 'model.4.attn.v_proj.weight' -> (None, -1)
459
+ """
460
+ parts = name.split('.')
461
+ if len(parts) < 3:
462
+ return None, -1
463
+
464
+ kind = parts[-2]
465
+
466
+ layer_idx = -1
467
+ for part in reversed(parts):
468
+ if part.isdigit():
469
+ layer_idx = int(part)
470
+ break
471
+
472
+ if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
473
+ return kind, layer_idx
474
+
475
+ return None, -1
476
+
477
+
478
+ @dataclass
479
+ class QKClipInfo:
480
+ """Per-parameter dynamic info computed from config + runtime logits."""
481
+ kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
482
+ indices: List[int] # which heads to consider for clipping
483
+ head_dim: int # from config
484
+ threshold: float # from config
485
+ logit: Optional[torch.Tensor]
486
+
487
+
488
+ class Muon(torch.optim.Optimizer):
489
+ """
490
+ Muon - MomentUm Orthogonalized by Newton-schulz
491
+
492
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
493
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
494
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
495
+ the advantage that it can be stably run in bfloat16 on the GPU.
496
+
497
+ Some warnings:
498
+ - We believe this optimizer is unlikely to work well for training with small batch size.
499
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
500
+
501
+ Arguments:
502
+ params: The parameters (or parameter groups) to be optimized. Each group must set "use_muon";
503
+ see get_default_muon_param_groups(model, is_muon_func) for building these groups from a model.
504
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
505
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
506
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
507
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
508
+ weight_decay: The weight decay for Muon and AdamW.
509
+ Parameters that are {0, 1}-D or are detected as being the embed or lm_head are optimized by AdamW as well.
510
+ adamw_lr: The learning rate for the internal AdamW.
511
+ adamw_betas: The betas for the internal AdamW.
512
+ adamw_eps: The epsilon for the internal AdamW.
513
+ none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
514
+ debug: Whether to print debug information.
515
+ clip_config : Configuration for QK clipping. Expected keys:
516
+ - "q_indices" (list[int]): Indices of query heads to consider.
517
+ - "k_indices" (list[int]): Indices of key heads to consider.
518
+ - "head_dim" (int): Dimensionality of each attention head.
519
+ - "threshold" (float): Threshold value; heads whose QK logits exceed
520
+ this value will be scaled down.
521
+ Default is:
522
+ {
523
+ "q_indices": [],
524
+ "k_indices": [],
525
+ "head_dim": 128,
526
+ "threshold": 100
527
+ }
528
+ overlap_step : How many all2all gather and compute operations are launched in advance
529
+ before the corresponding all2all scatter steps begin.
530
+ A higher overlap_step increases memory usage but can improve
531
+ performance by overlapping communication.
532
+ Parallel muon only.
533
+ """
534
+
535
+ def __init__(self,
536
+ params,
537
+ lr=1e-3,
538
+ momentum=0.95,
539
+ nesterov=True,
540
+ ns_steps=5,
541
+ weight_decay=0.1,
542
+ adamw_betas=(0.9, 0.95),
543
+ adamw_eps=1e-8,
544
+ none_grad=True,
545
+ debug=False,
546
+ clip_config={
547
+ "q_indices": [],
548
+ "k_indices": [],
549
+ "head_dim": 128,
550
+ "threshold": 100
551
+ },
552
+ overlap_step=5):
553
+ defaults = dict(
554
+ lr=lr,
555
+ weight_decay=weight_decay,
556
+ momentum=momentum,
557
+ nesterov=nesterov,
558
+ ns_steps=ns_steps,
559
+ adamw_betas=adamw_betas,
560
+ adamw_eps=adamw_eps,
561
+ none_grad=none_grad,
562
+ use_muon=True,
563
+ )
564
+ error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
565
+ instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
566
+
567
+ if isinstance(params, types.GeneratorType):
568
+ raise ValueError(error_message.format(idx=0) + instruction_code)
569
+ for _idx, param_group in enumerate(params):
570
+ if param_group.get("use_muon", None) is None:
571
+ raise ValueError(
572
+ error_message.format(idx=_idx) + instruction_code)
573
+
574
+ super().__init__(params, defaults)
575
+
576
+ self.rank = None
577
+
578
+ self.comm_stream = torch.cuda.Stream()
579
+ self.compute_stream = torch.cuda.Stream()
580
+ self.debug = debug
581
+ self.clip_config = clip_config
582
+ self.overlap_step = overlap_step
583
+
584
+ def _calc_flops(self, G, steps):
585
+ assert len(G.shape) == 2
586
+ M, N = G.shape
587
+ if M > N:
588
+ M, N = N, M
589
+
590
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
591
+
592
+ def adjust_lr_for_muon(self, lr, param_shape):
593
+ A, B = param_shape[:2]
594
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
595
+ # as described in the paper
596
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
597
+ adjusted_lr = lr * adjusted_ratio
598
+ return adjusted_lr
599
+
600
+ def get_shard_mesh(self, p):
601
+ """
602
+ Get the shard mesh for a parameter p on the given rank.
603
+ """
604
+ assert isinstance(
605
+ p, DTensor), "Parallel Muon only supports DTensor parameters."
606
+
607
+ if p.placements == (Shard(dim=0), ):
608
+ # Case for FSDP
609
+ process_group = p.device_mesh.get_group(mesh_dim=0)
610
+ if self.rank is None:
611
+ self.rank = dist.get_rank(group=process_group)
612
+ else:
613
+ assert self.rank == dist.get_rank(group=process_group)
614
+ return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
+ elif p.placements == (Replicate(), Shard(dim=0)):
616
+ # Case for HSDP
617
+ process_group = p.device_mesh.get_group(mesh_dim=1)
618
+ if self.rank is None:
619
+ self.rank = dist.get_rank(group=process_group)
620
+ else:
621
+ assert self.rank == dist.get_rank(group=process_group)
622
+ for i, shard_mesh in enumerate(p.device_mesh.mesh):
623
+ if self.rank in shard_mesh:
624
+ return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
625
+ else:
626
+ raise ValueError(f"Unsupported placements ({p.placements}).")
627
+
628
+ def init_state_and_assign_params(self, names, params, group, qk_logits):
629
+ param_to_state = {}
630
+ param_to_flops = {}
631
+
632
+ total_flops = 0
633
+ for p in params:
634
+ g = p.grad
635
+ if g is None:
636
+ continue
637
+ assert g.ndim == 2, "Muon only supports 2D parameters."
638
+
639
+ flops = self._calc_flops(g, group["ns_steps"])
640
+ param_to_flops[id(p)] = flops
641
+ total_flops += flops
642
+
643
+ if self.debug:
644
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
645
+ flush=True)
646
+
647
+ paired = list(zip(names, params))
648
+
649
+ paired_sorted = sorted(paired,
650
+ key=lambda x: param_to_flops[id(x[1])],
651
+ reverse=True)
652
+
653
+ names_sorted, params_sorted = zip(*paired_sorted)
654
+ ordered_names = list(names_sorted)
655
+ ordered_params = list(params_sorted)
656
+
657
+ round_robin = 0
658
+ mesh = None
659
+ shard_mesh = None
660
+ process_group = None
661
+ for n, p in zip(ordered_names, ordered_params):
662
+ if mesh is None:
663
+ mesh = p.device_mesh
664
+ shard_mesh, process_group = self.get_shard_mesh(p)
665
+ elif mesh != p.device_mesh:
666
+ raise ValueError("All parameters must be on the same mesh.")
667
+ num_ranks = dist.get_world_size(group=process_group)
668
+ param_to_state[id(p)] = _muon_state()
669
+ param_to_state[id(
670
+ p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
671
+ param_to_state[id(p)].process_group = process_group
672
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
673
+ param_to_state[id(p)].qk_clip_state = qk_clip_state
674
+ round_robin = (round_robin + 1) % len(shard_mesh)
675
+
676
+ return param_to_state, ordered_params
677
+
678
+ def base(self, names, params, group, lr, weight_decay, momentum,
679
+ qk_logits):
680
+ # generate weight updates in distributed fashion
681
+ for n, p in zip(names, params):
682
+ g = p.grad
683
+ if g is None:
684
+ continue
685
+ if g.ndim > 2:
686
+ g = g.view(g.size(0), -1)
687
+ assert g is not None
688
+
689
+ # calc update
690
+ state = self.state[p]
691
+ if "momentum_buffer" not in state:
692
+ state["momentum_buffer"] = torch.zeros_like(g)
693
+ buf = state["momentum_buffer"]
694
+ buf.mul_(momentum).add_(g)
695
+ if group["nesterov"]:
696
+ g = g.add(buf, alpha=momentum)
697
+ else:
698
+ g = buf
699
+
700
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
701
+ steps=group["ns_steps"])
702
+
703
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
704
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
705
+
706
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
707
+
708
+ scales_full = self._compute_scales(p, qk_clip_state)
709
+ if scales_full is not None:
710
+ Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
711
+
712
+ def _update_g(self, p, g, group, momentum):
713
+ # calc update
714
+ state = self.state[p]
715
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
716
+ torch.add(g, buf, alpha=momentum, out=buf)
717
+ if group["nesterov"]:
718
+ g.add_(buf, alpha=momentum)
719
+ return g
720
+ return buf
721
+
722
+ @staticmethod
723
+ def _update_p(p, u, lr, adjusted_lr, weight_decay):
724
+ # apply weight decay
725
+ p.data.mul_(1 - lr * weight_decay)
726
+ # apply update
727
+ p.data.add_(u, alpha=-adjusted_lr)
728
+
729
+ def get_qk_clip_info(self, n, qk_logits):
730
+ head_dim = self.clip_config.get('head_dim')
731
+ threshold = self.clip_config.get('threshold')
732
+ kind, layer_idx = parse_qk_layer(n)
733
+
734
+ logit, indices = None, []
735
+ if qk_logits is not None and kind is not None:
736
+ logit = qk_logits[layer_idx]
737
+ indices_key = 'q_indices' if 'q' in kind else 'k_indices'
738
+ indices = self.clip_config.get(indices_key, []) or []
739
+
740
+ return QKClipInfo(
741
+ kind=kind,
742
+ indices=indices,
743
+ head_dim=head_dim,
744
+ threshold=threshold,
745
+ logit=logit,
746
+ )
747
+
748
+ @staticmethod
749
+ def _compute_scales(p, qk_clip_state):
750
+ kind = qk_clip_state.kind
751
+ indices = qk_clip_state.indices
752
+ head_dim = qk_clip_state.head_dim
753
+ threshold = qk_clip_state.threshold
754
+ logit = qk_clip_state.logit
755
+
756
+ H_global = p.shape[0] // head_dim
757
+ scales_full = torch.ones(H_global, device=p.data.device)
758
+ scaling = 0
759
+
760
+ for logit_idx, head_idx in enumerate(indices):
761
+ v_ele = float(logit[logit_idx])
762
+ if v_ele > threshold:
763
+ new_scale = math.sqrt(threshold / v_ele)
764
+ if new_scale < scales_full[head_idx]:
765
+ scales_full[head_idx] = new_scale
766
+ logger.info(
767
+ f"[{kind}] Head {head_idx} exceeded threshold "
768
+ f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
769
+ )
770
+ scaling += 1
771
+
772
+ return scales_full if scaling > 0 else None
773
+
774
+ @staticmethod
775
+ def _qk_clip(p, scales, head_dim):
776
+ W = p.data.view(-1, head_dim, p.data.shape[1])
777
+ W.mul_(scales.view(-1, 1, 1))
778
+
779
+ def parallel(self, names, params, group, lr, weight_decay, momentum,
780
+ qk_logits):
781
+ """
782
+ Perform a parallel optimization step using Muon.
783
+ """
784
+
785
+ for p in params:
786
+ g = p.grad
787
+ if g is None:
788
+ continue
789
+ if g.ndim > 2:
790
+ g = g.view(g.size(0), -1)
791
+
792
+ # Update g in the local rank
793
+ g = self._update_g(
794
+ p,
795
+ g,
796
+ group,
797
+ momentum=momentum,
798
+ )
799
+ p.grad = g
800
+
801
+ param_to_state, ordered_params = self.init_state_and_assign_params(
802
+ names, params, group, qk_logits)
803
+
804
+ assert self.rank is not None
805
+
806
+ def enqueue_all2all_gather(start_idx, chunk_size):
807
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
808
+ if target_params:
809
+ alloc_event = _alloc_gathered_grad(target_params,
810
+ param_to_state, self.rank,
811
+ self.compute_stream)
812
+ _all2all_gather(target_params, param_to_state, self.rank,
813
+ self.comm_stream, group["none_grad"],
814
+ alloc_event)
815
+
816
+ def enqueue_computes(start_idx, chunk_size):
817
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
818
+ state = param_to_state[id(p)]
819
+ _compute_u(p, state, group["ns_steps"], self.rank,
820
+ self.compute_stream)
821
+
822
+ def enqueue_all2all_scatter(start_idx, chunk_size):
823
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
824
+ if target_params:
825
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
826
+ self.rank,
827
+ self.compute_stream)
828
+ _all2all_scatter(target_params, param_to_state, self.rank,
829
+ self.comm_stream, alloc_event)
830
+
831
+ def enqueue_update_param(start_idx, chunk_size):
832
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
833
+ state = param_to_state[id(p)]
834
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
835
+ _update_param(p, state, lr, adjusted_lr, weight_decay,
836
+ self.rank, self.compute_stream)
837
+
838
+ chunk_size = dist.get_world_size(param_to_state[id(
839
+ params[0])].process_group)
840
+
841
+ # Wait grad update
842
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
843
+
844
+ overlap_step = self.overlap_step
845
+ for i in range(0, overlap_step):
846
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
847
+ enqueue_computes(i * chunk_size, chunk_size)
848
+
849
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
850
+ enqueue_all2all_scatter(i, chunk_size)
851
+ enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
852
+ enqueue_update_param(i, chunk_size)
853
+ enqueue_computes(i + overlap_step * chunk_size, chunk_size)
854
+
855
+ # Wait the last update_param to finish
856
+ torch.cuda.current_stream().wait_stream(self.compute_stream)
857
+
858
+ @staticmethod
859
+ def _fused_adamw(
860
+ params: list[torch.Tensor],
861
+ grads: list[torch.Tensor],
862
+ exp_avgs: list[torch.Tensor],
863
+ exp_avg_sqs: list[torch.Tensor],
864
+ max_exp_avg_sqs: list[torch.Tensor],
865
+ state_steps: list[torch.Tensor],
866
+ amsgrad: bool,
867
+ beta1: float,
868
+ beta2: float,
869
+ lr: Union[float, torch.Tensor],
870
+ weight_decay: float,
871
+ eps: float,
872
+ maximize: bool,
873
+ ) -> None:
874
+ if not params:
875
+ return
876
+
877
+ # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
878
+ # treating it as a scalar.
879
+ lr_dict: Optional[DeviceDict] = ({
880
+ lr.device: lr
881
+ } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
882
+ None)
883
+ grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
884
+ [
885
+ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
886
+ state_steps
887
+ ] # type: ignore[list-item]
888
+ )
889
+ for (device, _), (
890
+ (
891
+ device_params_,
892
+ device_grads_,
893
+ device_exp_avgs_,
894
+ device_exp_avg_sqs_,
895
+ device_max_exp_avg_sqs,
896
+ device_state_steps_,
897
+ ),
898
+ _,
899
+ ) in grouped_tensors.items():
900
+ device_params = cast(list[torch.Tensor], device_params_)
901
+ device_grads = cast(list[torch.Tensor], device_grads_)
902
+ device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
903
+ device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
904
+ device_state_steps = cast(list[torch.Tensor], device_state_steps_)
905
+
906
+ if lr_dict is not None and device not in lr_dict:
907
+ lr_dict[device] = lr.to(
908
+ device=device,
909
+ non_blocking=True) # type: ignore[union-attr]
910
+ lr = lr_dict[device]
911
+ torch._foreach_add_(device_state_steps, 1)
912
+ func = torch._fused_adamw_
913
+ func(
914
+ device_params,
915
+ device_grads,
916
+ device_exp_avgs,
917
+ device_exp_avg_sqs,
918
+ device_max_exp_avg_sqs, # type: ignore[arg-type]
919
+ device_state_steps,
920
+ amsgrad=amsgrad,
921
+ lr=lr, # type: ignore[arg-type]
922
+ beta1=beta1,
923
+ beta2=beta2,
924
+ weight_decay=weight_decay,
925
+ eps=eps,
926
+ maximize=maximize,
927
+ )
928
+
929
+ def step(self, closure=None, qk_logits=None):
930
+ """Perform a single optimization step.
931
+
932
+ Args:
933
+ closure (Callable, optional): A closure that reevaluates the model
934
+ and returns the loss.
935
+ qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
936
+ to 1D tensors of shape (num_heads,), representing the maximum
937
+ QK logits across all tokens, computed as
938
+ (1 / sqrt(head_dim)) * (Q @ K^T).
939
+ """
940
+ loss = None
941
+ if closure is not None:
942
+ with torch.enable_grad():
943
+ loss = closure()
944
+
945
+ for group in self.param_groups:
946
+ params = group["params"]
947
+
948
+ if group["use_muon"]:
949
+ ############################
950
+ # Muon #
951
+ ############################
952
+ lr = group["lr"]
953
+ weight_decay = group["weight_decay"]
954
+ momentum = group["momentum"]
955
+ names = group["names"]
956
+
957
+ param_dtensors = []
958
+ param_tensors = []
959
+ name_dtensors = []
960
+ name_tensors = []
961
+
962
+ for n, p in zip(names, params):
963
+ if p is None or p.grad is None:
964
+ continue
965
+ if isinstance(p.data, DTensor):
966
+ if all(
967
+ isinstance(placement, Replicate)
968
+ for placement in p.placements):
969
+ param_tensors.append(p)
970
+ name_tensors.append(n)
971
+ else:
972
+ param_dtensors.append(p)
973
+ name_dtensors.append(n)
974
+ elif isinstance(p.data, torch.Tensor):
975
+ param_tensors.append(p)
976
+ name_tensors.append(n)
977
+ else:
978
+ raise TypeError(
979
+ f"Unsupported parameter type: {type(p.data)}")
980
+
981
+ if self.debug:
982
+ print(
983
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
984
+ flush=True,
985
+ )
986
+
987
+ if len(param_dtensors) > 0:
988
+ if not dist.is_initialized():
989
+ raise RuntimeError(
990
+ "Parallel Muon requires torch.distributed to be initialized."
991
+ )
992
+
993
+ self.parallel(
994
+ name_dtensors,
995
+ param_dtensors,
996
+ group,
997
+ lr=lr,
998
+ weight_decay=weight_decay,
999
+ momentum=momentum,
1000
+ qk_logits=qk_logits,
1001
+ )
1002
+
1003
+ if len(param_tensors) > 0:
1004
+ self.base(
1005
+ name_tensors,
1006
+ param_tensors,
1007
+ group,
1008
+ lr=lr,
1009
+ weight_decay=weight_decay,
1010
+ momentum=momentum,
1011
+ qk_logits=qk_logits,
1012
+ )
1013
+
1014
+ else:
1015
+ ############################
1016
+ # AdamW backup #
1017
+ ############################
1018
+
1019
+ params_with_grads = []
1020
+ grads = []
1021
+ moment1 = []
1022
+ moment2 = []
1023
+ max_exp_avg_sqs = []
1024
+ state_steps = []
1025
+ lr = group["lr"]
1026
+ beta1, beta2 = group["adamw_betas"]
1027
+ eps = group["adamw_eps"]
1028
+ weight_decay = group["weight_decay"]
1029
+
1030
+ for p in params:
1031
+ g = p.grad
1032
+ if g is None:
1033
+ continue
1034
+ state = self.state[p]
1035
+ params_with_grads.append(p)
1036
+ grads.append(g)
1037
+ if "step" not in state:
1038
+ state["step"] = (torch.zeros((),
1039
+ dtype=torch.float32,
1040
+ device=p.device))
1041
+ state["moment1"] = torch.zeros_like(g)
1042
+ state["moment2"] = torch.zeros_like(g)
1043
+ moment1.append(state["moment1"])
1044
+ moment2.append(state["moment2"])
1045
+ if not isinstance(state["step"], torch.Tensor):
1046
+ step_tensor = torch.tensor(state["step"],
1047
+ dtype=torch.float32,
1048
+ device=p.device)
1049
+ else:
1050
+ step_tensor = state["step"]
1051
+ state_steps.append(step_tensor)
1052
+
1053
+ self._fused_adamw(
1054
+ params_with_grads,
1055
+ grads,
1056
+ moment1,
1057
+ moment2,
1058
+ max_exp_avg_sqs,
1059
+ state_steps,
1060
+ amsgrad=False,
1061
+ beta1=beta1,
1062
+ beta2=beta2,
1063
+ lr=lr,
1064
+ weight_decay=weight_decay,
1065
+ eps=eps,
1066
+ maximize=False,
1067
+ )
1068
+
1069
+ return loss
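The step() docstring above describes qk_logits as a per-layer map of maximum QK logits. A minimal hedged sketch of driving QK clipping through that interface, reusing `params` from get_default_muon_param_groups as in the earlier sketch (the logit values below are fabricated for illustration):

# Hedged sketch: head 1 of layer 0 exceeds the threshold of 100, so the matching
# q_proj/k_proj head rows would be rescaled by sqrt(100 / 120) after the Muon update.
import torch

optim = Muon(params,
             clip_config={"q_indices": [0, 1], "k_indices": [0, 1],
                          "head_dim": 128, "threshold": 100})
qk_logits = {0: torch.tensor([50.0, 120.0])}   # layer 0: per-head max QK logits
optim.step(qk_logits=qk_logits)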