Spaces:
Runtime error
Runtime error
import enum
import itertools
from dataclasses import dataclass

import numpy as np
import torch
import torch.optim as optim
def PowerIter(mat_g, error_tolerance=1e-6, num_iters=100):
    """Power iteration.

    Estimate the largest eigenvalue of ``mat_g`` (used by ComputePower for
    scaling the ridge term). The start vector is random with entries in (-1, 1).

    Args:
      mat_g: the symmetric PSD matrix.
      error_tolerance: stop once the Rayleigh-quotient estimate changes by no
        more than this between two iterations.
      num_iters: maximum number of iterations.

    Returns:
      A tuple ``(eigen value, eigen vector, num_iters)``: the eigenvalue
      estimate, the unit-norm eigenvector estimate, and the number of
      iterations actually run.
    """
    # Allocate on mat_g's own device and dtype. mat_g.get_device() returns -1
    # for CPU tensors, which torch rejects as a device, and a default-dtype
    # vector would not match a float64 matrix in torch.mv.
    v = torch.rand(mat_g.shape[0], device=mat_g.device, dtype=mat_g.dtype) * 2 - 1
    error = 1
    iters = 0
    singular_val = 0
    while error > error_tolerance and iters < num_iters:
        v = v / torch.norm(v)
        mat_v = torch.mv(mat_g, v)
        # Rayleigh quotient v^T G v of the unit vector v.
        s_v = torch.dot(v, mat_v)
        error = torch.abs(s_v - singular_val)
        v = mat_v
        singular_val = s_v
        iters += 1
    return singular_val, v / torch.norm(v), iters
def MatPower(mat_m, p):
    """Raise a square matrix to a positive integer power.

    Exponentiation by squaring: the bits of ``p`` are consumed from least to
    most significant while the base is repeatedly squared, and the base is
    folded into the result for every set bit.

    Args:
      mat_m: a square matrix
      p: a positive integer

    Returns:
      mat_m^p (``None`` when p <= 0, since no bit is ever consumed).
    """
    result = None
    base = mat_m
    remaining = p
    while remaining > 0:
        if remaining % 2 == 1:
            result = base if result is None else torch.matmul(base, result)
        remaining //= 2
        if remaining > 0:
            # Square only while bits remain; a final squaring would be unused.
            base = torch.matmul(base, base)
    return result
def ComputePower(mat_g, p,
                 iter_count=100,
                 error_tolerance=1e-6,
                 ridge_epsilon=1e-6):
    """A method to compute G^{-1/p} using a coupled Newton iteration.

    See for example equation 3.2 on page 9 of:
    A Schur-Newton Method for the Matrix p-th Root and its Inverse
    by Chun-Hua Guo and Nicholas J. Higham
    SIAM Journal on Matrix Analysis and Applications,
    2006, Vol. 28, No. 3 : pp. 788-804
    https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf

    ``mat_g`` itself is not modified.

    Args:
      mat_g: A square positive semidefinite matrix (a 1-D tensor is treated
        as a diagonal and handled elementwise).
      p: a positive integer
      iter_count: Stop iterating after this many rounds.
      error_tolerance: Threshold for stopping iteration
      ridge_epsilon: We add this times I to G, to make it positive definite.
        For scaling, we multiply it by the largest eigenvalue of G.

    Returns:
      (mat_g + rI)^{-1/p} (r = ridge_epsilon * max_eigenvalue of mat_g).
      For a 1-D input, (mat_g + ridge_epsilon)^{-1/p} elementwise (unscaled
      ridge). For a 1x1 matrix, the 1x1 identity is returned unchanged.
    """
    shape = list(mat_g.shape)
    if len(shape) == 1:
        return torch.pow(mat_g + ridge_epsilon, -1 / p)
    # Match mat_g's device and dtype: mat_g.get_device() is -1 for CPU tensors
    # (rejected by torch), and a float32 identity breaks matmul with float64.
    identity = torch.eye(shape[0], device=mat_g.device, dtype=mat_g.dtype)
    if shape[0] == 1:
        return identity
    alpha = -1.0 / p
    max_ev, _, _ = PowerIter(mat_g)
    ridge_epsilon *= max_ev
    # Out of place on purpose: callers pass their running statistics matrix,
    # which must not absorb the ridge term on every call.
    mat_g = mat_g + ridge_epsilon * identity
    z = (1 + p) / (2 * torch.norm(mat_g))
    # The best value for z is
    # (1 + p) * (c_max^{1/p} - c_min^{1/p}) /
    #            (c_max^{1+1/p} - c_min^{1+1/p})
    # where c_max and c_min are the largest and smallest singular values of
    # mat_g.
    # The above estimate assumes that c_max > c_min * 2^p
    # Can replace above line by the one below, but it is less accurate,
    # hence needs more iterations to converge.
    # z = (1 + p) / torch.trace(mat_g)
    # If we want the method to always converge, use z = 1 / norm(mat_g)
    # or z = 1 / torch.trace(mat_g), but these can result in many
    # extra iterations.
    mat_root = identity * torch.pow(z, 1.0 / p)
    mat_m = mat_g * z
    error = torch.max(torch.abs(mat_m - identity))
    count = 0
    while error > error_tolerance and count < iter_count:
        tmp_mat_m = (1 - alpha) * identity + alpha * mat_m
        new_mat_root = torch.matmul(mat_root, tmp_mat_m)
        mat_m = torch.matmul(MatPower(tmp_mat_m, p), mat_m)
        new_error = torch.max(torch.abs(mat_m - identity))
        # The iteration is diverging (error grew by more than 20%): keep the
        # last accepted root instead of the new one.
        if new_error > error * 1.2:
            break
        mat_root = new_mat_root
        error = new_error
        count += 1
    return mat_root
