File size: 72,426 Bytes
cd57f41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 |
#from https://github.com/google-research/google-research/blob/master/scalable_shampoo/optax/distributed_shampoo.py
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# An implementation of distributed Shampoo optimizer from:
#
# Scalable Second Order Optimization for Deep Learning
# Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
# Preprint Paper: https://arxiv.org/abs/2002.09018
#
# This implementation moves computation of inverse pth root back to the
# accelerator (if higher precision is available).
#
# Authors: Rohan Anil (rohananil at google dot com)
# & Vineet Gupta (vineet at google dot com)
#
"""Distributed Shampoo Implementation."""
import enum
import functools
import itertools
from typing import Any, List, NamedTuple
import chex
from flax import struct
import jax
from jax import lax
import jax.experimental.pjit as pjit
import jax.numpy as jnp
import numpy as np
import optax
PartitionSpec = pjit.PartitionSpec
# pylint:disable=no-value-for-parameter
@struct.dataclass
class QuantizedValue:
  """State associated with quantized value.

  Depending on `quantized_dtype`, `quantized` holds the float32 input as-is, a
  bfloat16 copy, or an int8/int16 bucketised copy whose scale (max |value|
  along axis 0 divided by the bucket count) is stored in `bucket_size`. With
  `extract_diagonal`, the diagonal of the 2D input is kept separately in
  `diagonal` and excluded from the integer buckets.
  """
  quantized: chex.Array
  diagonal: chex.Array  # Diagonal (if extract_diagonal is set)
  bucket_size: chex.Array
  quantized_dtype: jnp.dtype = struct.field(
      pytree_node=False)  # Dtype for the quantized value.
  extract_diagonal: bool = struct.field(
      pytree_node=False)  # Whether the diagonal is stored in `diagonal`.
  shape: Any = struct.field(pytree_node=False)  # Shape of the tensor.

  @classmethod
  def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
    """Quantizes `fvalue` and wraps the result together with its metadata.

    An empty list input produces a QuantizedValue whose array fields and shape
    are all empty lists.
    """
    empty_input = isinstance(fvalue, list) and not fvalue
    if empty_input:
      return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
    values, diagonal, scale = QuantizedValue.quantize(
        fvalue, quantized_dtype, extract_diagonal)
    return QuantizedValue(
        values, diagonal, scale, quantized_dtype, extract_diagonal,
        list(values.shape))

  # Quantization is from Lingvo JAX optimizers.
  # We extend it for int16 quantization of PSD matrices.
  @classmethod
  def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
    """Returns quantized value and the bucket.

    Returns:
      A tuple `(quantized, diagonal, bucket_size)`. For float32 the input is
      returned unchanged and for bfloat16 it is cast; in both cases the
      diagonal and bucket size are empty lists. For int8/int16 the max |value|
      along axis 0 is mapped to the top bucket (127 or 32767) and values are
      rounded to the nearest bucket.

    Raises:
      ValueError: for an unsupported dtype, when `extract_diagonal` is set on
        a non-2D input, or when the input has zero dimensions.
    """
    if quantized_dtype == jnp.float32:
      return fvalue, [], []
    if quantized_dtype == jnp.bfloat16:
      return fvalue.astype(jnp.bfloat16), [], []

    # The most negative integer (-128 / -32768) is deliberately left unused,
    # keeping the representable range symmetric.
    if quantized_dtype == jnp.int8:
      top_bucket = 127.0
    elif quantized_dtype == jnp.int16:
      top_bucket = 32767.0
    else:
      raise ValueError(f'Quantized dtype {quantized_dtype} not supported.')
    num_buckets = jnp.array(top_bucket, dtype=fvalue.dtype)

    diagonal_fvalue = []
    if extract_diagonal:
      if fvalue.ndim != 2:
        raise ValueError(
            f'Input array {fvalue} must be 2D to work with extract_diagonal.')
      diagonal_fvalue = jnp.diag(fvalue)
      # Zero the diagonal so only off-diagonal entries are bucketised.
      fvalue = fvalue - jnp.diag(diagonal_fvalue)

    # TODO(rohananil): Extend this by making use of information about the blocks
    # SM3 style which will be useful for diagonal statistics
    if fvalue.ndim < 1:
      raise ValueError(
          f'Input array {fvalue} must have a strictly positive number of '
          'dimensions.')
    # The largest magnitude along axis 0 maps to num_buckets.
    bucket_size = jnp.max(jnp.abs(fvalue), axis=0) / num_buckets
    scale = bucket_size[jnp.newaxis, Ellipsis]
    # An all-zero slice has a zero bucket; divide by 1 there instead of 0.
    safe_scale = jnp.where(scale > 0.0, scale, jnp.ones_like(scale))
    # Round to nearest so the quantization error has no systematic bias.
    quantized = jnp.round(fvalue / safe_scale)
    return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size

  def to_float(self):
    """Returns the float value.

    float32 storage is returned as-is, bfloat16 is cast to float32, and integer
    storage is rescaled by `bucket_size` (re-adding the diagonal if it was
    extracted). An empty list is returned unchanged.
    """
    values = self.quantized
    if isinstance(values, list) and not values:
      return values
    if self.quantized_dtype == jnp.float32:
      return values
    if self.quantized_dtype == jnp.bfloat16:
      return values.astype(jnp.float32)
    scale = self.bucket_size[jnp.newaxis, Ellipsis]
    result = values.astype(self.bucket_size.dtype) * scale
    if self.extract_diagonal:
      result += jnp.diag(self.diagonal)
    return result
# Per-parameter optimizer state used in data-parallel training.
class ParameterStats(NamedTuple):
  """State associated to each parameter of the model being trained.

  `statistics` and `preconditioners` are lists whose entries are either plain
  arrays or `QuantizedValue`s.
  """
  diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
  statistics: List[Any]  # Statistics (QuantizedValue, chex.Array)
  preconditioners: List[Any]  # Preconditioners (QuantizedValue, chex.Array)
  diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
  momentum: QuantizedValue  # Momentum for the shampoo preconditioner
# For training extremely large models, we keep a global state holding the
# concatenated statistics and preconditioner states for all variables. This
# lets us annotate the leading axis as sharded, saving memory at the cost of
# communication.
@struct.dataclass
class GlobalShardedParameterStats:
  """Concatenated statistics/preconditioners for all variables (sharded mode).

  The leading axis of `statistics` and `preconditioners` indexes individual
  square matrices. A parameter's matrices occupy the rows
  [index_start, index_start + len(sizes)) recorded in its local stats, and
  each one is cropped to its [size, size] top-left block when read back.
  """
  statistics: chex.Array  # Statistics
  preconditioners: chex.Array  # Preconditioners
  exponents: chex.Array  # exponents (presumably the p per matrix — TODO confirm)
# These are per-parameter local states. All statistics here mirror the
# parameter, so their sharding is copied over from the parameter specification.
@struct.dataclass
class LocalShardedParameterStats:
  """State associated to each parameter of the model being trained.

  Holds the per-parameter diagonal and momentum state, plus static metadata
  (`index_start`, `sizes`) locating this parameter's matrices inside the
  global concatenated statistics.
  """
  diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
  diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
  momentum: QuantizedValue  # Momentum for the shampoo preconditioner
  # The fields below are declared pytree_node=False, i.e. static metadata
  # rather than pytree leaves.
  index_start: np.int32 = struct.field(
      pytree_node=False)  # Index into global statistics array
  sizes: Any = struct.field(pytree_node=False)  # Sizes of the statistics.
class ShardedShampooStats(NamedTuple):
  """Shampoo state in sharded mode."""
  global_stats: Any  # presumably a GlobalShardedParameterStats — TODO confirm
  local_stats: Any  # presumably LocalShardedParameterStats per param — TODO confirm
class ShampooState(NamedTuple):
  """Top-level Shampoo optimizer state."""
  count: chex.Array  # presumably the optimizer step counter — TODO confirm
  stats: Any  # Per-parameter stats; structure is decided by the caller.
class InitFnState(NamedTuple):
  """Groups three functions used for optimizer-state setup.

  NOTE(review): judging only by the names, these are the state initializer,
  the PartitionSpec builder, and a shape/dtype builder — confirm at call site.
  """
  init_fn: Any
  pspec_fn: Any
  shape_and_dtype_fn: Any
class GraftingType(enum.IntEnum):
  """First-order method whose layer-wise update scale Shampoo is grafted onto."""
  SGD = 1
  ADAGRAD = 2
  RMSPROP = 3
  RMSPROP_NORMALIZED = 4
def power_iteration(
    matrix,
    num_iters=100,
    error_tolerance=1e-6,
    precision=lax.Precision.HIGHEST):
  r"""Power iteration algorithm.

  The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
  a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
  of `A`, and a vector v, which is the corresponding eigenvector of `A`.

  References:
    [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)

  Args:
    matrix: the symmetric PSD matrix.
    num_iters: Number of iterations.
    error_tolerance: Iterative exit condition; iteration stops once the
      eigenvalue estimate changes by at most this amount between steps.
    precision: precision XLA related flag, the available options are:
      a) lax.Precision.DEFAULT (better step time, but not precise)
      b) lax.Precision.HIGH (increased precision, slower)
      c) lax.Precision.HIGHEST (best possible precision, slowest)

  Returns:
    eigen vector, eigen value
  """
  matrix_size = matrix.shape[-1]

  def _iter_condition(state):
    i, unused_v, unused_s, unused_s_v, run_step = state
    return jnp.logical_and(i < num_iters, run_step)

  def _iter_body(state):
    """One step of power iteration."""
    i, new_v, s, s_v, unused_run_step = state
    new_v = new_v / jnp.linalg.norm(new_v)

    s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision)
    # Rayleigh quotient of the normalized iterate: the eigenvalue estimate.
    s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision)
    return (i + 1, s_v, s_new, s_v,
            jnp.greater(jnp.abs(s_new - s), error_tolerance))

  # Figure out how to use step as seed for random.
  v_0 = np.random.RandomState(1729).uniform(-1.0, 1.0,
                                            matrix_size).astype(matrix.dtype)

  init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
  _, v_out, s_out, _, _ = lax.while_loop(
      _iter_condition, _iter_body, init_state)
  v_out = v_out / jnp.linalg.norm(v_out)
  return v_out, s_out


