import numpy as np
import torch
import enum
import itertools
from dataclasses import dataclass
import torch.optim as optim

def PowerIter(mat_g, error_tolerance=1e-6, num_iters=100):
  """Power iteration.

  Compute the maximum eigenvalue of mat_g, for scaling.
  v is a random vector with values in (-1, 1).

  Args:
    mat_g: the symmetric PSD matrix.
    error_tolerance: Iterative exit condition.
    num_iters: Number of iterations.

  Returns:
    eigenvalue, eigenvector, num_iters
  """
  v = torch.rand(list(mat_g.shape)[0], device=mat_g.get_device()) * 2 - 1
  error = 1
  iters = 0
  singular_val = 0
  while error > error_tolerance and iters < num_iters:
    v = v / torch.norm(v)
    mat_v = torch.mv(mat_g, v)
    s_v = torch.dot(v, mat_v)
    error = torch.abs(s_v - singular_val)
    v = mat_v
    singular_val = s_v
    iters += 1
  return singular_val, v / torch.norm(v), iters
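
# Illustration (not part of the original module): a minimal sketch that checks
# PowerIter against torch.linalg.eigvalsh on a small random PSD matrix. The
# helper name `_power_iter_example` is ours; a CUDA device is assumed because
# this module indexes devices via `.get_device()`.
def _power_iter_example():
  if not torch.cuda.is_available():
    return
  a = torch.randn(16, 16, device='cuda')
  mat_g = a @ a.t()  # symmetric PSD
  max_ev, _, _ = PowerIter(mat_g)
  # eigvalsh returns eigenvalues in ascending order; the last is the largest.
  # Power iteration typically agrees to a few digits on such matrices.
  assert torch.allclose(max_ev, torch.linalg.eigvalsh(mat_g)[-1], rtol=1e-2)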

def MatPower(mat_m, p):
  """Computes mat_m^p, for p a positive integer.

  Args:
    mat_m: a square matrix
    p: a positive integer

  Returns:
    mat_m^p
  """
  if p in [1, 2, 4, 8, 16, 32]:
    p_done = 1
    res = mat_m
    while p_done < p:
      res = torch.matmul(res, res)
      p_done *= 2
    return res
  power = None
  while p > 0:
    if p % 2 == 1:
      power = torch.matmul(mat_m, power) if power is not None else mat_m
    p //= 2
    mat_m = torch.matmul(mat_m, mat_m)
  return power
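
# Illustration (not part of the original module): MatPower should agree with
# torch.linalg.matrix_power for positive integer exponents, whether p is a
# power of two (repeated squaring branch) or not (binary-exponentiation
# branch). `_mat_power_example` is a name we introduce for this sketch.
def _mat_power_example():
  m = 0.1 * torch.randn(5, 5)
  for p in (1, 3, 8):
    assert torch.allclose(MatPower(m, p), torch.linalg.matrix_power(m, p),
                          rtol=1e-4, atol=1e-6)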

def ComputePower(mat_g, p,
                 iter_count=100,
                 error_tolerance=1e-6,
                 ridge_epsilon=1e-6):
  """A method to compute G^{-1/p} using a coupled Newton iteration.

  See for example equation 3.2 on page 9 of:
  A Schur-Newton Method for the Matrix p-th Root and its Inverse
  by Chun-Hua Guo and Nicholas J. Higham
  SIAM Journal on Matrix Analysis and Applications,
  2006, Vol. 28, No. 3 : pp. 788-804
  https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf

  Args:
    mat_g: A square positive semidefinite matrix.
    p: a positive integer.
    iter_count: Stop iterating after this many rounds.
    error_tolerance: Threshold for stopping iteration.
    ridge_epsilon: We add this times I to G, to make it positive definite.
      For scaling, we multiply it by the largest eigenvalue of G.

  Returns:
    (mat_g + rI)^{-1/p} (r = ridge_epsilon * max_eigenvalue of mat_g).
  """
  shape = list(mat_g.shape)
  if len(shape) == 1:
    return torch.pow(mat_g + ridge_epsilon, -1/p)
  identity = torch.eye(shape[0], device=mat_g.get_device())
  if shape[0] == 1:
    return identity
  alpha = -1.0/p
  max_ev, _, _ = PowerIter(mat_g)
  ridge_epsilon *= max_ev
  mat_g += ridge_epsilon * identity
  z = (1 + p) / (2 * torch.norm(mat_g))
  # The best value for z is
  #   (1 + p) * (c_max^{1/p} - c_min^{1/p}) /
  #             (c_max^{1+1/p} - c_min^{1+1/p})
  # where c_max and c_min are the largest and smallest singular values of
  # mat_g.
  # The above estimate assumes that c_max > c_min * 2^p.
  # We can replace the line above by the one below, but it is less accurate,
  # hence needs more iterations to converge:
  #   z = (1 + p) / torch.trace(mat_g)
  # If we want the method to always converge, use z = 1 / norm(mat_g)
  # or z = 1 / torch.trace(mat_g), but these can result in many
  # extra iterations.
  mat_root = identity * torch.pow(z, 1.0/p)
  mat_m = mat_g * z
  error = torch.max(torch.abs(mat_m - identity))
  count = 0
  while error > error_tolerance and count < iter_count:
    tmp_mat_m = (1 - alpha) * identity + alpha * mat_m
    new_mat_root = torch.matmul(mat_root, tmp_mat_m)
    mat_m = torch.matmul(MatPower(tmp_mat_m, p), mat_m)
    new_error = torch.max(torch.abs(mat_m - identity))
    if new_error > error * 1.2:
      break
    mat_root = new_mat_root
    error = new_error
    count += 1
  return mat_root
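
# Illustration (not part of the original module): a hedged check of
# ComputePower against an eigendecomposition-based inverse p-th root on a
# small, well-conditioned SPD matrix. `_compute_power_example` is our name;
# a CUDA device is assumed because of the `.get_device()` calls above. Note
# that ComputePower mutates its argument (it adds the ridge term in place),
# so we pass a clone.
def _compute_power_example(p=4):
  if not torch.cuda.is_available():
    return
  a = torch.randn(8, 8, device='cuda')
  mat_g = a @ a.t() + torch.eye(8, device='cuda')  # SPD, reasonably conditioned
  root = ComputePower(mat_g.clone(), p, ridge_epsilon=1e-6)
  evals, evecs = torch.linalg.eigh(mat_g)
  reference = evecs @ torch.diag(evals.pow(-1.0 / p)) @ evecs.t()
  # The ridge shifts the eigenvalues slightly, so only loose agreement
  # is expected here.
  assert torch.allclose(root, reference, rtol=1e-2, atol=1e-2)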

# Grafting is a technique to fix the layerwise scale of the Shampoo optimizer.
# https://arxiv.org/pdf/2002.11803.pdf studies this in detail. It lets us plug
# the Shampoo optimizer into settings where SGD/AdaGrad is already well tuned.
# Grafting onto Shampoo means taking the Shampoo direction, but using the step
# magnitude from the grafted optimizer such as AdaGrad or SGD.
class LayerwiseGrafting(enum.IntEnum):
  NONE = 0
  SGD = 1
  ADAGRAD = 2

@dataclass
class ShampooHyperParams:
  """Shampoo hyperparameters."""
  beta2: float = 0.9
  diagonal_eps: float = 1e-6
  matrix_eps: float = 1e-12
  weight_decay: float = 0.0
  inverse_exponent_override: int = 2  # fixed exponent for preconditioner, if >0
  start_preconditioning_step: int = 1
  # Performance tuning params for controlling memory and compute requirements.
  # How often to compute preconditioners.
  preconditioning_compute_steps: int = 1
  # How often to compute statistics.
  statistics_compute_steps: int = 1
  # Block size for large layers (if > 0).
  # Block size = 1 ==> AdaGrad (don't do this, it is extremely inefficient!)
  # Block size should be as large as feasible under memory/time constraints.
  block_size: int = 128
  # Automatic shape interpretation (enabled by default): e.g. [4, 3, 1024, 512]
  # results in 12 x [1024, 512] L and R statistics. When disabled, Shampoo
  # constructs statistics of shapes [4, 4], [3, 3], [1024, 1024], [512, 512].
  best_effort_shape_interpretation: bool = True
  # Type of grafting (SGD or AdaGrad).
  # https://arxiv.org/pdf/2002.11803.pdf
  graft_type: int = LayerwiseGrafting.ADAGRAD
  # Nesterov momentum.
  nesterov: bool = True
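
# Illustration (not part of the original module): since ShampooHyperParams is
# a dataclass, individual fields can be overridden at construction time. The
# values below are arbitrary examples, not recommended settings.
_example_hps = ShampooHyperParams(
    block_size=256,                    # larger blocks: more memory and compute
    graft_type=LayerwiseGrafting.SGD,  # graft step sizes from SGD+momentum
    preconditioning_compute_steps=20,  # refresh inverse roots every 20 steps
)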

class Graft:
  """Base class to perform grafting onto Shampoo. This class does no grafting.
  """

  def __init__(self, hps, unused_var):
    self.hps = hps

  def add_statistics(self, grad):
    pass

  def precondition_gradient(self, grad):
    return grad

  def update_momentum(self, update, unused_beta1):
    return update

class SGDGraft(Graft):
  """Graft using SGD+momentum.

  momentum maintains an exponentially decayed sum of gradients.
  """

  def __init__(self, hps, var):
    super(SGDGraft, self).__init__(hps, var)
    self.momentum = torch.zeros_like(var.data, device=var.get_device())

  def update_momentum(self, update, beta1):
    self.momentum.mul_(beta1).add_(update)
    return self.momentum

class AdagradGraft(SGDGraft):
  """Graft using Adagrad.

  Essentially an implementation of Adagrad with momentum.
  """

  def __init__(self, hps, var):
    super(AdagradGraft, self).__init__(hps, var)
    self.statistics = torch.zeros_like(var.data, device=var.get_device())

  def add_statistics(self, grad):
    self.statistics.add_(grad * grad)

  def precondition_gradient(self, grad):
    return grad / (torch.sqrt(self.statistics) + self.hps.diagonal_eps)
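
# Illustration (not part of the original module): the essence of grafting as
# applied in Shampoo.step below. We keep the Shampoo *direction* but rescale
# it to the norm of the grafted optimizer's step. The tensors here are
# placeholders and the helper name is ours.
def _grafting_example(shampoo_grad, graft_grad):
  graft_norm = torch.norm(graft_grad)
  shampoo_norm = torch.norm(shampoo_grad)
  # Same direction as shampoo_grad, same magnitude as graft_grad.
  return shampoo_grad * (graft_norm / (shampoo_norm + 1e-16))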

class BlockPartitioner:
  """Partitions a tensor into smaller tensors for preconditioning.

  For example, if a variable has shape (4096, 512), we might split the
  4096 into 4 blocks, so we effectively have 4 variables of size
  (1024, 512) each.
  """

  def __init__(self, var, hps):
    self._shape = var.shape
    self._splits = []
    self._split_sizes = []
    split_sizes = []
    # We split var into smaller blocks. Here we store the metadata to make
    # that split.
    for i, d in enumerate(var.shape):
      if hps.block_size > 0 and d > hps.block_size:
        # d-1, otherwise split appends a 0-size array.
        nsplit = (d - 1) // hps.block_size
        indices = (np.arange(nsplit, dtype=np.int32) + 1) * hps.block_size
        sizes = np.ones(nsplit + 1, dtype=np.int32) * hps.block_size
        sizes[-1] = d - indices[-1]
        self._splits.append((i, indices))
        self._split_sizes.append((i, sizes))
        split_sizes.append(sizes)
      else:
        split_sizes.append(np.array([d], dtype=np.int32))
    self._num_splits = len(split_sizes)
    self._preconditioner_shapes = []
    for t in itertools.product(*split_sizes):
      self._preconditioner_shapes.extend([[d, d] for d in t])

  def shapes_for_preconditioners(self):
    return self._preconditioner_shapes

  def num_splits(self):
    return self._num_splits

  def partition(self, tensor):
    """Partition tensor into blocks."""
    assert tensor.shape == self._shape
    tensors = [tensor]
    for (i, sizes) in self._split_sizes:
      tensors_local = []
      for t in tensors:
        tensors_local.extend(torch.split(t, tuple(sizes), dim=i))
      tensors = tensors_local
    return tensors

  def merge_partitions(self, partitions):
    """Merge partitions back to original shape."""
    for (i, indices) in reversed(self._splits):
      n = len(indices) + 1
      partial_merged_tensors = []
      ind = 0
      while ind < len(partitions):
        partial_merged_tensors.append(
            torch.cat(partitions[ind:ind + n], axis=i))
        ind += n
      partitions = partial_merged_tensors
    assert len(partitions) == 1
    return partitions[0]
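
# Illustration (not part of the original module): partitioning a small tensor
# and merging it back. With block_size=2, a (5, 3) tensor is cut into
# 3 row-splits x 2 column-splits = 6 blocks, and each block later gets one
# [d, d] statistic per axis. The helper name is ours.
def _block_partitioner_example():
  hps = ShampooHyperParams(block_size=2)
  x = torch.arange(15.0).reshape(5, 3)
  part = BlockPartitioner(x, hps)
  blocks = part.partition(x)
  assert len(blocks) == 6
  assert torch.equal(part.merge_partitions(blocks), x)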

def _merge_small_dims(shape_to_merge, max_dim):
  """Merge small dimensions.

  If there are some small dimensions, we collapse them:
  e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
       [1, 2, 768, 1, 2048] --> [2, 768, 2048]

  Args:
    shape_to_merge: Shape to merge small dimensions.
    max_dim: Maximal dimension of output shape used in merging.

  Returns:
    Merged shape.
  """
  resulting_shape = []
  product = 1
  for d in shape_to_merge:
    if product * d <= max_dim:
      product *= d
    else:
      if product > 1:
        resulting_shape.append(product)
      product = d
  if product > 1:
    resulting_shape.append(product)
  return resulting_shape
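
# Illustration (not part of the original module): the docstring examples above,
# checked directly. Adjacent dimensions are greedily multiplied together as
# long as the running product stays within max_dim.
assert _merge_small_dims([1, 2, 512, 1, 2048, 1, 3, 4], 1024) == [1024, 2048, 12]
assert _merge_small_dims([1, 2, 768, 1, 2048], 1024) == [2, 768, 2048]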

class Preconditioner:
  """Compute statistics/shape from gradients for preconditioning."""

  def __init__(self, var, hps):
    self._hps = hps
    self._original_shape = var.shape
    self._transformed_shape = var.shape
    if hps.best_effort_shape_interpretation:
      self._transformed_shape = _merge_small_dims(
          self._original_shape, hps.block_size)

    reshaped_var = torch.reshape(var, self._transformed_shape)
    self._partitioner = BlockPartitioner(reshaped_var, hps)
    shapes = self._partitioner.shapes_for_preconditioners()
    rank = len(self._transformed_shape)
    device = var.get_device()
    if rank <= 1:
      self.statistics = []
      self.preconditioners = []
    else:
      eps = self._hps.matrix_eps
      self.statistics = [eps * torch.eye(s[0], device=device) for s in shapes]
      self.preconditioners = [torch.eye(s[0], device=device) for s in shapes]

  def add_statistics(self, grad):
    """Compute statistics from gradients and add to the correct state entries.

    Args:
      grad: Gradient to compute statistics from.
    """
    if not self.statistics: return
    reshaped_grad = torch.reshape(grad, self._transformed_shape)
    partitioned_grads = self._partitioner.partition(reshaped_grad)
    w1 = self._hps.beta2
    w2 = 1.0 if w1 == 1.0 else (1.0 - w1)
    rank = len(self._transformed_shape)
    for j, grad in enumerate(partitioned_grads):
      for i in range(rank):
        axes = list(range(i)) + list(range(i + 1, rank))
        stat = torch.tensordot(grad, grad, [axes, axes])
        self.statistics[j*rank + i].mul_(w1).add_(stat, alpha=w2)

  def exponent_for_preconditioner(self):
    """Returns exponent to use for inverse-pth root M^{-1/p}."""
    if self._hps.inverse_exponent_override > 0:
      return self._hps.inverse_exponent_override
    return 2 * len(self._transformed_shape)

  def compute_preconditioners(self):
    """Compute L^{-1/exp} for each stats matrix L."""
    exp = self.exponent_for_preconditioner()
    eps = self._hps.matrix_eps
    for i, stat in enumerate(self.statistics):
      self.preconditioners[i] = ComputePower(
          stat, exp, ridge_epsilon=eps)

  def preconditioned_grad(self, grad):
    """Precondition the gradient.

    Args:
      grad: A gradient tensor to precondition.

    Returns:
      A preconditioned gradient.
    """
    if not self.preconditioners: return grad
    reshaped_grad = torch.reshape(grad, self._transformed_shape)
    partitioned_grads = self._partitioner.partition(reshaped_grad)
    preconditioned_partitioned_grads = []
    num_splits = self._partitioner.num_splits()
    for i, grad in enumerate(partitioned_grads):
      preconditioners_for_grad = self.preconditioners[i * num_splits:(i + 1) *
                                                      num_splits]
      rank = len(grad.shape)
      precond_grad = grad
      for j in range(rank):
        preconditioner = preconditioners_for_grad[j]
        precond_grad = torch.tensordot(
            precond_grad, preconditioner, [[0], [0]])
      preconditioned_partitioned_grads.append(precond_grad)
    merged_grad = self._partitioner.merge_partitions(
        preconditioned_partitioned_grads)
    return torch.reshape(merged_grad, self._original_shape)
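
# Illustration (not part of the original module): one statistics/precondition
# round-trip on a small 2-D parameter. A CUDA device is assumed because
# Preconditioner allocates its statistics via `.get_device()`. Names are ours.
def _preconditioner_example():
  if not torch.cuda.is_available():
    return
  hps = ShampooHyperParams(block_size=4)
  var = torch.randn(6, 5, device='cuda')
  precond = Preconditioner(var, hps)
  grad = torch.randn_like(var)
  precond.add_statistics(grad)       # accumulate per-block L/R statistics
  precond.compute_preconditioners()  # inverse roots (default exponent is 2)
  out = precond.preconditioned_grad(grad)
  assert out.shape == var.shape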

STEP = 'step'
MOMENTUM = 'momentum'
PRECONDITIONER = 'preconditioner'
GRAFT = 'graft'

class Shampoo(optim.Optimizer):
  """The Shampoo optimizer."""

  def __init__(self,
               params,
               lr=1.0,
               momentum=0.9,
               hyperparams=ShampooHyperParams()):
    defaults = dict(lr=lr, momentum=momentum)
    self.hps = hyperparams
    super(Shampoo, self).__init__(params, defaults)

  def init_var_state(self, var, state):
    """Initialize the PyTorch state for a single variable."""
    state[STEP] = 0
    state[MOMENTUM] = torch.zeros_like(var.data, device=var.get_device())
    state[PRECONDITIONER] = Preconditioner(var, self.hps)
    if self.hps.graft_type == LayerwiseGrafting.ADAGRAD:
      state[GRAFT] = AdagradGraft(self.hps, var)
    elif self.hps.graft_type == LayerwiseGrafting.SGD:
      state[GRAFT] = SGDGraft(self.hps, var)
    else:
      state[GRAFT] = Graft(self.hps, var)

  def step(self, closure=None):
    hps = self.hps
    for group in self.param_groups:
      lr = group['lr']
      for p in group['params']:
        if p.grad is None: continue
        grad = p.grad.data
        if grad.is_sparse:
          raise RuntimeError('Shampoo does not support sparse yet')
        state = self.state[p]
        if not state:
          self.init_var_state(p, state)
        state[STEP] += 1

        preconditioner = state[PRECONDITIONER]
        graft = state[GRAFT]

        # Gather statistics, compute preconditioners
        graft.add_statistics(grad)
        if state[STEP] % hps.statistics_compute_steps == 0:
          preconditioner.add_statistics(grad)
        if state[STEP] % hps.preconditioning_compute_steps == 0:
          preconditioner.compute_preconditioners()

        # Precondition gradients
        graft_grad = graft.precondition_gradient(grad)
        shampoo_grad = grad
        if state[STEP] >= self.hps.start_preconditioning_step:
          shampoo_grad = preconditioner.preconditioned_grad(grad)

        # Grafting
        graft_norm = torch.norm(graft_grad)
        shampoo_norm = torch.norm(shampoo_grad)
        shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))

        # Weight decay
        if self.hps.weight_decay != 0.0:
          shampoo_grad.add_(p.data, alpha=self.hps.weight_decay)
          graft_grad.add_(p.data, alpha=self.hps.weight_decay)

        # Momentum and Nesterov momentum, if needed
        state[MOMENTUM].mul_(group['momentum']).add_(shampoo_grad)
        graft_momentum = graft.update_momentum(grad, group['momentum'])

        if state[STEP] >= self.hps.start_preconditioning_step:
          momentum_update = state[MOMENTUM]
          wd_update = shampoo_grad
        else:
          momentum_update = graft_momentum
          wd_update = graft_grad

        if hps.nesterov:
          momentum_update.mul_(group['momentum']).add_(wd_update)

        # Final update
        p.data.add_(momentum_update, alpha=-lr)
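
# Illustration (not part of the original module): a minimal, hedged usage
# sketch. The toy model, data, and hyperparameter choices are ours; a CUDA
# device is assumed because the optimizer state is allocated via
# `.get_device()` on the parameters.
if __name__ == '__main__':
  if torch.cuda.is_available():
    model = torch.nn.Linear(32, 4).cuda()
    opt = Shampoo(model.parameters(), lr=0.1, momentum=0.9,
                  hyperparams=ShampooHyperParams(
                      block_size=16, graft_type=LayerwiseGrafting.SGD))
    for _ in range(10):
      x = torch.randn(64, 32, device='cuda')
      y = torch.randn(64, 4, device='cuda')
      loss = torch.nn.functional.mse_loss(model(x), y)
      opt.zero_grad()
      loss.backward()
      opt.step()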