# Grafting is a technique to fix the layerwise scale of the Shampoo optimizer.
# https://arxiv.org/pdf/2002.11803.pdf studies this in detail. It allows the
# Shampoo optimizer to be plugged into settings where SGD/AdaGrad is already
# well tuned. Grafting onto Shampoo means taking the Shampoo direction but
# using the step magnitude from the grafted optimizer, such as AdaGrad or SGD.
class LayerwiseGrafting(enum.IntEnum):
    """Selects which optimizer supplies the per-layer step magnitude."""

    NONE = 0     # no grafting: pass-through Graft
    SGD = 1      # SGD with momentum: SGDGraft
    ADAGRAD = 2  # Adagrad with momentum: AdagradGraft
@dataclass
class ShampooHyperParams:
    """Shampoo hyper parameters.

    A dataclass, so any field can be overridden by keyword at construction,
    e.g. ``ShampooHyperParams(block_size=256, weight_decay=1e-4)``. Calling it
    with no arguments gives the defaults below.
    """
    # Decay of the running L/R statistics:
    # stat <- beta2 * stat + (1 - beta2) * new (new weight is 1.0 if beta2 == 1).
    beta2: float = 0.9
    # Added to the sqrt of the accumulated squared gradients (Adagrad graft).
    diagonal_eps: float = 1e-6
    # Scale of the identity the statistics start from; also the relative ridge
    # (times the largest eigenvalue) added before taking inverse roots.
    matrix_eps: float = 1e-12
    # Coefficient of the parameter value added to the update (0 disables it).
    weight_decay: float = 0.0
    # Fixed exponent for preconditioner, if > 0; otherwise 2 * rank is used.
    inverse_exponent_override: int = 2
    # First step at which the Shampoo-preconditioned update is applied;
    # earlier steps use the grafted optimizer's update.
    start_preconditioning_step: int = 1
    # Performance tuning params for controlling memory and compute requirements.
    # How often to compute preconditioner.
    preconditioning_compute_steps: int = 1
    # How often to compute statistics.
    statistics_compute_steps: int = 1
    # Block size for large layers (if > 0).
    # Block size = 1 ==> Adagrad (Don't do this, extremely inefficient!)
    # Block size should be as large as feasible under memory/time constraints.
    block_size: int = 128
    # Automatic shape interpretation (e.g. [4, 3, 1024, 512] would result in
    # 12 x [1024, 512] L and R statistics). Enabled by default; when disabled,
    # Shampoo constructs statistics [4, 4], [3, 3], [1024, 1024], [512, 512].
    best_effort_shape_interpretation: bool = True
    # Type of grafting (SGD or AdaGrad).
    # https://arxiv.org/pdf/2002.11803.pdf
    graft_type: int = LayerwiseGrafting.ADAGRAD
    # Nesterov momentum
    nesterov: bool = True
class Graft:
    """Base class to perform grafting onto Shampoo. This class does no grafting.

    Every hook is a pass-through: no statistics are kept, gradients are
    returned unchanged, and no momentum buffer exists.
    """

    def __init__(self, hps, unused_var):
        # The variable is accepted only so subclasses share one signature.
        del unused_var
        self.hps = hps

    def add_statistics(self, grad):
        """Ignore ``grad``; the base graft tracks nothing."""
        del grad
        return None

    def precondition_gradient(self, grad):
        """Return ``grad`` itself, unscaled."""
        unchanged = grad
        return unchanged

    def update_momentum(self, update, unused_beta1):
        """Return ``update`` itself; no momentum is applied."""
        unchanged = update
        return unchanged