def matrix_inverse_pth_root(
    matrix,
    p,
    num_iters=100,
    ridge_epsilon=1e-6,
    error_tolerance=1e-6,
    precision=lax.Precision.HIGHEST):
  """Computes `matrix^(-1/p)`, where `p` is a positive integer.

  This function uses the Coupled newton iterations algorithm for
  the computation of a matrix's inverse pth root.

  References:
    [Functions of Matrices, Theory and Computation,
     Nicholas J Higham, Pg 184, Eq 7.18](
     https://epubs.siam.org/doi/book/10.1137/1.9780898717778)

  Args:
    matrix: the symmetric PSD matrix whose power is to be computed.
    p: exponent, a positive integer. Any value is supported (not only powers
      of two); it is consumed through lax control flow, so it may also be a
      traced integer scalar.
    num_iters: Maximum number of iterations.
    ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
      It is multiplied by the largest eigenvalue (floored at 1e-16) first.
    error_tolerance: Error indicator, useful for early termination.
    precision: precision XLA related flag, the available options are:
      a) lax.Precision.DEFAULT (better step time, but not precise)
      b) lax.Precision.HIGH (increased precision, slower)
      c) lax.Precision.HIGHEST (best possible precision, slowest)

  Returns:
    A tuple `(matrix^(-1/p), error)` where `error` is the max absolute
    deviation of the final coupled iterate from the identity (0 for 1x1
    inputs).
  """

  assert matrix.shape[0] == matrix.shape[1]

  # We use float32 for the matrix inverse pth root.
  # Switch to f64 if you have hardware that supports it.
  matrix_size = matrix.shape[0]
  alpha = jnp.asarray(-1.0 / p, jnp.float32)
  identity = jnp.eye(matrix_size, dtype=jnp.float32)
  _, max_ev = power_iteration(
      matrix=matrix, num_iters=100,
      error_tolerance=1e-6, precision=precision)
  # Make the ridge relative to the spectral radius of the input.
  ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)

  def mat_power(mat_m, p):
    """Computes mat_m^p for any positive integer p by repeated squaring.

    Dispatching on round(log2(p)) would silently compute the wrong power for
    p outside {1, 2, 4, 8} (e.g. p = 6 for rank-3 tensors), so the exponent's
    bits are consumed in a lax.while_loop instead.

    Args:
      mat_m: a square matrix
      p: a positive integer (Python int or traced scalar)

    Returns:
      mat_m^p
    """

    def _pow_condition(state):
      remaining, unused_result, unused_base = state
      return remaining > 0

    def _pow_body(state):
      remaining, result, base = state
      # Multiply in the current power of two when the low bit is set.
      result = lax.cond(
          remaining % 2 == 1,
          lambda acc: jnp.matmul(base, acc, precision=precision),
          lambda acc: acc,
          result)
      remaining = remaining // 2
      # Only square while bits remain, avoiding a wasted final matmul.
      base = lax.cond(
          remaining > 0,
          lambda b: jnp.matmul(b, b, precision=precision),
          lambda b: b,
          base)
      return remaining, result, base

    init_state = (jnp.asarray(jnp.round(p), jnp.int32),
                  jnp.eye(mat_m.shape[0], dtype=mat_m.dtype), mat_m)
    _, result, _ = lax.while_loop(_pow_condition, _pow_body, init_state)
    return result

  def _iter_condition(state):
    (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
     run_step) = state
    error_above_threshold = jnp.logical_and(
        error > error_tolerance, run_step)
    return jnp.logical_and(i < num_iters, error_above_threshold)

  def _iter_body(state):
    (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
    mat_m_i = (1 - alpha) * identity + alpha * mat_m
    new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
    new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
    new_error = jnp.max(jnp.abs(new_mat_m - identity))
    # sometimes error increases after an iteration before decreasing and
    # converging. 1.2 factor is used to bound the maximal allowed increase.
    return (i + 1, new_mat_m, new_mat_h, mat_h, new_error,
            new_error < error * 1.2)

  if matrix_size == 1:
    resultant_mat_h = (matrix + ridge_epsilon)**alpha
    error = 0
  else:
    damped_matrix = matrix + ridge_epsilon * identity

    # Scale so the iteration starts in its convergence region.
    z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
    new_mat_m_0 = damped_matrix * z
    new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
    new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
    init_state = tuple(
        [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
    _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
        _iter_condition, _iter_body, init_state)
    error = jnp.max(jnp.abs(mat_m - identity))
    # If the loop stopped because the error grew, fall back to the previous
    # iterate.
    is_converged = jnp.asarray(convergence, old_mat_h.dtype)
    resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
    resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype)
  return resultant_mat_h, error
def merge_small_dims(shape_to_merge, max_dim):
  """Merge small dimensions.

  Consecutive dimensions are greedily multiplied together while the running
  product stays <= max_dim. A group whose product is 1 is never emitted, so an
  all-ones shape yields []:
    e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
         [1, 2, 768, 1, 2048] --> [2, 768, 2048]

  Args:
    shape_to_merge: Shape to merge small dimensions.
    max_dim: Maximal dimension of output shape used in merging.

  Returns:
    Merged shape.
  """
  merged = []
  running = 1
  for dim in shape_to_merge:
    candidate = running * dim
    if candidate <= max_dim:
      running = candidate
      continue
    # The next dim would overflow max_dim: flush the group and start anew.
    if running > 1:
      merged.append(running)
    running = dim
  if running > 1:
    merged.append(running)
  return merged
def pad_matrix(mat, max_size):
  """Pad a matrix to a max_size.

  Args:
    mat: a square matrix to pad.
    max_size: matrix size requested.

  Returns:
    Given M returns [[M, 0], [0, I]] (M itself when already max_size).
  """
  size = mat.shape[0]
  assert size <= max_size
  if size == max_size:
    return mat
  extra = max_size - size
  top_right = jnp.zeros([size, extra], dtype=mat.dtype)
  bottom_left = jnp.zeros([extra, size], dtype=mat.dtype)
  bottom_right = jnp.eye(extra, dtype=mat.dtype)
  return jnp.block([[mat, top_right], [bottom_left, bottom_right]])
def pad_vector(vec, max_size):
  """Pad a vector to a max_size.

  Args:
    vec: a vector to pad.
    max_size: matrix size requested.

  Returns:
    Given V returns [V, 0] (V itself when already max_size long).
  """
  current = vec.shape[0]
  assert current <= max_size
  missing = max_size - current
  if not missing:
    return vec
  return jnp.concatenate(
      (vec, jnp.zeros((missing,), dtype=vec.dtype)), axis=0)
def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
  """Avoids wasteful buffer allocation with XLA.

  Cond-like helper: returns `compute_fn(*args, **kwargs)` when `predicate` is
  true and `init_state` otherwise. It is expressed as a `lax.while_loop`
  whose body runs at most once (the body always clears the loop flag).

  Args:
    predicate: scalar boolean (may be traced) selecting whether to compute.
    compute_fn: function returning a sequence of arrays matching `init_state`
      in length, shapes and dtypes (required by `lax.while_loop`).
    init_state: sequence (list or tuple) of arrays returned when `predicate`
      is false.
    *args: positional arguments forwarded to `compute_fn`.
    **kwargs: keyword arguments forwarded to `compute_fn`.

  Returns:
    A tuple holding either the results of `compute_fn` or `init_state`.
  """

  def _iter_body(unused_state):
    results = compute_fn(*args, **kwargs)
    # Clearing the flag guarantees the body executes at most once.
    return (False, *results)

  def _iter_condition(state):
    return state[0]

  # Unpacking (rather than list concatenation) accepts tuples as well.
  results = jax.lax.while_loop(_iter_condition, _iter_body,
                               (predicate, *init_state))
  return tuple(results[1:])
class BlockPartitioner:
  """Partitions a tensor into smaller tensors.

  When `0 < block_size < dim`, that axis is cut into chunks of `block_size`,
  with the last chunk holding the remainder; other axes are left whole.
  """

  def __init__(self, param, block_size):
    self._shape = param.shape
    self._splits = []
    split_sizes = []
    # We split params into smaller blocks. Here we store the metadata to make
    # that split.
    for axis, dim in enumerate(param.shape):
      if not 0 < block_size < dim:
        split_sizes.append(np.array([dim], dtype=np.int32))
        continue
      # Use dim - 1 so an exact multiple does not append a 0-size block.
      num_cuts = (dim - 1) // block_size
      cut_points = (np.arange(num_cuts, dtype=np.int32) + 1) * block_size
      block_sizes = np.ones(num_cuts + 1, dtype=np.int32) * block_size
      block_sizes[-1] = dim - cut_points[-1]
      self._splits.append((axis, cut_points))
      split_sizes.append(block_sizes)
    self._num_splits = len(split_sizes)
    # One [d, d] preconditioner per axis of every block, in block order.
    self._preconditioner_shapes = [
        [size, size]
        for block_shape in itertools.product(*split_sizes)
        for size in block_shape
    ]

  def shapes_for_preconditioners(self):
    """Returns the [d, d] shape of every preconditioner, block by block."""
    return self._preconditioner_shapes

  def num_splits(self):
    """Returns the tensor rank, i.e. the number of preconditioners per block."""
    return self._num_splits

  def partition(self, tensor):
    """Partition tensor into blocks."""
    assert tensor.shape == self._shape
    blocks = [tensor]
    for axis, cut_points in self._splits:
      blocks = [
          piece
          for block in blocks
          for piece in jnp.split(
              block, indices_or_sections=cut_points, axis=axis)
      ]
    return blocks

  def merge_partitions(self, partitions):
    """Merge partitions back to original shape."""
    # Undo the splits innermost-first; each group of `group` consecutive
    # pieces came from one tensor split along `axis`.
    for axis, cut_points in reversed(self._splits):
      group = len(cut_points) + 1
      partitions = [
          jnp.concatenate(partitions[start:start + group], axis=axis)
          for start in range(0, len(partitions), group)
      ]
    assert len(partitions) == 1
    return partitions[0]
class Preconditioner:
  """Compute statistics/shape from gradients for preconditioning."""

  def __init__(self, param, block_size, best_effort_shape_interpretation):
    self._original_shape = param.shape
    if best_effort_shape_interpretation:
      self._transformed_shape = merge_small_dims(self._original_shape,
                                                 block_size)
    else:
      self._transformed_shape = param.shape
    reshaped_param = jnp.reshape(param, self._transformed_shape)
    self._partitioner = BlockPartitioner(reshaped_param, block_size)

  def _partitioned(self, tensor):
    """Reshapes `tensor` to the transformed shape and splits it into blocks."""
    reshaped = jnp.reshape(tensor, self._transformed_shape)
    return self._partitioner.partition(reshaped)

  def statistics_from_grad(self, grad):
    """Compute statistics from gradients.

    Args:
      grad: Gradient to compute statistics from.

    Returns:
      A list of gradient statistics for each partition: for every block and
      every axis, the block contracted with itself over all other axes.
    """
    stats = []
    for block in self._partitioned(grad):
      rank = len(block.shape)
      for axis in range(rank):
        other_axes = [a for a in range(rank) if a != axis]
        stats.append(jnp.tensordot(block, block, axes=(other_axes, other_axes)))
    return stats

  def shapes_for_preconditioners(self):
    """Returns shape from statistics."""
    return self._partitioner.shapes_for_preconditioners()

  def exponent_for_preconditioner(self):
    """Returns exponent to use for inverse-pth root M^{-1/p}."""
    return 2 * len(self._transformed_shape)

  def preconditioned_grad(self, grad, preconditioners):
    """Precondition the gradient.

    Args:
      grad: A gradient tensor to precondition.
      preconditioners: A list of preconditioners to apply, `num_splits()`
        consecutive entries per block.

    Returns:
      A preconditioned gradient.
    """
    per_block = self._partitioner.num_splits()
    preconditioned_blocks = []
    for block_idx, block in enumerate(self._partitioned(grad)):
      offset = block_idx * per_block
      block_preconditioners = preconditioners[offset:offset + per_block]
      result = block
      # Contracting axis 0 appends the new axis last, so after `rank` steps
      # the axes are back in their original order.
      for axis in range(len(block.shape)):
        result = jnp.tensordot(
            result, block_preconditioners[axis], axes=[[0], [0]])
      preconditioned_blocks.append(result)
    merged = self._partitioner.merge_partitions(preconditioned_blocks)
    return jnp.reshape(merged, self._original_shape)
def _convert_to_parameter_stats(global_stats, local_stat):
  """Creates parameter stats from sharded stats.

  Takes this parameter's rows [index_start, index_start + len(sizes)) of the
  global statistics and preconditioners, and crops each matrix to its
  [size, size] top-left block.
  """
  start = int(local_stat.index_start)
  stop = start + int(len(local_stat.sizes))
  stats_rows = global_stats.statistics[start:stop, :, :]
  precond_rows = global_stats.preconditioners[start:stop, :, :]
  sized = list(enumerate(local_stat.sizes))
  new_statistics = [stats_rows[i][:n, :n] for i, n in sized]
  new_preconditioners = [precond_rows[i][:n, :n] for i, n in sized]
  return ParameterStats(
      local_stat.diagonal_statistics,
      new_statistics,
      new_preconditioners,
      local_stat.diagonal_momentum,
      local_stat.momentum,
  )
def _convert_from_parameter_stats(parameter_stats, local_stats):
  """Creates sharded stats from parameter stats.

  Keeps the diagonal and momentum state from `parameter_stats` and the
  `index_start`/`sizes` metadata from `local_stats`; the full-matrix
  statistics and preconditioners are not copied here.
  """
  return LocalShardedParameterStats(
      diagonal_statistics=parameter_stats.diagonal_statistics,
      diagonal_momentum=parameter_stats.diagonal_momentum,
      momentum=parameter_stats.momentum,
      index_start=local_stats.index_start,
      sizes=local_stats.sizes,
  )
def batch(x, num_devices):
  """Batch `x` so that its leading axis is num_devices.

  Args:
    x: non-empty list of equally-shaped arrays whose length is a multiple of
      `num_devices`.
    num_devices: number of groups along the new leading axis (a concrete int).

  Returns:
    Array of shape [num_devices, len(x) // num_devices, *x[0].shape].

  Raises:
    ValueError: if `x` is empty, `num_devices` is not positive, or `len(x)` is
      not a multiple of `num_devices`. Without this check a non-multiple
      length silently produced the wrong leading dimension.
  """
  n = len(x)
  if num_devices <= 0 or n == 0 or n % num_devices:
    raise ValueError(
        f'batch() needs a non-empty list whose length ({n}) is a multiple of '
        f'num_devices ({num_devices}).')
  per_device = n // num_devices
  return jnp.stack([
      jnp.stack(x[idx:idx + per_device]) for idx in range(0, n, per_device)
  ])
def unbatch(batched_values):
  """Unbatch values across the two leading axes and return a list of elements.

  Args:
    batched_values: array of shape [b1, b2, ...]; b2 is the number of batched
      entries (preconditioner computations) per core.

  Returns:
    A list of b1 * b2 arrays, each of shape `batched_values.shape[2:]`,
    ordered row-major over (b1, b2).
  """
  b1, b2 = batched_values.shape[0], batched_values.shape[1]
  # Index the two leading axes explicitly. A blanket jnp.squeeze would also
  # drop size-1 trailing axes, e.g. turning 1x1 matrices into scalars.
  return [batched_values[i, j] for i in range(b1) for j in range(b2)]
def distributed_shampoo(
learning_rate,
block_size,
beta1=0.9,
beta2=0.999,
diagonal_epsilon=1e-10,
matrix_epsilon=1e-6,
weight_decay=0.0,
start_preconditioning_step=5,
preconditioning_compute_steps=1,
statistics_compute_steps=1,
best_effort_shape_interpretation=True,
graft_type=GraftingType.SGD,
nesterov=True,
exponent_override=0,
# Pass pmap 'batch axis name' in pmap mode.
batch_axis_name=None,
### Only set following 3 params in pjit/spmd mode.
### WARNING: Experimental
statistics_partition_spec=None,
preconditioner_partition_spec=None,
num_devices_for_pjit=None,
shard_optimizer_states=False,
###
### Experimental memory reduction mode
best_effort_memory_usage_reduction=False,
###
inverse_failure_threshold=0.1,
moving_average_for_momentum=False,
skip_preconditioning_dim_size_gt=4096,
clip_by_scaled_gradient_norm=None,
precision=lax.Precision.HIGHEST):
"""Distributed Shampoo optimizer.
Distributed Shampoo is a second-order preconditioned method (concretely, a
variant of full-matrix Adagrad), that provides significant convergence and
wall-clock time improvements compared to conventional first-order methods,
and that has been shown to scale to large state-of-the-art deep learning
models.
References:
Scalable Second Order Optimization for Deep Learning,
Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
Preprint: https://arxiv.org/abs/2002.09018
Args:
learning_rate: the step size used to update the parameters.
block_size: Block size for large layers (if > 0). Preconditioning compute
operation is cubic in the dimension of the tensor. Block size allows us to
chunk the layers into sub-layers of maximal dimension dictated by this
value. Use 128 as default (increase if you have compute budget).
beta1: momentum parameter.
beta2: second moment averaging parameter.
diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
to AdaGrad is enabled).
matrix_epsilon: epsilon to add to statistics before computing inverse pth
root. If you are running in f32 precision for inverse pth root
(recommended today) this can go upto 1e-6. If you have latest hardware
with native f64 precision, set this upto 1e-12.
weight_decay: Weight decay for regularization.
start_preconditioning_step: When to start Shampoo update before which
diagonal update is used. This is because we dont have enough information
to do stable inverse.
preconditioning_compute_steps: How often to compute preconditioner.
Performance tuning params for controlling memory and compute requirements.
Ideally set this and statistics_compute_steps params to 1.
statistics_compute_steps: How often to compute statistics.
best_effort_shape_interpretation: If there are some small dimensions,
collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
optimizer. This allows us to plugin the Shampoo optimizer into settings
where SGD/AdaGrad is already well tuned. Available options are:
GraftingType.SGD and GraftingType.ADAGRAD.
nesterov: Nesterov momentum.
exponent_override: Override the exponent used in matrix inverse.
batch_axis_name: labeled axis over pmap for data-parallel training the
optimizer used for.
statistics_partition_spec: PartitionSpec to be used in sharded mode.
preconditioner_partition_spec: PartitionSpec to be used in sharded mode.
num_devices_for_pjit: Number of devices to parallelize over when using pjit.
shard_optimizer_states: Shard optimizer states to save memory in model
parallel training.
best_effort_memory_usage_reduction: Best effort memory usage reduction.
diagonal_statistics -> jnp.bfloat16
momentum buffers (2x) -> jnp.int8
statistics, preconditioners -> jnp.int16 + diagonals
inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
determine that using this threshold.
moving_average_for_momentum: Whether to use moving average for momentum
instead of exponential moving average.
skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
greater than this value.
clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
when using RMSProp Grafting).
precision: precision XLA related flag, the available options are: a)
lax.Precision.DEFAULT (better step time, but not precise) b)
lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
(best possible precision, slowest)
Returns:
a GradientTransformation.
"""
def quantized_dtype_for_momentum_buffers():
return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32
# TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes.
def quantized_dtype_for_diagonal_statistics_buffers():
  """Storage dtype of the diagonal (grafting) statistics.

  Returns:
    jnp.bfloat16 when best_effort_memory_usage_reduction is enabled,
    otherwise jnp.float32.
  """
  if best_effort_memory_usage_reduction:
    return jnp.bfloat16
  return jnp.float32
# Preconditioner and statistics are both stored as int16 in this mode
# (memory reduction enabled AND batch_axis_name set).
# We take out the diagonal to make quantization easier.
def quantized_dtype_for_second_moment_statistics_buffers():
  """Storage dtype of the second-moment statistics matrices.

  int16 quantization requires both best_effort_memory_usage_reduction and a
  (truthy) batch_axis_name; in every other configuration the statistics
  stay in float32.
  """
  if best_effort_memory_usage_reduction and batch_axis_name:
    return jnp.int16
  return jnp.float32
# Preconditioner and statistics are both stored as int16 in this mode
# (memory reduction enabled AND batch_axis_name set).
# We take out the diagonal to make quantization easier.
def quantized_dtype_for_second_moment_preconditioner_buffers():
  """Storage dtype of the preconditioner matrices.

  Same rule as for the statistics: int16 only when memory reduction is on
  and a batch_axis_name is configured, float32 otherwise.
  """
  if best_effort_memory_usage_reduction and batch_axis_name:
    return jnp.int16
  return jnp.float32
def _to_float(maybe_quantized):
  """Return a float view of `maybe_quantized`.

  QuantizedValue instances are dequantized with `to_float()`; any other value
  is passed through untouched.
  """
  is_quantized = isinstance(maybe_quantized, QuantizedValue)
  return maybe_quantized.to_float() if is_quantized else maybe_quantized
def _maybe_quantize_statistics(statistics_list):
  """Quantize second-moment statistics if their storage dtype isn't float32."""
  target_dtype = quantized_dtype_for_second_moment_statistics_buffers()
  return _maybe_quantize_matrices_with_dtype(statistics_list, target_dtype)
def _maybe_quantize_preconditioners(statistics_list):
  """Quantize preconditioner matrices if their storage dtype isn't float32.

  Args:
    statistics_list: list of float preconditioner matrices (the parameter
      name is historical; it holds preconditioners, not statistics).
  """
  target_dtype = quantized_dtype_for_second_moment_preconditioner_buffers()
  return _maybe_quantize_matrices_with_dtype(statistics_list, target_dtype)
def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype):
  """Quantize each matrix to `quantized_dtype`, unless it is float32.

  Args:
    statistics_list: list of float matrices.
    quantized_dtype: target storage dtype.

  Returns:
    `statistics_list` itself for float32; otherwise a new list of
    QuantizedValue objects built with the diagonal extracted.
  """
  if quantized_dtype == jnp.float32:
    return statistics_list
  return [
      QuantizedValue.from_float_value(
          matrix, quantized_dtype, extract_diagonal=True)
      for matrix in statistics_list
  ]
