Kernels
dongseokmotif committed on
Commit
b220459
·
unverified ·
2 Parent(s): 33929c0 bf30b9b

Merge pull request #17 from MotifTechnologies/optimal-ns-coefficients

Browse files

Replace hardcoded NS coefficients with analytically optimal ones [ski…

Files changed (1) hide show
  1. torch-ext/optimizer/newton_schulz.py +134 -20
torch-ext/optimizer/newton_schulz.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  import torch
2
 
3
  from .matmul_transpose_triton import matmul_transpose_assign
@@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16
6
  DEFAULT_CHUNK_SIZE_RATIO = 4
7
 
8
 
9
- # This code snippet is a modified version adapted from the following GitHub repositories:
10
- # https://github.com/KellerJordan/Muon/blob/master/muon.py
11
- # Muon's Newton–Schulz iteration causes high variance in singular values
12
- # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  @torch.no_grad()
14
- # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
15
  def _zeropower_via_newtonschulz5(G, steps):
16
  """
17
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
18
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
19
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
20
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
21
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
22
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
23
- performance at all relative to UV^T, where USV^T = G is the SVD.
 
 
 
 
 
 
 
24
  """
25
  assert len(G.shape) == 2
26
  assert G.dtype == COMM_DTYPE
@@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps):
28
 
29
  if G.size(0) > G.size(1):
30
  X = X.T
31
- # Ensure spectral norm is at most 1
32
  X = X / (X.norm() + 1e-7)
 
 
33
  buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
34
  buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
35
  # Perform the NS iterations
36
- for a, b, c in [
37
- (4.0848, -6.8946, 2.9270),
38
- (3.9505, -6.3029, 2.6377),
39
- (3.7418, -5.5913, 2.3037),
40
- (2.8769, -3.1427, 1.2046),
41
- (2.8366, -3.0525, 1.2012),
42
- ]:
43
  matmul_transpose_assign(X, buf1)
44
  matmul_transpose_assign(buf1, buf2)
45
  buf1.mul_(b).add_(buf2, alpha=c)
@@ -47,4 +160,5 @@ def _zeropower_via_newtonschulz5(G, steps):
47
 
48
  if G.size(0) > G.size(1):
49
  X = X.T
 
50
  return X
 
1
+ from itertools import repeat
2
+ from math import inf, sqrt
3
+
4
+ import numpy as np
5
  import torch
6
 
7
  from .matmul_transpose_triton import matmul_transpose_assign
 
10
  DEFAULT_CHUNK_SIZE_RATIO = 4
11
 
12
 
