Spaces:
Build error
Build error
File size: 54,299 Bytes
b100e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 |
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utility functions for t5x."""
import collections.abc
from concurrent.futures import thread
import contextlib
import dataclasses
import functools
import importlib
import inspect
import os
import re
import time
import typing
from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union
import warnings
from absl import logging
import clu.data
from flax import traverse_util
import flax.core
from flax.core import scope as flax_scope
from flax.linen import partitioning as flax_partitioning
import jax
from jax import prng
from jax import pxla
from jax.experimental import multihost_utils
from jax.experimental.global_device_array import GlobalDeviceArray
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
import seqio
from t5x import checkpoints
from t5x import optimizers
from t5x import partitioning
from t5x import state_utils
from t5x import train_state as train_state_lib
import tensorflow as tf
from tensorflow.io import gfile
import typing_extensions
# Union of the array-like types accepted by the utilities in this module.
# `pxla` is the module imported at the top of this file (same object as
# `jax.pxla`).
Array = Union[np.ndarray, jnp.ndarray, pxla.ShardedDeviceArray, tf.Tensor]
# `jax.tree_structure` is a deprecated alias (removed in newer JAX releases);
# `jax.tree_util.tree_structure` is the canonical spelling and returns the
# same treedef type.
PyTreeDef = type(jax.tree_util.tree_structure(None))
PartitionSpec = partitioning.PartitionSpec
DType = Union[np.dtype, type(jnp.bfloat16)]
Shape = Tuple[int, ...]
# TODO(adarob): Remove namespace mapping after client gin files are updated.
TensorBoardLogger = seqio.TensorBoardLogger
# -----------------------------------------------------------------------------
# Configurations
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class SaveCheckpointConfig:
  """Options controlling how and when model checkpoints are written."""
  # Precision used for saved parameters; `__post_init__` only accepts
  # 'float32' or 'bfloat16'.
  dtype: str = 'float32'
  # Step interval between checkpoint writes.
  period: Optional[int] = None
  # How many of the newest checkpoints to retain; None retains every one.
  keep: Optional[int] = None
  # How many dataset checkpoints to retain; None retains every one.
  # Dataset checkpoints are additionally subject to `keep`.
  keep_dataset_checkpoints: Optional[int] = None
  # When True, the dataset iterator state is checkpointed as well.
  save_dataset: bool = False
  # Class used to construct the checkpointer.
  checkpointer_cls: checkpoints.CheckpointerConstructor = checkpoints.Checkpointer
  # Ordered transformations applied to the state prior to writing it.
  state_transformation_fns: Sequence[checkpoints.SaveStateTransformationFn] = (
      dataclasses.field(default_factory=list))

  def __post_init__(self):
    allowed_dtypes = ('float32', 'bfloat16')
    if self.dtype in allowed_dtypes:
      return
    raise ValueError(
        "`SaveCheckpointConfig.dtype` must be one of 'float32' or "
        f"'bfloat16'. Got {self.dtype}.")
@dataclasses.dataclass
class RestoreCheckpointConfig:
  """Configuration for restoring model from checkpoint."""
  # Path(s) to checkpoint to restore from or directory (depending on `mode`).
  path: Union[str, Sequence[str]]
  # One of 'specific', 'latest', or 'all'.
  # specific: load the checkpoint specified by `path`.
  # latest: load most recent checkpoint in the directory specified by `path`.
  # all: sequentially load all of checkpoints in the directory `path`.
  mode: str = 'latest'
  # An optional sequence of (pattern, replacement) regex pairs. The pattern
  # matches parameters in the model and the replacement matches the checkpoint
  # (after substitutions). The replacement may be None, in which case the
  # parameter can be dropped. Use `fallback_to_scratch` to fill them in with
  # newly initialized values.
  assignment_map: Optional[Sequence[Tuple[str, Optional[str]]]] = None
  # Whether to restore all optimizer parameters from the checkpoint.
  strict: bool = True
  # Whether to initialize parameters that are in the model being restored but
  # are missing from the checkpoint (after `assignment_map` is applied).
  fallback_to_scratch: bool = False
  # The dtype to restore ('float32' or 'bfloat16'), or None to load as saved.
  dtype: Optional[str] = None
  # Whether to restore the dataset checkpoint. Fails if checkpoint not present.
  restore_dataset: bool = False
  # The checkpointer class to use.
  checkpointer_cls: checkpoints.CheckpointerConstructor = checkpoints.Checkpointer
  # Transformations to apply, in order, to the state after reading. These will
  # be applied after the `assignment_map` transformations.
  state_transformation_fns: Sequence[
      checkpoints.RestoreStateTransformationFn] = ()

  def __post_init__(self):
    """Validates `mode`/`dtype` and folds `assignment_map` into the fns.

    Raises:
      ValueError: if `mode` or `dtype` is not one of the supported values.
    """
    if self.mode not in ('specific', 'latest', 'all'):
      raise ValueError(
          "`RestoreCheckpointConfig.mode` must be one of 'specific', 'latest', "
          f"or 'all'. Got {self.mode}.")
    if self.dtype not in (None, 'float32', 'bfloat16'):
      raise ValueError(
          "`RestoreCheckpointConfig.dtype` must be one of `None`, 'float32', "
          f"or 'bfloat16'. Got {self.dtype}.")
    if self.assignment_map is not None:
      existing_fns = tuple(self.state_transformation_fns)
      # `dataclasses.replace` constructs a new instance and runs
      # `__post_init__` again, with `state_transformation_fns` that already
      # start with the assignment-map transformation built for this very
      # `assignment_map`. Skip prepending in that case so the mapping is not
      # applied twice.
      first_fn = existing_fns[0] if existing_fns else None
      already_prepended = (
          isinstance(first_fn, functools.partial) and
          first_fn.func is state_utils.apply_assignment_map and
          first_fn.keywords.get('assignment_map') is self.assignment_map)
      if not already_prepended:
        # Turns `assignment_map` into a transformation function.
        assignment_map_fn = functools.partial(
            state_utils.apply_assignment_map,
            assignment_map=self.assignment_map)
        # Prepends the `assignment_map` transformation to the front of the
        # list.
        self.state_transformation_fns = (assignment_map_fn, *existing_fns)
@dataclasses.dataclass
class CheckpointConfig:
  """Groups the model/dataset checkpoint save and restore settings.

  Both halves default to None, so either may be left unset.
  """
  # Settings used when writing checkpoints, or None if not configured.
  save: Optional[SaveCheckpointConfig] = None
  # Settings used when reading checkpoints, or None if not configured.
  restore: Optional[RestoreCheckpointConfig] = None
class LegacyCheckpointer(orbax.checkpoint.Checkpointer):
  """Implementation of Checkpointer interface for T5X.

  Relies on underlying save_checkpointer and restore_checkpointer, which are
  t5x.checkpoints.Checkpointer objects. The async variants of the interface are
  not implemented.
  """

  def __init__(self,
               save_checkpointer: checkpoints.Checkpointer,
               restore_checkpointer: Optional[checkpoints.Checkpointer],
               *,
               strict: Optional[bool] = False):
    """Stores the wrapped checkpointers.

    Args:
      save_checkpointer: checkpointer that `save` delegates to.
      restore_checkpointer: checkpointer that `restore` delegates to, or None
        if this instance will never be used to restore.
      strict: forwarded as `strict` when restoring from a TensorFlow
        checkpoint.
    """
    self._save_checkpointer = save_checkpointer
    self._restore_checkpointer = restore_checkpointer
    self._strict = strict

  async def async_save(self, path: str, item: Any):
    raise NotImplementedError

  async def async_restore(self, path: str, item: Optional[Any] = None) -> Any:
    raise NotImplementedError

  def save(self,
           path: str,
           item: train_state_lib.TrainState,
           state_transformation_fns: Sequence[
               checkpoints.SaveStateTransformationFn] = (),
           *,
           concurrent_gb: int = 128):
    """Performs save operation using save_checkpointer.

    Args:
      path: ignored; the destination is owned by `save_checkpointer`.
      item: a TrainState PyTree to save.
      state_transformation_fns: Transformations to apply, in order, to the state
        before writing.
      concurrent_gb: the approximate number of gigabytes of partitionable
        parameters to process in parallel. Useful to preserve RAM.
    """
    train_state = item
    del path  # stored in save_checkpointer
    # dataset_iterator is also saved, but is provided in checkpointer init
    self._save_checkpointer.save(
        train_state, state_transformation_fns, concurrent_gb=concurrent_gb)

  def restore(self,
              path: str,
              item: Optional[train_state_lib.TrainState],
              state_transformation_fns: Sequence[
                  checkpoints.RestoreStateTransformationFn] = (),
              fallback_state: Optional[Mapping[str, Any]] = None,
              lazy_parameters: bool = False) -> train_state_lib.TrainState:
    """Performs restore operation using restore_checkpointer.

    A path for which `<path>.index` exists is treated as a TensorFlow
    checkpoint; otherwise it is restored as a T5X checkpoint.

    Args:
      path: the string path to restore from.
      item: a TrainState PyTree to restore. Unused.
      state_transformation_fns: Transformations to apply, in order, to the state
        after reading. Not supported for TensorFlow checkpoints.
      fallback_state: a state dict of an optimizer to fall back to for loading
        params that do not exist in the checkpoint (after applying all
        `state_transformation_fns`), but do exist in `Checkpointer.optimizer`.
        The union of `fallback_state` and state loaded from the checkpoint must
        match `Checkpointer.optimizer`. Ignored for TensorFlow checkpoints.
      lazy_parameters: whether to load the parameters as LazyArrays to preserve
        memory. Ignored for TensorFlow checkpoints.

    Returns:
      The restored train state.

    Raises:
      ValueError: if no restore checkpointer was configured, or if
        `state_transformation_fns` are given for a TensorFlow checkpoint.
    """
    del item  # not needed for restore in T5X
    if self._restore_checkpointer is None:
      # Without this check the failure is an opaque AttributeError on None.
      raise ValueError(
          'LegacyCheckpointer was constructed without a restore_checkpointer; '
          f'cannot restore from {path}.')
    from_tensorflow = gfile.exists(path + '.index')
    if from_tensorflow and state_transformation_fns:
      raise ValueError('Cannot initialize from a TensorFlow checkpoint using '
                       '`state_transformation_fns`.')
    if from_tensorflow:
      logging.info('Initializing parameters from TensorFlow checkpoint %s',
                   path)
      return self._restore_checkpointer.restore_from_tf_checkpoint(
          path, strict=self._strict)
    return self._restore_checkpointer.restore(
        path=path,
        state_transformation_fns=state_transformation_fns,
        fallback_state=fallback_state,
        lazy_parameters=lazy_parameters)
class LegacyCheckpointManager(orbax.checkpoint.CheckpointManager):
  """Implementation of CheckpointManager interface for T5X.

  Uses underlying LegacyCheckpointer to handle save/restore for Dataset and
  TrainState.
  """

  def __init__(self,
               save_cfg: SaveCheckpointConfig,
               restore_cfg: Optional[RestoreCheckpointConfig],
               train_state_shape: train_state_lib.TrainState,
               partitioner: partitioning.BasePartitioner,
               ds_iter: Optional[tf.data.Iterator] = None,
               model_dir: Optional[str] = None,
               use_gda: Optional[bool] = False):
    """Builds the save (and, if `restore_cfg` is set, restore) checkpointers.

    Args:
      save_cfg: configuration for the save checkpointer.
      restore_cfg: configuration for the restore checkpointer; when falsy, no
        restore checkpointer is created.
      train_state_shape: TrainState passed to the checkpointer constructors.
      partitioner: the partitioner to use.
      ds_iter: dataset iterator; required when `save_cfg.save_dataset` is set.
      model_dir: directory that the save checkpointer writes to.
      use_gda: forwarded to the save checkpointer.

    Raises:
      ValueError: if `save_cfg.save_dataset` is set but `ds_iter` is None.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if save_cfg.save_dataset and ds_iter is None:
      raise ValueError(
          '`ds_iter` must be provided when `save_cfg.save_dataset` is True.')
    save_checkpointer = save_cfg.checkpointer_cls(
        train_state=train_state_shape,
        partitioner=partitioner,
        checkpoints_dir=model_dir,
        dataset_iterator=ds_iter if save_cfg.save_dataset else None,
        save_dtype=save_cfg.dtype,
        keep=save_cfg.keep,
        use_gda=use_gda,
        keep_dataset_checkpoints=save_cfg.keep_dataset_checkpoints)
    if restore_cfg:
      restore_checkpointer = restore_cfg.checkpointer_cls(
          train_state=train_state_shape,
          partitioner=partitioner,
          checkpoints_dir='',  # unused for restore
          dataset_iterator=ds_iter if restore_cfg.restore_dataset else None,
          restore_dtype=jnp.dtype(restore_cfg.dtype)
          if restore_cfg.dtype else None)
      strict = restore_cfg.strict
    else:
      restore_checkpointer = None
      strict = False
    self._checkpointer = LegacyCheckpointer(
        save_checkpointer, restore_checkpointer, strict=strict)

  def save(self,
           train_state: train_state_lib.TrainState,
           state_transformation_fns: Sequence[
               checkpoints.SaveStateTransformationFn] = ()):
    """Performs save operation.

    Args:
      train_state: a TrainState PyTree to save.
      state_transformation_fns: Transformations to apply, in order, to the state
        before writing.
    """
    self._checkpointer.save(
        path='',  # not used
        item=train_state,
        state_transformation_fns=state_transformation_fns)

  def restore(
      self,
      paths: Sequence[str],
      restore_cfg: RestoreCheckpointConfig,
      fallback_state: Optional[Mapping[str, Any]] = None
  ) -> Optional[Union[train_state_lib.TrainState,
                      Sequence[train_state_lib.TrainState]]]:
    """Performs restore operation using restore_checkpointer.

    Each path is restored independently; see `LegacyCheckpointer.restore` for
    how TensorFlow checkpoints are detected.

    Args:
      paths: A sequence of paths to restore from.
      restore_cfg: RestoreCheckpointConfig specifying restoration information.
      fallback_state: a state dict of an optimizer to fall back to for loading
        params that do not exist in the checkpoint (after applying all
        `state_transformation_fns`), but do exist in `Checkpointer.optimizer`.
        The union of `fallback_state` and state loaded from the checkpoint must
        match `Checkpointer.optimizer`.

    Returns:
      None if `restore_cfg` or `paths` is None. Otherwise the restored
      TrainState if exactly one was restored, or a list of TrainStates (empty
      when `paths` is empty).
    """
    if restore_cfg is None or paths is None:
      return None
    restored = []
    for path in paths:
      logging.info('Initializing parameters from specific T5X checkpoint %s',
                   path)
      restored.append(
          self._checkpointer.restore(
              path=path,
              item=None,  # not used
              state_transformation_fns=restore_cfg.state_transformation_fns,
              fallback_state=fallback_state))
    if len(restored) == 1:
      return restored[0]
    return restored