class SGDGraft(Graft):
    """Graft using SGD+momentum.

    ``momentum`` holds an exponentially decaying sum of the updates passed to
    ``update_momentum``: momentum <- beta1 * momentum + update.
    """

    def __init__(self, hps, var):
        super(SGDGraft, self).__init__(hps, var)
        # zeros_like already inherits var's device and dtype. Passing
        # var.get_device() broke CPU tensors, where it returns -1.
        self.momentum = torch.zeros_like(var.data)

    def update_momentum(self, update, beta1):
        """Fold ``update`` into the momentum buffer in place and return it."""
        self.momentum.mul_(beta1).add_(update)
        return self.momentum
class AdagradGraft(SGDGraft):
    """Graft using Adagrad.

    Essentially an implementation of Adagrad with momentum: squared gradients
    are summed (without decay) into ``statistics`` and used to scale each
    coordinate; momentum handling is inherited from SGDGraft.
    """

    def __init__(self, hps, var):
        super(AdagradGraft, self).__init__(hps, var)
        # zeros_like already inherits var's device and dtype. Passing
        # var.get_device() broke CPU tensors, where it returns -1.
        self.statistics = torch.zeros_like(var.data)

    def add_statistics(self, grad):
        """Accumulate the elementwise squared gradient."""
        self.statistics.add_(grad * grad)

    def precondition_gradient(self, grad):
        """Return grad / (sqrt(accumulated grad^2) + hps.diagonal_eps)."""
        return grad / (torch.sqrt(self.statistics) + self.hps.diagonal_eps)
class BlockPartitioner:
    """Partitions a tensor into smaller tensors for preconditioning.

    Every dimension longer than ``hps.block_size`` (when block_size > 0) is
    cut into chunks of ``block_size``, the last chunk holding the remainder.
    For example, a (4096, 512) variable with block_size 1024 is split into
    4 blocks of shape (1024, 512).
    """

    def __init__(self, var, hps):
        self._shape = var.shape
        self._splits = []       # (dim, cut indices) for every split dim
        self._split_sizes = []  # (dim, chunk sizes) for every split dim
        block = hps.block_size
        sizes_per_dim = []
        for dim, extent in enumerate(var.shape):
            if block <= 0 or extent <= block:
                # This dimension stays whole.
                sizes_per_dim.append(np.array([extent], dtype=np.int32))
                continue
            # (extent - 1) so an exact multiple of the block size does not
            # produce a trailing zero-size chunk.
            n_cuts = (extent - 1) // block
            cut_points = block * (np.arange(n_cuts, dtype=np.int32) + 1)
            chunk_sizes = np.full(n_cuts + 1, block, dtype=np.int32)
            chunk_sizes[-1] = extent - cut_points[-1]
            self._splits.append((dim, cut_points))
            self._split_sizes.append((dim, chunk_sizes))
            sizes_per_dim.append(chunk_sizes)
        self._num_splits = len(sizes_per_dim)
        # One [d, d] preconditioner per dimension of every block, blocks in
        # itertools.product order.
        self._preconditioner_shapes = [
            [size, size]
            for block_shape in itertools.product(*sizes_per_dim)
            for size in block_shape
        ]

    def shapes_for_preconditioners(self):
        """Return the [d, d] shape of every preconditioner, block by block."""
        return self._preconditioner_shapes

    def num_splits(self):
        """Return the tensor rank, i.e. the number of preconditioners per block."""
        return self._num_splits

    def partition(self, tensor):
        """Partition tensor into blocks."""
        assert tensor.shape == self._shape
        pieces = [tensor]
        for dim, sizes in self._split_sizes:
            pieces = [
                chunk
                for piece in pieces
                for chunk in torch.split(piece, tuple(sizes), dim=dim)
            ]
        return pieces

    def merge_partitions(self, partitions):
        """Merge partitions back to original shape."""
        # Undo the splits last-dim-first: consecutive groups along the most
        # recently split dim are concatenated back together.
        for dim, cut_points in reversed(self._splits):
            group = len(cut_points) + 1
            partitions = [
                torch.cat(partitions[start:start + group], dim=dim)
                for start in range(0, len(partitions), group)
            ]
        assert len(partitions) == 1
        return partitions[0]
