| | """
|
| | Advanced Optimizers for Large Scale Training
|
| | Including Lion, AdamW variants, and SAM
|
| | """
|
| |
|
| | import math
|
| | import torch
|
| | import torch.nn as nn
|
| | from torch.optim import Optimizer
|
| | from typing import Any, Dict, Optional, Tuple, Union
|
| |
|
| |
|
class Lion(Optimizer):
    """Lion optimizer ("Symbolic Discovery of Optimization Algorithms").

    Keeps a single momentum buffer per parameter, so it is more memory
    efficient than AdamW for large models.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        maximize: bool = False,
        foreach: Optional[bool] = None,
    ):
        # Validate hyper-parameters up front so bad configs fail loudly.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        super().__init__(
            params,
            {
                "lr": lr,
                "betas": betas,
                "weight_decay": weight_decay,
                "maximize": maximize,
                "foreach": foreach,
            },
        )

    def __setstate__(self, state):
        # Backfill options that may be missing from older checkpoints.
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]

            active_params, active_grads, momenta = [], [], []
            for param in group["params"]:
                grad = param.grad
                if grad is None:
                    continue
                active_params.append(param)
                # Keep the update math in fp32 for low-precision grads.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    active_grads.append(grad.float())
                else:
                    active_grads.append(grad)

                state = self.state[param]
                if not state:
                    # Lazily create the momentum buffer on first use.
                    state["exp_avg"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )
                momenta.append(state["exp_avg"])

            lion(
                active_params,
                active_grads,
                momenta,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                maximize=group["maximize"],
            )

        return loss
|
| |
|
| |
|
def lion(
    params,
    grads,
    exp_avgs,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    maximize: bool,
):
    """Functional API that performs Lion algorithm computation."""
    for param, grad, exp_avg in zip(params, grads, exp_avgs):
        if maximize:
            # Ascend instead of descend by flipping the gradient.
            grad = -grad

        # Decoupled weight decay, applied directly to the weights.
        param.mul_(1 - lr * weight_decay)

        # Sign of the beta1-interpolated momentum drives the update.
        interp = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
        param.add_(interp.sign_(), alpha=-lr)

        # Momentum is refreshed with beta2 *after* the parameter update.
        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
|
| |
|
| |
|
class AdamWScale(torch.optim.AdamW):
    """
    AdamW with learning rate scaling based on parameter norm
    Useful for very large models

    Each step, the group's learning rate is multiplied by
    ``min(1, 1 / ||params||)`` for that step only; the configured value is
    restored afterwards so schedulers and subsequent steps see it unchanged.
    """

    def __init__(self, *args, scale_lr: bool = True, **kwargs):
        # scale_lr: when False, this behaves exactly like torch.optim.AdamW.
        super().__init__(*args, **kwargs)
        self.scale_lr = scale_lr

    def step(self, closure=None):
        """Performs a single (optionally norm-scaled) optimization step."""
        if not self.scale_lr:
            return super().step(closure)

        # BUG FIX: the previous implementation wrote the scaled value back
        # into group['lr'] permanently, compounding the shrinkage on every
        # step and corrupting any attached LR scheduler. Scale temporarily
        # and restore the original lr in a finally block instead.
        saved_lrs = [group['lr'] for group in self.param_groups]
        try:
            for group in self.param_groups:
                # L2 norm over all parameters (that have grads) in the group.
                total_norm = math.sqrt(sum(
                    p.data.norm().item() ** 2
                    for p in group['params']
                    if p.grad is not None
                ))
                if total_norm > 0:
                    group['lr'] = group['lr'] * min(1.0, 1.0 / total_norm)
            loss = super().step(closure)
        finally:
            # Always restore, even if the base step raises.
            for group, lr in zip(self.param_groups, saved_lrs):
                group['lr'] = lr
        return loss
|
| |
|
| |
|
class SAM(Optimizer):
    """
    Sharpness-Aware Minimization (SAM) optimizer
    Improves generalization by finding flatter minima

    Wraps a base optimizer class. The update is two-phase:
    ``first_step`` perturbs the weights towards the local worst case,
    ``second_step`` restores the original weights and applies the base
    optimizer's update using the gradients computed at the perturbed point.
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        # rho: radius of the neighborhood explored by the ascent step.
        # adaptive: if True, scale the perturbation element-wise by p**2
        #   (ASAM-style); see first_step/_grad_norm.
        # base_optimizer: an optimizer *class* (e.g. torch.optim.AdamW);
        #   instantiated below with this optimizer's param groups.
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        # The wrapped optimizer shares this optimizer's param_groups, so
        # hyper-parameters (lr etc.) stay in sync between the two objects.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Climb to the (approximate) worst-case weights w + e(w)."""
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # 1e-12 guards against division by zero when all grads are zero.
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None:
                    continue
                # Stash the unperturbed weights; second_step restores them.
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # move towards the local maximum of the loss

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore original weights, then apply the base optimizer's step."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = self.state[p]["old_p"]  # back to w from w + e(w)

        # Base optimizer consumes the gradients from the perturbed point.
        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Full SAM update; closure must re-evaluate loss AND gradients."""
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        # step() runs under no_grad; re-enable grad inside the closure so
        # the second forward/backward pass can populate p.grad.
        closure = torch.enable_grad()(closure)

        self.first_step(zero_grad=True)
        closure()  # second forward/backward at the perturbed weights
        self.second_step()

    def _grad_norm(self):
        """Global gradient norm across all param groups."""
        # Accumulate on a single device in fp32 to support model parallelism
        # and mixed-precision gradients.
        shared_device = self.param_groups[0]["params"][0].device
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(dtype=torch.float32).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            dtype=torch.float32
        )
        return norm

    def load_state_dict(self, state_dict):
        """Load state and re-link the base optimizer's param groups."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

    def state_dict(self):
        return super().state_dict()
