# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extensions to Jax/Flax core functions for Mixture of Experts training.
"""
import dataclasses
import re
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union

import flax
import jax
import numpy as np

from t5x import train_state

# Type Stubs
ParamTree = Any
PyTreeDef = Any

Gradients = Union[flax.core.FrozenDict, train_state.TrainState]


def match_fn(prefix: Optional[str]) -> Callable[[str], bool]:
"""Creates a function returning true iff a string matches the prefix.
Args:
prefix: Regex prefix to match. If none, then return match function will not
match any strings.
Returns:
Prefix match function.
"""
if not prefix:
return lambda name: False
params_regex = re.compile(f'^{prefix}')
return lambda name: params_regex.match(name) is not None
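
# A minimal usage sketch for match_fn; the prefix below is illustrative:
#
#   is_sharded = match_fn(r'expert')
#   is_sharded('expert/mlp/wi')   # True: the name starts with the prefix.
#   is_sharded('encoder/expert')  # False: the regex is anchored at the start.
#   match_fn(None)('expert/mlp/wi')  # False: no prefix matches nothing.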


def scale_sharded_grads(grads: Gradients,
                        sharded_match_fn: Optional[Callable[[str], bool]],
                        scale_factor: float) -> Gradients:
"""Scales sharded grads, identified by sharded_match_fn, by scale_factor.
Args:
grads: Parameter gradients.
sharded_match_fn: Filter function for distinguishing sharded parameters from
replicated parameters.
scale_factor: Amount by which to scale sharded parameter gradients.
Returns:
Gradients matching input, expect with sharded parameter gradients rescaled.
"""
if sharded_match_fn:
names_and_grads, tree_def = _tree_flatten_with_names(grads)
scaled_grads = [
grad * scale_factor if sharded_match_fn(name) else grad
for name, grad in names_and_grads
]
return tree_def.unflatten(scaled_grads)
else:
return grads
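
# A minimal sketch of how scale_sharded_grads composes with match_fn; the
# parameter names and gradient values here are hypothetical:
#
#   grads = flax.core.freeze({
#       'expert': {'wi': wi_grad},  # Sharded: name 'expert/wi' matches.
#       'dense': {'wo': wo_grad},   # Replicated: returned unchanged.
#   })
#   scaled = scale_sharded_grads(grads, match_fn(r'expert'), scale_factor=0.5)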


def tree_map_with_names(f, param_tree, match_name_fn=lambda name: True):
"""Like jax.tree_map but with a filter on the leaf path name.
Args:
f: The function to be applied to each parameter in `param_tree`.
param_tree: The tree of parameters `f` should be applied to.
match_name_fn: This function is called with each tree leave's path name,
which has a path-like format ('a/b/c'), and decides whether `f` should be
applied to that leaf or the leaf should be kept as-is.
Returns:
A tree identical in structure to `param_tree` but with the leaves the
result of calling `f` on them in the cases where `match_name_fn` returns
True for that leaf's path name.
"""
names_and_vals, tree_def = _tree_flatten_with_names(param_tree)
vals = [f(v) if match_name_fn(name) else v for name, v in names_and_vals]
return tree_def.unflatten(vals)
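
# For example, negating only the encoder parameters of a hypothetical
# two-level parameter tree:
#
#   params = {'encoder': {'kernel': 1.0}, 'decoder': {'kernel': 2.0}}
#   tree_map_with_names(lambda x: -x, params,
#                       match_name_fn=lambda name: name.startswith('encoder'))
#   # -> {'encoder': {'kernel': -1.0}, 'decoder': {'kernel': 2.0}}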


def _tree_flatten_with_names(
    tree: ParamTree) -> Tuple[Sequence[Tuple[str, Any]], PyTreeDef]:
"""Like jax.tree_flatten but also fetches leaf names.
Specialized to parameter trees of the form {'key0': {'subkey0': Any}, ...}.
Args:
tree: Tree of parameters to flatten.
Returns:
- A list of leaf name and value pairs: [(name, value), ...].
- A tree definition object representing the structure of the flattened tree.
"""
# PyTrees don't treat None values as leaves, so we explicitly declare them as
# such.
vals, tree_def = jax.tree_flatten(tree, is_leaf=lambda x: x is None)
  # 'Fake' token tree that is used to track jax's internal tree traversal and
  # adjust our custom tree traversal to be compatible with it.
tokens = range(len(vals))
token_tree = tree_def.unflatten(tokens)
val_names, perm = zip(*_traverse_with_names(token_tree))
inv_perm = np.argsort(perm)
# Custom traversal should visit the same number of leaves.
if len(val_names) != len(vals):
    raise ValueError(f'Pytree traversal detected {len(val_names)} names, '
                     f'but {len(vals)} leaves.\nTreeDef is:\n{tree_def}')
return [(val_names[i], v) for i, v in zip(inv_perm, vals)], tree_def
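
# For example:
#
#   names_and_vals, tree_def = _tree_flatten_with_names({'a': {'b': 0, 'c': 1}})
#   # names_and_vals == [('a/b', 0), ('a/c', 1)]
#   tree_def.unflatten([v for _, v in names_and_vals])  # Round-trips the tree.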


def _traverse_with_names(
    param_tree: ParamTree) -> Iterable[Tuple[str, ParamTree]]:
"""Traverses nested dicts/dataclasses and emits (leaf_name, leaf_val)."""
if dataclasses.is_dataclass(param_tree):
param_tree = flax.serialization.to_state_dict(param_tree)
if isinstance(param_tree, (dict, flax.core.FrozenDict)):
keys = sorted(param_tree.keys())
for key in keys:
for path, v in _traverse_with_names(param_tree[key]):
yield (key + '/' + path).rstrip('/'), v
else:
yield '', param_tree
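
# For example (None leaves keep their names, matching the is_leaf handling in
# _tree_flatten_with_names):
#
#   list(_traverse_with_names({'x': {'y': 1}, 'z': None}))
#   # -> [('x/y', 1), ('z', None)]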