def _maybe_dequantize_preconditioners(preconditioner_list):
  """Convert stored preconditioners back to float if they were quantized."""
  storage_dtype = quantized_dtype_for_second_moment_preconditioner_buffers()
  return _maybe_dequantize_matrices_with_dtype(preconditioner_list,
                                               storage_dtype)
def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype):
  """Dequantize each entry via `to_float()` unless the dtype is float32.

  Returns:
    `statistics_list` itself when `quantized_dtype` is float32, otherwise a
    new list of float values.
  """
  if quantized_dtype == jnp.float32:
    return statistics_list
  return [value.to_float() for value in statistics_list]
def _quantize_diagonal_statistics(diagonal_statistics):
  """Wrap the diagonal statistics in a QuantizedValue of the configured dtype."""
  target_dtype = quantized_dtype_for_diagonal_statistics_buffers()
  return QuantizedValue.from_float_value(diagonal_statistics, target_dtype)
def _quantize_momentum(momentum_statistics):
  """Wrap a momentum buffer in a QuantizedValue of the configured dtype."""
  target_dtype = quantized_dtype_for_momentum_buffers()
  return QuantizedValue.from_float_value(momentum_statistics, target_dtype)
def sharded_init_fn(params):
"""Returns optimizer state (for PJIT mode).
Statistics and preconditioners of every preconditioned parameter are padded
to a common max_size x max_size and stacked into global arrays; each
parameter's local stats record where its matrices start in that stack.
Args:
params: the parameters that should be updated.
Returns:
A ShampooState with an int32 zero step count and ShardedShampooStats made of
the global stacked statistics/preconditioners/exponents and the per-param
local stats (same tree structure as params).
"""
params_flat, treedef = jax.tree_flatten(params)
# Find max size to pad to.
# First pass: the largest preconditioner dimension over all preconditioned
# params; every matrix is padded to this size.
max_size = 0
for param in params_flat:
preconditioner = Preconditioner(param, block_size,
best_effort_shape_interpretation)
if not _skip_preconditioning(param):
shapes = preconditioner.shapes_for_preconditioners()
sizes = [s[0] for s in shapes]
max_size = max(max(sizes), max_size)
padded_statistics = []
padded_preconditioners = []
local_stats_flat = []
exponents = []
# Second pass: build the padded global matrices and per-param local stats.
for param in params_flat:
preconditioner = Preconditioner(param, block_size,
best_effort_shape_interpretation)
shapes = preconditioner.shapes_for_preconditioners()
sizes = []
statistics = []
preconditioners = []
# Offset of this param's first matrix in the stacked global arrays; skipped
# params keep sizes == [] and contribute no rows.
index_start = len(padded_statistics)
if not _skip_preconditioning(param):
sizes = [s[0] for s in shapes]
shapes = preconditioner.shapes_for_preconditioners()
# Statistics start at matrix_epsilon * I and preconditioners at I, already
# padded to max_size (the unpadded sizes are kept in `sizes`).
statistics = [matrix_epsilon * jnp.eye(max_size) for s in shapes]
preconditioners = [jnp.eye(max_size) for s in shapes]
padded_statistics.extend(statistics)
padded_preconditioners.extend(preconditioners)
# One exponent per stacked matrix, so `exponents` stays aligned with
# padded_statistics.
exponent = (
preconditioner.exponent_for_preconditioner()
if exponent_override == 0 else exponent_override)
exponents.extend([exponent] * len(shapes))
# SGD grafting keeps no diagonal statistics.
diagonal_statistics = []
if graft_type != GraftingType.SGD:
diagonal_statistics = jnp.zeros_like(param)
local_stats_flat.append(
LocalShardedParameterStats(
_quantize_diagonal_statistics(diagonal_statistics),
_quantize_momentum(jnp.zeros_like(param)),
_quantize_momentum(jnp.zeros_like(param)), index_start, sizes))
local_stats = jax.tree_unflatten(treedef, local_stats_flat)
# Pad the statistics and preconditioner matrices to be a multiple of
# num devices.
# Padding entries are identity matrices with exponent 1.
# TODO(rohananil): Relax to only the size of the mesh axis where the dim
# is split on.
to_pad = -len(padded_statistics) % num_devices_for_pjit
padded_statistics.extend([
jnp.eye(max_size, dtype=padded_statistics[0].dtype)
for _ in range(to_pad)
])
# NOTE(review): the preconditioner padding takes its dtype from
# padded_statistics[0]; both lists are built from jnp.eye here so the dtypes
# match, but padded_preconditioners[0] would be the more direct source.
padded_preconditioners.extend([
jnp.eye(max_size, dtype=padded_statistics[0].dtype)
for _ in range(to_pad)
])
exponents.extend([1 for _ in range(to_pad)])
# NOTE(review): if no param is preconditioned the lists are empty and
# jnp.stack raises.
global_stats = GlobalShardedParameterStats(
jnp.stack(padded_statistics), jnp.stack(padded_preconditioners),
jnp.stack(exponents))
return ShampooState(
count=jnp.zeros([], jnp.int32),
stats=ShardedShampooStats(global_stats, local_stats))
def _max_statistics_size_from_params(params):
  """Largest preconditioner dimension over all preconditioned params.

  Args:
    params: flat list of parameters (anything with `.shape` and `.dtype`).

  Returns:
    The maximum first-dimension size over every preconditioner shape, or 0
    when no parameter is preconditioned.
  """
  largest = 0
  for param in params:
    # A concrete zero array stands in for the param when building shapes.
    stand_in = jnp.zeros(param.shape, dtype=param.dtype)
    builder = Preconditioner(stand_in, block_size,
                             best_effort_shape_interpretation)
    if _skip_preconditioning(param):
      continue
    dims = [shape[0] for shape in builder.shapes_for_preconditioners()]
    largest = max(max(dims), largest)
  return largest
