ASR optimizer: Frequency-Adaptive Momentum (FAM) (wip)
https://github.com/sine2pi/Maxfactor
https://github.com/sine2pi/Echo
https://github.com/sine2pi/Focused-Attention
- Long-Range Dependencies and Specificity:
Scenario: Imagine a task involving long documents where you need to identify very specific pieces of information scattered throughout the text. For instance, answering questions about a legal document or summarizing a complex scientific paper.
Reasoning: When the attention span is long, you're allowing the model to consider a wide range of context. In this case, you might actually want the attention to be sharper. You don't want the model to be wishy-washy and distribute its attention equally across a large number of tokens. You want it to pinpoint the few most relevant pieces of information within that broad context. A softer attention (higher temperature) over a long span would likely lead to a diluted, less informative representation.
Example: If the question is "What is the defendant's age in Case 3.14159?", and Case 3.14159 spans several paragraphs, you'd want the model to sharply focus on the specific sentence mentioning the age, even within that large span.
- Avoiding "Attention Collapse" with Long Spans:
Scenario: With very long spans, standard (or softly scaled) attention can sometimes suffer from a phenomenon where the attention weights become too uniform. The model essentially "gives up" on trying to discriminate between tokens and attends to everything equally.
Reasoning: A sharper softmax (lower temperature) can act as a regularizer, preventing this "attention collapse." It forces the model to make more decisive choices, even when the context is large.
Analogy: Think of it like searching a large library. If you have no idea where to look (soft attention), you might just wander aimlessly. A sharper focus (even if you don't know exactly where to go) forces you to pick specific shelves and sections to examine, increasing your chances of finding what you need.
- Tasks Requiring Precise Identification within Broad Context:
Scenario: Tasks like named entity recognition (NER) or relation extraction, when applied to long documents.
Reasoning: You might need a broad context (long span) to understand the relationships between entities, but you still need to precisely identify the entities themselves (which might be short phrases). Softer attention over a long span might blur the boundaries of the entities, making it harder to extract them accurately.
- Hierarchical Reasoning:
Scenario: Imagine a multi-step reasoning task, where the model needs to first identify relevant sections of a document (long span, sharper attention) and then analyze those sections in more detail (shorter spans, possibly softer attention).
Reasoning: Because the ideal sharpness differs between these stages, you might want a temperature scaling approach that is learnable rather than fixed.
- Sparsity Inducement:
Scenario: The model needs to run on low-power devices.
Reasoning: A lower temperature pushes the attention weights toward a sparser distribution, which is cheaper to compute and store at inference time (see the sketch after this list).
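The common thread in the list above is softmax temperature: dividing attention scores by a temperature below 1 sharpens the distribution, while a temperature above 1 softens it. A minimal, self-contained sketch (the scores are made-up values, not taken from the model):

import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 1.0, 0.5, 0.1])         # illustrative attention scores for one query
for temperature in (1.5, 1.0, 0.5):                 # > 1 softer, < 1 sharper
    weights = F.softmax(scores / temperature, dim=-1)
    print(f"T={temperature}: {weights.tolist()}")
# At T=0.5 most of the mass lands on the top score (sharp, nearly sparse);
# at T=1.5 it spreads out (soft, diluted) - the failure mode described above for long spans.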
Focus blocks:
class MultiheadC(nn.Module):
use_sdpa: bool = True
def __init__(self, dims: int, heads: int, max_dist: int):
super().__init__()
if dims % heads != 0:
raise ValueError(f"dims ({dims}) must be divisible by heads ({heads})")
if dims % 2 != 0:
raise ValueError(f"dims ({dims}) must be even for rotary embeddings")
self.heads = heads
self.head_dim = dims // heads
self.dims = dims
self.max_dist = max_dist
scale = 1 / math.sqrt(self.head_dim)
self.query = nn.Linear(dims, dims)
self.key = nn.Linear(dims, dims, bias=False)
self.value = nn.Linear(dims, dims)
self.out = nn.Linear(dims, dims)
nn.init.normal_(self.query.weight, std=scale)
nn.init.normal_(self.key.weight, std=scale)
nn.init.normal_(self.value.weight, std=scale)
nn.init.zeros_(self.out.bias)
def forward(self, x: Tensor, xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None, kv_cache: Optional[Dict] = None) -> Tuple[Tensor, Optional[Tensor]]:
q = self.query(x)
if kv_cache is None or xa is None or self.key not in kv_cache:
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
else:
k = kv_cache[self.key]
v = kv_cache[self.value]
wv, qk = self.qkv_attention(q=q, k=k, v=v, mask=mask)
return self.out(wv), qk
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
batch, ctx, dims = q.shape
scale = (dims // self.heads) ** -0.25
q = q.view(batch, ctx, self.heads, self.head_dim).permute(0, 2, 1, 3)
        # use the key/value lengths so cross-attention (xa with a different ctx) reshapes correctly
        k = k.view(batch, k.shape[1], self.heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.view(batch, v.shape[1], self.heads, self.head_dim).permute(0, 2, 1, 3)
if self.use_sdpa and torch.cuda.is_available():
with torch.autocast('cuda'):
a = scaled_dot_product_attention(
query=q,
key=k,
value=v,
is_causal=mask is not None and ctx > 1
)
out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
qk = None
else:
qk = (q * scale) @ (k * scale).transpose(-1, -2)
if mask is not None:
qk = qk + mask[:ctx, :ctx]
qk = qk.float()
w = F.softmax(qk, dim=-1).to(q.dtype)
out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
qk = qk.detach()
return out, qk
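A quick, hypothetical shape check for MultiheadC (the sizes are arbitrary; it assumes the imports these blocks rely on: torch, torch.nn as nn, torch.nn.functional as F, scaled_dot_product_attention, math, and Optional/Dict/Tuple from typing):

mha = MultiheadC(dims=512, heads=8, max_dist=256)   # hypothetical sizes
x = torch.randn(2, 100, 512)                        # (batch, ctx, dims)
out, qk = mha(x)                                    # self-attention: xa=None, no mask
print(out.shape)                                    # torch.Size([2, 100, 512])
xa = torch.randn(2, 50, 512)                        # encoder memory for cross-attention
out, _ = mha(x, xa=xa)                              # keys/values come from xa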
class Refiner:
def __init__(self, states, actions, alpha=0.1, gamma=0.9, epsilon=0.1):
self.states = states
self.actions = actions
self.R = {}
self.alpha = alpha
self.gamma = gamma
self.epsilon = epsilon
self.default_value = 0.0
def get_value(self, state, action):
return self.R.get((state, action), self.default_value)
def set_value(self, state, action, value):
self.R[(state, action)] = value
def choose_action(self, state):
if np.random.random() < self.epsilon:
return np.random.randint(self.actions)
else:
action_values = [self.get_value(state, a) for a in range(self.actions)]
return np.argmax(action_values)
def update(self, state, action, reward, next_state):
next_values = [self.get_value(next_state, a) for a in range(self.actions)]
best_next_value = max(next_values)
old_value = self.get_value(state, action)
td_target = reward + self.gamma * best_next_value
td_error = td_target - old_value
new_value = old_value + self.alpha * td_error
self.set_value(state, action, new_value)
class Predictor(nn.Module):
def __init__(self, dims):
super().__init__()
self.linear = nn.Linear(in_features=dims, out_features=1)
nn.init.xavier_normal_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, global_out):
if global_out.dim() > 2:
global_out = global_out.mean(dim=1)
scale = torch.sigmoid(self.linear(global_out))
return scale
class AdaptiveSpan(nn.Module):
def __init__(self, dims, heads, max_dist, sharpen=True, temp_scale=0.01):
super().__init__()
self.heads = heads
self.max_dist = max_dist
self.dims = dims
self.temp_scale = temp_scale
self.sharpen = sharpen
self.span_scale = nn.Parameter(torch.tensor(1.0))
self.head_dim = dims // heads
self.register_buffer("scale", torch.tensor(self.head_dim**-0.25))
def forward(self, query, key, value, max_dist=None, max_span=None, span_scale=None):
if max_dist is None:
max_dist = self.max_dist
if max_span is None:
max_span = query.shape[1] # Default to sequence length
if span_scale is None:
span_scale = self.span_scale
span_mean = span_scale.mean().item()
span_len = min(int(max_span * span_mean), query.shape[1], key.shape[1], value.shape[1])
eff_span = min(span_len, max_dist)
if eff_span == 0:
batch_size = query.shape[0]
return (torch.zeros(batch_size, eff_span, self.dims, device=query.device), None)
q_span = query[:, :eff_span, :]
k_span = key[:, :eff_span, :]
v_span = value[:, :eff_span, :]
batch_size = q_span.shape[0]
reshape_dims = (batch_size, -1, self.heads, self.head_dim)
q = q_span.view(*reshape_dims).permute(0, 2, 1, 3)
k = k_span.view(*reshape_dims).permute(0, 2, 1, 3)
v = v_span.view(*reshape_dims).permute(0, 2, 1, 3)
with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
temperature = (
1.0 + self.temp_scale * (1.0 - span_mean)
if self.sharpen
else 0.5 + self.temp_scale * span_mean
)
scores = torch.matmul(q, k.transpose(-2, -1))
weights = torch.softmax((scores / temperature) * self.scale, dim=-1)
out = torch.matmul(weights, v)
out = out.permute(0, 2, 1, 3).reshape(batch_size, eff_span, self.dims)
return out, weights
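Illustrative usage of AdaptiveSpan: span_scale shrinks the prefix the module actually attends over (the values below are hypothetical).

span_attn = AdaptiveSpan(dims=512, heads=8, max_dist=256, sharpen=True, temp_scale=0.01)
q = k = v = torch.randn(2, 128, 512)
out, weights = span_attn(q, k, v, span_scale=torch.tensor(0.25))  # attend over ~a quarter of the sequence
print(out.shape, weights.shape)  # torch.Size([2, 32, 512]) torch.Size([2, 8, 32, 32])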
class FocusA(nn.Module):
def __init__(self, dims, heads, max_dist, sharpen=True, win_size=256, max_span=512):
super().__init__()
self.heads = heads
self.max_dist = max_dist
self.dims = dims
self.max_span = max_span
self.sliding_window = win_size
self.temp_scale = 0.01
self.sharpen = sharpen
self.head_dim = dims // heads
self.batch_size = None # Will be set during forward pass
self.refiner = Refiner(
states=10000, actions=10, alpha=0.1, gamma=0.9, epsilon=0.1
)
self.span_pred = Predictor(dims=dims)
self.attn_local = AdaptiveSpan(
dims=dims, heads=heads, max_dist=max_dist, sharpen=True, temp_scale=0.01
)
self.attn_global = MultiheadC(dims=dims, heads=heads, max_dist=max_dist)
self.projection = nn.Linear(in_features=2 * dims, out_features=dims)
self.ln_a = nn.LayerNorm(normalized_shape=dims)
self.ln_b = nn.LayerNorm(normalized_shape=dims)
mask = torch.empty(max_span, max_span).fill_(float("-inf")).triu_(diagonal=1)
self.register_buffer("mask", mask, persistent=False)
self.register_buffer("window_mask", None, persistent=False)
self.register_buffer("threshold", torch.tensor(1e-4), persistent=False)
self.register_buffer("s_factor", torch.tensor(0.1), persistent=False)
def forward(self, x, xa=None, mask=None, kv_cache=None):
if mask is None:
mask = self.mask
local = self.ln_a(x)
globe = self.ln_b(x)
        # pass xa by keyword; a third positional argument would be interpreted as the mask
        globe_out, _ = self.attn_global(globe, xa=globe)
base_scale = self.span_pred(globe_out)
state = self.extract(local)
action = self.refiner.choose_action(state=state)
refine = self.action_scale(action=action)
span_scale = torch.clamp(base_scale * refine, min=0.0, max=1.0)
span_mean = span_scale.mean().item()
with torch.no_grad():
current_win_size = max(1, int(self.sliding_window * span_mean))
current_span_len = max(1, int(self.max_span * span_mean))
effective_max = min(self.max_dist, local.size(1))
local_max = min(self.max_dist, current_span_len, current_win_size)
globe_max = effective_max
self.attn_local.max_dist = local_max
self.attn_global.max_dist = globe_max
local_out = self.slide_win(
x=local,
win_size=current_win_size,
span_len=current_span_len,
span_scale=span_scale,
mask=mask,
)
with torch.no_grad():
quality = self.quality(output=local_out)
next_state = self.extract(local_out)
self.refiner.update(
state=state, action=action, reward=quality, next_state=next_state)
combined = torch.cat([local_out, globe_out], dim=-1)
x = self.projection(combined)
return x
def quality(self, output):
with torch.no_grad():
safe_output = output.clamp(min=1e-10)
entropy = -(safe_output * torch.log(safe_output)).sum(-1).mean()
coverage = (output > 0.01).float().mean()
return float(coverage - 0.1 * entropy)
def extract(self, x):
with torch.no_grad():
mean_state = x.mean(dim=(0, 1))
var_state = x.var(dim=(0, 1), unbiased=False)
state = torch.cat([mean_state, var_state])
state_id = self.discretize(state.cpu().numpy())
return state_id
def discretize(self, state):
bins = np.linspace(-1, 1, num=10)
state_discrete = np.digitize(state, bins)
state_hash = hash(tuple(state_discrete))
state_id = state_hash % (self.refiner.states - 1)
return state_id
def action_scale(self, action):
span_value = action / (self.refiner.actions - 1)
device = next(self.parameters()).device
dtype = next(self.parameters()).dtype
span_scale = torch.tensor([span_value], device=device, dtype=dtype)
return span_scale
def _focus(self, query, key, value, span_scale, mask):
max_iterations = 10
iteration = 0
prev_attn = torch.zeros_like(input=query)
attn_out = torch.zeros_like(input=query)
attn_weights = None
threshold = self.threshold.item()
s_factor = self.s_factor.item()
while iteration < max_iterations:
span_len = int(self.max_span * span_scale.mean().item())
span_len = min(span_len, query.size(1), key.size(1), value.size(1))
eff_span = min(span_len, self.max_dist)
if eff_span == 0:
break
q_span = query[:, :eff_span, :]
k_span = key[:, :eff_span, :]
v_span = value[:, :eff_span, :]
batch_size, seq_len, dims = q_span.size()
d_k = dims // self.heads
scale_factor = 1 / math.sqrt(d_k)
q = q_span.view(batch_size, seq_len, self.heads, -1).transpose(1, 2)
k = k_span.view(batch_size, seq_len, self.heads, -1).transpose(1, 2)
v = v_span.view(batch_size, seq_len, self.heads, -1).transpose(1, 2)
if self.sharpen:
temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
else:
temperature = 0.5 + self.temp_scale * span_scale.mean().item()
attn_scores = (
torch.matmul(q, k.transpose(-2, -1)) * scale_factor / temperature
)
if mask.size(-2) != attn_scores.size(-2) or mask.size(
-1
) != attn_scores.size(-1):
mask_q_len = min(mask.size(-2), attn_scores.size(-2))
mask_k_len = min(mask.size(-1), attn_scores.size(-1))
resized_mask = torch.ones(
(
batch_size,
self.heads,
attn_scores.size(-2),
attn_scores.size(-1),
),
device=mask.device,
dtype=mask.dtype,
)
resized_mask[:, :, :mask_q_len, :mask_k_len] = mask[
:, :, :mask_q_len, :mask_k_len
]
mask = resized_mask
attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))
attn_weights = torch.softmax(attn_scores, dim=-1)
attn_out = torch.matmul(attn_weights, v)
attn_out = (
attn_out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
)
diff = torch.abs(attn_out - prev_attn).mean()
dynamic_threshold = threshold + s_factor * diff
if diff < dynamic_threshold:
break
prev_attn = attn_out
query = query + attn_out
iteration += 1
return attn_out, attn_weights
def slide_win(self, x, win_size, span_len, span_scale, mask):
batch_size, seq_len, dims = x.size()
self.batch_size = batch_size
num_windows = (seq_len + win_size - 1) // win_size
output = torch.zeros_like(x)
device = x.device
default_mask = None
for i in range(num_windows):
start_idx = i * win_size
end_idx = min((i + 1) * win_size, seq_len)
window_size = end_idx - start_idx
key_start = max(0, start_idx - span_len + win_size)
key_end = min(start_idx + span_len, seq_len)
span_size = key_end - key_start
query = x[:, start_idx:end_idx, :]
key = x[:, key_start:key_end, :]
value = key
if mask is not None:
if mask.dim() == 4:
window_mask = mask[:, :, start_idx:end_idx, key_start:key_end]
if window_mask.size(1) == 1:
window_mask = window_mask.expand(-1, self.heads, -1, -1)
else:
if (
default_mask is None
or default_mask.size(-2) != window_size
or default_mask.size(-1) != span_size
):
default_mask = torch.ones(
(batch_size, self.heads, window_size, span_size),
device=device,
dtype=torch.bool,
)
window_mask = default_mask
else:
if (
default_mask is None
or default_mask.size(-2) != window_size
or default_mask.size(-1) != span_size
):
default_mask = torch.ones(
(batch_size, self.heads, window_size, span_size),
device=device,
dtype=torch.bool,
)
window_mask = default_mask
attn_out, _ = self._focus(
query=query,
key=key,
value=value,
span_scale=span_scale,
mask=window_mask,
)
output[:, start_idx:end_idx, :] = attn_out
return output
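A hypothetical end-to-end check of the focus block (sizes are illustrative, and numpy is assumed to be imported as np for the Refiner): the forward pass runs the global head, predicts a span, refines it with the Q-table, then applies windowed local attention.

focus = FocusA(dims=512, heads=8, max_dist=256, sharpen=True, win_size=64, max_span=512)
x = torch.randn(2, 200, 512)     # (batch, seq_len, dims)
y = focus(x)
print(y.shape)                   # torch.Size([2, 200, 512])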
### optimizer
class MaxFactor(Optimizer):
def __init__(self, params, lr=0.01, beta2_decay=-0.8, eps=(1e-10, 1e-3), d=1.0,
weight_decay=0.01, gamma=0.99, eps_rms=1e-8, maximize=False):
defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d, weight_decay=weight_decay,
gamma=gamma, eps_rms=eps_rms, maximize=maximize)
super().__init__(params=params, defaults=defaults)
def _get_lr(self, param_group, param_state):
step = param_state["step"]
step_float = step.item()
decay_factor = min(1.0, 1.0 / (step_float ** 0.5 + 1e-8))
param_scale = max(param_group["eps"][1], param_state["RMS"])
return min(param_group["lr"], param_scale * decay_factor)
@staticmethod
def _rms(tensor):
return tensor.norm() / (tensor.numel() ** 0.5)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []
eps1, eps2 = group["eps"]
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
state = self.state[p]
if len(state) == 0:
state["step"] = torch.tensor(0.0, dtype=torch.float32)
if p.grad.dim() > 1:
row_shape, col_shape = list(p.grad.shape), list(p.grad.shape)
row_shape[-1], col_shape[-2] = 1, 1
state["row_var"], state["col_var"] = p.grad.new_zeros(row_shape), p.grad.new_zeros(col_shape)
state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)
row_vars.append(state.get("row_var", None))
col_vars.append(state.get("col_var", None))
v.append(state["v"])
state_steps.append(state["step"])
params_with_grad.append(p)
grads.append(grad)
for i, param in enumerate(params_with_grad):
grad = grads[i]
if group["maximize"]:
grad = -grad
step_t, row_var, col_var, vi = state_steps[i], row_vars[i], col_vars[i], v[i]
if eps1 is None:
eps1 = torch.finfo(param.dtype).eps
step_t += 1
step_float = step_t.item()
one_minus_beta2_t = step_float ** group["beta2_decay"]
state["RMS"] = self._rms(param).item()
adaptive_lr = self._get_lr(group, state)
rho_t = min(group["lr"], 1 / (step_float ** 0.5))
alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t
if group["weight_decay"] != 0:
param.mul_(1 - group["lr"] * group["weight_decay"])
if grad.dim() > 1:
row_mean = torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1) + 1e-8)
row_var.lerp_(row_mean, one_minus_beta2_t)
col_mean = torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2) + 1e-8)
col_var.lerp_(col_mean, one_minus_beta2_t)
var_estimate = row_var @ col_var
max_row_var = row_var.max(dim=-2, keepdim=True)[0]
var_estimate.div_(max_row_var.clamp_(min=eps1))
else:
vi.mul_(group["gamma"]).add_(grad ** 2, alpha=1 - group["gamma"])
var_estimate = vi
update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_().mul_(grad)
update = update.div_(torch.norm(update, float('inf')).clamp_(min=eps1))
denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group["d"]))
param.add_(-adaptive_lr / denom * update.sign() * update.abs().max(dim=-1, keepdim=True)[0])
return loss
### experimental: frequency-adaptive momentum (FAM)
def frequency_adaptive_momentum(grad, state, alpha=0.9, beta=0.999):
"""
Apply frequency-adaptive momentum to gradients.
Args:
grad: Current gradient
state: Optimizer state containing spectral history
alpha: Short-term frequency decay factor
beta: Long-term frequency decay factor
Returns:
Updated gradient with frequency-adaptive momentum
"""
# Initialize state if needed
if "freq_history" not in state:
state["freq_history"] = {}
state["step_freq"] = 0
state["step_freq"] += 1
# For matrices (likely attention-related parameters)
if grad.dim() > 1 and min(grad.shape) > 4: # Only for substantial matrices
# Compute spectral signature using FFT on flattened gradient
with torch.no_grad():
# Sample spectral signature for efficiency
if grad.numel() > 10000:
# Sample along both dimensions for large matrices
                row_indices = torch.randperm(grad.size(0), device=grad.device)[:min(grad.size(0), 100)]
                col_indices = torch.randperm(grad.size(1), device=grad.device)[:min(grad.size(1), 100)]
grad_sample = grad[row_indices][:, col_indices].flatten()
else:
grad_sample = grad.flatten()
# Get frequency representation
freq_repr = torch.fft.rfft(grad_sample.float())
freq_power = torch.abs(freq_repr)
# Normalize power spectrum
if freq_power.sum() > 0:
freq_power = freq_power / freq_power.sum()
# Track frequency bands (divide spectrum into 10 bands)
n_bands = 10
band_size = freq_power.shape[0] // n_bands
band_powers = [freq_power[i*band_size:(i+1)*band_size].sum().item()
for i in range(n_bands)]
# Update frequency history with exponential averaging
for i, power in enumerate(band_powers):
if f"band_{i}" not in state["freq_history"]:
state["freq_history"][f"band_{i}"] = power
else:
state["freq_history"][f"band_{i}"] = (
beta * state["freq_history"][f"band_{i}"] +
(1-beta) * power
)
# Compute adaptive dampening factors based on frequency history
# High-frequency components get more dampening
dampening_factors = []
for i in range(n_bands):
# Higher bands get more dampening, but modulated by recent activity
base_dampening = i / n_bands # 0 to 0.9
recent_activity = state["freq_history"][f"band_{i}"]
# Bands with more recent activity get less dampening (more momentum)
adaptive_dampening = base_dampening * (1 - recent_activity * 5)
dampening_factors.append(max(0, min(0.9, adaptive_dampening)))
# Apply frequency-selective momentum to the gradient
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(grad)
# Apply band-specific momentum with inverse FFT
momentum_buffer = state["momentum_buffer"].flatten()
freq_momentum = torch.fft.rfft(momentum_buffer[:grad_sample.shape[0]].float())
# Apply different momentum factors to different frequency bands
for i in range(n_bands):
start_idx = i * band_size
end_idx = (i+1) * band_size
dampening = dampening_factors[i]
# Higher momentum for bands with higher recent activity
momentum_factor = alpha * (1 - dampening)
grad_factor = 1.0 + dampening # Boost gradient for damped frequencies
# Apply selective momentum in frequency domain
if start_idx < freq_momentum.shape[0]:
actual_end = min(end_idx, freq_momentum.shape[0])
freq_momentum[start_idx:actual_end] = (
momentum_factor * freq_momentum[start_idx:actual_end] +
grad_factor * freq_repr[start_idx:actual_end]
)
# Convert back to time domain and reshape
new_grad_sample = torch.fft.irfft(freq_momentum, n=grad_sample.shape[0])
# Update momentum buffer (in time domain)
state["momentum_buffer"] = alpha * state["momentum_buffer"] + (1-alpha) * grad
# Calculate adaptation factor to blend with original gradient
# Early steps: more gradient, later steps: more frequency adaptation
blend_factor = min(0.8, state["step_freq"] / 1000)
# Create a scaling mask based on frequency characteristics
scaling_mask = torch.ones_like(grad)
# For demonstration - actual implementation would map frequency insights
# back to the full gradient in a more sophisticated way
if state["step_freq"] > 100: # Only apply after initial training
# Example: Speech models often have issues with high-frequency noise
# Identify components likely responding to different frequencies
# Compute row and column variances as proxies for frequency response
row_var = grad.var(dim=1, keepdim=True)
col_var = grad.var(dim=0, keepdim=True)
# Normalize
row_var = row_var / (row_var.mean() + 1e-8)
col_var = col_var / (col_var.mean() + 1e-8)
# Create mask emphasizing stable gradient components
                scaling_mask = 1.0 + 0.5 * (
                    torch.sigmoid(3 * (row_var - 1.5)) @
                    torch.sigmoid(3 * (col_var - 1.5))  # no transpose: (rows, 1) @ (1, cols) -> (rows, cols), matching grad
                )
# Apply adaptive mask to gradient
grad = grad * scaling_mask
return grad
else:
# For vectors and small matrices, use standard momentum
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(grad)
state["momentum_buffer"] = alpha * state["momentum_buffer"] + (1-alpha) * grad
return state["momentum_buffer"]
@torch.no_grad()
def step(self, closure=None):
for i, param in enumerate(params_with_grad):
grad = grads[i]
state = self.state[param]
# Apply frequency-adaptive momentum if enabled
if self.use_fam and param.dim() > 1:
grad = frequency_adaptive_momentum(
grad,
state,
alpha=self.fam_alpha,
beta=self.fam_beta
)
optimizer = MaxFactor(
model.parameters(),
lr=0.01,
beta2_decay=-0.8,
eps=(1e-10, 1e-4),
d=1.0,
weight_decay=0.01,
gamma=0.99,
eps_rms=1e-8,
maximize=False,
)
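For reference, MaxFactor drops into an ordinary PyTorch training loop; a minimal hypothetical sketch (model, dataloader, and loss_fn are placeholders, not defined in this repo):

for batch in dataloader:
    optimizer.zero_grad()
    loss = loss_fn(model(batch["input"]), batch["target"])
    loss.backward()
    optimizer.step()   # MaxFactor.step() applies the factored / RMS-scaled update shown above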