|
| |
|
| |
|
class Sophia(Optimizer):
    """
    Sophia optimizer - Second-order clipped stochastic optimization
    More efficient than Adam for large language models
    """
    # Usage: call update_hessian() periodically after a backward pass to
    # refresh the diagonal Hessian estimate; step() then applies the
    # Hessian-preconditioned, rho-clipped momentum update.

    def __init__(
        self,
        params,
        lr=1e-4,
        betas=(0.965, 0.99),
        rho=0.04,
        weight_decay=1e-1,
        *,
        maximize: bool = False,
        capturable: bool = False,
    ):
        # Fail fast on invalid hyper-parameters.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= rho:
            raise ValueError(f"Invalid rho parameter: {rho}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr, betas=betas, rho=rho, weight_decay=weight_decay,
            maximize=maximize, capturable=capturable
        )
        super().__init__(defaults=defaults, params=params) if False else super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('maximize', False)
            group.setdefault('capturable', False)
        # Older checkpoints stored 'step' as a plain number; convert to a
        # tensor so the rest of the code can assume tensor steps.
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
        if not step_is_tensor:
            for s in state_values:
                s['step'] = torch.tensor(float(s['step']))

    @torch.no_grad()
    def update_hessian(self):
        """Update Hessian diagonal approximation"""
        # EMA of squared gradients serves as the diagonal Hessian estimate.
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    state = self.state[p]
                    if len(state) == 0:
                        # Lazy state init (mirrors step()).
                        state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
                            if group['capturable'] else torch.tensor(0.)
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    hessian = state['hessian']
                    beta2 = group['betas'][1]

                    # h <- beta2 * h + (1 - beta2) * g * g
                    hessian.mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)

    @torch.no_grad()
    def step(self, closure=None, bs=5120):
        """Performs a single optimization step."""
        # NOTE(review): `bs` is accepted but never used - it is not
        # forwarded to the functional sophia() call below. Confirm whether
        # it was meant to scale the Hessian term.
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            hessians = []
            state_steps = []
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                # Promote low-precision grads to fp32 for the update math.
                if p.grad.dtype in {torch.float16, torch.bfloat16}:
                    grads.append(p.grad.float())
                else:
                    grads.append(p.grad)

                state = self.state[p]

                if len(state) == 0:
                    # Lazy state init (mirrors update_hessian()).
                    state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
                        if group['capturable'] else torch.tensor(0.)
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                hessians.append(state['hessian'])
                state_steps.append(state['step'])

            sophia(
                params_with_grad,
                grads,
                exp_avgs,
                hessians,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                rho=group['rho'],
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                maximize=group['maximize'],
                capturable=group['capturable'],
            )

        return loss
