Spaces:
Running
Running
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX Ops for symmetric matrices used by the Shampoo optimizer."""
import functools | |
from typing import Any, List, Sequence, Union | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
from flax import struct | |
from jax import lax | |
@struct.dataclass
class SlicedSymmetricMatrix:
    """A symmetric matrix represented by lower-triangular block row slices.

    For example, the symmetric matrix M = [[a, b^T], [b, c]] would be
    represented by the block rows a and [b, c].

    The matrix may be batched, in which case each entry of block_rows may have
    dimension greater than 2. The last two dimensions represent the rows and
    cols.

    The class is declared as a dataclass so that it gets the keyword
    constructor used throughout this module
    (`SlicedSymmetricMatrix(block_rows=...)`); a bare class with only an
    annotation has no such constructor and raises TypeError when called.
    """

    # Lower-triangular block rows, ordered top to bottom. Each entry is one
    # block row of the matrix, truncated at (and including) its diagonal block.
    block_rows: List[jnp.ndarray]
def product_with_transpose(
    mat1,
    mat2,
    axes,
    precision=lax.Precision.DEFAULT,
):
    """Contracts mat1 with mat2 over the given axes (mat1 * mat2^T).

    This is a thin wrapper around `jnp.tensordot`. When `axes` selects the
    column dimension of both operands the result is mat1 * mat2^T.

    Args:
      mat1: First matrix.
      mat2: Second matrix.
      axes: The axes over which to apply the product, forwarded unchanged to
        `jnp.tensordot`.
      precision: JAX precision to use for the multiplication.

    Returns:
      The result of the tensordot contraction.
    """
    contraction = jnp.tensordot(mat1, mat2, axes=axes, precision=precision)
    return contraction
def sliced_transposed_product(
    mat,
    block_size,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
    """Computes the lower-triangular block rows of mat contracted with itself.

    The contraction runs over `axes`; exactly one axis of `mat` must be left
    uncontracted, and that axis plays the role of the row dimension. Block row
    i is the product of rows [i * block_size, (i + 1) * block_size) with rows
    [0, (i + 1) * block_size), i.e. everything up to and including the
    diagonal block.

    Args:
      mat: The matrix for which we will compute a contraction with itself.
      block_size: The size of row blocks to compute.
      axes: Axes to use for the contraction.
      precision: The precision to use in each computation.

    Returns:
      A SlicedSymmetricMatrix holding one block row per row block.

    Raises:
      ValueError: Raised when the specified block size does not evenly divide
        the number of rows of the input mat.
    """
    ndim = len(mat.shape)

    def _normalize(axis):
        # Map a possibly negative axis onto its non-negative equivalent.
        assert -ndim <= axis < ndim
        return axis if axis >= 0 else axis + ndim

    contracted = [_normalize(axis) for axis in axes]
    assert len(contracted) == len(axes)

    # Whatever axis is not contracted is the row axis; there must be just one.
    uncontracted = set(range(ndim)) - set(contracted)
    assert len(uncontracted) == 1
    row_axis = uncontracted.pop()

    num_rows = mat.shape[row_axis]
    if num_rows % block_size != 0:
        raise ValueError(
            "The row dimension must be divisible by block_size. "
            f"Instead got row dimension={num_rows} and block_size={block_size}."
        )

    def _block_row(index):
        # Rows belonging to block `index` only.
        row_start = [0] * ndim
        row_start[row_axis] = index * block_size
        row_shape = list(mat.shape)
        row_shape[row_axis] = block_size
        # All rows from the top through the end of block `index`.
        prefix_shape = list(mat.shape)
        prefix_shape[row_axis] = (index + 1) * block_size

        current = lax.dynamic_slice(
            mat, start_indices=row_start, slice_sizes=row_shape
        )
        prefix = lax.dynamic_slice(
            mat, start_indices=[0] * ndim, slice_sizes=prefix_shape
        )
        return product_with_transpose(
            current,
            prefix,
            axes=(axes, axes),
            precision=precision,
        )

    return SlicedSymmetricMatrix(
        block_rows=[_block_row(i) for i in range(num_rows // block_size)]
    )
def sliced_transposed_product_concat(
    mat,
    block_size,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
    """Returns the block rows of mat*mat^T concatenated along the last axis.

    The block rows are produced by `sliced_transposed_product` and then joined
    side by side into a single array.

    Args:
      mat: The matrix for which we will compute mat*mat^T. It does not need to
        be square, and may be batched.
      block_size: The size of row blocks to compute.
      axes: Axes to use for the contraction.
      precision: The precision to use in each computation.

    Returns:
      A single array holding all lower-triangular block rows, concatenated
      along the last dimension.

    Raises:
      ValueError: Raised when the specified block size does not evenly divide
        the number of rows of the input mat.
    """
    block_rows = sliced_transposed_product(
        mat=mat, block_size=block_size, axes=axes, precision=precision
    ).block_rows
    return jnp.concatenate(block_rows, axis=-1)
def materialize_matrix(symmetric_matrix):
    """Expands lower-triangular block rows into the full symmetric matrix.

    Args:
      symmetric_matrix: the matrix represented by lower-triangular block
        slices (read through its `block_rows` attribute).

    Returns:
      The dense matrix assembled with `jnp.block`, where every block above the
      diagonal is the transpose (over the last two dimensions) of its mirror
      block below the diagonal.
    """
    rows = symmetric_matrix.block_rows
    size = rows[0].shape[-2]
    count = len(rows)

    def _stored_block(row, col):
        # Block (row, col) with col <= row, as stored in the block rows.
        return rows[row][Ellipsis, col * size : (col + 1) * size]

    grid = []
    for row in range(count):
        # Blocks up to and including the diagonal come straight from storage.
        lower = [_stored_block(row, col) for col in range(row + 1)]
        # Blocks right of the diagonal mirror the stored blocks below it.
        upper = [
            jnp.swapaxes(_stored_block(col, row), -1, -2)
            for col in range(row + 1, count)
        ]
        grid.append(lower + upper)
    return jnp.block(grid)
def materialize_matrix_from_concat(
    block_rows_concat,
    num_blocks,
):
    """Returns a materialized symmetric matrix from concatenated slices.

    Args:
      block_rows_concat: The matrix represented as the concatenated
        lower-triangular blocks.
      num_blocks: The number of block-rows used to represent the symmetric
        matrix.

    Returns:
      The dense symmetric matrix produced by `materialize_matrix` from the
      recovered block rows.
    """
    block_size = block_rows_concat.shape[-2]
    block_rows = []
    for k in range(num_blocks):
        # Block row k holds k + 1 blocks and is preceded by the
        # 1 + 2 + ... + k = k * (k + 1) / 2 blocks of the earlier rows, so it
        # ends at block offset (k + 1) * (k + 2) / 2. (The previous slice end
        # carried an extra "+ 1", making every row one block too wide.)
        start = (k * (k + 1)) // 2 * block_size
        stop = ((k + 1) * (k + 2)) // 2 * block_size
        block_rows.append(block_rows_concat[Ellipsis, start:stop])
    return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
def update_sliced_rows(
    symmetric_matrix,
    mat,
    alpha,
    beta,
    axes=(-1,),
):
    """Implements the blocked equivalent of SYRK.

    The symmetric matrix (stored as lower-triangular block rows) is combined
    with the sliced self-contraction of mat. The block size is taken from the
    first block row of symmetric_matrix.

    Args:
      symmetric_matrix: The symmetric matrix to update.
      mat: The matrix to use for the update = mat * mat^T. The number of rows
        should match that of symmetric_matrix.
      alpha: The weight for the update.
      beta: The weight for the original symmetric matrix.
      axes: Axes to use for the contraction of the update.

    Returns:
      The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
    """
    existing_rows = symmetric_matrix.block_rows
    block_size = existing_rows[0].shape[-2]
    update_rows = sliced_transposed_product(
        mat=mat, block_size=block_size, axes=axes
    ).block_rows

    combined = []
    for update, row in zip(update_rows, existing_rows):
        combined.append(update * alpha + row * beta)
    return SlicedSymmetricMatrix(block_rows=combined)
def find_num_blocks(block_rows_concat):
    """Returns the number of (row) blocks representing the concatenated matrix.

    For example, an input with dimensions [256, 2560] represents 10 square
    blocks, which matches 4 lower-triangular block rows (1+2+3+4). So this
    function will return 4.

    Use ordinary numpy functions here so that the returned value is static.

    Args:
      block_rows_concat: The concatenated block array. Only its last two
        dimensions are inspected.

    Returns:
      The number of block rows, as a numpy int32 scalar.

    Raises:
      ValueError: When the dimensions of the matrix do not correspond to a
        lower triangular block representation (including a zero-sized row
        dimension, for which no block size can be inferred).
    """
    error_message = (
        "Could not determine an appropriate number of blocks for "
        "the concatenated matrix."
    )
    num_cols = block_rows_concat.shape[-1]
    block_size = block_rows_concat.shape[-2]
    # A zero block size or a partial trailing block can never be valid; report
    # both through the documented ValueError rather than ZeroDivisionError.
    if block_size == 0 or num_cols % block_size != 0:
        raise ValueError(error_message)
    # Compute the number of square blocks used to represent the matrix.
    total_blocks = num_cols // block_size
    # Determine the number of block rows by inverting y = x*(x+1)/2.
    num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
    # Verify with exact Python integers so the check cannot be affected by
    # int32 overflow or float rounding.
    candidate = int(num_blocks)
    if candidate * (candidate + 1) // 2 != total_blocks:
        raise ValueError(error_message)
    return num_blocks