@dataclasses.dataclass
class DatasetConfig:
  """Settings for building a dataset from a registered SeqIO Task/Mixture."""
  # Name of the SeqIO Task or Mixture to load.
  mixture_or_task_name: str
  # Per-feature lengths, keyed by feature name.
  task_feature_lengths: Mapping[str, int]
  # Which split of the Task/Mixture to read.
  split: str
  # Number of examples per batch.
  batch_size: int
  # Whether examples should be shuffled.
  shuffle: bool
  # Seed value; may be None.
  seed: Optional[int]
  # When True, read a precomputed copy of the dataset from a cache directory.
  use_cached: bool = False
  # Whether examples are packed.
  pack: bool = False
  # When True, use the tensor2tensor custom ops for faster packing.
  use_custom_packing_ops: bool = False
  # DEPRECATED. Module to import so that the referenced Task or Mixture gets
  # registered.
  module: Optional[str] = None
  # Whether to keep the dataset cached in memory (evaluation data only).
  use_memory_cache: bool = True
#------------------------------------------------------------------------------
# Fast *nondeterministic* hardware RNG for faster Dropout
#------------------------------------------------------------------------------
def _hardware_uniform(
    rng_key: Array,
    shape: Shape,
    dtype: jnp.dtype = np.float32,
    minval: Array = np.float32(0),
    maxval: Array = np.float32(1)
) -> Array:
  """Draws uniform samples with `jax.lax.rng_uniform`, ignoring `rng_key`.

  The sampling relies on the accelerator's (non-deterministic) hardware RNG
  instead of a JAX PRNG key; both bounds are cast to `dtype` first.
  """
  del rng_key  # Unused: the hardware RNG does not consume a key.
  low, high = (jax.lax.convert_element_type(bound, dtype)
               for bound in (minval, maxval))
  return jax.lax.rng_uniform(low, high, shape)
