import numpy as np
import torch
import enum
import itertools
from dataclasses import dataclass
import torch.optim as optim


@torch.no_grad()
def PowerIter(mat_g, error_tolerance=1e-6, num_iters=100):
  """Power iteration.

  Compute the maximum eigenvalue of mat_g, for scaling.
  v is a random vector with values in (-1, 1).

  Args:
    mat_g: the symmetric PSD matrix.
    error_tolerance: Iterative exit condition.
    num_iters: Number of iterations.

  Returns:
    eigenvalue, eigenvector, num_iters
  """
  v = torch.rand(list(mat_g.shape)[0], device=mat_g.device) * 2 - 1
  error = 1
  iters = 0
  singular_val = 0
  while error > error_tolerance and iters < num_iters:
    v = v / torch.norm(v)
    mat_v = torch.mv(mat_g, v)
    s_v = torch.dot(v, mat_v)
    error = torch.abs(s_v - singular_val)
    v = mat_v
    singular_val = s_v
    iters += 1
  return singular_val, v / torch.norm(v), iters


@torch.no_grad()
def MatPower(mat_m, p):
  """Computes mat_m^p, for p a positive integer.

  Args:
    mat_m: a square matrix.
    p: a positive integer.

  Returns:
    mat_m^p
  """
  if p in [1, 2, 4, 8, 16, 32]:
    p_done = 1
    res = mat_m
    while p_done < p:
      res = torch.matmul(res, res)
      p_done *= 2
    return res

  power = None
  while p > 0:
    if p % 2 == 1:
      power = torch.matmul(mat_m, power) if power is not None else mat_m
    p //= 2
    mat_m = torch.matmul(mat_m, mat_m)
  return power


@torch.no_grad()
def ComputePower(mat_g, p,
                 iter_count=100,
                 error_tolerance=1e-6,
                 ridge_epsilon=1e-6):
  """A method to compute G^{-1/p} using a coupled Newton iteration.

  See for example equation 3.2 on page 9 of:
  A Schur-Newton Method for the Matrix p-th Root and its Inverse
  by Chun-Hua Guo and Nicholas J. Higham
  SIAM Journal on Matrix Analysis and Applications,
  2006, Vol. 28, No. 3 : pp. 788-804
  https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf

  Args:
    mat_g: A square positive semidefinite matrix.
    p: a positive integer.
    iter_count: Stop iterating after this many rounds.
    error_tolerance: Threshold for stopping iteration.
    ridge_epsilon: We add this times I to G, to make it positive definite.
      For scaling, we multiply it by the largest eigenvalue of G.

  Returns:
    (mat_g + rI)^{-1/p} (r = ridge_epsilon * max_eigenvalue of mat_g).
  """
  shape = list(mat_g.shape)
  if len(shape) == 1:
    return torch.pow(mat_g + ridge_epsilon, -1/p)
  identity = torch.eye(shape[0], device=mat_g.device)
  if shape[0] == 1:
    return identity
  alpha = -1.0/p
  max_ev, _, _ = PowerIter(mat_g)
  ridge_epsilon *= max_ev
  mat_g += ridge_epsilon * identity
  z = (1 + p) / (2 * torch.norm(mat_g))
  # The best value for z is
  #   (1 + p) * (c_max^{1/p} - c_min^{1/p}) /
  #             (c_max^{1+1/p} - c_min^{1+1/p})
  # where c_max and c_min are the largest and smallest singular values of
  # mat_g.
  # The above estimate assumes that c_max > c_min * 2^p.
  # Can replace the line above by the one below, but it is less accurate,
  # hence needs more iterations to converge:
  #   z = (1 + p) / torch.trace(mat_g)
  # If we want the method to always converge, use z = 1 / torch.norm(mat_g)
  # or z = 1 / torch.trace(mat_g), but these can result in many
  # extra iterations.
  mat_root = identity * torch.pow(z, 1.0/p)
  mat_m = mat_g * z
  error = torch.max(torch.abs(mat_m - identity))
  count = 0
  while error > error_tolerance and count < iter_count:
    tmp_mat_m = (1 - alpha) * identity + alpha * mat_m
    new_mat_root = torch.matmul(mat_root, tmp_mat_m)
    mat_m = torch.matmul(MatPower(tmp_mat_m, p), mat_m)
    new_error = torch.max(torch.abs(mat_m - identity))
    if new_error > error * 1.2:
      break
    mat_root = new_mat_root
    error = new_error
    count += 1
  return mat_root
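
# Illustrative sketch (ours, not part of the original optimizer code below):
# a quick numerical check that ComputePower(G, p) approximates G^{-1/p} on a
# small random positive definite matrix. The helper name
# `_example_inverse_pth_root` is made up for exposition and is not called
# anywhere else in this file.
def _example_inverse_pth_root():
  torch.manual_seed(0)
  a = torch.randn(6, 6)
  mat_g = a @ a.t() + 0.1 * torch.eye(6)  # symmetric positive definite
  p = 4
  # Clone because ComputePower adds the ridge term to its argument in place.
  root = ComputePower(mat_g.clone(), p)
  # If root ~= (mat_g + ridge)^{-1/p}, then root^p @ mat_g should be close to
  # the identity (up to the tiny ridge term and the iteration tolerance).
  approx_identity = torch.matmul(MatPower(root, p), mat_g)
  return torch.max(torch.abs(approx_identity - torch.eye(6)))
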

# Grafting is a technique to fix the layerwise scale of the Shampoo optimizer.
# https://arxiv.org/pdf/2002.11803.pdf studies this in detail. It allows us to
# plug the Shampoo optimizer into settings where SGD/AdaGrad is already well
# tuned. Grafting onto Shampoo means taking the Shampoo direction, but using
# the step magnitude from the grafted optimizer, such as AdaGrad or SGD.
class LayerwiseGrafting(enum.IntEnum):
  NONE = 0
  SGD = 1
  ADAGRAD = 2


@dataclass
class ShampooHyperParams:
  """Shampoo hyperparameters."""
  beta2: float = 0.9
  diagonal_eps: float = 1e-6
  matrix_eps: float = 1e-12
  weight_decay: float = 0.0
  inverse_exponent_override: int = 2  # fixed exponent for preconditioner, if > 0
  start_preconditioning_step: int = 1
  # Performance tuning params for controlling memory and compute requirements.
  # How often to compute the preconditioner.
  preconditioning_compute_steps: int = 1
  # How often to compute statistics.
  statistics_compute_steps: int = 1
  # Block size for large layers (if > 0).
  # Block size = 1 ==> AdaGrad (don't do this, it is extremely inefficient!)
  # Block size should be as large as feasible under memory/time constraints.
  block_size: int = 128
  # Automatic shape interpretation: e.g. a [4, 3, 1024, 512] tensor is
  # reshaped to [12, 1024, 512] before statistics are constructed. If
  # disabled, Shampoo constructs statistics of shape [4, 4], [3, 3],
  # [1024, 1024], [512, 512]. Enabled by default.
  best_effort_shape_interpretation: bool = True
  # Type of grafting (SGD or AdaGrad).
  # https://arxiv.org/pdf/2002.11803.pdf
  graft_type: int = LayerwiseGrafting.ADAGRAD
  # Nesterov momentum.
  nesterov: bool = True


class Graft:
  """Base class to perform grafting onto Shampoo.

  This class does no grafting.
  """

  def __init__(self, hps, unused_var):
    self.hps = hps

  def add_statistics(self, grad):
    pass

  def precondition_gradient(self, grad):
    return grad

  def update_momentum(self, update, unused_beta1):
    return update


class SGDGraft(Graft):
  """Graft using SGD+momentum.

  momentum maintains an exponentially weighted moving average of gradients.
  """

  def __init__(self, hps, var):
    super(SGDGraft, self).__init__(hps, var)
    self.momentum = torch.zeros_like(var.data, device=var.device)

  def update_momentum(self, update, beta1):
    self.momentum.mul_(beta1).add_(update)
    return self.momentum


class AdagradGraft(SGDGraft):
  """Graft using AdaGrad.

  Essentially an implementation of AdaGrad with momentum.
  """

  def __init__(self, hps, var):
    super(AdagradGraft, self).__init__(hps, var)
    self.statistics = torch.zeros_like(var.data, device=var.device)

  def add_statistics(self, grad):
    self.statistics.add_(grad * grad)

  def precondition_gradient(self, grad):
    return grad / (torch.sqrt(self.statistics) + self.hps.diagonal_eps)


class BlockPartitioner:
  """Partitions a tensor into smaller tensors for preconditioning.

  For example, if a variable has shape (4096, 512), we might split the
  4096 into 4 blocks, so we effectively have 4 variables of size
  (1024, 512) each.
  """

  def __init__(self, var, hps):
    self._shape = var.shape
    self._splits = []
    self._split_sizes = []
    split_sizes = []
    # We split var into smaller blocks. Here we store the metadata to make
    # that split.
    for i, d in enumerate(var.shape):
      if hps.block_size > 0 and d > hps.block_size:
        # d-1, otherwise split appends a 0-size array.
        nsplit = (d-1) // hps.block_size
        indices = (np.arange(nsplit, dtype=np.int32) + 1) * hps.block_size
        sizes = np.ones(nsplit + 1, dtype=np.int32) * hps.block_size
        sizes[-1] = d - indices[-1]
        self._splits.append((i, indices))
        self._split_sizes.append((i, sizes))
        split_sizes.append(sizes)
      else:
        split_sizes.append(np.array([d], dtype=np.int32))
    self._num_splits = len(split_sizes)
    self._preconditioner_shapes = []
    for t in itertools.product(*split_sizes):
      self._preconditioner_shapes.extend([[d, d] for d in t])

  def shapes_for_preconditioners(self):
    return self._preconditioner_shapes

  def num_splits(self):
    return self._num_splits

  def partition(self, tensor):
    """Partition tensor into blocks."""
    assert tensor.shape == self._shape
    tensors = [tensor]
    for (i, sizes) in self._split_sizes:
      tensors_local = []
      for t in tensors:
        tensors_local.extend(torch.split(t, tuple(sizes), dim=i))
      tensors = tensors_local
    return tensors

  def merge_partitions(self, partitions):
    """Merge partitions back to original shape."""
    for (i, indices) in reversed(self._splits):
      n = len(indices) + 1
      partial_merged_tensors = []
      ind = 0
      while ind < len(partitions):
        partial_merged_tensors.append(
            torch.cat(partitions[ind:ind + n], axis=i))
        ind += n
      partitions = partial_merged_tensors
    assert len(partitions) == 1
    return partitions[0]


def _merge_small_dims(shape_to_merge, max_dim):
  """Merge small dimensions.

  If there are some small dimensions, we collapse them:
  e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
       [1, 2, 768, 1, 2048] --> [2, 768, 2048]

  Args:
    shape_to_merge: Shape to merge small dimensions.
    max_dim: Maximal dimension of output shape used in merging.

  Returns:
    Merged shape.
  """
  resulting_shape = []
  product = 1
  for d in shape_to_merge:
    if product * d <= max_dim:
      product *= d
    else:
      if product > 1:
        resulting_shape.append(product)
      product = d
  if product > 1:
    resulting_shape.append(product)
  return resulting_shape


class Preconditioner:
  """Compute statistics/shape from gradients for preconditioning."""

  def __init__(self, var, hps):
    self._hps = hps
    self._original_shape = var.shape
    self._transformed_shape = var.shape
    if hps.best_effort_shape_interpretation:
      self._transformed_shape = _merge_small_dims(
          self._original_shape, hps.block_size)

    reshaped_var = torch.reshape(var, self._transformed_shape)
    self._partitioner = BlockPartitioner(reshaped_var, hps)
    shapes = self._partitioner.shapes_for_preconditioners()
    rank = len(self._transformed_shape)
    device = var.device
    if rank <= 1:
      self.statistics = []
      self.preconditioners = []
    else:
      eps = self._hps.matrix_eps
      self.statistics = [eps * torch.eye(s[0], device=device) for s in shapes]
      self.preconditioners = [torch.eye(s[0], device=device) for s in shapes]

  def add_statistics(self, grad):
    """Compute statistics from gradients and add to the correct state entries.

    Args:
      grad: Gradient to compute statistics from.
    """
    if not self.statistics:
      return
    reshaped_grad = torch.reshape(grad, self._transformed_shape)
    partitioned_grads = self._partitioner.partition(reshaped_grad)
    w1 = self._hps.beta2
    w2 = 1.0 if w1 == 1.0 else (1.0 - w1)
    rank = len(self._transformed_shape)
    for j, grad in enumerate(partitioned_grads):
      for i in range(rank):
        axes = list(range(i)) + list(range(i + 1, rank))
        stat = torch.tensordot(grad, grad, [axes, axes])
        self.statistics[j*rank + i].mul_(w1).add_(stat, alpha=w2)

  def exponent_for_preconditioner(self):
    """Returns exponent to use for inverse-pth root M^{-1/p}."""
    if self._hps.inverse_exponent_override > 0:
      return self._hps.inverse_exponent_override
    return 2 * len(self._transformed_shape)

  def compute_preconditioners(self):
    """Compute L^{-1/exp} for each stats matrix L."""
    exp = self.exponent_for_preconditioner()
    eps = self._hps.matrix_eps
    for i, stat in enumerate(self.statistics):
      self.preconditioners[i] = ComputePower(
          stat, exp, ridge_epsilon=eps)

  def preconditioned_grad(self, grad):
    """Precondition the gradient.

    Args:
      grad: A gradient tensor to precondition.

    Returns:
      A preconditioned gradient.
    """
    if not self.preconditioners:
      return grad
    reshaped_grad = torch.reshape(grad, self._transformed_shape)
    partitioned_grads = self._partitioner.partition(reshaped_grad)
    preconditioned_partitioned_grads = []
    num_splits = self._partitioner.num_splits()
    for i, grad in enumerate(partitioned_grads):
      preconditioners_for_grad = self.preconditioners[i * num_splits:
                                                      (i + 1) * num_splits]
      rank = len(grad.shape)
      precond_grad = grad
      for j in range(rank):
        preconditioner = preconditioners_for_grad[j]
        precond_grad = torch.tensordot(
            precond_grad, preconditioner, [[0], [0]])
      preconditioned_partitioned_grads.append(precond_grad)
    merged_grad = self._partitioner.merge_partitions(
        preconditioned_partitioned_grads)
    return torch.reshape(merged_grad, self._original_shape)

STEP = 'step'
MOMENTUM = 'momentum'
PRECONDITIONER = 'preconditioner'
GRAFT = 'graft'


class Shampoo(optim.Optimizer):
  """The Shampoo optimizer."""

  def __init__(self,
               params,
               lr=1.0,
               momentum=0.9,
               hyperparams=ShampooHyperParams()):
    defaults = dict(lr=lr, momentum=momentum)
    self.hps = hyperparams
    super(Shampoo, self).__init__(params, defaults)

  def init_var_state(self, var, state):
    """Initialize the PyTorch state for a single variable."""
    state[STEP] = 0
    state[MOMENTUM] = torch.zeros_like(var.data, device=var.device)
    state[PRECONDITIONER] = Preconditioner(var, self.hps)
    if self.hps.graft_type == LayerwiseGrafting.ADAGRAD:
      state[GRAFT] = AdagradGraft(self.hps, var)
    elif self.hps.graft_type == LayerwiseGrafting.SGD:
      state[GRAFT] = SGDGraft(self.hps, var)
    else:
      state[GRAFT] = Graft(self.hps, var)

  def step(self, closure=None):
    hps = self.hps
    for group in self.param_groups:
      lr = group['lr']
      for p in group['params']:
        if p.grad is None:
          continue
        grad = p.grad.data
        if grad.is_sparse:
          raise RuntimeError('Shampoo does not support sparse yet')
        state = self.state[p]
        if not state:
          self.init_var_state(p, state)
        state[STEP] += 1

        preconditioner = state[PRECONDITIONER]
        graft = state[GRAFT]

        # Gather statistics, compute preconditioners
        graft.add_statistics(grad)
        if state[STEP] % hps.statistics_compute_steps == 0:
          preconditioner.add_statistics(grad)
        if state[STEP] % hps.preconditioning_compute_steps == 0:
          preconditioner.compute_preconditioners()

        # Precondition gradients
        graft_grad = graft.precondition_gradient(grad)
        shampoo_grad = grad
        if state[STEP] >= self.hps.start_preconditioning_step:
          shampoo_grad = preconditioner.preconditioned_grad(grad)

        # Grafting: rescale the Shampoo direction to the grafted step magnitude.
        graft_norm = torch.norm(graft_grad)
        shampoo_norm = torch.norm(shampoo_grad)
        shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))

        # Weight decay
        if self.hps.weight_decay != 0.0:
          shampoo_grad.add_(p.data, alpha=self.hps.weight_decay)
          graft_grad.add_(p.data, alpha=self.hps.weight_decay)

        # Momentum and Nesterov momentum, if needed
        state[MOMENTUM].mul_(group['momentum']).add_(shampoo_grad)

        graft_momentum = graft.update_momentum(grad, group['momentum'])

        if state[STEP] >= self.hps.start_preconditioning_step:
          momentum_update = state[MOMENTUM]
          wd_update = shampoo_grad
        else:
          momentum_update = graft_momentum
          wd_update = graft_grad

        if hps.nesterov:
          momentum_update.mul_(group['momentum']).add_(wd_update)

        # Final update
        p.data.add_(momentum_update, alpha=-lr)
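
# Minimal usage sketch (ours, not part of the original file): train a tiny
# linear model with Shampoo for a few steps. The model, data, learning rate
# and hyperparameter values are made up for illustration only.
if __name__ == '__main__':
  torch.manual_seed(0)
  model = torch.nn.Linear(64, 32)
  hps = ShampooHyperParams(block_size=128,
                           graft_type=LayerwiseGrafting.ADAGRAD)
  opt = Shampoo(model.parameters(), lr=1e-2, momentum=0.9, hyperparams=hps)
  x = torch.randn(128, 64)
  y = torch.randn(128, 32)
  for i in range(10):
    opt.zero_grad()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    opt.step()  # statistics, preconditioners, grafting and update in one call
    print(i, loss.item())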