def _remove_leading_sharding_annotation(pspec):
  """Mapping from N-d to (N-1)-d, used for quantization, factoring etc.

  Drops the first axis annotation of a PartitionSpec. Falsy specs (None, an
  empty spec) and single-entry specs are returned unchanged; None and
  PSpec(None) are both valid PartitionSpecs.
  """
  if not pspec or len(pspec) <= 1:
    return pspec
  return PartitionSpec(*pspec[1:])
def sharded_init_partition_spec_fn(params, params_partition_spec,
partition_spec_for_statistics):
"""Returns a parallel state tree with PartitionSpec associated with state.
The returned tree is meant to mirror the one built by sharded_init_fn, with
each array leaf replaced by its PartitionSpec ([] for absent fields).
Args:
params: A pytree with params.
params_partition_spec: A pytree with PartitionSpec for params.
partition_spec_for_statistics: PartitionSpec for the statistics.
Returns:
A ShampooState of PartitionSpecs: a replicated PartitionSpec() for the step
count and the exponents, partition_spec_for_statistics for the stacked
statistics and preconditioners, and per-param local specs.
"""
# Parallel lists of spec, and params.
param_pspec_flat, _ = jax.tree_flatten(params_partition_spec)
params_flat, treedef = jax.tree_flatten(params)
assert param_pspec_flat
assert params_flat
# Step is replicated across cores.
# None means cores.
local_stats_flat = []
num_statistics = 0
for param, param_pspec in zip(params_flat, param_pspec_flat):
param_clone = jnp.zeros(param.shape, dtype=param.dtype)
preconditioner = Preconditioner(param_clone, block_size,
best_effort_shape_interpretation)
shapes = preconditioner.shapes_for_preconditioners()
sizes = []
# index_start/sizes follow the same counting as sharded_init_fn so the local
# stats address the same rows of the stacked global statistics.
index_start = num_statistics
if not _skip_preconditioning(param):
sizes = [s[0] for s in shapes]
shapes = preconditioner.shapes_for_preconditioners()
num_statistics += len(shapes)
# Diagonal statistics share the param's spec; a quantized buffer's scale has
# one dimension fewer, so its spec drops the leading annotation.
diagonal_statistics_pspec = []
diagonal_statistics_scale_pspec = []
if graft_type != GraftingType.SGD:
# Identically shaped param.
diagonal_statistics_pspec = param_pspec
if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32:
diagonal_statistics_scale_pspec = _remove_leading_sharding_annotation(
param_pspec)
# Both momentum buffers are param-shaped and follow the same rule.
m1_pspec = param_pspec
m2_pspec = param_pspec
m1_scale_pspec = []
m2_scale_pspec = []
if quantized_dtype_for_momentum_buffers() != jnp.float32:
m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)
m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)
# QuantizedValue positional fields here: (value spec, diagonal spec [],
# scale spec, dtype, False -- presumably the extract-diagonal flag, shape).
local_stats_flat.append(
LocalShardedParameterStats(
QuantizedValue(diagonal_statistics_pspec, [],
diagonal_statistics_scale_pspec,
quantized_dtype_for_diagonal_statistics_buffers(),
False, list(param.shape)),
QuantizedValue(m1_pspec, [], m1_scale_pspec,
quantized_dtype_for_momentum_buffers(), False,
list(param.shape)),
QuantizedValue(m2_pspec, [], m2_scale_pspec,
quantized_dtype_for_momentum_buffers(), False,
list(param.shape)), index_start, sizes))
local_stats = jax.tree_unflatten(treedef, local_stats_flat)
# Stacked statistics and preconditioners share one spec; exponents are
# replicated.
global_stats = GlobalShardedParameterStats(partition_spec_for_statistics,
partition_spec_for_statistics,
PartitionSpec())
count_pspec = PartitionSpec()
return ShampooState(
count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats))
def sharded_init_shape_and_dtype_fn(params):
  """Returns a parallel state tree with shape, dtype associated with state.

  Each array leaf of the sharded optimizer state is described as a
  `[shape, dtype]` pair, matching what the sharded init builds: an int32 step
  count, float32 stacked statistics/preconditioners of shape
  [num_statistics, max_size, max_size] (num_statistics padded to a multiple
  of num_devices_for_pjit), int32 exponents, and per-param local stats.

  Args:
    params: A pytree with params.

  Returns:
    A ShampooState whose leaves are `[shape, dtype]` lists ([] for absent
    fields).
  """
  params_flat, treedef = jax.tree_flatten(params)
  assert params_flat
  local_stats_flat = []
  num_statistics = 0
  for param in params_flat:
    param_clone = jnp.zeros(param.shape, dtype=param.dtype)
    preconditioner = Preconditioner(param_clone, block_size,
                                    best_effort_shape_interpretation)
    shapes = preconditioner.shapes_for_preconditioners()
    sizes = []
    # Offset of this param's first matrix in the stacked global statistics.
    index_start = num_statistics
    if not _skip_preconditioning(param):
      sizes = [s[0] for s in shapes]
      num_statistics += len(shapes)

    # SGD grafting keeps no diagonal statistics.
    diagonal_statistics_shape_and_dtype = []
    diagonal_statistics_scale_shape_and_dtype = []
    if graft_type != GraftingType.SGD:
      diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
      qdtype = quantized_dtype_for_diagonal_statistics_buffers()
      if qdtype != jnp.float32:
        diagonal_statistics_shape_and_dtype = [list(param.shape), qdtype]
        diagonal_statistics_scale_shape_and_dtype = [
            list(param.shape)[1:], param.dtype
        ]

    m1_shape_and_dtype = [list(param.shape), param.dtype]
    m2_shape_and_dtype = [list(param.shape), param.dtype]
    m1_scale_shape_and_dtype = []
    m2_scale_shape_and_dtype = []
    qdtype = quantized_dtype_for_momentum_buffers()
    if qdtype != jnp.float32:
      m1_shape_and_dtype = [list(param.shape), qdtype]
      m2_shape_and_dtype = [list(param.shape), qdtype]
      # Only the values are quantized; the scales (bucket sizes) stay in
      # floating point, exactly like the diagonal-statistics scale above.
      m1_scale_shape_and_dtype = [list(param.shape)[1:], param.dtype]
      m2_scale_shape_and_dtype = [list(param.shape)[1:], param.dtype]

    local_stats_flat.append(
        LocalShardedParameterStats(
            QuantizedValue(diagonal_statistics_shape_and_dtype, [],
                           diagonal_statistics_scale_shape_and_dtype,
                           quantized_dtype_for_diagonal_statistics_buffers(),
                           False, list(param.shape)),
            QuantizedValue(m1_shape_and_dtype, [], m1_scale_shape_and_dtype,
                           quantized_dtype_for_momentum_buffers(), False,
                           list(param.shape)),
            QuantizedValue(m2_shape_and_dtype, [], m2_scale_shape_and_dtype,
                           quantized_dtype_for_momentum_buffers(), False,
                           list(param.shape)), index_start, sizes))

  local_stats = jax.tree_unflatten(treedef, local_stats_flat)
  max_statistics_size = _max_statistics_size_from_params(params_flat)
  # Same padding rule as the sharded init: a multiple of the device count.
  to_pad = -num_statistics % num_devices_for_pjit
  num_statistics += to_pad
  statistics_shape = [
      num_statistics, max_statistics_size, max_statistics_size
  ]
  global_stats = GlobalShardedParameterStats([statistics_shape, jnp.float32],
                                             [statistics_shape, jnp.float32],
                                             [[num_statistics], jnp.int32])
  # The step counter is created as jnp.zeros([], jnp.int32) by the init fns.
  return ShampooState(
      count=[[], jnp.int32],
      stats=ShardedShampooStats(global_stats, local_stats))