def _hardware_bernoulli(
    rng_key: Array, p: np.ndarray = np.float32(0.5),
    shape: Shape = ()) -> Array:
  """Returns a boolean mask where each entry is True with probability `p`.

  Uses `jax.lax.rng_uniform` on [0, 1) and ignores `rng_key`.
  """
  del rng_key  # Unused: the hardware RNG does not consume a key.
  uniform_draw = jax.lax.rng_uniform(0.0, 1.0, shape)
  return uniform_draw < p
def set_hardware_rng_ops():
  """Enable JAX Custom PRNG extension.

  Turns on `jax_enable_custom_prng` and monkey-patches `PRNGKey` (both the
  public `jax.random` and the private `jax._src.random` bindings) so new keys
  are seeded with `prng.unsafe_rbg_prng_impl`. Per the original authors, this
  uses the fast TPU hardware PRNG with an iterated-hash "split" substitute and
  is expected to be deterministic for a fixed partitioning.
  """
  jax.config.update('jax_enable_custom_prng', True)
  # TODO(levskaya): replace with jax global config option once we debug it.
  seed_with_rbg = functools.partial(prng.seed_with_impl,
                                    prng.unsafe_rbg_prng_impl)
  jax.random.PRNGKey = seed_with_rbg
  jax._src.random.PRNGKey = seed_with_rbg  # pylint: disable=protected-access
# -----------------------------------------------------------------------------
# Training utility functions.
# -----------------------------------------------------------------------------
def get_zeros_batch_like_spec(
    batch_spec: Mapping[str,
                        jax.ShapeDtypeStruct]) -> Mapping[str, jnp.ndarray]:
  """Returns a dict of zero-filled arrays, one per entry of `batch_spec`.

  Each array takes its shape and dtype from the corresponding
  `jax.ShapeDtypeStruct`.
  """
  zeros_batch = {}
  for feature_name, struct in batch_spec.items():
    zeros_batch[feature_name] = jnp.zeros(struct.shape, struct.dtype)
  return zeros_batch
def get_zeros_batch_like_dataset(dataset: tf.data.Dataset,
                                 batch_size=None) -> Mapping[str, jnp.ndarray]:
  """Returns zero arrays shaped and typed like `dataset.element_spec`.

  Args:
    dataset: dataset whose `element_spec` is a mapping of feature name to
      tensor spec.
    batch_size: if truthy, replaces the leading dimension of every feature
      shape; otherwise the element shapes are used unchanged.
  """

  def _target_shape(shape):
    if batch_size:
      return (batch_size,) + shape[1:]
    return tuple(shape)

  batch_spec = {}
  for feature_name, tensor_spec in dataset.element_spec.items():
    batch_spec[feature_name] = jax.ShapeDtypeStruct(
        _target_shape(tensor_spec.shape), tensor_spec.dtype.as_numpy_dtype)
  return get_zeros_batch_like_spec(batch_spec)
class InitFnCallable(typing_extensions.Protocol):
  """Structural type for callables that build initial model variables.

  Implementations receive a PRNG key, a mapping of input shapes and an optional
  mapping of input dtypes, and return frozen Flax variables.
  """

  def __call__(
      self, rng: Array, input_shapes: Mapping[str, Array],
      input_types: Optional[Mapping[str,
                                    DType]]) -> flax_scope.FrozenVariableDict:
    ...
class LearningRateCallable(typing_extensions.Protocol):
  """Structural type for learning-rate schedules: maps a step to a rate."""

  def __call__(self, step: jnp.ndarray) -> jnp.ndarray:
    ...
