import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import math
from typing import Dict, List

class MambaOptimizer:
    """Optimizer and learning-rate schedule setup for Mamba models."""

    def __init__(self, model, config):
        self.config = config
        self.model = model

        # Split parameters into two groups: weight decay is applied to most
        # weights, but not to biases, normalization layers, or embeddings.
        decay_params = []
        no_decay_params = []

        for name, param in model.named_parameters():
            if param.requires_grad:
                if 'bias' in name or 'norm' in name or 'embedding' in name:
                    no_decay_params.append(param)
                else:
                    decay_params.append(param)
        # Two optimizer parameter groups: with and without weight decay.
        param_groups = [
            {'params': decay_params, 'weight_decay': config.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]

        # AdamW with betas=(0.9, 0.95), a common choice for pretraining
        # large sequence models.
        self.optimizer = optim.AdamW(
            param_groups,
            lr=config.learning_rate,
            betas=(0.9, 0.95),
            eps=1e-8,
        )

        self.scheduler = self._create_scheduler()
    def _create_scheduler(self):
        """Linear warmup followed by cosine decay of the learning rate."""
        def lr_lambda(step):
            if step < self.config.warmup_steps:
                # Linear warmup from 0 up to the base learning rate.
                return step / self.config.warmup_steps
            else:
                # Cosine decay from the base learning rate down to 0 over the
                # remaining steps; clamp progress so the multiplier does not
                # rise again if training runs past max_steps.
                progress = (step - self.config.warmup_steps) / (self.config.max_steps - self.config.warmup_steps)
                progress = min(progress, 1.0)
                return 0.5 * (1 + math.cos(math.pi * progress))

        return LambdaLR(self.optimizer, lr_lambda)
    def step(self):
        """Clip gradients, step the optimizer and scheduler, and return the current learning rate."""
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.optimizer.step()
        self.scheduler.step()

        return self.scheduler.get_last_lr()[0]
    def zero_grad(self):
        """Zero gradients"""
        self.optimizer.zero_grad()
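
# Illustrative usage sketch (not part of the class above): the tiny linear
# "model" and the SimpleNamespace "config" below are stand-ins chosen only to
# exercise MambaOptimizer end to end; a real Mamba model and training config
# would take their place.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        learning_rate=3e-4,
        weight_decay=0.1,
        warmup_steps=10,
        max_steps=100,
    )

    # A single linear layer: its weight falls into the decay group and its
    # bias into the no-decay group, so both parameter groups are exercised.
    model = torch.nn.Linear(16, 16)

    optimizer = MambaOptimizer(model, config)
    for step_idx in range(config.max_steps):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 16)).pow(2).mean()
        loss.backward()
        lr = optimizer.step()  # clips grads, steps optimizer and scheduler
        if step_idx % 20 == 0:
            print(f"step {step_idx:3d}  lr {lr:.2e}  loss {loss.item():.4f}")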