diff --git "a/tools/train/distributed_shampoo.py" "b/tools/train/distributed_shampoo.py" --- "a/tools/train/distributed_shampoo.py" +++ "b/tools/train/distributed_shampoo.py" @@ -48,103 +48,114 @@ import optax # pylint:disable=no-value-for-parameter @struct.dataclass class QuantizedValue: - """State associated with quantized value.""" - quantized: chex.Array - diagonal: chex.Array # Diagonal (if extract_diagonal is set) - bucket_size: chex.Array - quantized_dtype: jnp.dtype = struct.field( - pytree_node=False) # Dtype for the quantized value. - extract_diagonal: bool = struct.field( - pytree_node=False) # In case its centered. - shape: Any = struct.field(pytree_node=False) # Shape of the tensor. - - @classmethod - def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False): - if isinstance(fvalue, list) and not fvalue: - return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, []) - quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize( - fvalue, quantized_dtype, extract_diagonal) - return QuantizedValue(quantized, diagonal_fvalue, bucket_size, - quantized_dtype, extract_diagonal, - list(quantized.shape)) - - # Quantization is from Lingvo JAX optimizers. - # We extend it for int16 quantization of PSD matrices. - @classmethod - def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False): - """Returns quantized value and the bucket.""" - if quantized_dtype == jnp.float32: - return fvalue, [], [] - elif quantized_dtype == jnp.bfloat16: - return fvalue.astype(jnp.bfloat16), [], [] - - float_dtype = fvalue.dtype - if quantized_dtype == jnp.int8: - # value -128 is not used. - num_buckets = jnp.array(127.0, dtype=float_dtype) - elif quantized_dtype == jnp.int16: - # value -32768 is not used. - num_buckets = jnp.array(32767.0, dtype=float_dtype) - else: - raise ValueError(f'Quantized dtype {quantized_dtype} not supported.') - # max value is mapped to num_buckets - - if extract_diagonal and fvalue.ndim != 2: - raise ValueError( - f'Input array {fvalue} must be 2D to work with extract_diagonal.') - - diagonal_fvalue = [] - if extract_diagonal: - diagonal_fvalue = jnp.diag(fvalue) - # Remove the diagonal entries. - fvalue = fvalue - jnp.diag(diagonal_fvalue) - - # TODO(rohananil): Extend this by making use of information about the blocks - # SM3 style which will be useful for diagonal statistics - # We first decide the scale. - if fvalue.ndim < 1: - raise ValueError( - f'Input array {fvalue} must have a strictly positive number of ' - 'dimensions.') - - max_abs = jnp.max(jnp.abs(fvalue), axis=0) - bucket_size = max_abs / num_buckets - bs_expanded = bucket_size[jnp.newaxis, Ellipsis] - # To avoid divide by 0.0 - bs_nonzero = jnp.where(bs_expanded > 0.0, bs_expanded, - jnp.ones_like(bs_expanded)) - ratio = fvalue / bs_nonzero - # We use rounding to remove bias. 
- quantized = jnp.round(ratio) - return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size - - def to_float(self): - """Returns the float value.""" - if isinstance(self.quantized, list) and not self.quantized: - return self.quantized - - if self.quantized_dtype == jnp.float32: - return self.quantized - - if self.quantized_dtype == jnp.bfloat16: - return self.quantized.astype(jnp.float32) - - float_dtype = self.bucket_size.dtype - bucket_size = self.bucket_size[jnp.newaxis, Ellipsis] - val = self.quantized.astype(float_dtype) * bucket_size - if self.extract_diagonal: - val += jnp.diag(self.diagonal) - return val + """State associated with quantized value.""" + + quantized: chex.Array + diagonal: chex.Array # Diagonal (if extract_diagonal is set) + bucket_size: chex.Array + quantized_dtype: jnp.dtype = struct.field( + pytree_node=False + ) # Dtype for the quantized value. + extract_diagonal: bool = struct.field(pytree_node=False) # In case its centered. + shape: Any = struct.field(pytree_node=False) # Shape of the tensor. + + @classmethod + def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False): + if isinstance(fvalue, list) and not fvalue: + return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, []) + quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize( + fvalue, quantized_dtype, extract_diagonal + ) + return QuantizedValue( + quantized, + diagonal_fvalue, + bucket_size, + quantized_dtype, + extract_diagonal, + list(quantized.shape), + ) + + # Quantization is from Lingvo JAX optimizers. + # We extend it for int16 quantization of PSD matrices. + @classmethod + def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False): + """Returns quantized value and the bucket.""" + if quantized_dtype == jnp.float32: + return fvalue, [], [] + elif quantized_dtype == jnp.bfloat16: + return fvalue.astype(jnp.bfloat16), [], [] + + float_dtype = fvalue.dtype + if quantized_dtype == jnp.int8: + # value -128 is not used. + num_buckets = jnp.array(127.0, dtype=float_dtype) + elif quantized_dtype == jnp.int16: + # value -32768 is not used. + num_buckets = jnp.array(32767.0, dtype=float_dtype) + else: + raise ValueError(f"Quantized dtype {quantized_dtype} not supported.") + # max value is mapped to num_buckets + + if extract_diagonal and fvalue.ndim != 2: + raise ValueError( + f"Input array {fvalue} must be 2D to work with extract_diagonal." + ) + + diagonal_fvalue = [] + if extract_diagonal: + diagonal_fvalue = jnp.diag(fvalue) + # Remove the diagonal entries. + fvalue = fvalue - jnp.diag(diagonal_fvalue) + + # TODO(rohananil): Extend this by making use of information about the blocks + # SM3 style which will be useful for diagonal statistics + # We first decide the scale. + if fvalue.ndim < 1: + raise ValueError( + f"Input array {fvalue} must have a strictly positive number of " + "dimensions." + ) + + max_abs = jnp.max(jnp.abs(fvalue), axis=0) + bucket_size = max_abs / num_buckets + bs_expanded = bucket_size[jnp.newaxis, Ellipsis] + # To avoid divide by 0.0 + bs_nonzero = jnp.where( + bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded) + ) + ratio = fvalue / bs_nonzero + # We use rounding to remove bias. 
+ quantized = jnp.round(ratio) + return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size + + def to_float(self): + """Returns the float value.""" + if isinstance(self.quantized, list) and not self.quantized: + return self.quantized + + if self.quantized_dtype == jnp.float32: + return self.quantized + + if self.quantized_dtype == jnp.bfloat16: + return self.quantized.astype(jnp.float32) + + float_dtype = self.bucket_size.dtype + bucket_size = self.bucket_size[jnp.newaxis, Ellipsis] + val = self.quantized.astype(float_dtype) * bucket_size + if self.extract_diagonal: + val += jnp.diag(self.diagonal) + return val # Per parameter optimizer state used in data-parallel training. class ParameterStats(NamedTuple): - """State associated to each parameter of the model being trained.""" - diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner - statistics: List[Any] # Statistics (QuantizedValue, chex.Array) - preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array) - diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner - momentum: QuantizedValue # Momentum for the shampoo preconditioner + """State associated to each parameter of the model being trained.""" + + diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner + statistics: List[Any] # Statistics (QuantizedValue, chex.Array) + preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array) + diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner + momentum: QuantizedValue # Momentum for the shampoo preconditioner # For training extremely large model; We keep a global state with a concatenated @@ -153,91 +164,98 @@ class ParameterStats(NamedTuple): # communication. @struct.dataclass class GlobalShardedParameterStats: - statistics: chex.Array # Statistics - preconditioners: chex.Array # Preconditioners + statistics: chex.Array # Statistics + preconditioners: chex.Array # Preconditioners # These are per-parameter local states; All statistics here mirror the parameter # Thus the sharding is copied over from the param specification. @struct.dataclass class LocalShardedParameterStats: - """State associated to each parameter of the model being trained.""" - diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner - diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner - momentum: QuantizedValue # Momentum for the shampoo preconditioner - index_start: np.int32 = struct.field( - pytree_node=False) # Index into global statistics array - sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics. + """State associated to each parameter of the model being trained.""" + + diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner + diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner + momentum: QuantizedValue # Momentum for the shampoo preconditioner + index_start: np.int32 = struct.field( + pytree_node=False + ) # Index into global statistics array + sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics. 
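A minimal round-trip sketch of the bucket quantization above (hedged: a made-up 2x2 matrix, assuming the QuantizedValue definition from this file is in scope):

import jax.numpy as jnp

stat = jnp.array([[4.0, 1.0], [1.0, 3.0]], dtype=jnp.float32)
q = QuantizedValue.from_float_value(stat, jnp.int8, extract_diagonal=True)
# q.quantized holds int8 buckets of the off-diagonal part, q.diagonal the extracted
# diagonal, and q.bucket_size the per-column scale (column max_abs / 127).
recovered = q.to_float()
# recovered is close to stat; the per-entry error is at most about half a bucket.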
class ShardedShampooStats(NamedTuple): - """Shampoo state in sharded mode.""" - global_stats: Any - local_stats: Any + """Shampoo state in sharded mode.""" + + global_stats: Any + local_stats: Any class ShampooState(NamedTuple): - count: chex.Array - stats: Any + count: chex.Array + stats: Any class GraftingType(enum.IntEnum): - SGD = 1 - ADAGRAD = 2 - RMSPROP = 3 - RMSPROP_NORMALIZED = 4 + SGD = 1 + ADAGRAD = 2 + RMSPROP = 3 + RMSPROP_NORMALIZED = 4 def power_iteration( - matrix, - num_iters=100, - error_tolerance=1e-6, - precision=lax.Precision.HIGHEST): - r"""Power iteration algorithm. - - The power iteration algorithm takes a symmetric PSD matrix `A`, and produces - a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue - of `A`, and a vector v, which is the corresponding eigenvector of `A`. - - References: - [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration) - - Args: - matrix: the symmetric PSD matrix. - num_iters: Number of iterations. - error_tolerance: Iterative exit condition. - precision: precision XLA related flag, the available options are: - a) lax.Precision.DEFAULT (better step time, but not precise) - b) lax.Precision.HIGH (increased precision, slower) - c) lax.Precision.HIGHEST (best possible precision, slowest) - - Returns: - eigen vector, eigen value - """ - matrix_size = matrix.shape[-1] - def _iter_condition(state): - i, unused_v, unused_s, unused_s_v, run_step = state - return jnp.logical_and(i < num_iters, run_step) - - def _iter_body(state): - """One step of power iteration.""" - i, new_v, s, s_v, unused_run_step = state - new_v = new_v / jnp.linalg.norm(new_v) - - s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision) - s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision) - return (i + 1, s_v, s_new, s_v, - jnp.greater(jnp.abs(s_new - s), error_tolerance)) - - # Figure out how to use step as seed for random. - v_0 = np.random.RandomState(1729).uniform(-1.0, 1.0, - matrix_size).astype(matrix.dtype) - - init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True]) - _, v_out, s_out, _, _ = lax.while_loop( - _iter_condition, _iter_body, init_state) - v_out = v_out / jnp.linalg.norm(v_out) - return v_out, s_out + matrix, num_iters=100, error_tolerance=1e-6, precision=lax.Precision.HIGHEST +): + r"""Power iteration algorithm. + + The power iteration algorithm takes a symmetric PSD matrix `A`, and produces + a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue + of `A`, and a vector v, which is the corresponding eigenvector of `A`. + + References: + [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration) + + Args: + matrix: the symmetric PSD matrix. + num_iters: Number of iterations. + error_tolerance: Iterative exit condition. 
+ precision: precision XLA related flag, the available options are: + a) lax.Precision.DEFAULT (better step time, but not precise) + b) lax.Precision.HIGH (increased precision, slower) + c) lax.Precision.HIGHEST (best possible precision, slowest) + + Returns: + eigen vector, eigen value + """ + matrix_size = matrix.shape[-1] + + def _iter_condition(state): + i, unused_v, unused_s, unused_s_v, run_step = state + return jnp.logical_and(i < num_iters, run_step) + + def _iter_body(state): + """One step of power iteration.""" + i, new_v, s, s_v, unused_run_step = state + new_v = new_v / jnp.linalg.norm(new_v) + + s_v = jnp.einsum("ij,j->i", matrix, new_v, precision=precision) + s_new = jnp.einsum("i,i->", new_v, s_v, precision=precision) + return ( + i + 1, + s_v, + s_new, + s_v, + jnp.greater(jnp.abs(s_new - s), error_tolerance), + ) + + # Figure out how to use step as seed for random. + v_0 = ( + np.random.RandomState(1729).uniform(-1.0, 1.0, matrix_size).astype(matrix.dtype) + ) + + init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True]) + _, v_out, s_out, _, _ = lax.while_loop(_iter_condition, _iter_body, init_state) + v_out = v_out / jnp.linalg.norm(v_out) + return v_out, s_out def matrix_inverse_pth_root( @@ -246,381 +264,391 @@ def matrix_inverse_pth_root( num_iters=100, ridge_epsilon=1e-6, error_tolerance=1e-6, - precision=lax.Precision.HIGHEST): - """Computes `matrix^(-1/p)`, where `p` is a positive integer. - - This function uses the Coupled newton iterations algorithm for - the computation of a matrix's inverse pth root. - - - References: - [Functions of Matrices, Theory and Computation, - Nicholas J Higham, Pg 184, Eq 7.18]( - https://epubs.siam.org/doi/book/10.1137/1.9780898717778) - - Args: - matrix: the symmetric PSD matrix whose power it to be computed - p: exponent, for p a positive integer. - num_iters: Maximum number of iterations. - ridge_epsilon: Ridge epsilon added to make the matrix positive definite. - error_tolerance: Error indicator, useful for early termination. - precision: precision XLA related flag, the available options are: - a) lax.Precision.DEFAULT (better step time, but not precise) - b) lax.Precision.HIGH (increased precision, slower) - c) lax.Precision.HIGHEST (best possible precision, slowest) - - Returns: - matrix^(-1/p) - """ - - # We use float32 for the matrix inverse pth root. - # Switch to f64 if you have hardware that supports it. - matrix_size = matrix.shape[0] - alpha = jnp.asarray(-1.0 / p, jnp.float32) - identity = jnp.eye(matrix_size, dtype=jnp.float32) - _, max_ev = power_iteration( - matrix=matrix, num_iters=100, - error_tolerance=1e-6, precision=precision) - ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16) - - def _unrolled_mat_pow_1(mat_m): - """Computes mat_m^1.""" - return mat_m - - def _unrolled_mat_pow_2(mat_m): - """Computes mat_m^2.""" - return jnp.matmul(mat_m, mat_m, precision=precision) - - def _unrolled_mat_pow_4(mat_m): - """Computes mat_m^4.""" - mat_pow_2 = _unrolled_mat_pow_2(mat_m) - return jnp.matmul( - mat_pow_2, mat_pow_2, precision=precision) - - def _unrolled_mat_pow_8(mat_m): - """Computes mat_m^4.""" - mat_pow_4 = _unrolled_mat_pow_4(mat_m) - return jnp.matmul( - mat_pow_4, mat_pow_4, precision=precision) - - def mat_power(mat_m, p): - """Computes mat_m^p, for p == 1, 2, 4 or 8. + precision=lax.Precision.HIGHEST, +): + """Computes `matrix^(-1/p)`, where `p` is a positive integer. 
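A quick sanity check for power_iteration (a hedged sketch with a made-up 2x2 symmetric PSD matrix, compared against a dense eigendecomposition):

import jax.numpy as jnp

a = jnp.array([[2.0, 1.0], [1.0, 3.0]], dtype=jnp.float32)
v, lam = power_iteration(a, num_iters=100)
lam_ref = jnp.max(jnp.linalg.eigh(a)[0])  # largest eigenvalue, for comparison
# lam should match lam_ref, and a @ v should be close to lam * v, up to error_tolerance.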
+ + This function uses the Coupled newton iterations algorithm for + the computation of a matrix's inverse pth root. + + + References: + [Functions of Matrices, Theory and Computation, + Nicholas J Higham, Pg 184, Eq 7.18]( + https://epubs.siam.org/doi/book/10.1137/1.9780898717778) Args: - mat_m: a square matrix - p: a positive integer + matrix: the symmetric PSD matrix whose power it to be computed + p: exponent, for p a positive integer. + num_iters: Maximum number of iterations. + ridge_epsilon: Ridge epsilon added to make the matrix positive definite. + error_tolerance: Error indicator, useful for early termination. + precision: precision XLA related flag, the available options are: + a) lax.Precision.DEFAULT (better step time, but not precise) + b) lax.Precision.HIGH (increased precision, slower) + c) lax.Precision.HIGHEST (best possible precision, slowest) Returns: - mat_m^p + matrix^(-1/p) """ - # We unrolled the loop for performance reasons. - exponent = jnp.round(jnp.log2(p)) - return lax.switch( - jnp.asarray(exponent, jnp.int32), [ - _unrolled_mat_pow_1, - _unrolled_mat_pow_2, - _unrolled_mat_pow_4, - _unrolled_mat_pow_8, - ], (mat_m)) - - def _iter_condition(state): - (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, - run_step) = state - error_above_threshold = jnp.logical_and( - error > error_tolerance, run_step) - return jnp.logical_and(i < num_iters, error_above_threshold) - - def _iter_body(state): - (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state - mat_m_i = (1 - alpha) * identity + alpha * mat_m - new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision) - new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision) - new_error = jnp.max(jnp.abs(new_mat_m - identity)) - # sometimes error increases after an iteration before decreasing and - # converging. 1.2 factor is used to bound the maximal allowed increase. - return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, - new_error < error * 1.2) - - if matrix_size == 1: - resultant_mat_h = (matrix + ridge_epsilon)**alpha - error = 0 - else: - damped_matrix = matrix + ridge_epsilon * identity - - z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix)) - new_mat_m_0 = damped_matrix * z - new_error = jnp.max(jnp.abs(new_mat_m_0 - identity)) - new_mat_h_0 = identity * jnp.power(z, 1.0 / p) - init_state = tuple( - [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True]) - _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop( - _iter_condition, _iter_body, init_state) - error = jnp.max(jnp.abs(mat_m - identity)) - is_converged = jnp.asarray(convergence, old_mat_h.dtype) - resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h - resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype) - return resultant_mat_h, error + + # We use float32 for the matrix inverse pth root. + # Switch to f64 if you have hardware that supports it. 
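A hedged sanity check for matrix_inverse_pth_root (made-up 2x2 SPD matrix): for p = 2 the result should square to roughly the inverse, up to the ridge term.

import jax.numpy as jnp

a = jnp.array([[3.0, 1.0], [1.0, 2.0]], dtype=jnp.float32)
h, err = matrix_inverse_pth_root(a, p=2, ridge_epsilon=1e-6)
# h @ h is close to jnp.linalg.inv(a) up to the ridge regularization; err is the
# final residual max|M - I| of the coupled Newton iteration.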
+ matrix_size = matrix.shape[0] + alpha = jnp.asarray(-1.0 / p, jnp.float32) + identity = jnp.eye(matrix_size, dtype=jnp.float32) + _, max_ev = power_iteration( + matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision + ) + ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16) + + def _unrolled_mat_pow_1(mat_m): + """Computes mat_m^1.""" + return mat_m + + def _unrolled_mat_pow_2(mat_m): + """Computes mat_m^2.""" + return jnp.matmul(mat_m, mat_m, precision=precision) + + def _unrolled_mat_pow_4(mat_m): + """Computes mat_m^4.""" + mat_pow_2 = _unrolled_mat_pow_2(mat_m) + return jnp.matmul(mat_pow_2, mat_pow_2, precision=precision) + + def _unrolled_mat_pow_8(mat_m): + """Computes mat_m^4.""" + mat_pow_4 = _unrolled_mat_pow_4(mat_m) + return jnp.matmul(mat_pow_4, mat_pow_4, precision=precision) + + def mat_power(mat_m, p): + """Computes mat_m^p, for p == 1, 2, 4 or 8. + + Args: + mat_m: a square matrix + p: a positive integer + + Returns: + mat_m^p + """ + # We unrolled the loop for performance reasons. + exponent = jnp.round(jnp.log2(p)) + return lax.switch( + jnp.asarray(exponent, jnp.int32), + [ + _unrolled_mat_pow_1, + _unrolled_mat_pow_2, + _unrolled_mat_pow_4, + _unrolled_mat_pow_8, + ], + (mat_m), + ) + + def _iter_condition(state): + (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, run_step) = state + error_above_threshold = jnp.logical_and(error > error_tolerance, run_step) + return jnp.logical_and(i < num_iters, error_above_threshold) + + def _iter_body(state): + (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state + mat_m_i = (1 - alpha) * identity + alpha * mat_m + new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision) + new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision) + new_error = jnp.max(jnp.abs(new_mat_m - identity)) + # sometimes error increases after an iteration before decreasing and + # converging. 1.2 factor is used to bound the maximal allowed increase. + return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, new_error < error * 1.2) + + if matrix_size == 1: + resultant_mat_h = (matrix + ridge_epsilon) ** alpha + error = 0 + else: + damped_matrix = matrix + ridge_epsilon * identity + + z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix)) + new_mat_m_0 = damped_matrix * z + new_error = jnp.max(jnp.abs(new_mat_m_0 - identity)) + new_mat_h_0 = identity * jnp.power(z, 1.0 / p) + init_state = tuple([0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True]) + _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop( + _iter_condition, _iter_body, init_state + ) + error = jnp.max(jnp.abs(mat_m - identity)) + is_converged = jnp.asarray(convergence, old_mat_h.dtype) + resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h + resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype) + return resultant_mat_h, error def merge_small_dims(shape_to_merge, max_dim): - """Merge small dimensions. - - If there are some small dimensions, we collapse them: - e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024 - [1, 2, 768, 1, 2048] --> [2, 768, 2048] - - Args: - shape_to_merge: Shape to merge small dimensions. - max_dim: Maximal dimension of output shape used in merging. - - Returns: - Merged shape. - """ - resulting_shape = [] - product = 1 - for d in shape_to_merge: - if product * d <= max_dim: - product *= d - else: - if product > 1: + """Merge small dimensions. + + If there are some small dimensions, we collapse them: + e.g. 
[1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024 + [1, 2, 768, 1, 2048] --> [2, 768, 2048] + + Args: + shape_to_merge: Shape to merge small dimensions. + max_dim: Maximal dimension of output shape used in merging. + + Returns: + Merged shape. + """ + resulting_shape = [] + product = 1 + for d in shape_to_merge: + if product * d <= max_dim: + product *= d + else: + if product > 1: + resulting_shape.append(product) + product = d + if product > 1: resulting_shape.append(product) - product = d - if product > 1: - resulting_shape.append(product) - return resulting_shape + return resulting_shape def pad_matrix(mat, max_size): - """Pad a matrix to a max_size. - - Args: - mat: a matrix to pad. - max_size: matrix size requested. - - Returns: - Given M returns [[M, 0], [0, I]] - """ - size = mat.shape[0] - assert size <= max_size - if size == max_size: + """Pad a matrix to a max_size. + + Args: + mat: a matrix to pad. + max_size: matrix size requested. + + Returns: + Given M returns [[M, 0], [0, I]] + """ + size = mat.shape[0] + assert size <= max_size + if size == max_size: + return mat + pad_size = max_size - size + zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype) + zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype) + eye = jnp.eye(pad_size, dtype=mat.dtype) + mat = jnp.concatenate([mat, zs1], 1) + mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0) return mat - pad_size = max_size - size - zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype) - zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype) - eye = jnp.eye(pad_size, dtype=mat.dtype) - mat = jnp.concatenate([mat, zs1], 1) - mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0) - return mat def pad_vector(vec, max_size): - """Pad a vector to a max_size. + """Pad a vector to a max_size. - Args: - vec: a vector to pad. - max_size: matrix size requested. + Args: + vec: a vector to pad. + max_size: matrix size requested. - Returns: - Given V returns [V, 0] - """ - size = vec.shape[0] - assert size <= max_size - if size == max_size: - return vec - pad_size = max_size - size - zs1 = jnp.zeros([pad_size], dtype=vec.dtype) - return jnp.concatenate([vec, zs1], 0) + Returns: + Given V returns [V, 0] + """ + size = vec.shape[0] + assert size <= max_size + if size == max_size: + return vec + pad_size = max_size - size + zs1 = jnp.zeros([pad_size], dtype=vec.dtype) + return jnp.concatenate([vec, zs1], 0) def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs): - """Avoids wasteful buffer allocation with XLA.""" + """Avoids wasteful buffer allocation with XLA.""" - def _iter_body(unused_state): - results = compute_fn(*args, **kwargs) - return tuple([False] + list(results)) + def _iter_body(unused_state): + results = compute_fn(*args, **kwargs) + return tuple([False] + list(results)) - def _iter_condition(state): - return state[0] + def _iter_condition(state): + return state[0] - results = jax.lax.while_loop(_iter_condition, _iter_body, - tuple([predicate] + init_state)) - return tuple(results[1:]) + results = jax.lax.while_loop( + _iter_condition, _iter_body, tuple([predicate] + init_state) + ) + return tuple(results[1:]) class BlockPartitioner: - """Partitions a tensor into smaller tensors.""" - - def __init__(self, param, block_size): - self._shape = param.shape - self._splits = [] - split_sizes = [] - # We split params into smaller blocks. Here we store the metadata to make - # that split. - for i, d in enumerate(param.shape): - if 0 < block_size < d: - # d-1, otherwise split appends a 0-size array. 
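The shape helpers above can be exercised directly; a small sketch using the docstring's own examples plus a made-up 3x3 matrix for padding:

import jax.numpy as jnp

merge_small_dims([1, 2, 512, 1, 2048, 1, 3, 4], 1024)  # -> [1024, 2048, 12]
merge_small_dims([1, 2, 768, 1, 2048], 1024)           # -> [2, 768, 2048]

m = jnp.ones((3, 3), dtype=jnp.float32)
pad_matrix(m, 5)             # 5x5 block matrix [[M, 0], [0, I]]
pad_vector(jnp.ones(3), 5)   # [1., 1., 1., 0., 0.]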
- nsplit = (d - 1) // block_size - indices = (np.arange(nsplit, dtype=np.int32) + 1) * block_size - sizes = np.ones(nsplit + 1, dtype=np.int32) * block_size - sizes[-1] = d - indices[-1] - self._splits.append((i, indices)) - split_sizes.append(sizes) - else: - split_sizes.append(np.array([d], dtype=np.int32)) - self._num_splits = len(split_sizes) - self._preconditioner_shapes = [] - for t in itertools.product(*split_sizes): - self._preconditioner_shapes.extend([[d, d] for d in t]) - - def shapes_for_preconditioners(self): - return self._preconditioner_shapes - - def num_splits(self): - return self._num_splits - - def partition(self, tensor): - """Partition tensor into blocks.""" - - assert tensor.shape == self._shape - tensors = [tensor] - for (i, indices) in self._splits: - tensors_local = [] - for t in tensors: - tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i)) - tensors = tensors_local - return tensors - - def merge_partitions(self, partitions): - """Merge partitions back to original shape.""" - - for (i, indices) in reversed(self._splits): - n = len(indices) + 1 - partial_merged_tensors = [] - ind = 0 - while ind < len(partitions): - partial_merged_tensors.append( - jnp.concatenate(partitions[ind:ind + n], axis=i)) - ind += n - partitions = partial_merged_tensors - assert len(partitions) == 1 - return partitions[0] + """Partitions a tensor into smaller tensors.""" + + def __init__(self, param, block_size): + self._shape = param.shape + self._splits = [] + split_sizes = [] + # We split params into smaller blocks. Here we store the metadata to make + # that split. + for i, d in enumerate(param.shape): + if 0 < block_size < d: + # d-1, otherwise split appends a 0-size array. + nsplit = (d - 1) // block_size + indices = (np.arange(nsplit, dtype=np.int32) + 1) * block_size + sizes = np.ones(nsplit + 1, dtype=np.int32) * block_size + sizes[-1] = d - indices[-1] + self._splits.append((i, indices)) + split_sizes.append(sizes) + else: + split_sizes.append(np.array([d], dtype=np.int32)) + self._num_splits = len(split_sizes) + self._preconditioner_shapes = [] + for t in itertools.product(*split_sizes): + self._preconditioner_shapes.extend([[d, d] for d in t]) + + def shapes_for_preconditioners(self): + return self._preconditioner_shapes + + def num_splits(self): + return self._num_splits + + def partition(self, tensor): + """Partition tensor into blocks.""" + + assert tensor.shape == self._shape + tensors = [tensor] + for (i, indices) in self._splits: + tensors_local = [] + for t in tensors: + tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i)) + tensors = tensors_local + return tensors + + def merge_partitions(self, partitions): + """Merge partitions back to original shape.""" + + for (i, indices) in reversed(self._splits): + n = len(indices) + 1 + partial_merged_tensors = [] + ind = 0 + while ind < len(partitions): + partial_merged_tensors.append( + jnp.concatenate(partitions[ind : ind + n], axis=i) + ) + ind += n + partitions = partial_merged_tensors + assert len(partitions) == 1 + return partitions[0] class Preconditioner: - """Compute statistics/shape from gradients for preconditioning.""" - - def __init__(self, param, block_size, best_effort_shape_interpretation): - self._original_shape = param.shape - self._transformed_shape = param.shape - if best_effort_shape_interpretation: - self._transformed_shape = merge_small_dims(self._original_shape, - block_size) - reshaped_param = jnp.reshape(param, self._transformed_shape) - self._partitioner = 
BlockPartitioner(reshaped_param, block_size) - - def statistics_from_grad(self, grad): - """Compute statistics from gradients. - - Args: - grad: Gradient to compute statistics from. - - Returns: - A list of gradient statistics for each partition. - """ - reshaped_grad = jnp.reshape(grad, self._transformed_shape) - partitioned_grads = self._partitioner.partition(reshaped_grad) - stats = [] - for g in partitioned_grads: - g_stats = [] - rank = len(g.shape) - for i in range(rank): - axes = list(range(i)) + list(range(i + 1, rank)) - stat = jnp.tensordot(g, g, axes=(axes, axes)) - g_stats.append(stat) - stats.extend(g_stats) - return stats - - def shapes_for_preconditioners(self): - """Returns shape from statistics.""" - return self._partitioner.shapes_for_preconditioners() - - def exponent_for_preconditioner(self): - """Returns exponent to use for inverse-pth root M^{-1/p}.""" - return 2 * len(self._transformed_shape) - - def preconditioned_grad(self, grad, preconditioners): - """Precondition the gradient. - - Args: - grad: A gradient tensor to precondition. - preconditioners: A list of preconditioners to apply. - - Returns: - A preconditioned gradient. - """ - - reshaped_grad = jnp.reshape(grad, self._transformed_shape) - partitioned_grads = self._partitioner.partition(reshaped_grad) - preconditioned_partitioned_grads = [] - num_splits = self._partitioner.num_splits() - for i, g in enumerate(partitioned_grads): - preconditioners_for_grad = preconditioners[i * num_splits:(i + 1) * - num_splits] - rank = len(g.shape) - precond_g = g - for j in range(rank): - precond_g = jnp.tensordot( - precond_g, preconditioners_for_grad[j], axes=[[0], [0]]) - preconditioned_partitioned_grads.append(precond_g) - merged_grad = self._partitioner.merge_partitions( - preconditioned_partitioned_grads) - return jnp.reshape(merged_grad, self._original_shape) + """Compute statistics/shape from gradients for preconditioning.""" + + def __init__(self, param, block_size, best_effort_shape_interpretation): + self._original_shape = param.shape + self._transformed_shape = param.shape + if best_effort_shape_interpretation: + self._transformed_shape = merge_small_dims(self._original_shape, block_size) + reshaped_param = jnp.reshape(param, self._transformed_shape) + self._partitioner = BlockPartitioner(reshaped_param, block_size) + + def statistics_from_grad(self, grad): + """Compute statistics from gradients. + + Args: + grad: Gradient to compute statistics from. + + Returns: + A list of gradient statistics for each partition. + """ + reshaped_grad = jnp.reshape(grad, self._transformed_shape) + partitioned_grads = self._partitioner.partition(reshaped_grad) + stats = [] + for g in partitioned_grads: + g_stats = [] + rank = len(g.shape) + for i in range(rank): + axes = list(range(i)) + list(range(i + 1, rank)) + stat = jnp.tensordot(g, g, axes=(axes, axes)) + g_stats.append(stat) + stats.extend(g_stats) + return stats + + def shapes_for_preconditioners(self): + """Returns shape from statistics.""" + return self._partitioner.shapes_for_preconditioners() + + def exponent_for_preconditioner(self): + """Returns exponent to use for inverse-pth root M^{-1/p}.""" + return 2 * len(self._transformed_shape) + + def preconditioned_grad(self, grad, preconditioners): + """Precondition the gradient. + + Args: + grad: A gradient tensor to precondition. + preconditioners: A list of preconditioners to apply. + + Returns: + A preconditioned gradient. 
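A hedged sketch of BlockPartitioner on a made-up (6, 3) parameter with block_size=3, showing the partition/merge round trip:

import jax.numpy as jnp

param = jnp.arange(18.0).reshape(6, 3)
partitioner = BlockPartitioner(param, block_size=3)
blocks = partitioner.partition(param)            # two (3, 3) blocks split along axis 0
partitioner.shapes_for_preconditioners()         # [[3, 3], [3, 3], [3, 3], [3, 3]]
restored = partitioner.merge_partitions(blocks)  # equal to param again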
+ """ + + reshaped_grad = jnp.reshape(grad, self._transformed_shape) + partitioned_grads = self._partitioner.partition(reshaped_grad) + preconditioned_partitioned_grads = [] + num_splits = self._partitioner.num_splits() + for i, g in enumerate(partitioned_grads): + preconditioners_for_grad = preconditioners[ + i * num_splits : (i + 1) * num_splits + ] + rank = len(g.shape) + precond_g = g + for j in range(rank): + precond_g = jnp.tensordot( + precond_g, preconditioners_for_grad[j], axes=[[0], [0]] + ) + preconditioned_partitioned_grads.append(precond_g) + merged_grad = self._partitioner.merge_partitions( + preconditioned_partitioned_grads + ) + return jnp.reshape(merged_grad, self._original_shape) def _convert_to_parameter_stats(global_stats, local_stat): - """Creates parameter stats from sharded stats.""" - index_start = int(local_stat.index_start) - index_end = int(len(local_stat.sizes)) + index_start - statistics = global_stats.statistics[index_start:index_end, :, :] - preconditioners = global_stats.preconditioners[index_start:index_end, :, :] - new_statistics = [] - new_preconditioners = [] - for i, size in enumerate(local_stat.sizes): - new_statistics.append(statistics[i][:size, :size]) - new_preconditioners.append(preconditioners[i][:size, :size]) - return ParameterStats(local_stat.diagonal_statistics, new_statistics, - new_preconditioners, local_stat.diagonal_momentum, - local_stat.momentum) + """Creates parameter stats from sharded stats.""" + index_start = int(local_stat.index_start) + index_end = int(len(local_stat.sizes)) + index_start + statistics = global_stats.statistics[index_start:index_end, :, :] + preconditioners = global_stats.preconditioners[index_start:index_end, :, :] + new_statistics = [] + new_preconditioners = [] + for i, size in enumerate(local_stat.sizes): + new_statistics.append(statistics[i][:size, :size]) + new_preconditioners.append(preconditioners[i][:size, :size]) + return ParameterStats( + local_stat.diagonal_statistics, + new_statistics, + new_preconditioners, + local_stat.diagonal_momentum, + local_stat.momentum, + ) def _convert_from_parameter_stats(parameter_stats, local_stats): - """Creates sharded stats from paramter stats.""" - return LocalShardedParameterStats(parameter_stats.diagonal_statistics, - parameter_stats.diagonal_momentum, - parameter_stats.momentum, - local_stats.index_start, local_stats.sizes) + """Creates sharded stats from paramter stats.""" + return LocalShardedParameterStats( + parameter_stats.diagonal_statistics, + parameter_stats.diagonal_momentum, + parameter_stats.momentum, + local_stats.index_start, + local_stats.sizes, + ) def batch(x, num_devices): - """Batch `x` so that so that leading axis is num_devices.""" - n = len(x) - b = int(n / num_devices) - return jnp.stack([jnp.stack(x[idx:idx + b]) for idx in range(0, n, b)]) + """Batch `x` so that so that leading axis is num_devices.""" + n = len(x) + b = int(n / num_devices) + return jnp.stack([jnp.stack(x[idx : idx + b]) for idx in range(0, n, b)]) def unbatch(batched_values): - """Unbatch values across leading axis and return a list of elements.""" - b1, b2 = batched_values.shape[0], batched_values.shape[1] - results = [] - for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0): - v_array = jnp.squeeze(v_array) - # b2 = batches (number of preconditioner computation) per core. 
- if b2 > 1: - for v in jnp.split(v_array, indices_or_sections=b2, axis=0): - results.append(jnp.squeeze(v)) - else: - results.append(v_array) - return results + """Unbatch values across leading axis and return a list of elements.""" + b1, b2 = batched_values.shape[0], batched_values.shape[1] + results = [] + for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0): + v_array = jnp.squeeze(v_array) + # b2 = batches (number of preconditioner computation) per core. + if b2 > 1: + for v in jnp.split(v_array, indices_or_sections=b2, axis=0): + results.append(jnp.squeeze(v)) + else: + results.append(v_array) + return results def distributed_shampoo( @@ -653,959 +681,1146 @@ def distributed_shampoo( moving_average_for_momentum=False, skip_preconditioning_dim_size_gt=4096, clip_by_scaled_gradient_norm=None, - precision=lax.Precision.HIGHEST): - """Distributed Shampoo optimizer. - - Distributed Shampoo is a second-order preconditioned method (concretely, a - variant of full-matrix Adagrad), that provides significant convergence and - wall-clock time improvements compared to conventional first-order methods, - and that has been shown to scale to large state-of-the-art deep learning - models. - - References: - Scalable Second Order Optimization for Deep Learning, - Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer - - Preprint: https://arxiv.org/abs/2002.09018 - - Args: - learning_rate: the step size used to update the parameters. - block_size: Block size for large layers (if > 0). Preconditioning compute - operation is cubic in the dimension of the tensor. Block size allows us to - chunk the layers into sub-layers of maximal dimension dictated by this - value. Use 128 as default (increase if you have compute budget). - beta1: momentum parameter. - beta2: second moment averaging parameter. - diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting - to AdaGrad is enabled). - matrix_epsilon: epsilon to add to statistics before computing inverse pth - root. If you are running in f32 precision for inverse pth root - (recommended today) this can go upto 1e-6. If you have latest hardware - with native f64 precision, set this upto 1e-12. - weight_decay: Weight decay for regularization. - start_preconditioning_step: When to start Shampoo update before which - diagonal update is used. This is because we dont have enough information - to do stable inverse. - preconditioning_compute_steps: How often to compute preconditioner. - Performance tuning params for controlling memory and compute requirements. - Ideally set this and statistics_compute_steps params to 1. - statistics_compute_steps: How often to compute statistics. - best_effort_shape_interpretation: If there are some small dimensions, - collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if - block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048] - graft_type: Grafting is a technique to fix the layerwise scale of Shampoo - optimizer. This allows us to plugin the Shampoo optimizer into settings - where SGD/AdaGrad is already well tuned. Available options are: - GraftingType.SGD and GraftingType.ADAGRAD. - nesterov: Nesterov momentum. - exponent_override: Override the exponent used in matrix inverse. - batch_axis_name: labeled axis over pmap for data-parallel training the - optimizer used for. - mesh_axis_names: Axis names for the mesh (used in pjit). - num_devices_for_pjit: Number of devices to parallelize over when using pjit. 
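A hedged sketch of the batch/unbatch pair with made-up values: the flat list of equally shaped statistics is reshaped so the leading axis matches the device count, and unbatch restores the flat list.

import jax.numpy as jnp

mats = [jnp.eye(4, dtype=jnp.float32) * i for i in range(8)]
stacked = batch(mats, 4)   # shape (4, 2, 4, 4): 4 devices, 2 matrices per device
flat = unbatch(stacked)    # back to a list of 8 arrays of shape (4, 4)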
- shard_optimizer_states: Shard optimizer states to save memory in model - parallel training. - best_effort_memory_usage_reduction: Best effort memory usage reduction. - diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x) -> jnp.int8 - statistics, preconditioners -> jnp.int16 + diagonals - inverse_failure_threshold: numerics are hard and inverses fail sometimes; we - determine that using this threshold. - moving_average_for_momentum: Whether to use moving average for momentum - instead of exponential moving average. - skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is - greater than this value. - clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful - when using RMSProp Grafting). - precision: precision XLA related flag, the available options are: a) - lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST - (best possible precision, slowest) - - Returns: - a GradientTransformation. - """ - - def quantized_dtype_for_momentum_buffers(): - return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32 - - # TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes. - def quantized_dtype_for_diagonal_statistics_buffers(): - return jnp.bfloat16 if best_effort_memory_usage_reduction else jnp.float32 - - # Preconditioner and statistics are both stores as int16 in this mode. - # We take out the diagonal to make quantization easier. - def quantized_dtype_for_second_moment_statistics_buffers(): - return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32 - - # Preconditioner and statistics are both stores as int16 in this mode. - # We take out the diagonal to make quantization easier. - def quantized_dtype_for_second_moment_preconditioner_buffers(): - return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32 - - def _to_float(maybe_quantized): - if isinstance(maybe_quantized, QuantizedValue): - return maybe_quantized.to_float() - else: - return maybe_quantized - - def _maybe_quantize_statistics(statistics_list): - return _maybe_quantize_matrices_with_dtype( - statistics_list, quantized_dtype_for_second_moment_statistics_buffers()) - - def _maybe_quantize_preconditioners(statistics_list): - return _maybe_quantize_matrices_with_dtype( - statistics_list, - quantized_dtype_for_second_moment_preconditioner_buffers()) - - def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype): - if quantized_dtype != jnp.float32: - return ([ - QuantizedValue.from_float_value( - s, quantized_dtype, extract_diagonal=True) - for s in statistics_list - ]) - else: - return statistics_list + precision=lax.Precision.HIGHEST, +): + """Distributed Shampoo optimizer. - def _maybe_dequantize_preconditioners(preconditioner_list): - return _maybe_dequantize_matrices_with_dtype( - preconditioner_list, - quantized_dtype_for_second_moment_preconditioner_buffers()) + Distributed Shampoo is a second-order preconditioned method (concretely, a + variant of full-matrix Adagrad), that provides significant convergence and + wall-clock time improvements compared to conventional first-order methods, + and that has been shown to scale to large state-of-the-art deep learning + models. 
- def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype): - if quantized_dtype != jnp.float32: - return [s.to_float() for s in statistics_list] - else: - return statistics_list - - def _quantize_diagonal_statistics(diagonal_statistics): - return QuantizedValue.from_float_value( - diagonal_statistics, quantized_dtype_for_diagonal_statistics_buffers()) - - def _quantize_momentum(momentum_statistics): - return QuantizedValue.from_float_value( - momentum_statistics, quantized_dtype_for_momentum_buffers()) - - def sharded_init_fn(params): - params_flat, treedef = jax.tree_flatten(params) - # Find max size to pad to. - max_size = 0 - for param in params_flat: - preconditioner = Preconditioner(param, block_size, - best_effort_shape_interpretation) - if not _skip_preconditioning(param): - shapes = preconditioner.shapes_for_preconditioners() - sizes = [s[0] for s in shapes] - max_size = max(max(sizes), max_size) - - padded_statistics = [] - padded_preconditioners = [] - local_stats_flat = [] - for param in params_flat: - preconditioner = Preconditioner(param, block_size, - best_effort_shape_interpretation) - shapes = preconditioner.shapes_for_preconditioners() - sizes = [] - - statistics = [] - preconditioners = [] - index_start = len(padded_statistics) - if not _skip_preconditioning(param): - sizes = [s[0] for s in shapes] - shapes = preconditioner.shapes_for_preconditioners() - statistics = [matrix_epsilon * jnp.eye(max_size) for s in shapes] - preconditioners = [jnp.eye(max_size) for s in shapes] - padded_statistics.extend(statistics) - padded_preconditioners.extend(preconditioners) - - diagonal_statistics = [] - if graft_type != GraftingType.SGD: - diagonal_statistics = jnp.zeros_like(param) - local_stats_flat.append( - LocalShardedParameterStats( - _quantize_diagonal_statistics(diagonal_statistics), - _quantize_momentum(jnp.zeros_like(param)), - _quantize_momentum(jnp.zeros_like(param)), index_start, sizes)) - - local_stats = jax.tree_unflatten(treedef, local_stats_flat) - # Pad the statistics and preconditioner matrices to be a multiple of - # num devices. - # TODO(rohananil): Relax to only the size of the mesh axis where the dim - # is split on. - to_pad = -len(padded_statistics) % num_devices_for_pjit - padded_statistics.extend([ - jnp.eye(max_size, dtype=padded_statistics[0].dtype) - for _ in range(to_pad) - ]) - padded_preconditioners.extend([ - jnp.eye(max_size, dtype=padded_statistics[0].dtype) - for _ in range(to_pad) - ]) - global_stats = GlobalShardedParameterStats( - jnp.stack(padded_statistics), jnp.stack(padded_preconditioners)) - return ShampooState( - count=jnp.zeros([], jnp.int32), - stats=ShardedShampooStats(global_stats, local_stats)) - - def sharded_update_fn(grads, state, params): - """Transform the input gradient and update all statistics in sharded mode. - - Args: - grads: the gradient tensors for the parameters. - state: a named tuple containing the state of the optimizer - params: the parameters that should be updated. + References: + Scalable Second Order Optimization for Deep Learning, + Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer - Returns: - A tuple containing the new parameters and the new optimizer state. 
- """ - params_flat, treedef = jax.tree_flatten(params) - grads_flat = treedef.flatten_up_to(grads) - - global_stats = state.stats.global_stats - local_stats_flat = treedef.flatten_up_to(state.stats.local_stats) - stats_flat = [ - _convert_to_parameter_stats(global_stats, local_stat) - for local_stat in local_stats_flat - ] - new_stats_flat = jax.tree_multimap( - lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat, - stats_flat, params_flat) - - exponents = [] - for stat, param in zip(new_stats_flat, params_flat): - num_statistics = len(stat.statistics) - if num_statistics > 0: - preconditioner = Preconditioner(param, block_size, - best_effort_shape_interpretation) - exponent = ( - preconditioner.exponent_for_preconditioner() - if exponent_override == 0 else exponent_override) - exponents.extend([exponent] * num_statistics) - - outputs = jax.tree_multimap( - lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, - new_stats_flat, params_flat) - updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ()) - - updates = jax.tree_unflatten(treedef, updates_flat) - # Create new local_stats - new_local_stats_flat = [ - _convert_from_parameter_stats(new_stat, local_stat) - for new_stat, local_stat in zip(new_stats_flat, local_stats_flat) - ] - new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat) - - max_size = global_stats.statistics.shape[1] - new_padded_statistics = [] - for stat in new_stats_flat: - new_padded_statistics.extend( - [pad_matrix(stat, max_size) for stat in stat.statistics]) - - # Create global stats - # TODO(rohananil): Preconditioner is not updated every step, so cost of - # stack/pad can be obviated away. - # Pad the statistics and preconditioner matrices to be a multiple of - # num devices. - # TODO(rohananil): Relax to only the size of the mesh axis where the dim - # is split on. - to_pad = -len(new_padded_statistics) % num_devices_for_pjit - new_padded_statistics.extend([ - jnp.eye(max_size, dtype=new_padded_statistics[0].dtype) - for _ in range(to_pad) - ]) - exponents.extend([1 for _ in range(to_pad)]) - new_stacked_padded_statistics = jnp.stack(new_padded_statistics) - new_stacked_exponents = jnp.stack(exponents) - def _matrix_inverse_pth_root_vmap(xs, ps): - mi_pth_root = functools.partial( - matrix_inverse_pth_root, - ridge_epsilon=matrix_epsilon, - precision=precision) - preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps) - return preconditioners, errors - - def _internal_inverse_pth_root_all(): - preconditioners, errors = _matrix_inverse_pth_root_vmap( - new_stacked_padded_statistics, new_stacked_exponents) - return preconditioners, errors - - if preconditioning_compute_steps == 1: - new_preconditioners, errors = _internal_inverse_pth_root_all() - else: - # Passing statistics instead of preconditioners as they are similarly - # shaped tensors. Note statistics will be ignored as we are passing in - # a large init value for error. - preconditioners_init = new_stacked_padded_statistics - errors_init = np.stack([inverse_failure_threshold] * len(exponents)) - init_state = [preconditioners_init, errors_init] - perform_step = state.count % preconditioning_compute_steps == 0 - new_preconditioners, errors = efficient_cond( - perform_step, _internal_inverse_pth_root_all, init_state) - - errors = errors.reshape((-1, 1, 1)) - predicate = jnp.logical_or( - jnp.isnan(errors), - errors >= inverse_failure_threshold).astype(new_preconditioners.dtype) - # TODO(rohananil): Check for numerical instabilities. 
- new_conditional_preconditioners = ( - predicate * global_stats.preconditioners + - (1.0 - predicate) * new_preconditioners) - new_global_stats = GlobalShardedParameterStats( - new_stacked_padded_statistics, new_conditional_preconditioners) - new_shampoo_state = ShampooState( - count=state.count + 1, - stats=ShardedShampooStats(new_global_stats, new_local_stats)) - return updates, new_shampoo_state - - def init_fn(params): - """Initialise the optimiser's state.""" - - def _init(param): - preconditioner = Preconditioner(param, block_size, - best_effort_shape_interpretation) - statistics = [] - preconditioners = [] - if not _skip_preconditioning(param): - shapes = preconditioner.shapes_for_preconditioners() - statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes] - preconditioners = [jnp.eye(s[0]) for s in shapes] - - diagonal_statistics = [] - if graft_type != GraftingType.SGD: - diagonal_statistics = jnp.zeros_like(param) - return ParameterStats( - _quantize_diagonal_statistics(diagonal_statistics), - _maybe_quantize_statistics(statistics), - _maybe_quantize_preconditioners(preconditioners), - _quantize_momentum(jnp.zeros_like(param)), - _quantize_momentum(jnp.zeros_like(param))) - return ShampooState( - count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)) - - def _skip_preconditioning(param): - return len(param.shape) < 1 or any( - [s > skip_preconditioning_dim_size_gt for s in param.shape]) - - def _compute_stats(grad, state, param, step): - """Compute per-parameter statistics.""" - preconditioner = Preconditioner(param, block_size, - best_effort_shape_interpretation) - new_statistics = [[]] * len(state.statistics) - w1 = beta2 - w2 = beta2 if beta2 == 1.0 else (1.0 - beta2) - if not _skip_preconditioning(param): - - def compute_updated_statistics(): - new_stats = preconditioner.statistics_from_grad(grad) - new_stats_accumulators = [] - for stat, stat_accumulator in zip(new_stats, state.statistics): - new_stats_accumulators.append(w1 * _to_float(stat_accumulator) + - w2 * stat) - return _maybe_quantize_statistics(new_stats_accumulators) - - if statistics_compute_steps > 1: - perform_step = step % statistics_compute_steps == 0 - init_state = state.statistics - new_statistics = list( - efficient_cond(perform_step, compute_updated_statistics, - init_state)) - else: - new_statistics = compute_updated_statistics() - return ParameterStats(state.diagonal_statistics, new_statistics, - state.preconditioners, state.diagonal_momentum, - state.momentum) - - def _matrix_inverse_pth_root_vmap(xs, ps): - mi_pth_root = functools.partial( - matrix_inverse_pth_root, - ridge_epsilon=matrix_epsilon, - precision=precision) - return jax.vmap(mi_pth_root)(xs, ps) - - def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps): - - def _quantized_to_float(qx, qd, qb): - qv = QuantizedValue(qx, qd, qb, qx.dtype, True, list(qx.shape)) - return qv.to_float() - - def matrix_inverse_pth_root_wrapper(qx, qd, qb, p): - v = _quantized_to_float(qx, qd, qb) - preconditioner, error = matrix_inverse_pth_root( - v, p, ridge_epsilon=matrix_epsilon, precision=precision) - qp = QuantizedValue.from_float_value(preconditioner, qx.dtype, True) - return qp.quantized, qp.diagonal, qp.bucket_size, error - - return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps) - - def _matrix_inverse_pth_root_pjit(xs, ps): - mesh_axis_names_tuple = tuple(mesh_axis_names) - # Partition the concatenated statistics matrix across all cores. 
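The statistics update in _compute_stats above follows a rule worth spelling out (hedged paraphrase with an illustrative helper name): with beta2 == 1.0 the per-block statistics accumulate as a plain sum (Adagrad-style), otherwise they form an exponential moving average.

def update_statistic(stat_accumulator, new_stat, beta2):
    # Mirrors the w1/w2 weighting used in _compute_stats.
    w1 = beta2
    w2 = beta2 if beta2 == 1.0 else 1.0 - beta2
    return w1 * stat_accumulator + w2 * new_stat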
- partitioned_xs, partitioned_ps = pjit.pjit( - lambda x, y: (x, y), - in_axis_resources=None, - out_axis_resources=pjit.PartitionSpec(mesh_axis_names_tuple,))(xs, ps) - # Run matrix inverse pth root on each shard. - partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap( - partitioned_xs, partitioned_ps) - # Recombine the outputs at each core. - preconditioners, errors = pjit.pjit( - lambda x, y: (x, y), - in_axis_resources=(pjit.PartitionSpec(mesh_axis_names_tuple,), - pjit.PartitionSpec(mesh_axis_names_tuple,)), - out_axis_resources=(None, None))(partitioned_preconditioners, - partitioned_errors) - return preconditioners, errors - - def _pmap_compute_preconditioners(states, step, statistics, - num_statistics_per_state, original_shapes, - exponents, max_size, prev_preconditioners): - """Computes preconditioners for given statistics in states in PMAP mode. + Preprint: https://arxiv.org/abs/2002.09018 Args: - states: A list of optimizer states. - step: Current step number - statistics: A list of statistics for all variables (for every dim) - num_statistics_per_state: Number of statistis per state to reconstruct - output states. - original_shapes: A list of shapes of the statistics. - exponents: Exponent power to use for inverse-pth roots. - max_size: Maximum dim of the statistics to pad. - prev_preconditioners: Previously available preconditioner. + learning_rate: the step size used to update the parameters. + block_size: Block size for large layers (if > 0). Preconditioning compute + operation is cubic in the dimension of the tensor. Block size allows us to + chunk the layers into sub-layers of maximal dimension dictated by this + value. Use 128 as default (increase if you have compute budget). + beta1: momentum parameter. + beta2: second moment averaging parameter. + diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting + to AdaGrad is enabled). + matrix_epsilon: epsilon to add to statistics before computing inverse pth + root. If you are running in f32 precision for inverse pth root + (recommended today) this can go upto 1e-6. If you have latest hardware + with native f64 precision, set this upto 1e-12. + weight_decay: Weight decay for regularization. + start_preconditioning_step: When to start Shampoo update before which + diagonal update is used. This is because we dont have enough information + to do stable inverse. + preconditioning_compute_steps: How often to compute preconditioner. + Performance tuning params for controlling memory and compute requirements. + Ideally set this and statistics_compute_steps params to 1. + statistics_compute_steps: How often to compute statistics. + best_effort_shape_interpretation: If there are some small dimensions, + collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if + block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048] + graft_type: Grafting is a technique to fix the layerwise scale of Shampoo + optimizer. This allows us to plugin the Shampoo optimizer into settings + where SGD/AdaGrad is already well tuned. Available options are: + GraftingType.SGD and GraftingType.ADAGRAD. + nesterov: Nesterov momentum. + exponent_override: Override the exponent used in matrix inverse. + batch_axis_name: labeled axis over pmap for data-parallel training the + optimizer used for. + mesh_axis_names: Axis names for the mesh (used in pjit). + num_devices_for_pjit: Number of devices to parallelize over when using pjit. 
+ shard_optimizer_states: Shard optimizer states to save memory in model + parallel training. + best_effort_memory_usage_reduction: Best effort memory usage reduction. + diagonal_statistics -> jnp.bfloat16 + momentum buffers (2x) -> jnp.int8 + statistics, preconditioners -> jnp.int16 + diagonals + inverse_failure_threshold: numerics are hard and inverses fail sometimes; we + determine that using this threshold. + moving_average_for_momentum: Whether to use moving average for momentum + instead of exponential moving average. + skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is + greater than this value. + clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful + when using RMSProp Grafting). + precision: precision XLA related flag, the available options are: a) + lax.Precision.DEFAULT (better step time, but not precise) b) + lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST + (best possible precision, slowest) Returns: - New optimizer states after computing the preconditioner. + a GradientTransformation. """ - num_devices = lax.psum(1, batch_axis_name) - num_statistics = len(statistics) - # Pad statistics and exponents to next multiple of num_devices. - packed_statistics = [pad_matrix(stat, max_size) for stat in statistics] - to_pad = -num_statistics % num_devices - packed_statistics.extend([ - jnp.eye(max_size, dtype=packed_statistics[0].dtype) - for _ in range(to_pad) - ]) - exponents.extend([1 for _ in range(to_pad)]) - - if not packed_statistics: - return states - - all_statistics = batch(packed_statistics, num_devices) - all_exponents = batch(exponents, num_devices) - - def _internal_inverse_pth_root_all(): - current_replica = lax.axis_index(batch_axis_name) - preconditioners, errors = _matrix_inverse_pth_root_vmap( - all_statistics[current_replica], all_exponents[current_replica]) - preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name) - errors = jax.lax.all_gather(errors, batch_axis_name) - preconditioners_flat = unbatch(preconditioners) - errors_flat = unbatch(errors) - return preconditioners_flat, errors_flat - - if preconditioning_compute_steps == 1: - preconditioners_flat, errors_flat = _internal_inverse_pth_root_all() - else: - # Passing statistics instead of preconditioners as they are similarly - # shaped tensors. Note statistics will be ignored as we are passing in - # a large init value for error. - preconditioners_init = packed_statistics - errors_init = ([inverse_failure_threshold] * len(packed_statistics)) - init_state = [preconditioners_init, errors_init] - perform_step = step % preconditioning_compute_steps == 0 - preconditioners_flat, errors_flat = efficient_cond( - perform_step, _internal_inverse_pth_root_all, init_state) - - def _skip(error): - condition = jnp.logical_or( - jnp.isnan(error), error >= inverse_failure_threshold) - return condition.astype(error.dtype) - - def _select_preconditioner(error, new_p, old_p): - return lax.cond( - _skip(error), lambda _: old_p, lambda _: new_p, operand=None) - - new_preconditioners_flat = [] - for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes, - prev_preconditioners, errors_flat): - new_preconditioners_flat.append( - _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p)) - - assert len(states) == len(num_statistics_per_state) - assert len(new_preconditioners_flat) == num_statistics - - # Add back empty preconditioners so we that we can set the optimizer state. 
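Since the docstring above states that distributed_shampoo returns a GradientTransformation, it drops into the standard optax loop; a hedged usage sketch (hyperparameter values are illustrative, params and grads are assumed to exist, and when training under pmap the batch_axis_name argument should name the pmapped axis with the update running inside the pmapped step):

import optax

tx = distributed_shampoo(
    learning_rate=1e-3,
    block_size=128,
    beta1=0.9,
    beta2=0.999,
    graft_type=GraftingType.ADAGRAD,
)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)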
- preconditioners_for_states = [] - idx = 0 - for num_statistics, state in zip(num_statistics_per_state, states): - if num_statistics == 0: - preconditioners_for_states.append([]) - else: - preconditioners_for_state = new_preconditioners_flat[idx:idx + - num_statistics] - assert len(state.statistics) == len(preconditioners_for_state) - preconditioners_for_states.append(preconditioners_for_state) - idx += num_statistics - new_states = [] - for state, new_preconditioners in zip(states, preconditioners_for_states): - new_states.append( - ParameterStats(state.diagonal_statistics, state.statistics, - new_preconditioners, state.diagonal_momentum, - state.momentum)) - - return new_states - - def _pmap_quantized_compute_preconditioners(states, step, statistics, - num_statistics_per_state, - original_shapes, exponents, - max_size, prev_preconditioners): - """Computes preconditioners for given statistics in states in PMAP mode. - - For quantization, each statistic is represented by three values: - quantized matrix, diagonal, and bucket sizes, we run inverse pth-roots - without ever recreating the original matrix in f32. - - Args: - states: A list of optimizer states. - step: Current step number - statistics: A list of statistics for all variables (for every dim) - num_statistics_per_state: Number of statistis per state to reconstruct - output states. - original_shapes: A list of shapes of the statistics. - exponents: Exponent power to use for inverse-pth roots. - max_size: Maximum dim of the statistics to pad. - prev_preconditioners: Previously available preconditioner. - Returns: - New optimizer states after computing the preconditioner. - """ - num_devices = lax.psum(1, batch_axis_name) - num_statistics = len(statistics) - quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers() - # Complexity here is around: shapes needing be statically shaped, - # our custom quantization type requires a different type of packing. 
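Both preconditioner paths flatten every per-block statistic into one list, run the inverse roots in bulk, and then slice the results back per parameter using `num_statistics_per_state`, exactly as in the loop above. A toy sketch of that regrouping (names and sizes invented):

```python
# Toy illustration of the flat-list -> per-parameter regrouping used above.
# Names and sizes are invented for the example.
def regroup(flat_preconditioners, num_statistics_per_state):
    grouped, idx = [], 0
    for num_statistics in num_statistics_per_state:
        # Parameters that skipped preconditioning contribute an empty group.
        grouped.append(flat_preconditioners[idx : idx + num_statistics])
        idx += num_statistics
    assert idx == len(flat_preconditioners)
    return grouped


# Three parameters: a 2-D kernel (2 blocks), a skipped scalar, a 3-D kernel (3 blocks).
flat = ["P0", "P1", "P2", "P3", "P4"]
assert regroup(flat, [2, 0, 3]) == [["P0", "P1"], [], ["P2", "P3", "P4"]]
```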
- - # Parallel tensors: - # quantized [dxd] - # diagonals [d] f32 - # bucket_sizes [d] f32 - packed_quantized_statistics = [ - pad_matrix(stat.quantized, max_size) for stat in statistics - ] - packed_quantized_diagonals = [ - pad_vector(stat.diagonal, max_size) for stat in statistics - ] - packed_quantized_bucket_sizes = [ - pad_vector(stat.bucket_size, max_size) for stat in statistics - ] - - to_pad = -num_statistics % num_devices - padded_eye = jnp.eye(max_size, dtype=jnp.float32) - quantized_eye = QuantizedValue.from_float_value(padded_eye, quantized_dtype, - True) - packed_quantized_statistics.extend( - [quantized_eye.quantized for _ in range(to_pad)]) - packed_quantized_diagonals.extend( - [quantized_eye.diagonal for _ in range(to_pad)]) - packed_quantized_bucket_sizes.extend( - [quantized_eye.bucket_size for _ in range(to_pad)]) - exponents.extend([1 for _ in range(to_pad)]) - - if not packed_quantized_statistics: - return states - - all_quantized_statistics = batch(packed_quantized_statistics, num_devices) - all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices) - all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes, - num_devices) - all_exponents = batch(exponents, num_devices) - - def _internal_inverse_pth_root_all(): - current_replica = lax.axis_index(batch_axis_name) - quantized_preconditioners, quantized_diagonals, quantized_bucket_sizes, errors = ( - _quantized_matrix_inverse_pth_root_vmap( - all_quantized_statistics[current_replica], - all_quantized_diagonals[current_replica], - all_quantized_bucket_sizes[current_replica], - all_exponents[current_replica])) - quantized_preconditioners = jax.lax.all_gather(quantized_preconditioners, - batch_axis_name) - quantized_diagonals = jax.lax.all_gather(quantized_diagonals, - batch_axis_name) - quantized_bucket_sizes = jax.lax.all_gather(quantized_bucket_sizes, - batch_axis_name) - errors = jax.lax.all_gather(errors, batch_axis_name) - quantized_preconditioners_flat = unbatch(quantized_preconditioners) - quantized_diagonals_flat = unbatch(quantized_diagonals) - quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes) - errors_flat = unbatch(errors) - return (quantized_preconditioners_flat, quantized_diagonals_flat, - quantized_bucket_sizes_flat, errors_flat) - - if preconditioning_compute_steps == 1: - (quantized_preconditioners_flat, quantized_diagonals_flat, - quantized_bucket_sizes_flat, errors_flat) = ( - _internal_inverse_pth_root_all()) - else: - # Passing statistics instead of preconditioners as they are similarly - # shaped tensors. Note statistics will be ignored as we are passing in - # a large init value for error. 
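The packing above leans on Python's modulo of a negative operand: `-num_statistics % num_devices` is the number of dummy identity matrices needed to reach the next multiple of `num_devices`. A short numpy sketch of that layout (shapes invented; `batch`/`unbatch` here are stand-ins for the helpers referenced in the diff):

```python
# Sketch of the pad-to-multiple-of-devices layout used by the pmap path.
# Shapes are invented; the real batch/unbatch helpers are defined elsewhere in the file.
import numpy as np

num_devices = 8
statistics = [np.eye(4) * k for k in range(1, 12)]   # 11 padded [4, 4] statistics

to_pad = -len(statistics) % num_devices               # 5 dummies -> 16 total
statistics += [np.eye(4) for _ in range(to_pad)]
assert len(statistics) % num_devices == 0

# One leading axis per device, matching what lax.axis_index selects per replica.
stacked = np.stack(statistics).reshape(num_devices, -1, 4, 4)
assert stacked.shape == (8, 2, 4, 4)

# Recombining (the all_gather + unbatch step) just flattens the two leading axes.
recombined = stacked.reshape(-1, 4, 4)
assert recombined.shape == (16, 4, 4)
```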
- quantized_preconditioners_init = packed_quantized_statistics - quantized_diagonals_init = packed_quantized_diagonals - quantized_bucket_sizes_init = packed_quantized_bucket_sizes - errors_init = ([inverse_failure_threshold] * - len(quantized_preconditioners_init)) - init_state = [ - quantized_preconditioners_init, quantized_diagonals_init, - quantized_bucket_sizes_init, errors_init - ] - perform_step = step % preconditioning_compute_steps == 0 - (quantized_preconditioners_flat, quantized_diagonals_flat, - quantized_bucket_sizes_flat, errors_flat) = ( - efficient_cond(perform_step, _internal_inverse_pth_root_all, - init_state)) - - def _skip(error): - condition = jnp.logical_or( - jnp.isnan(error), error >= inverse_failure_threshold) - return condition.astype(error.dtype) - - def _select_preconditioner(error, new_p, old_p): - return lax.cond( - _skip(error), lambda _: old_p, lambda _: new_p, operand=None) - - new_quantized_preconditioners_flat = [] - new_quantized_diagonals_flat = [] - new_quantized_bucket_sizes_flat = [] - for p, d, b, shape, prev_p, error in zip(quantized_preconditioners_flat, - quantized_diagonals_flat, - quantized_bucket_sizes_flat, - original_shapes, - prev_preconditioners, errors_flat): - new_quantized_preconditioners_flat.append( - _select_preconditioner(error, p[:shape[0], :shape[1]], - prev_p.quantized)) - new_quantized_diagonals_flat.append( - _select_preconditioner(error, d[:shape[0]], prev_p.diagonal)) - new_quantized_bucket_sizes_flat.append( - _select_preconditioner(error, b[:shape[0]], prev_p.bucket_size)) - - assert len(states) == len(num_statistics_per_state) - assert len(new_quantized_preconditioners_flat) == num_statistics - assert len(new_quantized_diagonals_flat) == num_statistics - assert len(new_quantized_bucket_sizes_flat) == num_statistics - - # Add back empty preconditioners so we that we can set the optimizer state. - preconditioners_for_states = [] - idx = 0 - for num_statistics, state in zip(num_statistics_per_state, states): - if num_statistics == 0: - preconditioners_for_states.append([]) - else: - quantized_preconditioners_for_state = new_quantized_preconditioners_flat[ - idx:idx + num_statistics] - quantized_diagonals_for_state = new_quantized_diagonals_flat[ - idx:idx + num_statistics] - quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[ - idx:idx + num_statistics] - - assert len(state.statistics) == len(quantized_preconditioners_for_state) - assert len(state.statistics) == len(quantized_diagonals_for_state) - assert len(state.statistics) == len(quantized_bucket_sizes_for_state) - - quantized_preconditioners = [] - for qv, qd, qb in zip(quantized_preconditioners_for_state, - quantized_diagonals_for_state, - quantized_bucket_sizes_for_state): - quantized_preconditioners.append( - QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape))) - preconditioners_for_states.append(quantized_preconditioners) - idx += num_statistics - new_states = [] - for state, new_preconditioners in zip(states, preconditioners_for_states): - new_states.append( - ParameterStats(state.diagonal_statistics, state.statistics, - new_preconditioners, state.diagonal_momentum, - state.momentum)) - - return new_states - - def _pjit_compute_preconditioners(states, step, statistics, - num_statistics_per_state, original_shapes, - exponents, max_size, prev_preconditioners): - """Computes preconditioners for given statistics in states in PJIT mode. - - Args: - states: A list of optimizer states. 
- step: Current step number - statistics: A list of statistics for all variables (for every dim) - num_statistics_per_state: Number of statistis per state to reconstruct - output states. - original_shapes: A list of shapes of the statistics. - exponents: Exponent power to use for inverse-pth roots. - max_size: Maximum dim of the statistics to pad. - prev_preconditioners: Previously available preconditioner. - - Returns: - New optimizer states after computing the preconditioner. - """ - num_statistics = len(statistics) - to_pad = -num_statistics % num_devices_for_pjit - padded_statistics = [pad_matrix(stat, max_size) for stat in statistics] - padded_statistics.extend([ - jnp.eye(max_size, dtype=padded_statistics[0].dtype) - for _ in range(to_pad) - ]) - exponents.extend([1 for _ in range(to_pad)]) - all_statistics = jnp.stack(padded_statistics) - all_exponents = jnp.stack(exponents) - - def _internal_inverse_pth_root_all(): - preconditioners, errors = _matrix_inverse_pth_root_pjit( - all_statistics, all_exponents) - b1 = preconditioners.shape[0] - - def split(batched_values): - return [ - jnp.squeeze(v) - for v in jnp.split(batched_values, indices_or_sections=b1, axis=0) + def quantized_dtype_for_momentum_buffers(): + return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32 + + # TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes. + def quantized_dtype_for_diagonal_statistics_buffers(): + return jnp.bfloat16 if best_effort_memory_usage_reduction else jnp.float32 + + # Preconditioner and statistics are both stores as int16 in this mode. + # We take out the diagonal to make quantization easier. + def quantized_dtype_for_second_moment_statistics_buffers(): + return ( + jnp.int16 + if best_effort_memory_usage_reduction and batch_axis_name + else jnp.float32 + ) + + # Preconditioner and statistics are both stores as int16 in this mode. + # We take out the diagonal to make quantization easier. 
+ def quantized_dtype_for_second_moment_preconditioner_buffers(): + return ( + jnp.int16 + if best_effort_memory_usage_reduction and batch_axis_name + else jnp.float32 + ) + + def _to_float(maybe_quantized): + if isinstance(maybe_quantized, QuantizedValue): + return maybe_quantized.to_float() + else: + return maybe_quantized + + def _maybe_quantize_statistics(statistics_list): + return _maybe_quantize_matrices_with_dtype( + statistics_list, quantized_dtype_for_second_moment_statistics_buffers() + ) + + def _maybe_quantize_preconditioners(statistics_list): + return _maybe_quantize_matrices_with_dtype( + statistics_list, quantized_dtype_for_second_moment_preconditioner_buffers() + ) + + def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype): + if quantized_dtype != jnp.float32: + return [ + QuantizedValue.from_float_value( + s, quantized_dtype, extract_diagonal=True + ) + for s in statistics_list + ] + else: + return statistics_list + + def _maybe_dequantize_preconditioners(preconditioner_list): + return _maybe_dequantize_matrices_with_dtype( + preconditioner_list, + quantized_dtype_for_second_moment_preconditioner_buffers(), + ) + + def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype): + if quantized_dtype != jnp.float32: + return [s.to_float() for s in statistics_list] + else: + return statistics_list + + def _quantize_diagonal_statistics(diagonal_statistics): + return QuantizedValue.from_float_value( + diagonal_statistics, quantized_dtype_for_diagonal_statistics_buffers() + ) + + def _quantize_momentum(momentum_statistics): + return QuantizedValue.from_float_value( + momentum_statistics, quantized_dtype_for_momentum_buffers() + ) + + def sharded_init_fn(params): + params_flat, treedef = jax.tree_flatten(params) + # Find max size to pad to. + max_size = 0 + for param in params_flat: + preconditioner = Preconditioner( + param, block_size, best_effort_shape_interpretation + ) + if not _skip_preconditioning(param): + shapes = preconditioner.shapes_for_preconditioners() + sizes = [s[0] for s in shapes] + max_size = max(max(sizes), max_size) + + padded_statistics = [] + padded_preconditioners = [] + local_stats_flat = [] + for param in params_flat: + preconditioner = Preconditioner( + param, block_size, best_effort_shape_interpretation + ) + shapes = preconditioner.shapes_for_preconditioners() + sizes = [] + + statistics = [] + preconditioners = [] + index_start = len(padded_statistics) + if not _skip_preconditioning(param): + sizes = [s[0] for s in shapes] + shapes = preconditioner.shapes_for_preconditioners() + statistics = [matrix_epsilon * jnp.eye(max_size) for s in shapes] + preconditioners = [jnp.eye(max_size) for s in shapes] + padded_statistics.extend(statistics) + padded_preconditioners.extend(preconditioners) + + diagonal_statistics = [] + if graft_type != GraftingType.SGD: + diagonal_statistics = jnp.zeros_like(param) + local_stats_flat.append( + LocalShardedParameterStats( + _quantize_diagonal_statistics(diagonal_statistics), + _quantize_momentum(jnp.zeros_like(param)), + _quantize_momentum(jnp.zeros_like(param)), + index_start, + sizes, + ) + ) + + local_stats = jax.tree_unflatten(treedef, local_stats_flat) + # Pad the statistics and preconditioner matrices to be a multiple of + # num devices. + # TODO(rohananil): Relax to only the size of the mesh axis where the dim + # is split on. 
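The helpers above only pick a dtype and hand off to `QuantizedValue`; the round-trip sketch below shows the effect on a PSD statistic. The import path and the error tolerance are assumptions for illustration.

```python
# Round-trip sketch for the int16 statistics/preconditioner quantization, calling
# the QuantizedValue API the same way the helpers above do.
# Import path and tolerance are assumptions for illustration.
import jax.numpy as jnp
import numpy as np

from tools.train.distributed_shampoo import QuantizedValue  # assumed path

rng = np.random.default_rng(0)
g = jnp.asarray(rng.standard_normal((64, 16)), dtype=jnp.float32)
stat = g @ g.T                       # PSD statistic, as accumulated by Shampoo

q = QuantizedValue.from_float_value(stat, jnp.int16, extract_diagonal=True)
stat_hat = q.to_float()

# int16 buckets keep the reconstruction close while shrinking the buffer;
# the diagonal is stored separately in full precision.
rel_err = jnp.max(jnp.abs(stat - stat_hat)) / jnp.max(jnp.abs(stat))
assert float(rel_err) < 1e-3         # illustrative tolerance only
```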
+ to_pad = -len(padded_statistics) % num_devices_for_pjit + padded_statistics.extend( + [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)] + ) + padded_preconditioners.extend( + [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)] + ) + global_stats = GlobalShardedParameterStats( + jnp.stack(padded_statistics), jnp.stack(padded_preconditioners) + ) + return ShampooState( + count=jnp.zeros([], jnp.int32), + stats=ShardedShampooStats(global_stats, local_stats), + ) + + def sharded_update_fn(grads, state, params): + """Transform the input gradient and update all statistics in sharded mode. + + Args: + grads: the gradient tensors for the parameters. + state: a named tuple containing the state of the optimizer + params: the parameters that should be updated. + + Returns: + A tuple containing the new parameters and the new optimizer state. + """ + params_flat, treedef = jax.tree_flatten(params) + grads_flat = treedef.flatten_up_to(grads) + + global_stats = state.stats.global_stats + local_stats_flat = treedef.flatten_up_to(state.stats.local_stats) + stats_flat = [ + _convert_to_parameter_stats(global_stats, local_stat) + for local_stat in local_stats_flat ] + new_stats_flat = jax.tree_multimap( + lambda g, s, p: _compute_stats(g, s, p, state.count), + grads_flat, + stats_flat, + params_flat, + ) + + exponents = [] + for stat, param in zip(new_stats_flat, params_flat): + num_statistics = len(stat.statistics) + if num_statistics > 0: + preconditioner = Preconditioner( + param, block_size, best_effort_shape_interpretation + ) + exponent = ( + preconditioner.exponent_for_preconditioner() + if exponent_override == 0 + else exponent_override + ) + exponents.extend([exponent] * num_statistics) + + outputs = jax.tree_multimap( + lambda g, s, p: _transform_grad(g, s, p, state.count), + grads_flat, + new_stats_flat, + params_flat, + ) + updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ()) + + updates = jax.tree_unflatten(treedef, updates_flat) + # Create new local_stats + new_local_stats_flat = [ + _convert_from_parameter_stats(new_stat, local_stat) + for new_stat, local_stat in zip(new_stats_flat, local_stats_flat) + ] + new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat) + + max_size = global_stats.statistics.shape[1] + new_padded_statistics = [] + for stat in new_stats_flat: + new_padded_statistics.extend( + [pad_matrix(stat, max_size) for stat in stat.statistics] + ) + + # Create global stats + # TODO(rohananil): Preconditioner is not updated every step, so cost of + # stack/pad can be obviated away. + # Pad the statistics and preconditioner matrices to be a multiple of + # num devices. + # TODO(rohananil): Relax to only the size of the mesh axis where the dim + # is split on. 
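In the sharded path, each parameter's `LocalShardedParameterStats` keeps only `index_start` and `sizes`; the padded matrices themselves live in one stacked global array. The toy sketch below (invented shapes, plain dicts standing in for the real dataclasses) shows the bookkeeping that `_convert_to_parameter_stats`, referenced above but defined elsewhere in the file, presumably performs:

```python
# Toy sketch of the index_start/sizes bookkeeping behind the sharded state.
# Plain lists/dicts stand in for LocalShardedParameterStats / GlobalShardedParameterStats.
import numpy as np

max_size = 8                      # every statistic is padded to [max_size, max_size]
global_statistics = []            # later stacked into one [N, 8, 8] array
local_stats = []

for sizes in ([4, 6], [], [8]):   # per-parameter true block sizes (invented)
    index_start = len(global_statistics)
    for s in sizes:
        padded = np.zeros((max_size, max_size))
        padded[:s, :s] = np.eye(s)
        global_statistics.append(padded)
    local_stats.append({"index_start": index_start, "sizes": sizes})

stacked = np.stack(global_statistics)              # shape [3, 8, 8]


def statistics_for(local):
    # Slice a parameter's blocks back out of the stacked global array.
    start, sizes = local["index_start"], local["sizes"]
    return [stacked[start + i][:s, :s] for i, s in enumerate(sizes)]


assert [m.shape for m in statistics_for(local_stats[0])] == [(4, 4), (6, 6)]
assert statistics_for(local_stats[1]) == []        # skipped parameter
```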
+ to_pad = -len(new_padded_statistics) % num_devices_for_pjit + new_padded_statistics.extend( + [ + jnp.eye(max_size, dtype=new_padded_statistics[0].dtype) + for _ in range(to_pad) + ] + ) + exponents.extend([1 for _ in range(to_pad)]) + new_stacked_padded_statistics = jnp.stack(new_padded_statistics) + new_stacked_exponents = jnp.stack(exponents) + + def _matrix_inverse_pth_root_vmap(xs, ps): + mi_pth_root = functools.partial( + matrix_inverse_pth_root, + ridge_epsilon=matrix_epsilon, + precision=precision, + ) + preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps) + return preconditioners, errors + + def _internal_inverse_pth_root_all(): + preconditioners, errors = _matrix_inverse_pth_root_vmap( + new_stacked_padded_statistics, new_stacked_exponents + ) + return preconditioners, errors + + if preconditioning_compute_steps == 1: + new_preconditioners, errors = _internal_inverse_pth_root_all() + else: + # Passing statistics instead of preconditioners as they are similarly + # shaped tensors. Note statistics will be ignored as we are passing in + # a large init value for error. + preconditioners_init = new_stacked_padded_statistics + errors_init = np.stack([inverse_failure_threshold] * len(exponents)) + init_state = [preconditioners_init, errors_init] + perform_step = state.count % preconditioning_compute_steps == 0 + new_preconditioners, errors = efficient_cond( + perform_step, _internal_inverse_pth_root_all, init_state + ) + + errors = errors.reshape((-1, 1, 1)) + predicate = jnp.logical_or( + jnp.isnan(errors), errors >= inverse_failure_threshold + ).astype(new_preconditioners.dtype) + # TODO(rohananil): Check for numerical instabilities. + new_conditional_preconditioners = ( + predicate * global_stats.preconditioners + + (1.0 - predicate) * new_preconditioners + ) + new_global_stats = GlobalShardedParameterStats( + new_stacked_padded_statistics, new_conditional_preconditioners + ) + new_shampoo_state = ShampooState( + count=state.count + 1, + stats=ShardedShampooStats(new_global_stats, new_local_stats), + ) + return updates, new_shampoo_state + + def init_fn(params): + """Initialise the optimiser's state.""" + + def _init(param): + preconditioner = Preconditioner( + param, block_size, best_effort_shape_interpretation + ) + statistics = [] + preconditioners = [] + if not _skip_preconditioning(param): + shapes = preconditioner.shapes_for_preconditioners() + statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes] + preconditioners = [jnp.eye(s[0]) for s in shapes] + + diagonal_statistics = [] + if graft_type != GraftingType.SGD: + diagonal_statistics = jnp.zeros_like(param) + return ParameterStats( + _quantize_diagonal_statistics(diagonal_statistics), + _maybe_quantize_statistics(statistics), + _maybe_quantize_preconditioners(preconditioners), + _quantize_momentum(jnp.zeros_like(param)), + _quantize_momentum(jnp.zeros_like(param)), + ) + + return ShampooState( + count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params) + ) + + def _skip_preconditioning(param): + return len(param.shape) < 1 or any( + [s > skip_preconditioning_dim_size_gt for s in param.shape] + ) + + def _compute_stats(grad, state, param, step): + """Compute per-parameter statistics.""" + preconditioner = Preconditioner( + param, block_size, best_effort_shape_interpretation + ) + new_statistics = [[]] * len(state.statistics) + w1 = beta2 + w2 = beta2 if beta2 == 1.0 else (1.0 - beta2) + if not _skip_preconditioning(param): + + def compute_updated_statistics(): + new_stats = 
preconditioner.statistics_from_grad(grad) + new_stats_accumulators = [] + for stat, stat_accumulator in zip(new_stats, state.statistics): + new_stats_accumulators.append( + w1 * _to_float(stat_accumulator) + w2 * stat + ) + return _maybe_quantize_statistics(new_stats_accumulators) + + if statistics_compute_steps > 1: + perform_step = step % statistics_compute_steps == 0 + init_state = state.statistics + new_statistics = list( + efficient_cond(perform_step, compute_updated_statistics, init_state) + ) + else: + new_statistics = compute_updated_statistics() + return ParameterStats( + state.diagonal_statistics, + new_statistics, + state.preconditioners, + state.diagonal_momentum, + state.momentum, + ) - return split(preconditioners), split(errors) - - if preconditioning_compute_steps == 1: - preconditioners_flat, errors_flat = _internal_inverse_pth_root_all() - else: - # Passing statistics instead of preconditioners as they are similarly - # shaped tensors. Note statistics will be ignored as we are passing in - # a large init value for error. - preconditioners_init = padded_statistics - errors_init = [inverse_failure_threshold] * len(padded_statistics) - init_state = [preconditioners_init, errors_init] - perform_step = step % preconditioning_compute_steps == 0 - preconditioners_flat, errors_flat = efficient_cond( - perform_step, _internal_inverse_pth_root_all, init_state) - - def _skip(error): - condition = jnp.logical_or( - jnp.isnan(error), error >= inverse_failure_threshold) - return condition.astype(error.dtype) - - def _select_preconditioner(error, new_p, old_p): - return lax.cond( - _skip(error), lambda _: old_p, lambda _: new_p, operand=None) - - new_preconditioners_flat = [] - for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes, - prev_preconditioners, errors_flat): - new_preconditioners_flat.append( - _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p)) - - assert len(states) == len(num_statistics_per_state) - assert len(new_preconditioners_flat) == num_statistics - - # Add back empty preconditioners so we that we can set the optimizer state. - preconditioners_for_states = [] - idx = 0 - for num_statistics, state in zip(num_statistics_per_state, states): - if num_statistics == 0: - preconditioners_for_states.append([]) - else: - preconditioners_for_state = new_preconditioners_flat[idx:idx + - num_statistics] - assert len(state.statistics) == len(preconditioners_for_state) - preconditioners_for_states.append(preconditioners_for_state) - idx += num_statistics - new_states = [] - for state, new_preconditioners in zip(states, preconditioners_for_states): - new_states.append( - ParameterStats(state.diagonal_statistics, state.statistics, - new_preconditioners, state.diagonal_momentum, - state.momentum)) - - return new_states - - def _compute_preconditioners(states, params, step): - """Computes preconditioners for given statistics in states. - - Args: - states: A list of optimizer states. - params: A list of params. - step: Current step number - - Returns: - New optimizer states after computing the preconditioner. 
- """ - statistics = [] - num_statistics_per_state = [] - original_shapes = [] - exponents = [] - max_size = 0 - prev_preconditioners = [] - - for state, param in zip(states, params): - num_statistics = len(state.statistics) - num_statistics_per_state.append(num_statistics) - original_shapes_for_state = [] - if num_statistics > 0: - preconditioner = Preconditioner(param, block_size, - best_effort_shape_interpretation) - for statistic in state.statistics: - exponents.append(preconditioner.exponent_for_preconditioner( - ) if exponent_override == 0 else exponent_override) - original_shapes_for_state.append(statistic.shape) - max_size = max(max_size, statistic.shape[0]) - - statistics.extend(state.statistics) - prev_preconditioners.extend(state.preconditioners) - original_shapes.extend(original_shapes_for_state) - - if batch_axis_name: - # Quantization is only enabled if batch_axis_name is not set. - quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers() - - if quantized_dtype == jnp.float32: - return _pmap_compute_preconditioners(states, step, statistics, - num_statistics_per_state, - original_shapes, exponents, - max_size, prev_preconditioners) - else: - return _pmap_quantized_compute_preconditioners( - states, step, statistics, num_statistics_per_state, original_shapes, - exponents, max_size, prev_preconditioners) - - else: - return _pjit_compute_preconditioners(states, step, statistics, - num_statistics_per_state, - original_shapes, exponents, max_size, - prev_preconditioners) - - def _transform_grad(grad, state, param, step): - """Transform per-parameter gradients.""" - preconditioner = Preconditioner(param, block_size, - best_effort_shape_interpretation) - sgd_update = grad - new_diagonal_statistics = state.diagonal_statistics.to_float() - if graft_type == GraftingType.ADAGRAD: - new_diagonal_statistics = state.diagonal_statistics.to_float( - ) + jnp.square(grad) - adagrad_update = grad / ( - jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon) - grafting_update = adagrad_update - elif (graft_type == GraftingType.RMSPROP or - graft_type == GraftingType.RMSPROP_NORMALIZED): - - scaled_grad = grad - if graft_type == GraftingType.RMSPROP_NORMALIZED: - scaled_grad = grad / jnp.linalg.norm(grad) - - w1 = beta2 - w2 = beta2 if beta2 == 1.0 else (1.0 - beta2) - - new_diagonal_statistics = ( - w1 * state.diagonal_statistics.to_float() + - w2 * jnp.square(scaled_grad)) - rmsprop_update = scaled_grad / ( - jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon) - - if clip_by_scaled_gradient_norm: - scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / ( - jnp.sqrt(float(rmsprop_update.size))) - clipping_denom = jnp.maximum( - 1., scaled_grad_norm / clip_by_scaled_gradient_norm) - rmsprop_update /= clipping_denom - - grafting_update = rmsprop_update - else: - grafting_update = sgd_update + def _matrix_inverse_pth_root_vmap(xs, ps): + mi_pth_root = functools.partial( + matrix_inverse_pth_root, ridge_epsilon=matrix_epsilon, precision=precision + ) + return jax.vmap(mi_pth_root)(xs, ps) + + def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps): + def _quantized_to_float(qx, qd, qb): + qv = QuantizedValue(qx, qd, qb, qx.dtype, True, list(qx.shape)) + return qv.to_float() + + def matrix_inverse_pth_root_wrapper(qx, qd, qb, p): + v = _quantized_to_float(qx, qd, qb) + preconditioner, error = matrix_inverse_pth_root( + v, p, ridge_epsilon=matrix_epsilon, precision=precision + ) + qp = QuantizedValue.from_float_value(preconditioner, qx.dtype, True) + return qp.quantized, 
qp.diagonal, qp.bucket_size, error + + return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps) + + def _matrix_inverse_pth_root_pjit(xs, ps): + mesh_axis_names_tuple = tuple(mesh_axis_names) + # Partition the concatenated statistics matrix across all cores. + partitioned_xs, partitioned_ps = pjit.pjit( + lambda x, y: (x, y), + in_axis_resources=None, + out_axis_resources=pjit.PartitionSpec( + mesh_axis_names_tuple, + ), + )(xs, ps) + # Run matrix inverse pth root on each shard. + partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap( + partitioned_xs, partitioned_ps + ) + # Recombine the outputs at each core. + preconditioners, errors = pjit.pjit( + lambda x, y: (x, y), + in_axis_resources=( + pjit.PartitionSpec( + mesh_axis_names_tuple, + ), + pjit.PartitionSpec( + mesh_axis_names_tuple, + ), + ), + out_axis_resources=(None, None), + )(partitioned_preconditioners, partitioned_errors) + return preconditioners, errors + + def _pmap_compute_preconditioners( + states, + step, + statistics, + num_statistics_per_state, + original_shapes, + exponents, + max_size, + prev_preconditioners, + ): + """Computes preconditioners for given statistics in states in PMAP mode. + + Args: + states: A list of optimizer states. + step: Current step number + statistics: A list of statistics for all variables (for every dim) + num_statistics_per_state: Number of statistis per state to reconstruct + output states. + original_shapes: A list of shapes of the statistics. + exponents: Exponent power to use for inverse-pth roots. + max_size: Maximum dim of the statistics to pad. + prev_preconditioners: Previously available preconditioner. + + Returns: + New optimizer states after computing the preconditioner. + """ + num_devices = lax.psum(1, batch_axis_name) + num_statistics = len(statistics) + # Pad statistics and exponents to next multiple of num_devices. + packed_statistics = [pad_matrix(stat, max_size) for stat in statistics] + to_pad = -num_statistics % num_devices + packed_statistics.extend( + [jnp.eye(max_size, dtype=packed_statistics[0].dtype) for _ in range(to_pad)] + ) + exponents.extend([1 for _ in range(to_pad)]) + + if not packed_statistics: + return states + + all_statistics = batch(packed_statistics, num_devices) + all_exponents = batch(exponents, num_devices) + + def _internal_inverse_pth_root_all(): + current_replica = lax.axis_index(batch_axis_name) + preconditioners, errors = _matrix_inverse_pth_root_vmap( + all_statistics[current_replica], all_exponents[current_replica] + ) + preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name) + errors = jax.lax.all_gather(errors, batch_axis_name) + preconditioners_flat = unbatch(preconditioners) + errors_flat = unbatch(errors) + return preconditioners_flat, errors_flat + + if preconditioning_compute_steps == 1: + preconditioners_flat, errors_flat = _internal_inverse_pth_root_all() + else: + # Passing statistics instead of preconditioners as they are similarly + # shaped tensors. Note statistics will be ignored as we are passing in + # a large init value for error. 
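`matrix_inverse_pth_root` is defined earlier in the file; it uses an iterative method and also returns an error estimate. As a mathematical cross-check only, the sketch below computes the same quantity, the inverse p-th root of a ridge-regularised SPD statistic, by eigendecomposition, batched with `jax.vmap` the same way `_matrix_inverse_pth_root_vmap` batches the real routine. Shapes and exponents are invented.

```python
# Reference-only sketch of A^(-1/p) for an SPD statistic, batched with jax.vmap.
# This is NOT the file's matrix_inverse_pth_root; it is an eigendecomposition cross-check.
import jax
import jax.numpy as jnp


def reference_inverse_pth_root(stat, p, ridge_epsilon=1e-6):
    n = stat.shape[0]
    eigvals, eigvecs = jnp.linalg.eigh(stat + ridge_epsilon * jnp.eye(n))
    # Q diag(lambda^(-1/p)) Q^T, written as a column scaling to stay vmap-friendly.
    return (eigvecs * eigvals ** (-1.0 / p)) @ eigvecs.T


# Stacked, padded statistics and their exponents, as in the pmap/pjit paths.
key = jax.random.PRNGKey(0)
g = jax.random.normal(key, (4, 16, 16))
stats = jnp.einsum("bij,bkj->bik", g, g)         # batch of PSD matrices
exponents = jnp.array([2.0, 2.0, 4.0, 4.0])      # typically 2 x tensor rank, unless overridden

roots = jax.vmap(reference_inverse_pth_root)(stats, exponents)
assert roots.shape == stats.shape
```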
+ preconditioners_init = packed_statistics + errors_init = [inverse_failure_threshold] * len(packed_statistics) + init_state = [preconditioners_init, errors_init] + perform_step = step % preconditioning_compute_steps == 0 + preconditioners_flat, errors_flat = efficient_cond( + perform_step, _internal_inverse_pth_root_all, init_state + ) + + def _skip(error): + condition = jnp.logical_or( + jnp.isnan(error), error >= inverse_failure_threshold + ) + return condition.astype(error.dtype) + + def _select_preconditioner(error, new_p, old_p): + return lax.cond( + _skip(error), lambda _: old_p, lambda _: new_p, operand=None + ) + + new_preconditioners_flat = [] + for p, shape, prev_p, error in zip( + preconditioners_flat, original_shapes, prev_preconditioners, errors_flat + ): + new_preconditioners_flat.append( + _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p) + ) + + assert len(states) == len(num_statistics_per_state) + assert len(new_preconditioners_flat) == num_statistics + + # Add back empty preconditioners so we that we can set the optimizer state. + preconditioners_for_states = [] + idx = 0 + for num_statistics, state in zip(num_statistics_per_state, states): + if num_statistics == 0: + preconditioners_for_states.append([]) + else: + preconditioners_for_state = new_preconditioners_flat[ + idx : idx + num_statistics + ] + assert len(state.statistics) == len(preconditioners_for_state) + preconditioners_for_states.append(preconditioners_for_state) + idx += num_statistics + new_states = [] + for state, new_preconditioners in zip(states, preconditioners_for_states): + new_states.append( + ParameterStats( + state.diagonal_statistics, + state.statistics, + new_preconditioners, + state.diagonal_momentum, + state.momentum, + ) + ) + + return new_states + + def _pmap_quantized_compute_preconditioners( + states, + step, + statistics, + num_statistics_per_state, + original_shapes, + exponents, + max_size, + prev_preconditioners, + ): + """Computes preconditioners for given statistics in states in PMAP mode. + + For quantization, each statistic is represented by three values: + quantized matrix, diagonal, and bucket sizes, we run inverse pth-roots + without ever recreating the original matrix in f32. + + Args: + states: A list of optimizer states. + step: Current step number + statistics: A list of statistics for all variables (for every dim) + num_statistics_per_state: Number of statistis per state to reconstruct + output states. + original_shapes: A list of shapes of the statistics. + exponents: Exponent power to use for inverse-pth roots. + max_size: Maximum dim of the statistics to pad. + prev_preconditioners: Previously available preconditioner. + + Returns: + New optimizer states after computing the preconditioner. + """ + num_devices = lax.psum(1, batch_axis_name) + num_statistics = len(statistics) + quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers() + # Complexity here is around: shapes needing be statically shaped, + # our custom quantization type requires a different type of packing. 
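The `_skip` / `_select_preconditioner` pair above is the numerical safety net: if an inverse root comes back NaN or with an error at or above `inverse_failure_threshold`, the previous preconditioner is kept. A minimal sketch of that predicate (threshold value invented):

```python
# Minimal sketch of the error-gated fallback used when an inverse pth root fails.
# The threshold value is invented for the example.
import jax.numpy as jnp
from jax import lax

inverse_failure_threshold = 0.1


def _skip(error):
    return jnp.logical_or(jnp.isnan(error), error >= inverse_failure_threshold)


def select(error, new_p, old_p):
    # lax.cond keeps this jittable; both branches must return matching shapes/dtypes.
    return lax.cond(_skip(error), lambda _: old_p, lambda _: new_p, None)


old_p = jnp.eye(2)
new_p = 2.0 * jnp.eye(2)

assert jnp.allclose(select(jnp.array(0.01), new_p, old_p), new_p)     # accepted
assert jnp.allclose(select(jnp.array(jnp.nan), new_p, old_p), old_p)  # NaN -> keep old
assert jnp.allclose(select(jnp.array(0.5), new_p, old_p), old_p)      # too large -> keep old
```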
+ + # Parallel tensors: + # quantized [dxd] + # diagonals [d] f32 + # bucket_sizes [d] f32 + packed_quantized_statistics = [ + pad_matrix(stat.quantized, max_size) for stat in statistics + ] + packed_quantized_diagonals = [ + pad_vector(stat.diagonal, max_size) for stat in statistics + ] + packed_quantized_bucket_sizes = [ + pad_vector(stat.bucket_size, max_size) for stat in statistics + ] - precond_grad = grad - if not _skip_preconditioning(param): - precond_grad = preconditioner.preconditioned_grad( - precond_grad, - _maybe_dequantize_preconditioners(state.preconditioners)) + to_pad = -num_statistics % num_devices + padded_eye = jnp.eye(max_size, dtype=jnp.float32) + quantized_eye = QuantizedValue.from_float_value( + padded_eye, quantized_dtype, True + ) + packed_quantized_statistics.extend( + [quantized_eye.quantized for _ in range(to_pad)] + ) + packed_quantized_diagonals.extend( + [quantized_eye.diagonal for _ in range(to_pad)] + ) + packed_quantized_bucket_sizes.extend( + [quantized_eye.bucket_size for _ in range(to_pad)] + ) + exponents.extend([1 for _ in range(to_pad)]) + + if not packed_quantized_statistics: + return states + + all_quantized_statistics = batch(packed_quantized_statistics, num_devices) + all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices) + all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes, num_devices) + all_exponents = batch(exponents, num_devices) + + def _internal_inverse_pth_root_all(): + current_replica = lax.axis_index(batch_axis_name) + ( + quantized_preconditioners, + quantized_diagonals, + quantized_bucket_sizes, + errors, + ) = _quantized_matrix_inverse_pth_root_vmap( + all_quantized_statistics[current_replica], + all_quantized_diagonals[current_replica], + all_quantized_bucket_sizes[current_replica], + all_exponents[current_replica], + ) + quantized_preconditioners = jax.lax.all_gather( + quantized_preconditioners, batch_axis_name + ) + quantized_diagonals = jax.lax.all_gather( + quantized_diagonals, batch_axis_name + ) + quantized_bucket_sizes = jax.lax.all_gather( + quantized_bucket_sizes, batch_axis_name + ) + errors = jax.lax.all_gather(errors, batch_axis_name) + quantized_preconditioners_flat = unbatch(quantized_preconditioners) + quantized_diagonals_flat = unbatch(quantized_diagonals) + quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes) + errors_flat = unbatch(errors) + return ( + quantized_preconditioners_flat, + quantized_diagonals_flat, + quantized_bucket_sizes_flat, + errors_flat, + ) + + if preconditioning_compute_steps == 1: + ( + quantized_preconditioners_flat, + quantized_diagonals_flat, + quantized_bucket_sizes_flat, + errors_flat, + ) = _internal_inverse_pth_root_all() + else: + # Passing statistics instead of preconditioners as they are similarly + # shaped tensors. Note statistics will be ignored as we are passing in + # a large init value for error. 
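`efficient_cond` is defined earlier in the file and is not shown in this diff; the comment above explains why the statistics are passed as the init value: they are only shape-matched placeholders, and the large init error ensures they are rejected downstream. The sketch below is a stand-in for that gating pattern using plain `lax.cond`, with invented values; the real helper may differ.

```python
# Stand-in sketch of the "recompute only every N steps" gating. This is NOT the
# file's efficient_cond helper (not shown in this diff); values are invented.
import jax.numpy as jnp
from jax import lax

preconditioning_compute_steps = 10
inverse_failure_threshold = 0.25   # chosen to be exactly representable in f32


def expensive_recompute():
    # Placeholder for _internal_inverse_pth_root_all(): fresh values, small error.
    return [2.0 * jnp.eye(2), jnp.array(0.0)]


def maybe_recompute(step, init_state):
    perform_step = step % preconditioning_compute_steps == 0
    # On skipped steps the init values pass through; because their error equals the
    # failure threshold, the selection logic falls back to the previous preconditioner.
    return lax.cond(perform_step, lambda _: expensive_recompute(), lambda _: init_state, None)


init_state = [jnp.eye(2), jnp.array(inverse_failure_threshold)]
_, err = maybe_recompute(jnp.array(7), init_state)
assert float(err) == inverse_failure_threshold   # skipped step
_, err = maybe_recompute(jnp.array(10), init_state)
assert float(err) == 0.0                         # recompute step
```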
+ quantized_preconditioners_init = packed_quantized_statistics + quantized_diagonals_init = packed_quantized_diagonals + quantized_bucket_sizes_init = packed_quantized_bucket_sizes + errors_init = [inverse_failure_threshold] * len( + quantized_preconditioners_init + ) + init_state = [ + quantized_preconditioners_init, + quantized_diagonals_init, + quantized_bucket_sizes_init, + errors_init, + ] + perform_step = step % preconditioning_compute_steps == 0 + ( + quantized_preconditioners_flat, + quantized_diagonals_flat, + quantized_bucket_sizes_flat, + errors_flat, + ) = efficient_cond(perform_step, _internal_inverse_pth_root_all, init_state) + + def _skip(error): + condition = jnp.logical_or( + jnp.isnan(error), error >= inverse_failure_threshold + ) + return condition.astype(error.dtype) + + def _select_preconditioner(error, new_p, old_p): + return lax.cond( + _skip(error), lambda _: old_p, lambda _: new_p, operand=None + ) + + new_quantized_preconditioners_flat = [] + new_quantized_diagonals_flat = [] + new_quantized_bucket_sizes_flat = [] + for p, d, b, shape, prev_p, error in zip( + quantized_preconditioners_flat, + quantized_diagonals_flat, + quantized_bucket_sizes_flat, + original_shapes, + prev_preconditioners, + errors_flat, + ): + new_quantized_preconditioners_flat.append( + _select_preconditioner( + error, p[: shape[0], : shape[1]], prev_p.quantized + ) + ) + new_quantized_diagonals_flat.append( + _select_preconditioner(error, d[: shape[0]], prev_p.diagonal) + ) + new_quantized_bucket_sizes_flat.append( + _select_preconditioner(error, b[: shape[0]], prev_p.bucket_size) + ) + + assert len(states) == len(num_statistics_per_state) + assert len(new_quantized_preconditioners_flat) == num_statistics + assert len(new_quantized_diagonals_flat) == num_statistics + assert len(new_quantized_bucket_sizes_flat) == num_statistics + + # Add back empty preconditioners so we that we can set the optimizer state. + preconditioners_for_states = [] + idx = 0 + for num_statistics, state in zip(num_statistics_per_state, states): + if num_statistics == 0: + preconditioners_for_states.append([]) + else: + quantized_preconditioners_for_state = ( + new_quantized_preconditioners_flat[idx : idx + num_statistics] + ) + quantized_diagonals_for_state = new_quantized_diagonals_flat[ + idx : idx + num_statistics + ] + quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[ + idx : idx + num_statistics + ] + + assert len(state.statistics) == len(quantized_preconditioners_for_state) + assert len(state.statistics) == len(quantized_diagonals_for_state) + assert len(state.statistics) == len(quantized_bucket_sizes_for_state) + + quantized_preconditioners = [] + for qv, qd, qb in zip( + quantized_preconditioners_for_state, + quantized_diagonals_for_state, + quantized_bucket_sizes_for_state, + ): + quantized_preconditioners.append( + QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape)) + ) + preconditioners_for_states.append(quantized_preconditioners) + idx += num_statistics + new_states = [] + for state, new_preconditioners in zip(states, preconditioners_for_states): + new_states.append( + ParameterStats( + state.diagonal_statistics, + state.statistics, + new_preconditioners, + state.diagonal_momentum, + state.momentum, + ) + ) + + return new_states + + def _pjit_compute_preconditioners( + states, + step, + statistics, + num_statistics_per_state, + original_shapes, + exponents, + max_size, + prev_preconditioners, + ): + """Computes preconditioners for given statistics in states in PJIT mode. 
+ + Args: + states: A list of optimizer states. + step: Current step number + statistics: A list of statistics for all variables (for every dim) + num_statistics_per_state: Number of statistis per state to reconstruct + output states. + original_shapes: A list of shapes of the statistics. + exponents: Exponent power to use for inverse-pth roots. + max_size: Maximum dim of the statistics to pad. + prev_preconditioners: Previously available preconditioner. + + Returns: + New optimizer states after computing the preconditioner. + """ + num_statistics = len(statistics) + to_pad = -num_statistics % num_devices_for_pjit + padded_statistics = [pad_matrix(stat, max_size) for stat in statistics] + padded_statistics.extend( + [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)] + ) + exponents.extend([1 for _ in range(to_pad)]) + all_statistics = jnp.stack(padded_statistics) + all_exponents = jnp.stack(exponents) + + def _internal_inverse_pth_root_all(): + preconditioners, errors = _matrix_inverse_pth_root_pjit( + all_statistics, all_exponents + ) + b1 = preconditioners.shape[0] + + def split(batched_values): + return [ + jnp.squeeze(v) + for v in jnp.split(batched_values, indices_or_sections=b1, axis=0) + ] + + return split(preconditioners), split(errors) + + if preconditioning_compute_steps == 1: + preconditioners_flat, errors_flat = _internal_inverse_pth_root_all() + else: + # Passing statistics instead of preconditioners as they are similarly + # shaped tensors. Note statistics will be ignored as we are passing in + # a large init value for error. + preconditioners_init = padded_statistics + errors_init = [inverse_failure_threshold] * len(padded_statistics) + init_state = [preconditioners_init, errors_init] + perform_step = step % preconditioning_compute_steps == 0 + preconditioners_flat, errors_flat = efficient_cond( + perform_step, _internal_inverse_pth_root_all, init_state + ) + + def _skip(error): + condition = jnp.logical_or( + jnp.isnan(error), error >= inverse_failure_threshold + ) + return condition.astype(error.dtype) + + def _select_preconditioner(error, new_p, old_p): + return lax.cond( + _skip(error), lambda _: old_p, lambda _: new_p, operand=None + ) + + new_preconditioners_flat = [] + for p, shape, prev_p, error in zip( + preconditioners_flat, original_shapes, prev_preconditioners, errors_flat + ): + new_preconditioners_flat.append( + _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p) + ) + + assert len(states) == len(num_statistics_per_state) + assert len(new_preconditioners_flat) == num_statistics + + # Add back empty preconditioners so we that we can set the optimizer state. + preconditioners_for_states = [] + idx = 0 + for num_statistics, state in zip(num_statistics_per_state, states): + if num_statistics == 0: + preconditioners_for_states.append([]) + else: + preconditioners_for_state = new_preconditioners_flat[ + idx : idx + num_statistics + ] + assert len(state.statistics) == len(preconditioners_for_state) + preconditioners_for_states.append(preconditioners_for_state) + idx += num_statistics + new_states = [] + for state, new_preconditioners in zip(states, preconditioners_for_states): + new_states.append( + ParameterStats( + state.diagonal_statistics, + state.statistics, + new_preconditioners, + state.diagonal_momentum, + state.momentum, + ) + ) + + return new_states + + def _compute_preconditioners(states, params, step): + """Computes preconditioners for given statistics in states. + + Args: + states: A list of optimizer states. 
+ params: A list of params. + step: Current step number + + Returns: + New optimizer states after computing the preconditioner. + """ + statistics = [] + num_statistics_per_state = [] + original_shapes = [] + exponents = [] + max_size = 0 + prev_preconditioners = [] + + for state, param in zip(states, params): + num_statistics = len(state.statistics) + num_statistics_per_state.append(num_statistics) + original_shapes_for_state = [] + if num_statistics > 0: + preconditioner = Preconditioner( + param, block_size, best_effort_shape_interpretation + ) + for statistic in state.statistics: + exponents.append( + preconditioner.exponent_for_preconditioner() + if exponent_override == 0 + else exponent_override + ) + original_shapes_for_state.append(statistic.shape) + max_size = max(max_size, statistic.shape[0]) + + statistics.extend(state.statistics) + prev_preconditioners.extend(state.preconditioners) + original_shapes.extend(original_shapes_for_state) + + if batch_axis_name: + # Quantization is only enabled if batch_axis_name is not set. + quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers() + + if quantized_dtype == jnp.float32: + return _pmap_compute_preconditioners( + states, + step, + statistics, + num_statistics_per_state, + original_shapes, + exponents, + max_size, + prev_preconditioners, + ) + else: + return _pmap_quantized_compute_preconditioners( + states, + step, + statistics, + num_statistics_per_state, + original_shapes, + exponents, + max_size, + prev_preconditioners, + ) + + else: + return _pjit_compute_preconditioners( + states, + step, + statistics, + num_statistics_per_state, + original_shapes, + exponents, + max_size, + prev_preconditioners, + ) + + def _transform_grad(grad, state, param, step): + """Transform per-parameter gradients.""" + preconditioner = Preconditioner( + param, block_size, best_effort_shape_interpretation + ) + sgd_update = grad + new_diagonal_statistics = state.diagonal_statistics.to_float() + if graft_type == GraftingType.ADAGRAD: + new_diagonal_statistics = state.diagonal_statistics.to_float() + jnp.square( + grad + ) + adagrad_update = grad / ( + jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon + ) + grafting_update = adagrad_update + elif ( + graft_type == GraftingType.RMSPROP + or graft_type == GraftingType.RMSPROP_NORMALIZED + ): + + scaled_grad = grad + if graft_type == GraftingType.RMSPROP_NORMALIZED: + scaled_grad = grad / jnp.linalg.norm(grad) + + w1 = beta2 + w2 = beta2 if beta2 == 1.0 else (1.0 - beta2) + + new_diagonal_statistics = ( + w1 * state.diagonal_statistics.to_float() + w2 * jnp.square(scaled_grad) + ) + rmsprop_update = scaled_grad / ( + jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon + ) + + if clip_by_scaled_gradient_norm: + scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / ( + jnp.sqrt(float(rmsprop_update.size)) + ) + clipping_denom = jnp.maximum( + 1.0, scaled_grad_norm / clip_by_scaled_gradient_norm + ) + rmsprop_update /= clipping_denom + + grafting_update = rmsprop_update + else: + grafting_update = sgd_update + + precond_grad = grad + if not _skip_preconditioning(param): + precond_grad = preconditioner.preconditioned_grad( + precond_grad, _maybe_dequantize_preconditioners(state.preconditioners) + ) + else: + precond_grad = grafting_update + + grafting_update_norm = jnp.linalg.norm(grafting_update) + precond_grad_norm = jnp.linalg.norm(precond_grad) + + multiplier = grafting_update_norm / (precond_grad_norm + 1e-16) + shampoo_update = precond_grad * multiplier + + shampoo_update_with_wd = 
shampoo_update + grafting_update_with_wd = grafting_update + if weight_decay != 0: + shampoo_update_with_wd = shampoo_update + weight_decay * param + grafting_update_with_wd = grafting_update + weight_decay * param + + w = (1.0 - beta1) if moving_average_for_momentum else 1.0 + shampoo_update_with_wd_momentum = ( + state.momentum.to_float() * beta1 + w * shampoo_update_with_wd + ) + grafting_update_with_wd_momentum = ( + state.diagonal_momentum.to_float() * beta1 + w * grafting_update_with_wd + ) + + run_shampoo = (step >= start_preconditioning_step).astype( + grafting_update_with_wd_momentum.dtype + ) + + momentum_update = ( + run_shampoo * shampoo_update_with_wd_momentum + + (1.0 - run_shampoo) * grafting_update_with_wd_momentum + ) + + wd_update = ( + run_shampoo * shampoo_update_with_wd + + (1.0 - run_shampoo) * grafting_update_with_wd + ) + + if nesterov: + momentum_update = w * wd_update + beta1 * momentum_update + + lr = learning_rate + if callable(learning_rate): + lr = learning_rate(step) + transformed_update = -1.0 * lr * momentum_update + + param_stats = ParameterStats( + _quantize_diagonal_statistics(new_diagonal_statistics), + state.statistics, + state.preconditioners, + _quantize_momentum(grafting_update_with_wd_momentum), + _quantize_momentum(shampoo_update_with_wd_momentum), + ) + return transformed_update, param_stats + + def update_fn(grads, state, params): + """Transform the input gradient and update all statistics. + + Args: + grads: the gradient tensors for the parameters. + state: a named tuple containing the state of the optimizer + params: the parameters that should be updated. + + Returns: + A tuple containing the new parameters and the new optimizer state. + """ + params_flat, treedef = jax.tree_flatten(params) + stats_flat = treedef.flatten_up_to(state.stats) + grads_flat = treedef.flatten_up_to(grads) + + new_stats_flat = jax.tree_multimap( + lambda g, s, p: _compute_stats(g, s, p, state.count), + grads_flat, + stats_flat, + params_flat, + ) + new_stats_flat = _compute_preconditioners( + new_stats_flat, params_flat, state.count + ) + + outputs = jax.tree_multimap( + lambda g, s, p: _transform_grad(g, s, p, state.count), + grads_flat, + new_stats_flat, + params_flat, + ) + updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ()) + + updates = jax.tree_unflatten(treedef, updates_flat) + new_stats = jax.tree_unflatten(treedef, new_stats_flat) + + new_state = ShampooState(count=state.count + 1, stats=new_stats) + return updates, new_state + + if shard_optimizer_states: + return optax.GradientTransformation(sharded_init_fn, sharded_update_fn) else: - precond_grad = grafting_update - - grafting_update_norm = jnp.linalg.norm(grafting_update) - precond_grad_norm = jnp.linalg.norm(precond_grad) - - multiplier = (grafting_update_norm / (precond_grad_norm + 1e-16)) - shampoo_update = precond_grad * multiplier - - shampoo_update_with_wd = shampoo_update - grafting_update_with_wd = grafting_update - if weight_decay != 0: - shampoo_update_with_wd = shampoo_update + weight_decay * param - grafting_update_with_wd = grafting_update + weight_decay * param - - w = (1.0 - beta1) if moving_average_for_momentum else 1.0 - shampoo_update_with_wd_momentum = ( - state.momentum.to_float() * beta1 + w * shampoo_update_with_wd) - grafting_update_with_wd_momentum = ( - state.diagonal_momentum.to_float() * beta1 + - w * grafting_update_with_wd) - - run_shampoo = (step >= start_preconditioning_step).astype( - grafting_update_with_wd_momentum.dtype) - - momentum_update = ( - 
run_shampoo * shampoo_update_with_wd_momentum + - (1.0 - run_shampoo) * grafting_update_with_wd_momentum) - - wd_update = ( - run_shampoo * shampoo_update_with_wd + - (1.0 - run_shampoo) * grafting_update_with_wd) - - if nesterov: - momentum_update = w * wd_update + beta1 * momentum_update - - lr = learning_rate - if callable(learning_rate): - lr = learning_rate(step) - transformed_update = -1.0 * lr * momentum_update - - param_stats = ParameterStats( - _quantize_diagonal_statistics(new_diagonal_statistics), - state.statistics, state.preconditioners, - _quantize_momentum(grafting_update_with_wd_momentum), - _quantize_momentum(shampoo_update_with_wd_momentum)) - return transformed_update, param_stats - - def update_fn(grads, state, params): - """Transform the input gradient and update all statistics. - - Args: - grads: the gradient tensors for the parameters. - state: a named tuple containing the state of the optimizer - params: the parameters that should be updated. - - Returns: - A tuple containing the new parameters and the new optimizer state. - """ - params_flat, treedef = jax.tree_flatten(params) - stats_flat = treedef.flatten_up_to(state.stats) - grads_flat = treedef.flatten_up_to(grads) - - new_stats_flat = jax.tree_multimap( - lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat, - stats_flat, params_flat) - new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat, - state.count) - - outputs = jax.tree_multimap( - lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, - new_stats_flat, params_flat) - updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ()) - - updates = jax.tree_unflatten(treedef, updates_flat) - new_stats = jax.tree_unflatten(treedef, new_stats_flat) - - new_state = ShampooState( - count=state.count+1, stats=new_stats) - return updates, new_state - - if shard_optimizer_states: - return optax.GradientTransformation(sharded_init_fn, sharded_update_fn) - else: - return optax.GradientTransformation(init_fn, update_fn) + return optax.GradientTransformation(init_fn, update_fn)
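The grafting step in `_transform_grad` above takes only the direction from the preconditioned gradient and borrows the per-layer step length from the grafted optimizer, which is what lets Shampoo reuse an existing SGD/AdaGrad/RMSProp tuning. A small numerical sketch of that norm matching (gradient values invented):

```python
# Numerical sketch of layerwise grafting: Shampoo supplies the direction,
# the grafted update supplies the per-layer step length. Values are invented.
import jax.numpy as jnp

grad = jnp.array([[0.3, -0.1], [0.2, 0.4]])
grafting_update = grad                                  # SGD grafting: the raw gradient
precond_grad = jnp.array([[1.5, -0.2], [0.9, 2.1]])     # pretend Shampoo output

multiplier = jnp.linalg.norm(grafting_update) / (jnp.linalg.norm(precond_grad) + 1e-16)
shampoo_update = precond_grad * multiplier

# Same length as the grafted update, direction from the preconditioned gradient.
assert jnp.allclose(jnp.linalg.norm(shampoo_update), jnp.linalg.norm(grafting_update))
```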
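Whichever branch returns, callers get a standard `optax.GradientTransformation`, so it slots into the usual init/update loop. The sketch below uses a plain `optax.sgd` as a stand-in so it runs on its own; in practice `tx` would come from the factory in this file, whose `update_fn` requires `params` to be passed.

```python
# Minimal training-step sketch around the returned GradientTransformation.
# optax.sgd is a runnable stand-in; swap in the factory from this file in practice.
import jax
import jax.numpy as jnp
import optax


def loss_fn(params, batch):
    pred = batch["x"] @ params["w"]
    return jnp.mean((pred - batch["y"]) ** 2)


params = {"w": jnp.zeros((3, 1))}
tx = optax.sgd(1e-2)               # stand-in; replace with the Shampoo factory
opt_state = tx.init(params)


@jax.jit
def train_step(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = tx.update(grads, opt_state, params)   # Shampoo needs params here
    return optax.apply_updates(params, updates), opt_state


batch = {"x": jnp.ones((8, 3)), "y": jnp.ones((8, 1))}
params, opt_state = train_step(params, opt_state, batch)
```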