13
+ def _optimal_quintic(l, u, max_iter=1000):
14
+ """
15
+ Use the simplified Remez algorithm to find the optimal odd quintic approximant
16
+ to the constant function x -> 1 over the interval [l, u].
17
+
18
+ Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum
19
+ approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
20
+ two interior equioscillation nodes q, r until convergence. Returns the
21
+ closed-form equioscillating solution when l ≈ u.
22
+
23
+ Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite
24
+ (NaN or inf). Raises RuntimeError if convergence is not reached within
25
+ max_iter iterations.
26
+ """
27
+ assert 0 <= l <= u
28
+ if 1 - 5e-6 <= l / u:
29
+ return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
30
+ q = (3 * l + u) / 4
31
+ r = (l + 3 * u) / 4
32
+ E = inf
33
+ for _ in range(max_iter):
34
+ old_E = E
35
+ LHS = np.array([
36
+ [l, l**3, l**5, 1],
37
+ [q, q**3, q**5, -1],
38
+ [r, r**3, r**5, 1],
39
+ [u, u**3, u**5, -1],
40
+ ])
41
+ a, b, c, E = np.linalg.solve(LHS, np.ones(4))
42
+ if not np.all(np.isfinite([a, b, c, E])):
43
+ raise ValueError(f"_optimal_quintic: non-finite solve result "
44
+ f"a={a}, b={b}, c={c}, E={E}")
45
+ q, r = np.sqrt(
46
+ (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
47
+ (10 * c))
48
+ if not np.all(np.isfinite([q, r])):
49
+ raise ValueError(
50
+ f"_optimal_quintic: non-finite node update q={q}, r={r}")
51
+ if abs(old_E - E) <= 1e-15:
52
+ break
53
+ else:
54
+ raise RuntimeError(
55
+ f"_optimal_quintic: did not converge after {max_iter} iterations")
56
+ return float(a), float(b), float(c)
57
+
58
+
59
def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
    """
    Compute the Polar Express coefficient series for `num_iters` quintic iterations.

    Builds a sequence of per-step odd quintic coefficients (a, b, c) that
    compose to map singular values from [l, 1] toward 1. At each step:
      1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. If the cushion
         is active (cushion*u > l), the coefficients are rescaled so that
         p(l) and p(u) are centered around 1 w.r.t. the true [l, u].
      2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all
         but the last iteration.
      3. Advances the interval: l <- p(l), u <- 2 - p(l).

    Args:
        l: Initial lower bound of the singular-value interval; must satisfy
            0 <= l <= 1 (the initial upper bound is fixed at 1).
        num_iters: Number of quintic iterations to generate coefficients for.
        safety_factor_eps: Relative deflation applied to every step except
            the last; 0 disables it.
        cushion: Fraction of the current upper bound used as a floor for the
            lower bound passed to `_optimal_quintic`; 0 disables it.

    Returns:
        A list of (a, b, c) tuples, one per iteration.

    Raises:
        ValueError: If l is outside [0, 1].

    Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
    Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
    """
    u = 1
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not 0 <= l <= u:
        raise ValueError(
            f"_optimal_composition: expected 0 <= l <= {u}, got l={l}")

    safety_factor = 1 + safety_factor_eps
    # Per-degree deflation factors are loop-invariant.
    safety_cubed = safety_factor**3
    safety_fifth = safety_factor**5

    coefficients = []
    # `step` rather than `iter`, to avoid shadowing the builtin.
    for step in range(num_iters):
        a, b, c = _optimal_quintic(max(l, cushion * u), u)
        if cushion * u > l:
            # The quintic was fitted on the cushioned interval; rescale so
            # that p(l) + p(u) == 2 on the true interval.
            pl = a * l + b * l**3 + c * l**5
            pu = a * u + b * u**3 + c * u**5
            rescaler = 2 / (pl + pu)
            a *= rescaler
            b *= rescaler
            c *= rescaler
        if step < num_iters - 1:
            # Equivalent to evaluating p at x / safety_factor.
            a /= safety_factor
            b /= safety_cubed
            c /= safety_fifth
        coefficients.append((a, b, c))
        # Advance the interval using the final (rescaled/deflated) polynomial.
        l = a * l + b * l**3 + c * l**5
        u = 2 - l
    return coefficients
100
+
101
+
102
# Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz iterations.
# The list is built once, when this module is imported, and then reused by every
# optimizer step. Each tuple comes from the minimax (Remez / equioscillation)
# odd quintic fit to x -> 1 on the singular-value interval reached at that step.
#
# How this differs from the 5 hardcoded tuples it replaces:
#   - Old coefficients: tuned empirically to maximize the slope at zero. They
#     did not drive singular values to 1, giving US'V^T with
#     S' ~ Uniform(0.5, 1.5) rather than the polar factor UV^T.
#   - Polar Express: each step is fitted to the shrinking interval [l, u], so
#     the composition pushes all singular values in [l, 1] toward 1 and the
#     result approximates the polar factor UV^T.
_coeffs_list = _optimal_composition(
    l=1e-3,
    num_iters=10,
    safety_factor_eps=1e-2,
    cushion=0.02,
)
118
+
119
+
120
+ # This code is adapted from:
121
+ # KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
122
+ # NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress)
123
+ # matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon)
124
  @torch.no_grad()
 
125
  def _zeropower_via_newtonschulz5(G, steps):
126
  """
127
+ Compute the polar factor of G via the Polar Express method.
128
+
129
+ Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c)
130
+ are the Polar Express coefficients from `_coeffs_list`. Each step is the
131
+ optimal odd quintic approximant to x -> 1 over the current singular-value
132
+ interval, minimizing the maximum approximation error (Remez / minimax criterion).
133
+ The composition maps singular values from [l, 1] to near 1, producing the
134
+ polar factor (orthogonal factor in the polar decomposition G = UP).
135
+
136
+ `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2,
137
+ cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated.
138
+
139
+ Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
140
+ Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
141
  """
142
  assert len(G.shape) == 2
143
  assert G.dtype == COMM_DTYPE
 
145
 
146
  if G.size(0) > G.size(1):
147
  X = X.T
148
+
149
  X = X / (X.norm() + 1e-7)
150
+ hs = _coeffs_list[:steps] + list(
151
+ repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
152
  buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
153
  buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
154
  # Perform the NS iterations
155
+ for a, b, c in hs:
 
 
 
 
 
 
156
  matmul_transpose_assign(X, buf1)
157
  matmul_transpose_assign(buf1, buf2)
158
  buf1.mul_(b).add_(buf2, alpha=c)
 
160
 
161
  if G.size(0) > G.size(1):
162
  X = X.T
163
+
164
  return X