def sharded_update_fn(grads, state, params):
"""Transform the input gradient and update all statistics in sharded mode.
Args:
grads: the gradient tensors for the parameters.
state: a named tuple containing the state of the optimizer
params: the parameters that should be updated.
Returns:
A tuple (updates, new_state): the transformed updates (same tree structure
as params) and the new ShampooState with the step count incremented.
"""
params_flat, treedef = jax.tree_flatten(params)
grads_flat = treedef.flatten_up_to(grads)
global_stats = state.stats.global_stats
local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
# Per-param ParameterStats views built from the global stacked arrays and each
# param's local stats (presumably via index_start/sizes; helper not shown).
stats_flat = [
_convert_to_parameter_stats(global_stats, local_stat)
for local_stat in local_stats_flat
]
# Accumulate statistics first, then transform every gradient with them.
new_stats_flat = jax.tree_multimap(
lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
stats_flat, params_flat)
outputs = jax.tree_multimap(
lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
new_stats_flat, params_flat)
updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
updates = jax.tree_unflatten(treedef, updates_flat)
# Create new local_stats
new_local_stats_flat = [
_convert_from_parameter_stats(new_stat, local_stat)
for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
]
new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
max_size = global_stats.statistics.shape[1]
# Re-pad every updated statistic to max_size x max_size for restacking. The
# comprehension variable `stat` shadows the loop variable; `stat.statistics`
# is evaluated with the outer loop's `stat`.
new_padded_statistics = []
for stat in new_stats_flat:
new_padded_statistics.extend(
[pad_matrix(stat, max_size) for stat in stat.statistics])
# Create global stats
# TODO(rohananil): Preconditioner is not updated every step, so cost of
# stack/pad can be obviated away.
# Pad the statistics and preconditioner matrices to be a multiple of
# num devices.
# TODO(rohananil): Relax to only the size of the mesh axis where the dim
# is split on.
to_pad = -len(new_padded_statistics) % num_devices_for_pjit
new_padded_statistics.extend([
jnp.eye(max_size, dtype=new_padded_statistics[0].dtype)
for _ in range(to_pad)
])
new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
new_stacked_padded_statistics = pjit.with_sharding_constraint(
new_stacked_padded_statistics, statistics_partition_spec)
def _internal_inverse_pth_root_all():
preconditioners, errors = _matrix_inverse_pth_root_pjit(
new_stacked_padded_statistics, global_stats.exponents,
statistics_partition_spec)
return preconditioners, errors
if preconditioning_compute_steps == 1:
new_preconditioners, errors = _internal_inverse_pth_root_all()
else:
# Passing statistics instead of preconditioners as they are similarly
# shaped tensors. Note statistics will be ignored as we are passing in
# a large init value for error.
preconditioners_init = new_stacked_padded_statistics
n = new_stacked_padded_statistics.shape[0]
errors_init = jnp.ones([n], jnp.float32) * inverse_failure_threshold
init_state = [preconditioners_init, errors_init]
perform_step = state.count % preconditioning_compute_steps == 0
new_preconditioners, errors = efficient_cond(
perform_step, _internal_inverse_pth_root_all, init_state)
# Reshape errors to [n, 1, 1] so they broadcast over the stacked matrices.
errors = errors.reshape((-1, 1, 1))
# predicate is 1.0 where the new root failed (NaN or error >= threshold) and
# 0.0 otherwise; the blend below keeps the old preconditioner for failures.
predicate = jnp.logical_or(
jnp.isnan(errors),
errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
# TODO(rohananil): Check for numerical instabilities.
new_conditional_preconditioners = (
predicate * global_stats.preconditioners +
(1.0 - predicate) * new_preconditioners)
new_global_stats = GlobalShardedParameterStats(
new_stacked_padded_statistics, new_conditional_preconditioners,
global_stats.exponents)
new_shampoo_state = ShampooState(
count=state.count + 1,
stats=ShardedShampooStats(new_global_stats, new_local_stats))
return updates, new_shampoo_state
def init_fn(params):
  """Initialise the optimiser's state (non-sharded mode).

  Args:
    params: pytree of parameters to optimize.

  Returns:
    A ShampooState with an int32 zero step count and, for every leaf of
    `params`, a ParameterStats holding (possibly quantized) diagonal
    statistics, statistics, preconditioners and two momentum buffers.
  """

  def _init_param_stats(param):
    builder = Preconditioner(param, block_size,
                             best_effort_shape_interpretation)
    if _skip_preconditioning(param):
      stat_mats, precond_mats = [], []
    else:
      dims = [shape[0] for shape in builder.shapes_for_preconditioners()]
      # Statistics start at matrix_epsilon * I, preconditioners at I.
      stat_mats = [matrix_epsilon * jnp.eye(dim) for dim in dims]
      precond_mats = [jnp.eye(dim) for dim in dims]
    # SGD grafting keeps no per-element diagonal statistics.
    diag = jnp.zeros_like(param) if graft_type != GraftingType.SGD else []
    return ParameterStats(
        _quantize_diagonal_statistics(diag),
        _maybe_quantize_statistics(stat_mats),
        _maybe_quantize_preconditioners(precond_mats),
        _quantize_momentum(jnp.zeros_like(param)),
        _quantize_momentum(jnp.zeros_like(param)))

  per_param_stats = jax.tree_map(_init_param_stats, params)
  return ShampooState(count=jnp.zeros([], jnp.int32), stats=per_param_stats)
def _skip_preconditioning(param):
  """Return True when `param` gets no Shampoo statistics/preconditioners.

  Scalars (rank 0) are skipped, as are parameters with any dimension larger
  than `skip_preconditioning_dim_size_gt`.

  Args:
    param: array-like with a `.shape`.
  """
  # Generator instead of a list: any() can stop at the first oversized dim.
  return len(param.shape) < 1 or any(
      dim > skip_preconditioning_dim_size_gt for dim in param.shape)
def _compute_stats(grad, state, param, step):
  """Compute per-parameter statistics.

  The accumulator update is `beta2 * old + w * new`, where `w` is 1.0 when
  beta2 == 1.0 (pure accumulation) and `1 - beta2` otherwise (EMA). When
  statistics_compute_steps > 1 the update only happens on steps that are a
  multiple of it; otherwise the previous statistics are kept.

  Returns:
    A ParameterStats with only the `statistics` field replaced.
  """
  builder = Preconditioner(param, block_size, best_effort_shape_interpretation)
  decay = beta2
  new_weight = beta2 if beta2 == 1.0 else (1.0 - beta2)

  if _skip_preconditioning(param):
    updated_statistics = [[]] * len(state.statistics)
  else:

    def _accumulate():
      grad_statistics = builder.statistics_from_grad(grad)
      accumulated = [
          decay * _to_float(old) + new_weight * fresh
          for fresh, old in zip(grad_statistics, state.statistics)
      ]
      return _maybe_quantize_statistics(accumulated)

    if statistics_compute_steps > 1:
      should_update = step % statistics_compute_steps == 0
      updated_statistics = list(
          efficient_cond(should_update, _accumulate, state.statistics))
    else:
      updated_statistics = _accumulate()

  return ParameterStats(state.diagonal_statistics, updated_statistics,
                        state.preconditioners, state.diagonal_momentum,
                        state.momentum)
def _matrix_inverse_pth_root_vmap(xs, ps):
  """Batched inverse p-th root: maps matrix_inverse_pth_root over axis 0.

  Args:
    xs: stacked matrices.
    ps: one exponent per matrix.

  Returns:
    Whatever matrix_inverse_pth_root returns, batched along axis 0.
  """

  def _single_root(x, p):
    return matrix_inverse_pth_root(
        x, p, ridge_epsilon=matrix_epsilon, precision=precision)

  return jax.vmap(_single_root)(xs, ps)
def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps):
  """Batched inverse p-th root over quantized statistics.

  Each batch element is a quantized matrix `qx` plus its diagonal `qd` and
  bucket sizes `qb`. It is dequantized, inverted with matrix_inverse_pth_root,
  and re-quantized to `qx.dtype` with the diagonal extracted.

  Returns:
    (quantized roots, diagonals, bucket sizes, errors), batched on axis 0.
  """

  def _root_of_one(qx, qd, qb, p):
    as_float = QuantizedValue(qx, qd, qb, qx.dtype, True,
                              list(qx.shape)).to_float()
    root, error = matrix_inverse_pth_root(
        as_float, p, ridge_epsilon=matrix_epsilon, precision=precision)
    requantized = QuantizedValue.from_float_value(root, qx.dtype, True)
    return (requantized.quantized, requantized.diagonal,
            requantized.bucket_size, error)

  return jax.vmap(_root_of_one)(qxs, qds, qbs, ps)