def create_learning_rate_scheduler(
    factors: str = 'constant * linear_warmup * rsqrt_decay',
    base_learning_rate: float = 0.5,
    warmup_steps: int = 1000,
    decay_factor: float = 0.5,
    steps_per_decay: int = 20000,
    steps_per_cycle: int = 100000,
    step_offset: int = 0,
    min_learning_rate: float = 1e-8) -> LearningRateCallable:
  """Creates learning rate schedule.

  Interprets factors in the factors string which can consist of:
  * constant: interpreted as the constant value,
  * linear_warmup: interpreted as linear warmup until warmup_steps. A
    non-positive `warmup_steps` disables the warmup (the factor is 1).
  * linear_decay: linear decay from warmup_steps with decay_factor slope. Note
    this option implies 'constant * linear_warmup', and should not be used in
    conjunction with `constant` or `linear_warmup` factors.
  * rsqrt_decay: divide by square root of max(step, warmup_steps)
  * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1)
  * decay_every: Every k steps decay the learning rate by decay_factor.
  * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter.

  Args:
    factors: string, factors separated by '*' that defines the schedule.
    base_learning_rate: float, the starting constant for the lr schedule.
    warmup_steps: int, how many steps to warm up for in the warmup schedule.
    decay_factor: float, the amount to decay the learning rate by.
    steps_per_decay: int, how often to decay the learning rate.
    steps_per_cycle: int, steps per cycle when using cosine decay.
    step_offset: int, an offset that the step parameters to this function are
      relative to.
    min_learning_rate: float, minimum learning rate to output. Useful for cases
      when a decay function is (mis)configured to decay to non-positive values.

  Returns:
    a function `learning_rate(step)` mapping a step to a float32 scalar
    learning rate (clipped from below at `min_learning_rate`).

  Raises:
    ValueError: if `factors` names an unknown factor. Names are validated when
      the schedule is created, not on its first call.
  """
  factor_names = [n.strip() for n in factors.split('*')]
  known_factors = ('constant', 'linear_warmup', 'linear_decay', 'rsqrt_decay',
                   'rsqrt_normalized_decay', 'decay_every', 'cosine_decay')
  # Fail fast on typos instead of deferring the error to the first step.
  for name in factor_names:
    if name not in known_factors:
      raise ValueError('Unknown factor %s.' % name)

  def step_fn(step: jnp.ndarray) -> jnp.ndarray:
    """Step to learning rate function."""
    step = jnp.maximum(0, step - step_offset)
    ret = 1.0
    for name in factor_names:
      if name == 'constant':
        ret *= base_learning_rate
      elif name == 'linear_warmup':
        # With no warmup, `step / warmup_steps` would be 0 / 0 = NaN at step 0.
        if warmup_steps > 0:
          ret *= jnp.minimum(1.0, step / warmup_steps)
      elif name == 'linear_decay':
        ret *= base_learning_rate * jnp.minimum(
            step / warmup_steps, 1.0 + decay_factor * (warmup_steps - step))
      elif name == 'rsqrt_decay':
        ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
      elif name == 'rsqrt_normalized_decay':
        ret *= jnp.sqrt(warmup_steps)
        ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
      elif name == 'decay_every':
        ret *= (decay_factor**(step // steps_per_decay))
      elif name == 'cosine_decay':
        progress = jnp.maximum(0.0,
                               (step - warmup_steps) / float(steps_per_cycle))
        ret *= jnp.maximum(0.0,
                           0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
      else:
        # Unreachable after the validation above; kept as a safeguard.
        raise ValueError('Unknown factor %s.' % name)
    ret = jnp.maximum(ret, min_learning_rate)
    return jnp.asarray(ret, dtype=jnp.float32)

  return step_fn
def get_first_valid_restore_config_and_paths(
    restore_cfgs: Sequence[RestoreCheckpointConfig]
) -> Tuple[Optional[RestoreCheckpointConfig], Sequence[str]]:
  """Returns the first usable restore config and the checkpoint paths to load.

  Args:
    restore_cfgs: a sequence of RestoreCheckpointConfig objects, examined in
      order.

  Returns:
    A `(restore_cfg, paths)` tuple, or `(None, [])` if no config yields a
    checkpoint.

    * A config with mode 'specific' is returned immediately, together with its
      `path` (wrapped in a list when it is a single string). These paths are
      not checked for existence.
    * For mode 'all' or 'latest', the directories in `path` are scanned in
      order. The first directory containing at least one checkpoint (either a
      TensorFlow checkpoint state or T5X `checkpoint_<step>` entries)
      determines the result: all of its checkpoint paths for 'all', or only
      the last one for 'latest'. Directories without checkpoints are skipped;
      if none has any, the next config is tried.

  Raises:
    ValueError: if, for mode 'all' or 'latest', a path is not a directory.
  """
  for restore_cfg in restore_cfgs:
    paths = ([restore_cfg.path]
             if isinstance(restore_cfg.path, str) else restore_cfg.path)
    if restore_cfg.mode == 'specific':
      return restore_cfg, paths
    elif restore_cfg.mode in ('all', 'latest'):
      for ckpt_dir in paths:
        if not gfile.isdir(ckpt_dir):
          raise ValueError(
              'Checkpoint path(s) must be valid directories when using '
              "restore mode 'all' or 'latest'.")
        # Check if this is a TensorFlow checkpoint dir.
        tf_ckpt_state = tf.train.get_checkpoint_state(ckpt_dir)
        if tf_ckpt_state:
          ckpt_paths = tf_ckpt_state.all_model_checkpoint_paths
        else:
          ckpt_paths = [
              os.path.join(ckpt_dir, f'checkpoint_{step}')
              for step in checkpoints.all_steps(ckpt_dir)
          ]
        if not ckpt_paths:
          logging.info('No checkpoints found in specified directory: %s',
                       ckpt_dir)
          continue
        if restore_cfg.mode == 'latest':
          logging.info('Using latest T5X checkpoint.')
          ckpt_paths = ckpt_paths[-1:]
        # Only the first directory that has checkpoints is used.
        return restore_cfg, ckpt_paths
    else:
      logging.error('Unsupported checkpoint restore mode: %s', restore_cfg.mode)
  return None, []
def get_fallback_state(restore_cfg: RestoreCheckpointConfig,
                       init_fn: Callable[[jnp.ndarray], Mapping[str, Any]],
                       init_rng: jnp.ndarray) -> Optional[Mapping[str, Any]]:
  """Builds the state used to fill parameters missing from a checkpoint.

  Returns None unless `restore_cfg` is given and has `fallback_to_scratch`
  enabled, in which case the result of `init_fn(init_rng)` is returned.

  Raises:
    ValueError: if `fallback_to_scratch` is requested without any
      `state_transformation_fns`, or without an `init_rng`.
  """
  if restore_cfg is None or not restore_cfg.fallback_to_scratch:
    return None
  if not restore_cfg.state_transformation_fns:
    raise ValueError('`state_transformation_fns` must be provided with '
                     '`fallback_to_scratch`')
  if init_rng is None:
    raise ValueError('An `init_rng` must be provided with '
                     '`fallback_to_scratch`')
  return init_fn(init_rng)
class TrainStateInitializer:
"""Helper for initializing partitioned TrainState from checkpoints or scratch.
Common use cases:
* To restore from a single checkpoint, use `from_checkpoint`.
* To iterate over multiple checkpoints without recompiling the model,
use `from_checkpoints`.
* To initialize from scratch, use `from_scratch`.
* To restore from a checkpoint with a fallback to initializing from scratch,
use `from_checkpoint_or_scratch`.
Attributes:
global_train_state_shape: a TrainState containing the global (unpartitioned)
shape (in `jax.ShapeDtypeStruct`) of each parameter instead of its value.
train_state_axes: a TrainState object containing a PartitionSpec (or None)
for each parameter, in place of the parameter itself.
"""
# TODO(adarob): Replace input_shapes and input_types with sample batch.
def __init__(self,
optimizer_def: Optional[optimizers.OptimizerDefType],
init_fn: InitFnCallable,
input_shapes: Mapping[str, Array],
partitioner: partitioning.BasePartitioner,
input_types: Optional[Mapping[str, DType]] = None):
"""TrainStateInitializer constructor.
Args:
optimizer_def: Optimizer def to be initialized, or None to create a
`InferenceState` without an optimizer.
init_fn: callable that initializes model variables from a PRNGKey and the
input shapes.
input_shapes: a mapping from key to array shape for each feature in the
global (unsharded) input batch.
partitioner: the partitioner to use.
input_types: a mapping from key to array type for each feature in the
global (unshared) input batch. If not provided, the type is assumed to
be `jnp.float32`.
"""
def initialize_train_state(rng: Array):
initial_variables = init_fn(
rng=rng, input_shapes=input_shapes, input_types=input_types)
if optimizer_def:
return train_state_lib.FlaxOptimTrainState.create(
optimizer_def, initial_variables)
return train_state_lib.InferenceState.create(initial_variables)
self._partitioner = partitioner
self.global_train_state_shape = jax.eval_shape(
initialize_train_state, rng=jax.random.PRNGKey(0))
self.train_state_axes = partitioner.get_mesh_axes(
self.global_train_state_shape)
self._initialize_train_state = initialize_train_state
# Currently scanned layers require passing annotations through to the
# point of the scan transformation to resolve an XLA SPMD issue.
# init_fn is always(?) equal to model.get_initial_variables, fetch the model
# instance from the bound method.
model = init_fn.__self__ # pytype: disable=attribute-error
if (hasattr(model, 'module') and hasattr(model.module, 'scan_layers') and
model.module.scan_layers):
if hasattr(model.module, 'spmd_annotations'):
# update top-level module with spmd annotations.
model.module = model.module.clone(
parent=None, spmd_annotations=self.train_state_axes.params)
def from_scratch(self, init_rng: Array) -> train_state_lib.TrainState:
"""Initializes the partitioned Optimizer from scratch."""
logging.info('Initializing parameters from scratch.')
# If pretraining and no checkpoint imported, we jit the (sharded-) init
# function to minimize fragmentation. We use the same partition
# setup as the training step/loop to initialize everything "in-place" and
# avoid communication or OOM.
p_initialize_train_state_fn = self._partitioner.partition(
self._initialize_train_state,
in_axis_resources=None,
out_axis_resources=self.train_state_axes)
return p_initialize_train_state_fn(init_rng)
# TODO(b/216650048) deprecate this function and use orbax.
def from_checkpoints(
self,
restore_cfgs: Sequence[RestoreCheckpointConfig],
ds_iter: Optional[tf.data.Iterator] = None,
init_rng: Optional[jnp.ndarray] = None,
) -> Iterable[train_state_lib.TrainState]:
"""Yields 0 or more restored partitioned Optimizers, and maybe datasets.
The manner in which parameters are initialized depends on `restore_cfgs` and
`restore_cfgs` is iterated over and the first config that matches one or
more existing checkpoints is used to generate restored optimizers from the
checkpoint(s). Any remaining configs are ignored.
Args:
restore_cfgs: ordered sequence of configurations specifying checkpoint(s)
to restore from. The first config to match a checkpoint will be used.
ds_iter: a tf.data.Iterator for the input data, or None. If provided, the
referenced iterator's state may be silently restored (depending on the
config's `restore_dataset` value) along with the optimizer.
init_rng: for initializing parameters from scratch when they are not
available in the checkpoint and `fallback_to_scratch` is True
Yields:
TrainState with initialized optimizer, with parameters copied to devices.
"""
def _restore_path(path, cfg):
restore_checkpointer = cfg.checkpointer_cls(
train_state=self.global_train_state_shape,
partitioner=self._partitioner,
checkpoints_dir='', # unused for restore
dataset_iterator=ds_iter if cfg.restore_dataset else None,
restore_dtype=jnp.dtype(cfg.dtype) if cfg.dtype else None)
from_tensorflow = gfile.exists(path + '.index')
if from_tensorflow and cfg.state_transformation_fns:
raise ValueError('Cannot initialize from a TensorFlow checkpoint using '
'`state_transformation_fns`.')
if from_tensorflow:
logging.info('Initializing parameters from TensorFlow checkpoint %s',
path)
return restore_checkpointer.restore_from_tf_checkpoint(
path, strict=cfg.strict)
else:
fallback_state = get_fallback_state(
cfg, lambda rng: self.from_scratch(rng).state_dict(), init_rng)
logging.info('Initializing parameters from specific T5X checkpoint %s',
path)
return restore_checkpointer.restore(
path=path,
state_transformation_fns=cfg.state_transformation_fns,
fallback_state=fallback_state)
restore_cfg, paths = get_first_valid_restore_config_and_paths(restore_cfgs)
for path in paths:
yield _restore_path(path, restore_cfg)
def from_checkpoint(
    self,
    ckpt_cfgs: Sequence[RestoreCheckpointConfig],
    *,
    ds_iter: Optional[tf.data.Iterator] = None,
    init_rng: Optional[jnp.ndarray] = None
) -> Optional[train_state_lib.TrainState]:
  """Restores at most one checkpoint via `from_checkpoints`.

  Args:
    ckpt_cfgs: ordered restore configurations, forwarded unchanged to
      `from_checkpoints`.
    ds_iter: optional input iterator forwarded to `from_checkpoints`, whose
      state may be restored alongside the train state.
    init_rng: RNG forwarded to `from_checkpoints` for from-scratch fallback.

  Returns:
    The single restored TrainState, or None if no checkpoint was restored.

  Raises:
    ValueError: if `from_checkpoints` yields more than one train state.
  """
  restored = [
      *self.from_checkpoints(ckpt_cfgs, ds_iter=ds_iter, init_rng=init_rng)
  ]
  if not restored:
    return None
  if len(restored) > 1:
    raise ValueError(
        f'Expected at most 1 checkpoint but got {len(restored)} for '
        f'config(s): {ckpt_cfgs}')
  return restored[0]
def from_checkpoint_or_scratch(
    self,
    ckpt_cfgs: Sequence[RestoreCheckpointConfig],
    *,
    init_rng: Array,
    ds_iter: Optional[tf.data.Iterator] = None) -> train_state_lib.TrainState:
  """Initializes from checkpoint, if found, or from scratch.

  Args:
    ckpt_cfgs: ordered restore configurations passed to `from_checkpoint`.
    init_rng: RNG used for from-scratch initialization; also forwarded to
      `from_checkpoint`.
    ds_iter: optional input iterator forwarded to `from_checkpoint`.

  Returns:
    The restored TrainState, or a freshly initialized one when
    `from_checkpoint` returned None.
  """
  restored = self.from_checkpoint(
      ckpt_cfgs, ds_iter=ds_iter, init_rng=init_rng)
  # Explicit None check: `restored or ...` would discard a restored state
  # whose truthiness happens to be False.
  if restored is not None:
    return restored
  return self.from_scratch(init_rng)
# -----------------------------------------------------------------------------
# Logging utility functions
# -----------------------------------------------------------------------------
def log_model_info(log_file: Optional[str],
                   full_train_state: train_state_lib.TrainState,
                   partitioner: partitioning.BasePartitioner):
  """Log the variable shapes information and optionally write it to a file.

  Only process 0 logs. For each leaf of the 'target' and 'state' collections
  of the train state dict, one line is emitted with the variable name, size,
  shape (annotated with logical axis names when their count matches the rank)
  and mesh partition spec. The total parameter count of 'target' is logged in
  between.

  Args:
    log_file: optional path; when not None every logged line is also written
      to this file.
    full_train_state: train state whose variables are described.
    partitioner: provides the logical axes and mesh axes of the train state.
  """
  # Only write logs on host 0.
  if jax.process_index() != 0:
    return

  state_dict = full_train_state.state_dict()
  # The initializer keeps an empty 'target' tree from raising in reduce().
  total_num_params = jax.tree_util.tree_reduce(
      np.add, jax.tree_util.tree_map(np.size, state_dict['target']), 0)

  logical_axes = partitioner.get_logical_axes(full_train_state).state_dict()

  mesh_axes = jax.tree_util.tree_map(
      lambda x: tuple(x) if x is not None else None,
      partitioner.get_mesh_axes(full_train_state).state_dict())

  def _log_info_and_write_to_file(writer, format_str, *args):
    logging.info(format_str, *args)
    if writer is not None:
      writer.write(format_str % args + '\n')

  with contextlib.ExitStack() as stack:
    writer = (
        stack.enter_context(gfile.GFile(log_file, 'w'))
        if log_file is not None else None)

    # Log params
    def _log_variable(name: str, arr: Optional[np.ndarray],
                      var_logical_axes: Optional[partitioning.AxisNames],
                      var_mesh_axes: Optional[partitioning.PartitionSpec]):
      # Log nothing on empty dict leaves, which occur with optax EmptyState().
      if isinstance(arr, dict) and not arr:
        return
      if arr is None:
        _log_info_and_write_to_file(writer, 'Variable %-80s None', name)
        return
      if var_logical_axes is None or len(var_logical_axes) != len(arr.shape):
        shape_str = str(arr.shape)
      else:
        # Annotate each dimension with its logical axis name.
        shape_str = '({})'.format(', '.join(
            f'{axis_name}={dimension}'
            for axis_name, dimension in zip(var_logical_axes, arr.shape)))
      _log_info_and_write_to_file(
          writer, 'Variable %-80s size %-12s shape %-40s partition spec %s',
          name, arr.size, shape_str, var_mesh_axes)

    jax.tree_util.tree_map(
        _log_variable,
        state_utils.get_name_tree(state_dict['target'], keep_empty_nodes=True),
        state_dict['target'], logical_axes['target'], mesh_axes['target'])

    _log_info_and_write_to_file(writer, 'Total number of parameters: %d',
                                total_num_params)

    # Add a blank line between params and states.
    _log_info_and_write_to_file(writer, '')

    jax.tree_util.tree_map(
        _log_variable,
        state_utils.get_name_tree(state_dict['state'], keep_empty_nodes=True),
        state_dict['state'], logical_axes['state'], mesh_axes['state'])
# -----------------------------------------------------------------------------
# Utility functions for prediction and evaluation.
# -----------------------------------------------------------------------------
class InferStepWithRngCallable(typing_extensions.Protocol):
  """Inference step that accepts an optional `rng` argument."""

  def __call__(self,
               params: Mapping[str, Any],
               batch: Mapping[str, jnp.ndarray],
               rng: Optional[jnp.ndarray] = None) -> PyTreeDef:
    """Runs an inference step returning a prediction or score."""
    ...
class InferStepWithoutRngCallable(typing_extensions.Protocol):
  """Inference step that takes only parameters and a batch (no `rng`)."""

  def __call__(
      self,
      params: Mapping[str, Any],
      batch: Mapping[str, jnp.ndarray],
  ) -> PyTreeDef:
    """Runs one inference step and returns a prediction or score pytree."""
# An inference step either accepts an `rng` argument or it does not;
# `get_infer_fn` inspects the step's signature to decide how to call it.
InferStepCallable = Union[InferStepWithRngCallable, InferStepWithoutRngCallable]
# NOTE: We're not more prescriptive than PyTreeDef because that's what
# InferStepCallable expects.
# Plain infer fn result: a sequence of (example index, inference) pairs.
_InferFnResult = Sequence[Tuple[int, PyTreeDef]]
# Infer fn result that additionally carries named auxiliary value sequences.
_InferFnWithAuxResult = Tuple[_InferFnResult, Mapping[str, Sequence[Any]]]
class InferFnCallable(typing_extensions.Protocol):
  """Callable that runs inference over an entire dataset."""

  def __call__(
      self,
      ds: tf.data.Dataset,
      train_state: train_state_lib.TrainState,
      rng: Optional[jnp.ndarray] = None,
  ) -> Union[_InferFnResult, _InferFnWithAuxResult]:
    """Runs inference on `ds`, returning indexed results (and maybe aux)."""
def _remove_padding(all_inferences, all_indices):
"""Remove padded examples.
Args:
all_inferences: PyTree[total_examples + padding_count, ...].
all_indices: [total_examples + padding_count].
Returns:
all_inferences in shape PyTree[total_examples, ...].
all_indices in shape [total_exmamples].
"""
non_pad_idxs = np.where(all_indices >= 0)
all_indices = all_indices[non_pad_idxs]
all_inferences = jax.tree_map(lambda x: x[non_pad_idxs], all_inferences)
return all_inferences, all_indices
def get_infer_fn(infer_step: InferStepCallable, batch_size: int,
                 train_state_axes: train_state_lib.TrainState,
                 partitioner: partitioning.BasePartitioner) -> InferFnCallable:
  """Get prediction function for the SeqIO evaluator.

  The returned prediction function should take in an enumerated dataset, make
  predictions and return in an enumerated form with the original indices and
  examples zipped together. This ensures that the predictions are compared to
  the targets in a correct order even if the dataset is sharded across
  multiple hosts and gathered in a nondeterministic way.

  jax.process_index == 0 is used as a "main host", i.e., it gathers all
  inference results and returns.

  Shape notation:
    Per replica set num replicas: R
    Per replica set batch size: B
    Number of replica sets: H
    Length: L

    Some transformations have shape transformation annotation, e.g.,
    [B, L] -> [R, B/R, L].

  Args:
    infer_step: a callable that executes one prediction step. Should not yet be
      partitioned or pmapped.
    batch_size: the global infer batch size.
    train_state_axes: Partitioning info for the train state object.
    partitioner: partitioner to use.

  Returns:
    infer_fn: a callable which takes in the enumerated infer dataset, a train
      state and an optional rng, and runs the prediction. It returns a list of
      (index, output) pairs sorted by index, plus a mapping of aux values when
      `infer_step` returns an `(outputs, Mapping)` tuple.

  Raises (from infer_fn):
    TypeError: if the dataset length is unknown.
    ValueError: if the number of outputs differs from the dataset length.
  """

  def infer_step_with_indices(params, batch, rng, indices):
    if 'rng' in inspect.signature(infer_step).parameters:
      res = typing.cast(InferStepWithRngCallable, infer_step)(params, batch,
                                                              rng)
    else:
      res = typing.cast(InferStepWithoutRngCallable, infer_step)(params, batch)
    return indices, res

  partitioned_infer_step = partitioner.partition(
      infer_step_with_indices,
      in_axis_resources=(train_state_axes.params,
                         partitioner.data_partition_spec, None,
                         partitioner.data_partition_spec),
      out_axis_resources=(None, None))

  data_layout = partitioner.get_data_layout(batch_size)
  shard_id = data_layout.shard_id
  num_shards = data_layout.num_shards

  per_shard_batch_size = batch_size // num_shards

  def _copy_to_host_async(x):
    """Issues an async device-to-host copy that serves as prefetching."""
    if isinstance(x, GlobalDeviceArray):
      x.local_data(0).copy_to_host_async()  # GDA is fully replicated
      return x.local_data(0)
    else:
      x.copy_to_host_async()
      return x

  def infer_fn(ds: tf.data.Dataset,
               train_state: train_state_lib.TrainState,
               rng: Optional[jnp.ndarray] = None):
    ds_shapes = jax.tree_util.tree_map(lambda x: jnp.array(x.shape),
                                       ds.element_spec)
    multihost_utils.assert_equal(
        ds_shapes, 'Dataset element shapes do not agree across hosts. '
        'This could be an indication that the dataset is nondeterministic.')
    try:
      original_ds_length = len(ds)
      dataset_remainder = original_ds_length % batch_size  # pytype:disable=wrong-arg-types
      logging.info('length of dataset = %s', original_ds_length)
    except TypeError as e:
      if str(e) == 'dataset length is unknown.':
        logging.warning(
            'The following error is likely due to the use of TensorFlow v1 in '
            'your dataset pipeline. Verify you are not importing from '
            '`tf.compat.v1` as part of your pipeline.')
      raise

    if dataset_remainder:
      dataset_pad_amt = batch_size - dataset_remainder
      logging.info(
          'Padding infer dataset with %d examples for even per-replica shards.',
          dataset_pad_amt)
      # Pad with the first example using an index of -1 so seqio will ignore.
      pad_ds = ds.take(1).map(lambda i, x: (np.int64(-1), x)).repeat(
          dataset_pad_amt)
      ds = ds.concatenate(pad_ds)

    # Shard the infer dataset across replica sets.
    sharded_ds = ds.shard(num_shards, shard_id).batch(
        per_shard_batch_size, drop_remainder=True)
    multihost_utils.assert_equal(
        jnp.array(len(sharded_ds)),
        'Dataset lengths do not agree across hosts.')

    logging.info(
        'The infer dataset is sharded into %d shards with per-shard '
        'batch size of %d', num_shards, per_shard_batch_size)

    # Run inference for each replica set.
    batched_results, all_indices = [], []
    for index, infer_batch in sharded_ds.as_numpy_iterator():
      if rng is None:
        step_rng = None
      else:
        step_rng, rng = jax.random.split(rng)
      # Run fast inference on batch.
      # [B, ...] -> [B * shard_count, ...]
      # partitioned_infer_step executes infer_step on sharded batched data, and
      # returns de-sharded batched indices and result replicated on all hosts.
      batch_indices, batch_result = partitioned_infer_step(
          train_state.params, infer_batch, step_rng, index)
      logging.info('Inference of batch %s done.', index)
      # Issue asynchronous copy request which serves as prefetching to the host.
      try:
        batch_result = jax.tree_util.tree_map(_copy_to_host_async,
                                              batch_result)
        batch_indices = jax.tree_util.tree_map(_copy_to_host_async,
                                               batch_indices)
      except AttributeError:
        # Similar to jax.device_get, we skip transfers for non DeviceArrays.
        pass

      batched_results.append(batch_result)
      all_indices.append(batch_indices)

    logging.info('Inference of all batches done.')
    all_inferences = batched_results

    # List[B * shard_count, ...] -> [B * shard_count * batch_count, ...]
    # `jax.tree_multimap` was removed from JAX; `tree_map` takes multiple trees.
    all_inferences = jax.tree_util.tree_map(
        lambda *args: np.concatenate(args), *all_inferences)
    all_indices = np.concatenate(all_indices)

    all_inferences, all_indices = _remove_padding(all_inferences, all_indices)

    # Results are returned from infer_step out of order due to shard operation.
    # Note: remove padding first, as -1 indices would mess up this operation.
    # Note: all_inferences may be a PyTree, not just an array, e.g. if
    # `infer_step` is `model.predict_batch_with_aux`.
    # Position k holds the result for example all_indices[k], so argsort gives
    # the permutation that restores dataset order. (Indexing by all_indices
    # itself applies the permutation rather than its inverse and only keeps
    # index/output pairs consistent without sorting them.)
    permutation = np.argsort(all_indices)
    all_inferences = jax.tree_util.tree_map(lambda x: x[permutation],
                                            all_inferences)
    all_indices = all_indices[permutation]

    # aux_values is supposed to be a dictionary that maps strings to a set of
    # auxiliary values.
    #
    # We don't want to flatten/unflatten the aux values. We want to preserve the
    # unflattened values with the type List[Mapping[str, Sequence[Any]]]. We do
    # this as a memory optimization to avoid lots of redundant keys if we'd
    # instead had List[Mapping[str, Any]].
    #
    # It has shape Mapping[str, [B * shard_count * batch_count, ...]]. That is,
    # the first dimension of each of the values in aux_values is equal to
    # len(all_inferences).
    aux_values = None
    if (isinstance(all_inferences, tuple) and len(all_inferences) == 2 and
        isinstance(all_inferences[1], Mapping)):
      all_inferences, aux_values = all_inferences

    # Translate to List[...] by flattening inferences making sure to
    # preserve structure of individual elements (inferences are not assumed to
    # be simple np.array). Finally, zip inferences with corresponding indices
    # and convert leaf np.arrays into lists.
    all_inferences, struct = jax.tree_util.tree_flatten(all_inferences)
    all_inferences = map(
        functools.partial(jax.tree_util.tree_unflatten, struct),
        zip(*all_inferences))
    indices_and_outputs = list(zip(all_indices, all_inferences))
    indices_and_outputs = jax.tree_util.tree_map(
        lambda x: np.array(x).tolist(), indices_and_outputs)
    if len(indices_and_outputs) != original_ds_length:
      raise ValueError(
          'Size of indices_and_outputs does not match length of original '
          'dataset: %d versus %d' %
          (len(indices_and_outputs), original_ds_length))

    if aux_values is None:
      return indices_and_outputs
    else:
      aux_values = jax.tree_util.tree_map(lambda x: np.array(x).tolist(),
                                          aux_values)
      return indices_and_outputs, aux_values

  return infer_fn
# -----------------------------------------------------------------------------
# SeqIO utility functions.
# -----------------------------------------------------------------------------
def import_module(module: str):
  """Imports the given module at runtime.

  Args:
    module: fully qualified name of the module to import.

  Raises:
    RuntimeError: with an actionable message, chained to the original error,
      if importing the module tried to add a gin configurable after the gin
      config was locked. Any other RuntimeError is re-raised unchanged.
  """
  logging.info('Importing %s.', module)
  try:
    importlib.import_module(module)
  except RuntimeError as e:
    if (str(e) ==
        'Attempted to add a new configurable after the config was locked.'):
      raise RuntimeError(
          'Your Task/Mixture module contains gin configurables that must be '
          'loaded before gin flag parsing. One fix is to add '
          f"'import {module}' in your gin file.") from e
    # Bare raise keeps the original traceback intact.
    raise
def get_vocabulary(
    cfg: DatasetConfig) -> Tuple[seqio.Vocabulary, seqio.Vocabulary]:
  """Returns the (inputs, targets) `seqio.Vocabulary` pair of a Mixture/Task.

  When both 'inputs' and 'targets' features exist, their vocabularies are
  returned. Otherwise all features must share one vocabulary, and that single
  vocabulary is returned twice. Non-PassThroughVocabularies are preferred;
  PassThroughVocabularies are only considered when nothing else is set.

  Args:
    cfg: the DatasetConfig naming the mixture or task whose vocabularies are
      requested.

  Returns:
    A tuple of seqio.Vocabulary for inputs and targets.

  Raises:
    ValueError: if inputs and targets are not both present and vocabularies
      are missing or different.
  """
  if cfg.module:
    warnings.warn(
        'The use of `DatasetConfig.module` and `MIXTURE_OR_TASK_MODULE` is '
        'deprecated in favor of importing the module directly or via gin.',
        DeprecationWarning)
    import_module(cfg.module)

  output_features = seqio.get_mixture_or_task(
      cfg.mixture_or_task_name).output_features

  if 'inputs' in output_features and 'targets' in output_features:
    return (output_features['inputs'].vocabulary,
            output_features['targets'].vocabulary)

  all_vocabs = [feature.vocabulary for feature in output_features.values()]
  # Prefer the non-PassThroughVocabularies when the two kinds are mixed; fall
  # back to all of them (i.e. only pass-through ones) otherwise.
  # TODO(b/185912004): Remove this once a more general solution is implemented.
  candidates = [
      vocab for vocab in all_vocabs
      if not isinstance(vocab, seqio.PassThroughVocabulary)
  ] or all_vocabs

  if not candidates:
    raise ValueError('"inputs" and "targets" are not both present, and '
                     'no vocabularies were set for any features.')

  first_vocab, *other_vocabs = candidates
  if any(vocab != first_vocab for vocab in other_vocabs):
    raise ValueError('"inputs" and "targets" are not both present, and '
                     'vocabularies are different.')
  return (first_vocab, first_vocab)
def get_dataset(cfg: DatasetConfig,
                shard_id: int,
                num_shards: int,
                feature_converter_cls: Type[seqio.FeatureConverter],
                num_epochs: Optional[int] = None,
                continue_from_last_checkpoint: bool = False) -> tf.data.Dataset:
  """Returns a dataset from SeqIO based on a `DatasetConfig`.

  Args:
    cfg: dataset configuration.
    shard_id: index of this shard.
    num_shards: total number of shards; must divide `cfg.batch_size`.
    feature_converter_cls: feature converter class instantiated downstream.
    num_epochs: optional number of epochs, forwarded to `get_dataset_inner`.
    continue_from_last_checkpoint: unsupported here; must be False.

  Returns:
    The batched per-shard dataset built by `get_dataset_inner`.

  Raises:
    ValueError: if `continue_from_last_checkpoint` is set or the batch size is
      not divisible by `num_shards`.
  """
  if continue_from_last_checkpoint:
    raise ValueError(
        '`continue_from_last_checkpoint` must be set to False as this is not '
        'supported by this dataset fn.')
  del continue_from_last_checkpoint

  if cfg.module:
    import_module(cfg.module)

  if cfg.batch_size % num_shards:
    raise ValueError(
        f'Batch size ({cfg.batch_size}) must be divisible by number of '
        f'shards ({num_shards}).')

  shard_info = seqio.ShardInfo(index=shard_id, num_shards=num_shards)

  seed = cfg.seed
  if seed is None:
    # No explicit seed: broadcast a timestamp so that every host shares it.
    seed = multihost_utils.broadcast_one_to_all(np.int32(time.time()))

  return get_dataset_inner(cfg, shard_info, feature_converter_cls, seed,
                           num_epochs)
def get_dataset_inner(cfg: DatasetConfig,
                      shard_info: seqio.ShardInfo,
                      feature_converter_cls: Type[seqio.FeatureConverter],
                      seed: Optional[int] = None,
                      num_epochs: Optional[int] = None):
  """Internal fn to load a dataset from SeqIO based on a `DatasetConfig`.

  Args:
    cfg: dataset configuration.
    shard_info: which shard of the data this host loads.
    feature_converter_cls: feature converter class, instantiated with the
      config's packing options.
    seed: optional shuffle seed; when given it is checked to be identical on
      all hosts.
    num_epochs: optional number of epochs passed to `seqio.get_dataset`.

  Returns:
    The dataset batched with the per-shard batch size, dropping the remainder.
  """
  batch_size = cfg.batch_size // shard_info.num_shards
  if seed is not None:
    multihost_utils.assert_equal(
        np.array(seed),
        f'`seed` is not same across hosts; {jax.process_index()} has a seed of '
        f'{seed}')
  # `%s` rather than `%d` so that a None seed still formats.
  logging.info(
      "Initializing dataset for task '%s' with a replica batch size of %d and "
      'a seed of %s', cfg.mixture_or_task_name, batch_size, seed)

  ds = seqio.get_dataset(
      mixture_or_task_name=cfg.mixture_or_task_name,
      task_feature_lengths=cfg.task_feature_lengths,
      dataset_split=cfg.split,
      shuffle=cfg.shuffle,
      num_epochs=num_epochs,
      feature_converter=feature_converter_cls(
          pack=cfg.pack, use_custom_packing_ops=cfg.use_custom_packing_ops),  # pytype: disable=not-instantiable
      shard_info=shard_info,
      use_cached=cfg.use_cached,
      seed=seed)
  ds = ds.batch(batch_size, drop_remainder=True)
  return ds
class GetDatasetCallable(typing_extensions.Protocol):
  """Interface for a function returning a dataset (iterator).

  `get_dataset` in this module is one implementation.
  """

  def __call__(
      self,
      cfg: DatasetConfig,
      shard_id: int,
      num_shards: int,
      feature_converter_cls: Callable[..., seqio.FeatureConverter],
      num_epochs: Optional[int] = None,
      continue_from_last_checkpoint: bool = True,
  ) -> Union[clu.data.DatasetIterator, tf.data.Dataset]:
    """Builds the dataset (or dataset iterator) for shard `shard_id`."""
def get_training_eval_datasets(
    cfg: DatasetConfig,
    shard_id: int,
    num_shards: int,
    eval_steps: int,
    feature_converter_cls: Callable[..., seqio.FeatureConverter],
    get_dataset_fn: GetDatasetCallable = get_dataset,
) -> Mapping[str, tf.data.Dataset]:
  """Returns a mapping from eval task name to its dataset.

  Every subtask that has `cfg.split` gets an entry. If the configured
  name refers to a Mixture, the mixture itself also gets an entry. Each
  dataset is loaded unsharded with batch size 1, then repeated, sharded,
  re-batched to the per-shard batch size, truncated to `eval_steps` batches
  and cached.

  Args:
    cfg: dataset configuration of the mixture or task.
    shard_id: index of this shard.
    num_shards: total number of shards; must divide `cfg.batch_size`.
    eval_steps: number of batches kept per dataset.
    feature_converter_cls: feature converter passed to `get_dataset_fn`.
    get_dataset_fn: function used to load each unsharded dataset.

  Returns:
    Mapping from task (or mixture) name to its evaluation dataset.

  Raises:
    ValueError: if the batch size is not divisible by `num_shards`, or
      `get_dataset_fn` returns something other than a tf.data.Dataset.
  """
  mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name)

  if cfg.batch_size % num_shards:
    raise ValueError(
        f'Batch size ({cfg.batch_size}) must be divisible by number of '
        f'shards ({num_shards}).')

  per_shard_batch_size = cfg.batch_size // num_shards

  def _load_for_eval(load_cfg: DatasetConfig) -> tf.data.Dataset:
    """Loads `load_cfg` unsharded, then repeats/shards/batches/caches it."""
    # `num_epochs` is finite to avoid infinite loops on shards that have input
    # examples that are all filtered.
    unsharded = get_dataset_fn(
        load_cfg,
        shard_id=0,
        num_shards=1,
        feature_converter_cls=feature_converter_cls,
        num_epochs=eval_steps * cfg.batch_size,
        continue_from_last_checkpoint=False)
    # Sharding and batching the full, repeated dataset avoids issues with
    # uneven file shards.
    if not isinstance(unsharded, tf.data.Dataset):
      raise ValueError('Only tf.data.Dataset objects supported.')
    return (unsharded.unbatch()
            .repeat()
            .shard(num_shards, shard_id)
            .batch(per_shard_batch_size, drop_remainder=True)
            .take(eval_steps)
            .cache())

  datasets = {}
  for task in seqio.get_subtasks(mixture_or_task):
    if cfg.split not in task.splits:
      logging.info("Task %s has no '%s' split; skipping training evaluation.",
                   task.name, cfg.split)
      continue
    logging.info('Loading task %s for training evaluation.', task.name)
    datasets[task.name] = _load_for_eval(
        dataclasses.replace(cfg, mixture_or_task_name=task.name, batch_size=1))

  if isinstance(mixture_or_task, seqio.Mixture):
    datasets[mixture_or_task.name] = _load_for_eval(
        dataclasses.replace(cfg, batch_size=1))

  return datasets
def round_vocab_size_to_multiple(vocabulary: seqio.Vocabulary,
                                 divisor: int = 128):
  """Rounds `vocabulary.vocab_size` up to a multiple of `divisor`.

  A padded vocabulary dimension improves TPU performance.

  Args:
    vocabulary: vocabulary whose `vocab_size` is rounded.
    divisor: the multiple to round up to.

  Returns:
    For a positive `divisor`, the smallest multiple of `divisor` that is
    greater than or equal to the vocabulary size.
  """
  # Integer ceiling division written as floor division of the negation.
  return -(-vocabulary.vocab_size // divisor) * divisor
def flatten_dict_string_keys(x):
  """Flattens a nested (possibly frozen) dict into '/'-joined path keys.

  Args:
    x: nested dictionary or FrozenDict; it is unfrozen before flattening.

  Returns:
    A flat dict whose keys are the nested key paths joined with '/'.
  """
  mutable_tree = flax.core.unfreeze(x)
  return traverse_util.flatten_dict(mutable_tree, sep='/')
class _RegexMap(collections.abc.Mapping):
"""Ordered mapping from regexes to values requiring a full match."""
def __init__(self, kvs: Sequence[Tuple[str, Any]]):
self._kvs = [(re.compile(k), v) for k, v in kvs]
def __getitem__(self, key: str) -> Any:
for pattern, v in self._kvs:
if pattern.fullmatch(key):
return v
raise KeyError(f'No pattern matching key: {key}')
def __len__(self) -> int:
return len(self._kvs)
def __iter__(self) -> Iterable[Tuple[re.Pattern, Any]]:
return iter(self._kvs)
def override_params_axes_names(
    model_variables: flax_scope.FrozenVariableDict,
    params_axes_names_override: Sequence[Tuple[str, Tuple[str, ...]]] = ()
) -> flax_scope.FrozenVariableDict:
  """Applies parameter axis names overrides to axes variables.

  Args:
    model_variables: the original model variables containing the 'params_axes'
      collection. May be a FrozenDict or a plain dict.
    params_axes_names_override: a priority-ordered mapping from regex patterns
      (fully matching parameter names) to tuples containing string logical axis
      names to replace model-derived names.

  Returns:
    an updated set of model variables with the overrides applied to the
    'params_axes' collection.

  Raises:
    ValueError: if there is no 'params_axes' collection, or an override's
      length does not match the rank of the parameter it applies to.
  """
  params_axes_names_override_map = _RegexMap(params_axes_names_override)

  if 'params_axes' not in model_variables:
    raise ValueError(
        "Model variables do not contain a 'params_axes' collection to apply an "
        'override to.')
  # `flax.core.unfreeze` accepts both FrozenDict and plain dict inputs, while
  # the `.unfreeze()` method only exists on FrozenDict.
  model_variables = flax.core.unfreeze(model_variables)
  flat_params = traverse_util.flatten_dict(model_variables['params'])
  flat_params_axes = traverse_util.flatten_dict(model_variables['params_axes'])

  for key, param in flat_params.items():
    param_name = '/'.join(key)
    override = params_axes_names_override_map.get(param_name)
    if override is None:
      continue

    # Validate before logging so an invalid override is never reported as
    # being applied.
    if param.ndim != len(override):
      raise ValueError(
          f'Provided axis name override for {param_name} does not match '
          f'param rank ({param.ndim}): {override}')

    # Axis metadata for param `x` is stored under the sibling key `x_axes`.
    param_axes_key = key[:-1] + (f'{key[-1]}_axes',)

    curr_metadata = flat_params_axes.get(param_axes_key)

    if curr_metadata is None:
      logging.info('Adding axis names for %s: %s', param_name, override)
    else:
      assert isinstance(curr_metadata, flax_partitioning.AxisMetadata)
      logging.info('Replacing axis names for %s (%s) with %s.', param_name,
                   curr_metadata.names, override)

    flat_params_axes[param_axes_key] = flax_partitioning.AxisMetadata(
        names=override)

  model_variables['params_axes'] = traverse_util.unflatten_dict(
      flat_params_axes)

  return flax.core.freeze(model_variables)
def get_local_data(x):
  """Returns the host-local piece of `x`.

  - GlobalDeviceArray: its first local shard, `x.local_data(0)`.
  - ShardedDeviceArray: its first device buffer. If that buffer has no
    `aval`, one is set in place from the buffer's shape and dtype.
  - Anything else: `x` itself, unchanged.
  """
  if isinstance(x, GlobalDeviceArray):
    return x.local_data(0)
  if not isinstance(x, pxla.ShardedDeviceArray):
    return x
  first_buffer = x.device_buffers[0]
  if first_buffer.aval is None:
    first_buffer.aval = jax.ShapedArray(first_buffer.shape, first_buffer.dtype)
  return first_buffer
|