def _merge_small_dims(shape_to_merge, max_dim): | |
"""Merge small dimensions. | |
If there are some small dimensions, we collapse them: | |
e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024 | |
[1, 2, 768, 1, 2048] --> [2, 768, 2048] | |
Args: | |
shape_to_merge: Shape to merge small dimensions. | |
max_dim: Maximal dimension of output shape used in merging. | |
Returns: | |
Merged shape. | |
""" | |
resulting_shape = [] | |
product = 1 | |
for d in shape_to_merge: | |
if product * d <= max_dim: | |
product *= d | |
else: | |
if product > 1: | |
resulting_shape.append(product) | |
product = d | |
if product > 1: | |
resulting_shape.append(product) | |
return resulting_shape | |
class Preconditioner:
    """Compute statistics/shape from gradients for preconditioning.

    The variable is optionally reshaped by merging small dimensions, then
    split into blocks by a BlockPartitioner. For every block and every block
    dimension ``d`` one ``[d, d]`` statistics matrix and one preconditioner
    are kept. Tensors of (transformed) rank <= 1 get no preconditioning.
    """

    def __init__(self, var, hps):
        self._hps = hps
        self._original_shape = var.shape
        self._transformed_shape = var.shape
        if hps.best_effort_shape_interpretation:
            self._transformed_shape = _merge_small_dims(
                self._original_shape, hps.block_size)
        reshaped_var = torch.reshape(var, self._transformed_shape)
        self._partitioner = BlockPartitioner(reshaped_var, hps)
        shapes = self._partitioner.shapes_for_preconditioners()
        rank = len(self._transformed_shape)
        # Use var's own device/dtype: var.get_device() returns -1 for CPU
        # tensors (rejected by torch), and float32 matrices would not match
        # float64 gradients in tensordot.
        device = var.device
        dtype = var.dtype
        if rank <= 1:
            self.statistics = []
            self.preconditioners = []
        else:
            eps = self._hps.matrix_eps
            self.statistics = [
                eps * torch.eye(s[0], device=device, dtype=dtype) for s in shapes]
            self.preconditioners = [
                torch.eye(s[0], device=device, dtype=dtype) for s in shapes]

    def add_statistics(self, grad):
        """Compute statistics from gradients and add to the correct state entries.

        For each block and each axis i, the block is contracted with itself
        over all other axes, giving a [d_i, d_i] matrix that is blended into
        the running statistic with weights (beta2, 1 - beta2).

        Args:
          grad: Gradient to compute statistics from.
        """
        if not self.statistics:
            return
        reshaped_grad = torch.reshape(grad, self._transformed_shape)
        partitioned_grads = self._partitioner.partition(reshaped_grad)
        w1 = self._hps.beta2
        w2 = 1.0 if w1 == 1.0 else (1.0 - w1)
        rank = len(self._transformed_shape)
        for j, block_grad in enumerate(partitioned_grads):
            for i in range(rank):
                axes = list(range(i)) + list(range(i + 1, rank))
                stat = torch.tensordot(block_grad, block_grad, [axes, axes])
                self.statistics[j * rank + i].mul_(w1).add_(stat, alpha=w2)

    def exponent_for_preconditioner(self):
        """Returns exponent to use for inverse-pth root M^{-1/p}."""
        if self._hps.inverse_exponent_override > 0:
            return self._hps.inverse_exponent_override
        return 2 * len(self._transformed_shape)

    def compute_preconditioners(self):
        """Compute L^{-1/exp} for each stats matrix L (with a scaled ridge)."""
        exp = self.exponent_for_preconditioner()
        eps = self._hps.matrix_eps
        for i, stat in enumerate(self.statistics):
            self.preconditioners[i] = ComputePower(
                stat, exp, ridge_epsilon=eps)

    def preconditioned_grad(self, grad):
        """Precondition the gradient.

        Each block is contracted, axis by axis, with that axis's
        preconditioner along dimension 0. tensordot appends the new axis at
        the end, so after ``rank`` contractions the axes are back in order.

        Args:
          grad: A gradient tensor to precondition.

        Returns:
          A preconditioned gradient (``grad`` itself when there are no
          preconditioners).
        """
        if not self.preconditioners:
            return grad
        reshaped_grad = torch.reshape(grad, self._transformed_shape)
        partitioned_grads = self._partitioner.partition(reshaped_grad)
        preconditioned_partitioned_grads = []
        # num_splits() is the rank, i.e. preconditioners per block.
        num_splits = self._partitioner.num_splits()
        for i, block_grad in enumerate(partitioned_grads):
            preconditioners_for_grad = self.preconditioners[
                i * num_splits:(i + 1) * num_splits]
            rank = len(block_grad.shape)
            precond_grad = block_grad
            for j in range(rank):
                preconditioner = preconditioners_for_grad[j]
                precond_grad = torch.tensordot(
                    precond_grad, preconditioner, [[0], [0]])
            preconditioned_partitioned_grads.append(precond_grad)
        merged_grad = self._partitioner.merge_partitions(
            preconditioned_partitioned_grads)
        return torch.reshape(merged_grad, self._original_shape)