def _matrix_inverse_pth_root_pjit(xs, ps, statistics_partition_spec=None):
  """Inverse p-th roots of stacked statistics, sharded under pjit.

  The stacked matrices are constrained to preconditioner_partition_spec (and
  the exponents to its leading axis) so each shard runs the vmap on its own
  slice. The results are then re-laid-out with `statistics_partition_spec`,
  and the errors are replicated.

  Args:
    xs: stacked statistics matrices.
    ps: one exponent per matrix.
    statistics_partition_spec: output spec for the preconditioners (None is
      passed through to with_sharding_constraint unchanged).

  Returns:
    (preconditioners, errors).
  """
  shard_spec = preconditioner_partition_spec
  sharded_xs = pjit.with_sharding_constraint(xs, shard_spec)
  # Exponents are 1-D: shard them only along the leading (stacking) axis.
  sharded_ps = pjit.with_sharding_constraint(
      ps, pjit.PartitionSpec(preconditioner_partition_spec[0]))
  roots, errs = _matrix_inverse_pth_root_vmap(sharded_xs, sharded_ps)
  # Keep the vmap output sharded like its input so vmap never sees the full
  # set of statistics, then recombine with the requested layout.
  roots = pjit.with_sharding_constraint(roots, shard_spec)
  return (pjit.with_sharding_constraint(roots, statistics_partition_spec),
          pjit.with_sharding_constraint(errs, pjit.PartitionSpec()))
def _pmap_compute_preconditioners(states, step, statistics,
                                  num_statistics_per_state, original_shapes,
                                  exponents, max_size, prev_preconditioners):
  """Computes preconditioners for given statistics in states in PMAP mode.

  The statistics are padded to max_size, split evenly over the
  `batch_axis_name` devices, and each device computes inverse p-th roots for
  its slice. The results are all-gathered. A new preconditioner whose error
  is NaN or >= inverse_failure_threshold is discarded in favour of the
  previous one.

  Args:
    states: A list of optimizer states.
    step: Current step number
    statistics: A list of statistics for all variables (for every dim)
    num_statistics_per_state: Number of statistis per state to reconstruct
      output states.
    original_shapes: A list of shapes of the statistics.
    exponents: Exponent power to use for inverse-pth roots (not modified).
    max_size: Maximum dim of the statistics to pad.
    prev_preconditioners: Previously available preconditioner.

  Returns:
    New optimizer states after computing the preconditioner (`states` itself
    when there are no statistics).
  """
  num_devices = lax.psum(1, batch_axis_name)
  num_statistics = len(statistics)
  if not statistics:
    return states

  # Pad statistics and exponents to the next multiple of num_devices with
  # identity matrices / exponent 1 so every device gets an equal share.
  packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
  to_pad = -num_statistics % num_devices
  packed_statistics.extend([
      jnp.eye(max_size, dtype=packed_statistics[0].dtype)
      for _ in range(to_pad)
  ])
  # Build a new list rather than extending the caller's list in place.
  exponents = list(exponents) + [1] * to_pad

  all_statistics = batch(packed_statistics, num_devices)
  all_exponents = batch(exponents, num_devices)

  def _internal_inverse_pth_root_all():
    current_replica = lax.axis_index(batch_axis_name)
    preconditioners, errors = _matrix_inverse_pth_root_vmap(
        all_statistics[current_replica], all_exponents[current_replica])
    preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
    errors = jax.lax.all_gather(errors, batch_axis_name)
    return unbatch(preconditioners), unbatch(errors)

  if preconditioning_compute_steps == 1:
    preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
  else:
    # Passing statistics instead of preconditioners as they are similarly
    # shaped tensors. Note statistics will be ignored as we are passing in
    # a large init value for error.
    preconditioners_init = packed_statistics
    errors_init = [inverse_failure_threshold] * len(packed_statistics)
    init_state = [preconditioners_init, errors_init]
    perform_step = step % preconditioning_compute_steps == 0
    preconditioners_flat, errors_flat = efficient_cond(
        perform_step, _internal_inverse_pth_root_all, init_state)

  def _skip(error):
    # lax.cond requires a boolean (or integer) scalar predicate, so keep the
    # comparison as bool instead of casting it to the float error dtype.
    return jnp.logical_or(jnp.isnan(error), error >= inverse_failure_threshold)

  def _select_preconditioner(error, new_p, old_p):
    return lax.cond(
        _skip(error), lambda _: old_p, lambda _: new_p, operand=None)

  # Crop each padded root back to its original shape; keep the previous
  # preconditioner wherever the new one failed. zip() drops padding entries.
  new_preconditioners_flat = [
      _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p)
      for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
                                         prev_preconditioners, errors_flat)
  ]
  assert len(states) == len(num_statistics_per_state)
  assert len(new_preconditioners_flat) == num_statistics

  # Regroup the flat list per state; states without statistics get [].
  preconditioners_for_states = []
  idx = 0
  for count, state in zip(num_statistics_per_state, states):
    if count == 0:
      preconditioners_for_states.append([])
      continue
    preconditioners_for_state = new_preconditioners_flat[idx:idx + count]
    assert len(state.statistics) == len(preconditioners_for_state)
    preconditioners_for_states.append(preconditioners_for_state)
    idx += count

  return [
      ParameterStats(state.diagonal_statistics, state.statistics,
                     new_preconditioners, state.diagonal_momentum,
                     state.momentum)
      for state, new_preconditioners in zip(states, preconditioners_for_states)
  ]
