Commit 9ecdd3f (1 parent: 2c583b3), committed by boris

feat: update shampoo

tools/train/scalable_shampoo/README.md CHANGED
@@ -4,4 +4,4 @@ Files copied from [google-research/scalable_shampoo/optax](https://github.com/go
 
 Imports have been modified to be relative.
 
-This will be replaced with `optax-shampoo` package eventually.
+This will eventually be replaced with `optax-shampoo` package.
tools/train/scalable_shampoo/distributed_shampoo.py CHANGED
@@ -25,13 +25,12 @@
 # Authors: Rohan Anil (rohananil at google dot com)
 # & Vineet Gupta (vineet at google dot com)
 #
-
 """Distributed Shampoo Implementation."""
 
 import enum
 import functools
 import itertools
-from typing import Any, List, NamedTuple
+from typing import Any, List, NamedTuple, Tuple
 
 import chex
 import jax
@@ -43,6 +42,7 @@ from flax import struct
 from jax import lax
 
 from .quantization_utils import QuantizedValue
+from .symmetric_matrices import symmetric_matrices
 
 # Dtype for inverse-pth root routine
 # Switch to f64 if you have hardware that supports it. Enable the jax flag
@@ -141,7 +141,10 @@ class GraftingType(enum.IntEnum):
 
 
 def power_iteration(
-    matrix, num_iters=100, error_tolerance=1e-6, precision=lax.Precision.HIGHEST
+    matrix,
+    num_iters=100,
+    error_tolerance=1e-6,
+    precision=lax.Precision.HIGHEST,
 ):
     r"""Power iteration algorithm.
 
@@ -156,10 +159,10 @@ def power_iteration(
       matrix: the symmetric PSD matrix.
      num_iters: Number of iterations.
      error_tolerance: Iterative exit condition.
-      precision: precision XLA related flag, the available options are:
-        a) lax.Precision.DEFAULT (better step time, but not precise)
-        b) lax.Precision.HIGH (increased precision, slower)
-        c) lax.Precision.HIGHEST (best possible precision, slowest)
+      precision: precision XLA related flag, the available options are: a)
+        lax.Precision.DEFAULT (better step time, but not precise) b)
+        lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
+        (best possible precision, slowest)
 
     Returns:
       eigen vector, eigen value
@@ -196,7 +199,11 @@ def power_iteration(
     return v_out, s_out
 
 
-def mat_power(mat_m, p, precision=lax.Precision.HIGHEST):
+def mat_power(
+    mat_m,
+    p,
+    precision=lax.Precision.HIGHEST,
+):
     """A simple matrix power method. M^p where p can be TracedValue."""
     power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE)
 
@@ -245,15 +252,19 @@ def matrix_inverse_pth_root(
      num_iters: Maximum number of iterations.
      ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
      error_tolerance: Error indicator, useful for early termination.
-      precision: precision XLA related flag, the available options are:
-        a) lax.Precision.DEFAULT (better step time, but not precise)
-        b) lax.Precision.HIGH (increased precision, slower)
-        c) lax.Precision.HIGHEST (best possible precision, slowest)
+      precision: precision XLA related flag, the available options are: a)
+        lax.Precision.DEFAULT (better step time, but not precise) b)
+        lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
+        (best possible precision, slowest)
 
     Returns:
      matrix^(-1/p)
    """
 
+    # If the input is not square, materialize it from the concatenated form.
+    if matrix.shape[0] != matrix.shape[1]:
+        matrix = symmetric_matrices.materialize_matrix_from_concat(matrix)
+
    assert matrix.shape[0] == matrix.shape[1]
 
    # We use _MAT_INV_PTH_ROOT_DTYPE for the matrix inverse pth root.
@@ -336,8 +347,8 @@ def merge_small_dims(shape_to_merge, max_dim):
    return resulting_shape
 
 
-def pad_matrix(mat, max_size):
-    """Pad a matrix to a max_size.
+def pad_square_matrix(mat, max_size):
+    """Pad a square matrix up to max_size.
 
    Args:
      mat: a matrix to pad.
@@ -346,19 +357,132 @@ def pad_matrix(mat, max_size):
    Returns:
      Given M returns [[M, 0], [0, I]]
    """
-    size = mat.shape[0]
-    assert size <= max_size
-    if size == max_size:
+    rows, cols = mat.shape
+    if rows != cols:
+        raise ValueError(
+            "Must have rows == cols, instead got " f"rows={rows}, cols={cols}"
+        )
+    if cols > max_size:
+        raise ValueError(
+            "Must have cols <= max_size. Instead got "
+            f"cols={cols}, max_size={max_size}."
+        )
+    if rows == max_size:
        return mat
-    pad_size = max_size - size
-    zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
-    zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
+    pad_size = max_size - rows
+
+    zs1 = jnp.zeros([rows, pad_size], dtype=mat.dtype)
+    zs2 = jnp.zeros([pad_size, rows], dtype=mat.dtype)
    eye = jnp.eye(pad_size, dtype=mat.dtype)
    mat = jnp.concatenate([mat, zs1], 1)
    mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
    return mat
 
 
+def make_sliced_padding(
+    symmetric_block_size,
+    num_blocks,
+    starting_block,
+    dtype,
+):
+    """Returns padding for symmetric block matrix.
+
+    Specifically, the padding is given concatenated rectangular matrices
+    representing the lower-triangular rows below the starting block. For example,
+    if we want to pad the symmetric matrix
+
+    M = [[A, B^T]
+         [B, C]],
+
+    the desired output (in terms of the full matrix) with num_blocks = 4 is
+
+    M_padded = [[A, B^T, 0, 0]
+                [B, C,   0, 0]
+                [0, 0,   I, 0]
+                 0, 0,   0, I].
+
+    We would represent M as the block matrix mat = [A, B, C]. In this form, the
+    additional padding to provide has form [0, 0, I, 0, 0, 0, I] (only the lower
+    triangular parts in the third and fourth rows).
+
+    Args:
+      symmetric_block_size: The size of each block.
+      num_blocks: The total number of blocks.
+      starting_block: The block where to start the padding.
+      dtype: The type to use for the blocks.
+    """
+    if starting_block == num_blocks:
+        return jnp.zeros(shape=(symmetric_block_size, 0), dtype=dtype)
+
+    blocks = []
+    for i in range(starting_block, num_blocks):
+        blocks.append(
+            jnp.zeros(
+                shape=(symmetric_block_size, symmetric_block_size * i), dtype=dtype
+            )
+        )
+        blocks.append(jnp.eye(symmetric_block_size, dtype=dtype))
+    return jnp.concatenate(blocks, axis=-1)
+
+
+def pad_block_symmetric_matrix(
+    mat,
+    symmetric_block_size,
+    max_num_blocks,
+):
+    """Returns the padded blocked symmetric matrix.
+
+    The size of the padded matrix will be:
+      [symmetric_block_size, symmetric_block_size * max_num_blocks]
+
+    The input matrix can either:
+      - Be square with size less or equal to symmetric_block_size. In this case,
+        mat will first be padded to a square matrix of size symmetric_block_size,
+        and then be padded again up to the full size of the blocked matrix.
+      - Be a rectangle with number of rows equal to block size.
+        In this case, number of columns must be a multiple of number of rows, and
+        the ratio must correspond to a block representation of a symmetric matrix.
+        That is, the ratio must have form x * (x + 1) / 2. Here, x represents the
+        number of block rows represented by the matrix.
+
+    Args:
+      mat: The input block matrix.
+      symmetric_block_size: The size of blocks.
+      max_num_blocks: The largest number of blocks to pad to.
+    """
+    rows, cols = mat.shape
+    if rows > symmetric_block_size:
+        raise ValueError(
+            "Must have rows <= symmetric_block_size. Instead got "
+            f"rows={rows}, symmetric_block_size={symmetric_block_size}."
+        )
+    if rows > cols:
+        raise ValueError(
+            "Must have rows <= cols, instead got " f"rows={rows}, cols={cols}."
+        )
+    if cols > symmetric_block_size * max_num_blocks:
+        raise ValueError(
+            "Must have cols <= symmetric_block_size * max_num_blocks "
+            f"Instead got cols={cols}, "
+            f"symmetric_block_size={symmetric_block_size}, "
+            f"max_num_blocks={max_num_blocks}."
+        )
+    if rows < symmetric_block_size:
+        mat = pad_square_matrix(mat, max_size=symmetric_block_size)
+    # Update rows and cols after possibly padding in pad_square_matrix.
+    rows, cols = mat.shape
+    assert rows == symmetric_block_size
+    assert cols % rows == 0
+    filled_blocks = cols // rows
+    padding_blocks = make_sliced_padding(
+        symmetric_block_size=symmetric_block_size,
+        num_blocks=symmetric_matrices.num_blocks_from_total_blocks(max_num_blocks),
+        starting_block=symmetric_matrices.num_blocks_from_total_blocks(filled_blocks),
+        dtype=mat.dtype,
+    )
+    return jnp.concatenate([mat, padding_blocks], axis=-1)
+
+
 def pad_vector(vec, max_size):
    """Pad a vector to a max_size.
 
@@ -694,18 +818,17 @@ def distributed_shampoo(
      num_devices_for_pjit: Number of devices to parallelize over when using pjit.
      shard_optimizer_states: Shard optimizer states to save memory in model
        parallel training.
-      best_effort_memory_usage_reduction: Best effort memory usage reduction.
-        diagonal_statistics -> jnp.bfloat16
-        momentum buffers (2x) -> jnp.int8
+      best_effort_memory_usage_reduction: Best effort memory usage reduction. -
+        diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x) -> jnp.int8 -
        statistics, preconditioners -> jnp.int16 + diagonals
      inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
        determine that using this threshold.
      moving_average_for_momentum: Whether to use moving average for momentum
        instead of exponential moving average.
      skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
-        greater than this value.
-      clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
-        when using RMSProp Grafting).
+        greater than this value.
+      clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful when
+        using RMSProp Grafting).
      precision: precision XLA related flag, the available options are: a)
        lax.Precision.DEFAULT (better step time, but not precise) b)
        lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
@@ -1167,7 +1290,7 @@ def distributed_shampoo(
        new_padded_statistics = []
        for stat in new_stats_flat:
            new_padded_statistics.extend(
-                [pad_matrix(stat, max_size) for stat in stat.statistics]
+                [pad_square_matrix(stat, max_size) for stat in stat.statistics]
            )
 
        # Create global stats
@@ -1388,7 +1511,7 @@
        num_devices = lax.psum(1, batch_axis_name)
        num_statistics = len(statistics)
        # Pad statistics and exponents to next multiple of num_devices.
-        packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
+        packed_statistics = [pad_square_matrix(stat, max_size) for stat in statistics]
        to_pad = -num_statistics % num_devices
        packed_statistics.extend(
            [jnp.eye(max_size, dtype=packed_statistics[0].dtype) for _ in range(to_pad)]
@@ -1540,7 +1663,7 @@
        # diagonals [d] f32
        # bucket_sizes [d] f32
        packed_quantized_statistics = [
-            pad_matrix(stat.quantized, max_size) for stat in statistics
+            pad_square_matrix(stat.quantized, max_size) for stat in statistics
        ]
        packed_quantized_diagonals = [
            pad_vector(stat.diagonal, max_size) for stat in statistics
@@ -1772,7 +1895,7 @@
        """
        num_statistics = len(statistics)
        to_pad = -num_statistics % num_devices_for_pjit
-        padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
+        padded_statistics = [pad_square_matrix(stat, max_size) for stat in statistics]
        padded_statistics.extend(
            [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
        )
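
Note: the padding helpers renamed/added above can be sanity-checked with a small standalone sketch. This is illustrative only; the import path is an assumption (it presumes the repository root is on PYTHONPATH and that the module is loaded as a package, since it uses relative imports).

    # Sketch: pad_square_matrix embeds M as [[M, 0], [0, I]]; pad_block_symmetric_matrix
    # pads the concatenated block-row form instead.
    import jax.numpy as jnp

    from tools.train.scalable_shampoo.distributed_shampoo import (
        pad_block_symmetric_matrix,
        pad_square_matrix,
    )

    m = jnp.array([[2.0, 1.0], [1.0, 3.0]])

    # Top-left corner holds M, bottom-right is a 2x2 identity block.
    padded = pad_square_matrix(m, max_size=4)
    # padded.shape == (4, 4); padded[2:, 2:] equals jnp.eye(2).

    # Blocked form: a 2x2 statistic padded to max_num_blocks=3 total blocks
    # (i.e. 2 block rows) becomes the side-by-side blocks [M, 0, I].
    blocked = pad_block_symmetric_matrix(m, symmetric_block_size=2, max_num_blocks=3)
    # blocked.shape == (2, 6)
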
tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py CHANGED
@@ -16,7 +16,7 @@
 """JAX Ops for symmetric matrices used by the Shampoo optimizer."""
 
 import functools
-from typing import Any, List, Sequence, Union
+from typing import Any, List, Optional, Sequence, Union
 
 import jax
 import jax.numpy as jnp
@@ -192,7 +192,7 @@ def materialize_matrix(symmetric_matrix):
 @functools.partial(jax.jit, static_argnames=("num_blocks"))
 def materialize_matrix_from_concat(
    block_rows_concat,
-    num_blocks,
+    num_blocks=None,
 ):
    """Returns a materialized symmetric matrix from concatenated slices.
 
@@ -200,7 +200,11 @@ def materialize_matrix_from_concat(
      block_rows_concat: The matrix represented as the concatenated
        lower-triangular blocks.
      num_blocks: The number of block-rows used to represent the symmetric matrix.
+        If not specified, it is inferred from the shape of block_rows_concat.
    """
+    if num_blocks is None:
+        num_blocks = find_num_blocks(block_rows_concat)
+
    block_size = block_rows_concat.shape[-2]
 
    block_rows = [
@@ -251,6 +255,28 @@ def update_sliced_rows(
    )
 
 
+def num_blocks_from_total_blocks(total_blocks):
+    """Returns the number of blocks (i.e.
+
+    block rows) from the total blocks.
+
+    This is the inverse of the function x -> x*(x+1)/2.
+
+    For example, the matrix M = [[A, B^T], [B, C]] may be represented using a
+    total of 3 blocks ([A, B, C]). The number of corresponding block rows is 2.
+
+    Args:
+      total_blocks: The total blocks used to represent the matrix.
+    """
+    num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
+    if (num_blocks * (num_blocks + 1)) / 2 != total_blocks:
+        raise ValueError(
+            f"total_blocks={total_blocks} does not correspond to "
+            "a symmetric matrix. It must have the form total_blocks = x*(x+1)/2."
+        )
+    return num_blocks
+
+
 def find_num_blocks(block_rows_concat):
    """Returns the number of (row) blocks representing the concatenated matrix.
 
@@ -270,11 +296,147 @@ def find_num_blocks(block_rows_concat):
    # Compute the number of square blocks used to represent the matrix.
    total_blocks = block_rows_concat.shape[-1] / block_rows_concat.shape[-2]
    # Determine the number of block rows by inverting y = x*(x+1)/2.
-    num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
-    if num_blocks * (num_blocks + 1) / 2 != total_blocks:
+    return num_blocks_from_total_blocks(total_blocks)
+
+
+@functools.partial(jax.jit, static_argnames=("block_size"))
+def slice_symmetric_matrix(
+    mat,
+    block_size,
+):
+    """Returns sliced row blocks.
+
+    Args:
+      mat: A symmetric matrix.
+      block_size: The size of the row slices.
+    """
+    num_rows = mat.shape[-2]
+    num_cols = mat.shape[-1]
+    if num_rows != num_cols:
+        raise ValueError("mat is not square.")
+    if num_rows % block_size != 0:
        raise ValueError(
-            "Could not determine an appropriate number of blocks for "
-            "the concatenated matrix."
+            "block size does not evenly divide rows. "
+            f"num_rows={num_rows}, block_size={block_size}"
        )
-    else:
-        return num_blocks
+    return SlicedSymmetricMatrix(
+        block_rows=[
+            mat[
+                Ellipsis,
+                i * block_size : (i + 1) * block_size,
+                0 : (i + 1) * block_size,
+            ]
+            for i in range(num_rows // block_size)
+        ]
+    )
+
+
+@functools.partial(jax.jit, static_argnames=("block_size"))
+def slice_symmetric_matrix_concat(
+    mat,
+    block_size,
+):
+    """Returns the concatenated sliced row blocks.
+
+    Args:
+      mat: A symmetric matrix.
+      block_size: The size of the row slices.
+    """
+    sliced_symmetric_matrix = slice_symmetric_matrix(mat=mat, block_size=block_size)
+    return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
+
+
+def sliced_matrix_diag(mat):
+    """Returns the diagonal of the symmetric matrix.
+
+    Args:
+      mat: The symmetric matrix represented in concatenated block form.
+    """
+    rows, cols = mat.shape
+    total_blocks = cols // rows
+    num_blocks = num_blocks_from_total_blocks(total_blocks)
+    diags = []
+    for i in range(num_blocks):
+        last_index = rows * ((i + 2) * (i + 1)) // 2
+        first_index = last_index - rows
+        diags.append(jnp.diag(mat[Ellipsis, first_index:last_index]))
+    return jnp.concatenate(diags, axis=-1)
+
+
+def diag_as_concat(diag, block_size):
+    """Returns the representation of a diagonal matrix in symmetric block form.
+
+    Args:
+      diag: The 1D array for the diagonals.
+      block_size: The size of blocks to use. Must divide the length of diag.
+    """
+    assert len(diag.shape) == 1  # diag must be 1D.
+    assert len(diag) % block_size == 0
+    num_diag_blocks = len(diag) // block_size
+    blocks = []
+    for i in range(num_diag_blocks):
+        blocks.append(jnp.zeros(shape=(block_size, block_size * i), dtype=diag.dtype))
+        blocks.append(jnp.diag(diag[i * block_size : (i + 1) * block_size]))
+    return jnp.concatenate(blocks, axis=-1)
+
+
+def row_abs_maxes(mat):
+    """Returns the max of the absolute values of the rows of the full matrix.
+
+    For example the symmetric matrix M = [[1, 6], [6, 2]] is represented using
+    mat = [1, 6, 2] with block_size = 1. In this case the function returns the
+    absolute row maxes of the original symmetric matrix, [6, 6].
+
+    Args:
+      mat: The symmetric matrix represented as the concatenated blocks.
+    """
+    rows, cols = mat.shape
+
+    # Find col and row max for each block.
+    col_maxes = []
+    row_maxes = []
+    for i in range(cols // rows):
+        block = jnp.abs(mat[Ellipsis, i * rows : (i + 1) * rows])
+        col_maxes.append(jnp.max(block, axis=1))
+        row_maxes.append(jnp.max(block, axis=0))
+
+    # global row max from block maxes.
+    num_blocks = num_blocks_from_total_blocks(cols // rows)
+    maxes = []
+    for i in range(num_blocks):
+        maxes.append(
+            jnp.concatenate(
+                row_maxes[(i * (i + 1) // 2) : ((i + 2) * (i + 1) // 2)]
+                + [
+                    col_maxes[((j + 1) * (j + 2)) // 2 - (j - i + 1)]
+                    for j in range(i + 1, num_blocks)
+                ],
+                axis=-1,
+            )
+        )
+
+    return jnp.max(jnp.stack(maxes), axis=0)
+
+
+def times_vector(mat, vec):
+    """Returns the symmetric block-concatenated matrix multiplied by a vector.
+
+    Specifically, each value in the vector is multiplied by a row of the full
+    matrix. That is, the vector is broadcast and multiplied element-wise. Note
+    this would be the transpose of full_mat * vec if full_mat represented the full
+    symmetric matrix.
+
+    Args:
+      mat: The symmetric matrix represented as the concatenated blocks.
+      vec: The vector, having the same dimension as the materialized matrix.
+    """
+    rows, cols = mat.shape
+    num_blocks = num_blocks_from_total_blocks(cols // rows)
+    multiplied = []
+    for i in range(num_blocks):
+        mat_block = mat[
+            Ellipsis, rows * ((i + 1) * i) // 2 : rows * ((i + 1) * (i + 2)) // 2
+        ]
+        vec_block = vec[Ellipsis, rows * i : rows * (i + 1)]
+        multiplied.append(jnp.einsum("...ij,...i->ij", mat_block, vec_block))
+    return jnp.concatenate(multiplied, axis=-1)
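
Note: a minimal sketch of the concatenated block-row representation these helpers operate on, and of the new `num_blocks` inference in `materialize_matrix_from_concat`. Illustrative only; the import path is an assumption (repository root on PYTHONPATH, module loaded as a package).

    # Sketch: slice a symmetric matrix into lower-triangular block rows,
    # then materialize it back from the concatenated form.
    import jax.numpy as jnp

    from tools.train.scalable_shampoo.symmetric_matrices import (
        symmetric_matrices as sm,
    )

    # A 4x4 symmetric PSD matrix.
    a = jnp.arange(16.0).reshape(4, 4)
    mat = a @ a.T

    # With block_size=2 the lower-triangular block rows are [A] and [B, C],
    # so the concatenated form has 3 blocks and shape (2, 6).
    concat = sm.slice_symmetric_matrix_concat(mat, block_size=2)

    # 3 total blocks correspond to 2 block rows, since 2 * (2 + 1) / 2 == 3.
    assert sm.num_blocks_from_total_blocks(3) == 2

    # num_blocks is now optional and inferred from the concatenated shape.
    reconstructed = sm.materialize_matrix_from_concat(concat)
    assert jnp.allclose(reconstructed, mat)
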