# Keys into each parameter's entry of the optimizer state dict.
STEP, MOMENTUM, PRECONDITIONER, GRAFT = (
    'step',
    'momentum',
    'preconditioner',
    'graft',
)
class Shampoo(optim.Optimizer):
    """The Shampoo optimizer.

    Each parameter's update direction comes from the Shampoo preconditioner,
    and its norm is rescaled to match the gradient preconditioned by the
    grafted optimizer (``hyperparams.graft_type``). Before
    ``start_preconditioning_step``, the grafted optimizer's own update is used.
    """

    def __init__(self,
                 params,
                 lr=1.0,
                 momentum=0.9,
                 hyperparams=None):
        defaults = dict(lr=lr, momentum=momentum)
        # A fresh default per optimizer instead of one instance created once
        # at class-definition time and shared by every optimizer.
        self.hps = ShampooHyperParams() if hyperparams is None else hyperparams
        super(Shampoo, self).__init__(params, defaults)

    def init_var_state(self, var, state):
        """Initialize the PyTorch state of for a single variable."""
        state[STEP] = 0
        # zeros_like inherits var's device/dtype; var.get_device() is -1 on CPU.
        state[MOMENTUM] = torch.zeros_like(var.data)
        state[PRECONDITIONER] = Preconditioner(var, self.hps)
        if self.hps.graft_type == LayerwiseGrafting.ADAGRAD:
            state[GRAFT] = AdagradGraft(self.hps, var)
        elif self.hps.graft_type == LayerwiseGrafting.SGD:
            state[GRAFT] = SGDGraft(self.hps, var)
        else:
            state[GRAFT] = Graft(self.hps, var)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
          closure: optional callable that re-evaluates the model and returns
            the loss; it is run with gradients enabled.

        Returns:
          The closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        hps = self.hps
        for group in self.param_groups:
            lr = group['lr']
            beta1 = group['momentum']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Shampoo does not support sparse yet')
                state = self.state[p]
                if not state:
                    self.init_var_state(p, state)
                state[STEP] += 1
                preconditioner = state[PRECONDITIONER]
                graft = state[GRAFT]

                # Gather statistics, compute preconditioners
                graft.add_statistics(grad)
                if state[STEP] % hps.statistics_compute_steps == 0:
                    preconditioner.add_statistics(grad)
                if state[STEP] % hps.preconditioning_compute_steps == 0:
                    preconditioner.compute_preconditioners()

                # Precondition gradients. Either helper may return `grad`
                # itself, so everything below is out of place to leave p.grad
                # untouched (and to never apply weight decay twice to it).
                graft_grad = graft.precondition_gradient(grad)
                shampoo_grad = grad
                if state[STEP] >= hps.start_preconditioning_step:
                    shampoo_grad = preconditioner.preconditioned_grad(grad)

                # Grafting: Shampoo direction, grafted optimizer's magnitude.
                graft_norm = torch.norm(graft_grad)
                shampoo_norm = torch.norm(shampoo_grad)
                shampoo_grad = shampoo_grad * (graft_norm / (shampoo_norm + 1e-16))

                # Weight decay
                if hps.weight_decay != 0.0:
                    shampoo_grad = shampoo_grad.add(p.data, alpha=hps.weight_decay)
                    graft_grad = graft_grad.add(p.data, alpha=hps.weight_decay)

                # Momentum. The graft momentum accumulates the grafted update
                # (with weight decay), matching what Nesterov adds below.
                state[MOMENTUM].mul_(beta1).add_(shampoo_grad)
                graft_momentum = graft.update_momentum(graft_grad, beta1)
                if state[STEP] >= hps.start_preconditioning_step:
                    momentum_update = state[MOMENTUM]
                    wd_update = shampoo_grad
                else:
                    momentum_update = graft_momentum
                    wd_update = graft_grad
                if hps.nesterov:
                    # Out of place: the look-ahead must not rescale the stored
                    # momentum buffer.
                    momentum_update = momentum_update * beta1 + wd_update

                # Final update
                p.data.add_(momentum_update, alpha=-lr)
        return loss