def _pmap_quantized_compute_preconditioners(states, step, statistics,
                                            num_statistics_per_state,
                                            original_shapes, exponents,
                                            max_size, prev_preconditioners):
  """Computes preconditioners for given statistics in states in PMAP mode.

  For quantization, each statistic is represented by three values:
  quantized matrix, diagonal, and bucket sizes. These are padded, split over
  the `batch_axis_name` devices and all-gathered in quantized form; each
  device dequantizes only its own slice inside the inverse-pth-root vmap.
  A new preconditioner whose error is NaN or >= inverse_failure_threshold is
  discarded in favour of the previous one.

  Args:
    states: A list of optimizer states.
    step: Current step number
    statistics: A list of statistics for all variables (for every dim)
    num_statistics_per_state: Number of statistis per state to reconstruct
      output states.
    original_shapes: A list of shapes of the statistics.
    exponents: Exponent power to use for inverse-pth roots (not modified).
    max_size: Maximum dim of the statistics to pad.
    prev_preconditioners: Previously available preconditioner.

  Returns:
    New optimizer states after computing the preconditioner (`states` itself
    when there are no statistics).
  """
  num_devices = lax.psum(1, batch_axis_name)
  num_statistics = len(statistics)
  if not statistics:
    return states
  quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
  # Complexity here is around: shapes needing be statically shaped,
  # our custom quantization type requires a different type of packing.
  # Parallel tensors:
  # quantized [dxd]
  # diagonals [d] f32
  # bucket_sizes [d] f32
  packed_quantized_statistics = [
      pad_matrix(stat.quantized, max_size) for stat in statistics
  ]
  packed_quantized_diagonals = [
      pad_vector(stat.diagonal, max_size) for stat in statistics
  ]
  packed_quantized_bucket_sizes = [
      pad_vector(stat.bucket_size, max_size) for stat in statistics
  ]

  # Pad with a quantized identity so every device gets an equal share.
  to_pad = -num_statistics % num_devices
  padded_eye = jnp.eye(max_size, dtype=jnp.float32)
  quantized_eye = QuantizedValue.from_float_value(padded_eye, quantized_dtype,
                                                  True)
  packed_quantized_statistics.extend(
      [quantized_eye.quantized for _ in range(to_pad)])
  packed_quantized_diagonals.extend(
      [quantized_eye.diagonal for _ in range(to_pad)])
  packed_quantized_bucket_sizes.extend(
      [quantized_eye.bucket_size for _ in range(to_pad)])
  # Build a new list rather than extending the caller's list in place.
  exponents = list(exponents) + [1] * to_pad

  all_quantized_statistics = batch(packed_quantized_statistics, num_devices)
  all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices)
  all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes,
                                     num_devices)
  all_exponents = batch(exponents, num_devices)

  def _internal_inverse_pth_root_all():
    current_replica = lax.axis_index(batch_axis_name)
    (quantized_preconditioners, quantized_diagonals, quantized_bucket_sizes,
     errors) = _quantized_matrix_inverse_pth_root_vmap(
         all_quantized_statistics[current_replica],
         all_quantized_diagonals[current_replica],
         all_quantized_bucket_sizes[current_replica],
         all_exponents[current_replica])
    quantized_preconditioners = jax.lax.all_gather(quantized_preconditioners,
                                                   batch_axis_name)
    quantized_diagonals = jax.lax.all_gather(quantized_diagonals,
                                             batch_axis_name)
    quantized_bucket_sizes = jax.lax.all_gather(quantized_bucket_sizes,
                                                batch_axis_name)
    errors = jax.lax.all_gather(errors, batch_axis_name)
    return (unbatch(quantized_preconditioners), unbatch(quantized_diagonals),
            unbatch(quantized_bucket_sizes), unbatch(errors))

  if preconditioning_compute_steps == 1:
    (quantized_preconditioners_flat, quantized_diagonals_flat,
     quantized_bucket_sizes_flat, errors_flat) = (
         _internal_inverse_pth_root_all())
  else:
    # Passing statistics instead of preconditioners as they are similarly
    # shaped tensors. Note statistics will be ignored as we are passing in
    # a large init value for error.
    errors_init = ([inverse_failure_threshold] *
                   len(packed_quantized_statistics))
    init_state = [
        packed_quantized_statistics, packed_quantized_diagonals,
        packed_quantized_bucket_sizes, errors_init
    ]
    perform_step = step % preconditioning_compute_steps == 0
    (quantized_preconditioners_flat, quantized_diagonals_flat,
     quantized_bucket_sizes_flat, errors_flat) = (
         efficient_cond(perform_step, _internal_inverse_pth_root_all,
                        init_state))

  def _skip(error):
    # lax.cond requires a boolean (or integer) scalar predicate, so keep the
    # comparison as bool instead of casting it to the float error dtype.
    return jnp.logical_or(jnp.isnan(error), error >= inverse_failure_threshold)

  def _select_preconditioner(error, new_p, old_p):
    return lax.cond(
        _skip(error), lambda _: old_p, lambda _: new_p, operand=None)

  # Crop to the original shapes; keep the previous quantized preconditioner
  # (all three parts) wherever the new one failed.
  new_quantized_preconditioners_flat = []
  new_quantized_diagonals_flat = []
  new_quantized_bucket_sizes_flat = []
  for p, d, b, shape, prev_p, error in zip(quantized_preconditioners_flat,
                                           quantized_diagonals_flat,
                                           quantized_bucket_sizes_flat,
                                           original_shapes,
                                           prev_preconditioners, errors_flat):
    new_quantized_preconditioners_flat.append(
        _select_preconditioner(error, p[:shape[0], :shape[1]],
                               prev_p.quantized))
    new_quantized_diagonals_flat.append(
        _select_preconditioner(error, d[:shape[0]], prev_p.diagonal))
    new_quantized_bucket_sizes_flat.append(
        _select_preconditioner(error, b[:shape[0]], prev_p.bucket_size))

  assert len(states) == len(num_statistics_per_state)
  assert len(new_quantized_preconditioners_flat) == num_statistics
  assert len(new_quantized_diagonals_flat) == num_statistics
  assert len(new_quantized_bucket_sizes_flat) == num_statistics

  # Regroup per state and rebuild QuantizedValues; states without statistics
  # get [].
  preconditioners_for_states = []
  idx = 0
  for count, state in zip(num_statistics_per_state, states):
    if count == 0:
      preconditioners_for_states.append([])
      continue
    end = idx + count
    quantized_preconditioners_for_state = (
        new_quantized_preconditioners_flat[idx:end])
    quantized_diagonals_for_state = new_quantized_diagonals_flat[idx:end]
    quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[idx:end]
    assert len(state.statistics) == len(quantized_preconditioners_for_state)
    assert len(state.statistics) == len(quantized_diagonals_for_state)
    assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
    preconditioners_for_states.append([
        QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape))
        for qv, qd, qb in zip(quantized_preconditioners_for_state,
                              quantized_diagonals_for_state,
                              quantized_bucket_sizes_for_state)
    ])
    idx = end

  return [
      ParameterStats(state.diagonal_statistics, state.statistics,
                     new_preconditioners, state.diagonal_momentum,
                     state.momentum)
      for state, new_preconditioners in zip(states, preconditioners_for_states)
  ]
def _pjit_compute_preconditioners(states, step, statistics,
                                  num_statistics_per_state, original_shapes,
                                  exponents, max_size, prev_preconditioners):
  """Computes preconditioners for given statistics in states in PJIT mode.

  The statistics are padded to max_size and stacked; the count is padded to a
  multiple of num_devices_for_pjit. Inverse p-th roots are computed with
  _matrix_inverse_pth_root_pjit. A new preconditioner whose error is NaN or
  >= inverse_failure_threshold is discarded in favour of the previous one.

  Args:
    states: A list of optimizer states.
    step: Current step number
    statistics: A list of statistics for all variables (for every dim)
    num_statistics_per_state: Number of statistis per state to reconstruct
      output states.
    original_shapes: A list of shapes of the statistics.
    exponents: Exponent power to use for inverse-pth roots (not modified).
    max_size: Maximum dim of the statistics to pad.
    prev_preconditioners: Previously available preconditioner.

  Returns:
    New optimizer states after computing the preconditioner (`states` itself
    when there are no statistics, e.g. every parameter was skipped).
  """
  num_statistics = len(statistics)
  if not statistics:
    # Nothing to precondition; jnp.stack below would fail on empty lists.
    return states
  to_pad = -num_statistics % num_devices_for_pjit
  padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
  padded_statistics.extend([
      jnp.eye(max_size, dtype=padded_statistics[0].dtype)
      for _ in range(to_pad)
  ])
  # Build a new list rather than extending the caller's list in place.
  exponents = list(exponents) + [1] * to_pad
  all_statistics = jnp.stack(padded_statistics)
  all_exponents = jnp.stack(exponents)

  def _internal_inverse_pth_root_all():
    preconditioners, errors = _matrix_inverse_pth_root_pjit(
        all_statistics, all_exponents)
    b1 = preconditioners.shape[0]

    def split(batched_values):
      # Squeeze only the stacking axis: a bare squeeze would also collapse
      # 1x1 matrices when max_size == 1.
      return [
          jnp.squeeze(v, axis=0)
          for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
      ]

    return split(preconditioners), split(errors)

  if preconditioning_compute_steps == 1:
    preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
  else:
    # Passing statistics instead of preconditioners as they are similarly
    # shaped tensors. Note statistics will be ignored as we are passing in
    # a large init value for error.
    preconditioners_init = padded_statistics
    errors_init = [inverse_failure_threshold] * len(padded_statistics)
    init_state = [preconditioners_init, errors_init]
    perform_step = step % preconditioning_compute_steps == 0
    preconditioners_flat, errors_flat = efficient_cond(
        perform_step, _internal_inverse_pth_root_all, init_state)

  def _skip(error):
    # lax.cond requires a boolean (or integer) scalar predicate, so keep the
    # comparison as bool instead of casting it to the float error dtype.
    return jnp.logical_or(jnp.isnan(error), error >= inverse_failure_threshold)

  def _select_preconditioner(error, new_p, old_p):
    return lax.cond(
        _skip(error), lambda _: old_p, lambda _: new_p, operand=None)

  # Crop each padded root back to its original shape; zip() drops padding.
  new_preconditioners_flat = [
      _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p)
      for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
                                         prev_preconditioners, errors_flat)
  ]
  assert len(states) == len(num_statistics_per_state)
  assert len(new_preconditioners_flat) == num_statistics

  # Regroup the flat list per state; states without statistics get [].
  preconditioners_for_states = []
  idx = 0
  for count, state in zip(num_statistics_per_state, states):
    if count == 0:
      preconditioners_for_states.append([])
      continue
    preconditioners_for_state = new_preconditioners_flat[idx:idx + count]
    assert len(state.statistics) == len(preconditioners_for_state)
    preconditioners_for_states.append(preconditioners_for_state)
    idx += count

  return [
      ParameterStats(state.diagonal_statistics, state.statistics,
                     new_preconditioners, state.diagonal_momentum,
                     state.momentum)
      for state, new_preconditioners in zip(states, preconditioners_for_states)
  ]
