# =============================================================================
# training/optimizer.py
# =============================================================================
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import math
from typing import Dict, List


class MambaOptimizer:
    """Optimizer setup for Mamba models."""

    def __init__(self, model, config):
        self.config = config
        self.model = model

        # Separate parameters that should and shouldn't have weight decay
        decay_params = []
        no_decay_params = []
        for name, param in model.named_parameters():
            if param.requires_grad:
                # No weight decay for biases, norms, or embeddings
                if 'bias' in name or 'norm' in name or 'embedding' in name:
                    no_decay_params.append(param)
                else:
                    decay_params.append(param)
        # Create parameter groups
        param_groups = [
            {'params': decay_params, 'weight_decay': config.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]

        # Initialize optimizer
        self.optimizer = optim.AdamW(
            param_groups,
            lr=config.learning_rate,
            betas=(0.9, 0.95),
            eps=1e-8,
        )

        # Learning rate scheduler
        self.scheduler = self._create_scheduler()

    def _create_scheduler(self):
        """Create a learning-rate scheduler: linear warmup followed by cosine decay."""
        def lr_lambda(step):
            if step < self.config.warmup_steps:
                # Linear warmup (guard against warmup_steps == 0)
                return step / max(1, self.config.warmup_steps)
            # Cosine decay from the peak LR down to zero at max_steps
            progress = (step - self.config.warmup_steps) / (self.config.max_steps - self.config.warmup_steps)
            progress = min(progress, 1.0)  # hold the LR at zero past max_steps
            return 0.5 * (1 + math.cos(math.pi * progress))

        return LambdaLR(self.optimizer, lr_lambda)
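
    # Shape of the schedule defined above (illustrative only, assuming the
    # hypothetical values warmup_steps=1000 and max_steps=10000; the real
    # values come from the config object):
    #   step     0 -> lr_lambda = 0.0   (warmup starts from zero)
    #   step  1000 -> lr_lambda = 1.0   (peak learning rate reached)
    #   step  5500 -> lr_lambda = 0.5   (halfway through the cosine decay)
    #   step 10000 -> lr_lambda = 0.0   (fully decayed)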

    def step(self):
        """Optimizer step with gradient clipping; returns the current learning rate."""
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        # Optimizer and scheduler step
        self.optimizer.step()
        self.scheduler.step()
        return self.scheduler.get_last_lr()[0]

    def zero_grad(self):
        """Zero gradients"""
        self.optimizer.zero_grad()
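

# -----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the training pipeline above): wires
# MambaOptimizer into a minimal training loop. The _DemoConfig dataclass and
# the nn.Linear stand-in model are hypothetical; only the config attributes
# read above (learning_rate, weight_decay, warmup_steps, max_steps) are assumed.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass

    import torch.nn as nn

    @dataclass
    class _DemoConfig:
        learning_rate: float = 3e-4
        weight_decay: float = 0.1
        warmup_steps: int = 10
        max_steps: int = 100

    model = nn.Linear(16, 16)          # stand-in for an actual Mamba model
    opt = MambaOptimizer(model, _DemoConfig())

    for step in range(20):
        x = torch.randn(8, 16)
        loss = model(x).pow(2).mean()  # dummy loss for demonstration
        loss.backward()
        lr = opt.step()                # clips grads, steps optimizer + scheduler
        opt.zero_grad()
        if step % 5 == 0:
            print(f"step {step:3d}  lr {lr:.2e}  loss {loss.item():.4f}")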