|
| |
|
| |
|
def sophia(
    params,
    grads,
    exp_avgs,
    hessians,
    state_steps,
    capturable: bool = False,
    *,
    beta1: float,
    beta2: float,
    rho: float,
    lr: float,
    weight_decay: float,
    maximize: bool,
):
    """Functional API that performs Sophia algorithm computation.

    Updates `params`, `exp_avgs`, and `state_steps` in place; `hessians`
    is read-only here (it is refreshed by Sophia.update_hessian).
    Raises RuntimeError if `state_steps` does not contain tensors.
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        hessian = hessians[i]
        step_t = state_steps[i]

        if capturable:
            # CUDA-graph capture requires fp32 parameters.
            # (Removed a dead `bs` tensor that was allocated here but
            # never used.)
            assert param.dtype == torch.float32

        step_t += 1

        # Decoupled weight decay, applied directly to the weights.
        param.mul_(1 - lr * weight_decay)

        # EMA of gradients (momentum).
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

        # NOTE: .item() forces a host sync; incompatible with graph capture.
        bias_correction1 = 1 - beta1 ** step_t.item()

        # Precondition by the clamped Hessian EMA, then clip to [-rho, rho]
        # so single coordinates cannot take arbitrarily large steps.
        k = hessian.abs().clamp_(min=1e-8)
        u = (exp_avg / bias_correction1) / k.sqrt()
        u.clamp_(min=-rho, max=rho)

        param.add_(u, alpha=-lr)
|
| |
|
| |
|
def get_optimizer(model, config):
    """Get optimizer based on configuration"""

    # Partition parameters: biases, norm layers, and 1-D tensors are
    # exempt from weight decay.
    decay, no_decay = [], []
    for pname, param in model.named_parameters():
        if not param.requires_grad:
            continue
        exempt = param.ndim < 2 or 'bias' in pname or 'norm' in pname.lower()
        (no_decay if exempt else decay).append(param)

    optim_groups = [
        {'params': decay, 'weight_decay': config.weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]

    optimizer_name = config.optimizer.lower()
    lr = float(config.learning_rate)
    betas = (float(config.beta1), float(config.beta2))
    wd = float(config.weight_decay)

    if optimizer_name == "adamw":
        return torch.optim.AdamW(
            optim_groups,
            lr=lr,
            betas=betas,
            eps=float(config.eps),
            weight_decay=wd,
        )
    if optimizer_name == "adamw_scale":
        return AdamWScale(
            optim_groups,
            lr=lr,
            betas=betas,
            eps=float(config.eps),
            weight_decay=wd,
            scale_lr=True,
        )
    if optimizer_name == "lion":
        # Lion typically wants a smaller LR than AdamW; scale down by 0.3.
        return Lion(
            optim_groups,
            lr=lr * 0.3,
            betas=betas,
            weight_decay=wd,
        )
    if optimizer_name == "sophia":
        return Sophia(
            optim_groups,
            lr=lr,
            betas=betas,
            rho=0.04,
            weight_decay=wd,
        )
    if optimizer_name == "sam_adamw":
        # SAM wraps AdamW; extra kwargs are forwarded to the base optimizer.
        return SAM(
            optim_groups,
            torch.optim.AdamW,
            rho=0.05,
            adaptive=False,
            lr=lr,
            betas=betas,
            eps=float(config.eps),
            weight_decay=wd,
        )
    raise ValueError(f"Unknown optimizer: {optimizer_name}")
|
| |
|
| |
|
def get_scheduler(optimizer, config):
    """Get learning rate scheduler"""
    name = config.lr_scheduler.lower()
    main_iters = config.total_steps - config.warmup_steps
    extra = getattr(config, 'lr_scheduler_kwargs', {})

    # Build the post-warmup ("main") schedule; unknown names yield None.
    if name == "cosine":
        from torch.optim.lr_scheduler import CosineAnnealingLR
        main = CosineAnnealingLR(
            optimizer,
            T_max=main_iters,
            eta_min=extra.get('eta_min', 0),
        )
    elif name == "linear":
        from torch.optim.lr_scheduler import LinearLR
        main = LinearLR(
            optimizer,
            start_factor=1.0,
            end_factor=0.1,
            total_iters=main_iters,
        )
    elif name == "polynomial":
        from torch.optim.lr_scheduler import PolynomialLR
        main = PolynomialLR(
            optimizer,
            total_iters=main_iters,
            power=extra.get('power', 1.0),
        )
    else:
        return None

    if config.warmup_steps <= 0:
        return main

    # Prepend a linear warmup phase ahead of the main schedule.
    from torch.optim.lr_scheduler import LinearLR, SequentialLR
    warmup = LinearLR(
        optimizer,
        start_factor=1e-8,
        end_factor=1.0,
        total_iters=config.warmup_steps,
    )
    return SequentialLR(
        optimizer,
        schedulers=[warmup, main],
        milestones=[config.warmup_steps],
    )
|
| |
|