def _compute_preconditioners(states, params, step):
  """Computes preconditioners for given statistics in states.

  Flattens the statistics of all parameters into parallel lists (statistics,
  exponents, original shapes, previous preconditioners) and dispatches them
  to the pmap-based implementation (quantized or not) when `batch_axis_name`
  is set, or to the pjit-based implementation otherwise.

  Args:
    states: A list of optimizer states, one per parameter.
    params: A list of params, aligned with `states`.
    step: Current step number.

  Returns:
    New optimizer states after computing the preconditioner.
  """
  statistics = []
  num_statistics_per_state = []
  original_shapes = []
  exponents = []
  max_size = 0
  prev_preconditioners = []
  for state, param in zip(states, params):
    num_statistics = len(state.statistics)
    # Recorded even when zero so the flat results can later be regrouped
    # per state.
    num_statistics_per_state.append(num_statistics)
    if num_statistics == 0:
      continue
    preconditioner = Preconditioner(param, block_size,
                                    best_effort_shape_interpretation)
    # exponent_for_preconditioner() takes no per-statistic input, so it is
    # evaluated once per parameter instead of once per statistic.
    exponent = (
        preconditioner.exponent_for_preconditioner()
        if exponent_override == 0 else exponent_override)
    for statistic in state.statistics:
      exponents.append(exponent)
      original_shapes.append(statistic.shape)
      # Largest leading dimension; presumably the common size all
      # statistics are padded to downstream.
      max_size = max(max_size, statistic.shape[0])
    statistics.extend(state.statistics)
    prev_preconditioners.extend(state.preconditioners)

  if batch_axis_name:
    # Quantized statistics buffers are only handled on this (pmap) path,
    # i.e. when batch_axis_name IS set; the pjit path below has no
    # quantized variant.
    quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
    if quantized_dtype == jnp.float32:
      return _pmap_compute_preconditioners(states, step, statistics,
                                           num_statistics_per_state,
                                           original_shapes, exponents,
                                           max_size, prev_preconditioners)
    return _pmap_quantized_compute_preconditioners(
        states, step, statistics, num_statistics_per_state, original_shapes,
        exponents, max_size, prev_preconditioners)
  return _pjit_compute_preconditioners(states, step, statistics,
                                       num_statistics_per_state,
                                       original_shapes, exponents, max_size,
                                       prev_preconditioners)
def _transform_grad(grad, state, param, step):
  """Transform per-parameter gradients.

  Computes a grafting update (Adagrad, RMSProp, normalized RMSProp or plain
  SGD, selected by `graft_type`). It also computes a Shampoo-preconditioned
  update rescaled to the grafting update's norm. Weight decay and momentum
  are applied to both, and the Shampoo result is used once
  `step >= start_preconditioning_step`.

  Args:
    grad: Gradient of this parameter.
    state: ParameterStats of this parameter.
    param: The parameter value (used to build the Preconditioner and for
      weight decay).
    step: Current step number.

  Returns:
    A tuple `(transformed_update, param_stats)`:
      - `transformed_update` is the update multiplied by `-learning_rate`.
      - `param_stats` holds the new diagonal statistics and both momentum
        buffers, re-encoded via `_quantize_diagonal_statistics` /
        `_quantize_momentum`.
  """
  preconditioner = Preconditioner(param, block_size,
                                  best_effort_shape_interpretation)
  sgd_update = grad
  diagonal_statistics = state.diagonal_statistics.to_float()
  # Unchanged unless the grafting type below maintains diagonal statistics.
  new_diagonal_statistics = diagonal_statistics
  if graft_type == GraftingType.ADAGRAD:
    new_diagonal_statistics = diagonal_statistics + jnp.square(grad)
    adagrad_update = grad / (
        jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
    grafting_update = adagrad_update
  elif (graft_type == GraftingType.RMSPROP or
        graft_type == GraftingType.RMSPROP_NORMALIZED):
    scaled_grad = grad
    if graft_type == GraftingType.RMSPROP_NORMALIZED:
      # The 1e-16 (same guard as the multiplier below) avoids 0 / 0 = NaN
      # for an all-zero gradient. Such a NaN would be folded into the stored
      # diagonal statistics and persist for every later step.
      scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16)
    w1 = beta2
    # beta2 == 1.0 accumulates squared gradients instead of taking an
    # exponential moving average.
    w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
    new_diagonal_statistics = (
        w1 * diagonal_statistics + w2 * jnp.square(scaled_grad))
    rmsprop_update = scaled_grad / (
        jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
    if clip_by_scaled_gradient_norm:
      # Root-mean-square of the update, scaled down to the clip threshold
      # whenever it exceeds it.
      scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
          jnp.sqrt(float(rmsprop_update.size)))
      clipping_denom = jnp.maximum(
          1., scaled_grad_norm / clip_by_scaled_gradient_norm)
      rmsprop_update /= clipping_denom
    grafting_update = rmsprop_update
  else:
    grafting_update = sgd_update

  precond_grad = grad
  if not _skip_preconditioning(param):
    precond_grad = preconditioner.preconditioned_grad(
        precond_grad,
        _maybe_dequantize_preconditioners(state.preconditioners))
  else:
    precond_grad = grafting_update

  # Grafting: keep the direction of precond_grad but rescale it to the norm
  # of grafting_update (1e-16 avoids division by zero).
  grafting_update_norm = jnp.linalg.norm(grafting_update)
  precond_grad_norm = jnp.linalg.norm(precond_grad)
  multiplier = (grafting_update_norm / (precond_grad_norm + 1e-16))
  shampoo_update = precond_grad * multiplier

  shampoo_update_with_wd = shampoo_update
  grafting_update_with_wd = grafting_update
  if weight_decay != 0:
    shampoo_update_with_wd = shampoo_update + weight_decay * param
    grafting_update_with_wd = grafting_update + weight_decay * param

  # With moving_average_for_momentum the new term is scaled by (1 - beta1);
  # otherwise it is added unscaled.
  w = (1.0 - beta1) if moving_average_for_momentum else 1.0
  shampoo_update_with_wd_momentum = (
      state.momentum.to_float() * beta1 + w * shampoo_update_with_wd)
  grafting_update_with_wd_momentum = (
      state.diagonal_momentum.to_float() * beta1 +
      w * grafting_update_with_wd)

  # 0/1 float mask: 1.0 once step >= start_preconditioning_step (use the
  # Shampoo update), 0.0 before that (use the grafting update). Both
  # momentum buffers are updated every step regardless.
  run_shampoo = (step >= start_preconditioning_step).astype(
      grafting_update_with_wd_momentum.dtype)
  momentum_update = (
      run_shampoo * shampoo_update_with_wd_momentum +
      (1.0 - run_shampoo) * grafting_update_with_wd_momentum)
  wd_update = (
      run_shampoo * shampoo_update_with_wd +
      (1.0 - run_shampoo) * grafting_update_with_wd)
  if nesterov:
    momentum_update = w * wd_update + beta1 * momentum_update

  lr = learning_rate
  if callable(learning_rate):
    lr = learning_rate(step)
  transformed_update = -1.0 * lr * momentum_update

  param_stats = ParameterStats(
      _quantize_diagonal_statistics(new_diagonal_statistics),
      state.statistics, state.preconditioners,
      _quantize_momentum(grafting_update_with_wd_momentum),
      _quantize_momentum(shampoo_update_with_wd_momentum))
  return transformed_update, param_stats
def update_fn(grads, state, params):
  """Transform the input gradient and update all statistics.

  Args:
    grads: the gradient tensors for the parameters.
    state: a named tuple containing the state of the optimizer
    params: the parameters that should be updated.

  Returns:
    A tuple containing the updates to apply to the parameters (same tree
    structure as `params`) and the new optimizer state.
  """
  # jax.tree_util.* replaces jax.tree_multimap, which has been removed from
  # JAX. It also replaces the deprecated top-level jax.tree_flatten /
  # jax.tree_unflatten aliases.
  params_flat, treedef = jax.tree_util.tree_flatten(params)
  stats_flat = treedef.flatten_up_to(state.stats)
  grads_flat = treedef.flatten_up_to(grads)
  step = state.count

  # grads_flat and stats_flat are flattened up to the params treedef, so
  # all lists are aligned element-for-element with params_flat.
  new_stats_flat = [
      _compute_stats(g, s, p, step)
      for g, s, p in zip(grads_flat, stats_flat, params_flat)
  ]
  new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat, step)
  outputs = [
      _transform_grad(g, s, p, step)
      for g, s, p in zip(grads_flat, new_stats_flat, params_flat)
  ]
  # Split the (update, stats) pairs; handle a parameter-free tree.
  updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
  updates = jax.tree_util.tree_unflatten(treedef, updates_flat)
  new_stats = jax.tree_util.tree_unflatten(treedef, new_stats_flat)
  new_state = ShampooState(count=state.count + 1, stats=new_stats)
  return updates, new_state
if shard_optimizer_states:
# Hijacks the init_fn signature so we can return an OptState with
# appropriate init_fns.
def _init_fns(unused_params):
return InitFnState(
init_fn=sharded_init_fn,
pspec_fn=sharded_init_partition_spec_fn,
shape_and_dtype_fn=sharded_init_shape_and_dtype_fn)
return optax.GradientTransformation(_init_fns, sharded_update_fn)
else:
return optax.GradientTransformation(init_fn, update_fn) |