|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
class LARS(torch.optim.Optimizer):
    """LARS (Layer-wise Adaptive Rate Scaling) optimizer.

    Parameters with <= 1 dimension (biases, norm scales) are updated with
    plain SGD-with-momentum: no trust-ratio scaling and no weight decay.
    Parameters with > 1 dimension (weight matrices) get L2 weight decay and
    a per-layer trust ratio of ``trust_coefficient * ||p|| / ||grad||``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate (default 0; typically driven by a scheduler).
        weight_decay: L2 penalty coefficient, applied only to >1-D params.
        momentum: momentum factor for the per-parameter update buffer.
        trust_coefficient: scaling factor for the layer-wise trust ratio.
    """

    def __init__(
        self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001
    ):
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            trust_coefficient=trust_coefficient,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss (standard ``torch.optim`` convention).
                It is run with gradients enabled.

        Returns:
            The loss returned by ``closure``, or ``None`` if no closure
            was given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure must be able
            # to build a graph, so re-enable grad around it.
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g["params"]:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # weight matrices: decay + trust-ratio scaling
                    dp = dp.add(p, alpha=g["weight_decay"])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Trust ratio; nested where() falls back to 1.0 when
                    # either norm is zero so we never divide by zero.
                    q = torch.where(
                        param_norm > 0.0,
                        torch.where(
                            update_norm > 0,
                            (g["trust_coefficient"] * param_norm / update_norm),
                            one,
                        ),
                        one,
                    )
                    dp = dp.mul(q)

                # Momentum buffer, created lazily on first step for each param.
                param_state = self.state[p]
                if "mu" not in param_state:
                    param_state["mu"] = torch.zeros_like(p)
                mu = param_state["mu"]
                mu.mul_(g["momentum"]).add_(dp)
                p.add_(mu, alpha=-g["lr"])

        return loss
|
|