# Copyright 2022 The T5X Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """T5 Checkpoint Importer.""" import asyncio from concurrent.futures import thread import re from typing import Any, Callable, Mapping, MutableMapping, Optional, Union from flax import traverse_util import jax from jax import numpy as jnp import numpy as np import orbax.checkpoint import tensorflow as tf import tensorstore as ts # TODO(b/233659813): Cleanup clients depending on t5x.checkpoint_importer for # LazyArray. Reconcile divergence in subclass implementation when possible. LazyArray = orbax.checkpoint.lazy_array.LazyArray # TODO(brianlester): The choice between using a `LazyTreadPoolArray` or a # `LazyAwaitableArray` is dependent on if the user provided `get_fn` is blocking # or async respectively, if we can detect which it is, we can automatically # proxy to the correct subclass. We cannot detect of `get_fn` is a lambda that # wraps an async call so this isn't possible yet. Add this dispatch once we are # able to detect that, python3.8+ can detect async for partial'ed functions but # not lambdas. class LazyThreadPoolArray(LazyArray): """Lazily and asynchronously loads an array when the `get_fn` blocks.""" # Uses a global threadpool to enable asynchronous loading. executor = thread.ThreadPoolExecutor() def get_async(self) -> asyncio.Future: return asyncio.wrap_future(self.executor.submit(self.get)) def get(self) -> np.ndarray: arr = self._get_fn() if arr.dtype != self.dtype: arr = arr.astype(self.dtype) return arr class LazyAwaitableArray(LazyArray): """Lazily and asynchronously loads an array when the `get_fn` is async. Note: The synchronous load method `.get` requires the asyncio event loop and calling `.run_until_complete`. This is not supported when the event loop is already running (for example, from inside another async function). Note: Currently, this class has a few helper methods for creating a LazyAwaitableArray when the input could be either an array, or a TensorStore spec. Most people use async code when dealing with TensorStore so the classmethods have been placed here. When someone eventually uses a blocking function to read from TensorStore they can be moved to the LazyArray base class. """ def get_async(self) -> asyncio.Future: async def _get_and_cast(): # Pytype has a false positive here, where it treats our _get_fn (_read_ts # in this case) as having a return type of `np.ndarray` instead of # wrapping it in an Awaitable. Related to this bug # https://github.com/google/pytype/issues/527 arr = await self._get_fn() # pytype: disable=bad-return-type if arr.dtype != self.dtype: arr = arr.astype(self.dtype) return arr return asyncio.ensure_future(_get_and_cast()) def get(self) -> np.ndarray: loop = asyncio.get_event_loop() return loop.run_until_complete(self.get_async()) @classmethod def from_tensor_store_spec( cls, ts_spec: ts.Spec, get_fn: Callable[[], np.ndarray], dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray': """Create a LazyAwaitableArray based on a tensorstore.Spec.""" ts_spec = ts_spec.to_json() shape = ts_spec['metadata']['shape'] if dtype is None: dtype = jnp.dtype(ts_spec['dtype']) else: dtype = jnp.dtype(dtype) # v2 T5X checkpoints use uint16 as the TensorStore datatype and then store # the bfloat16 bytes as in in the 16 bytes uint16 has (no actual cast). When # When reading the dtype from the TensorStore, if we keep the dtype of these # v2 checkpoints as np.uint16 then the _get_fn (which has a possible cast to # support the `restore_dtype` parameter for the checkpointer) will actually # cast the bfloat16 values to uint16, generally resulting in an array of all # zeros. This check avoid the actual cast to uint16 by replacing the dtype. if dtype == np.uint16: dtype = jnp.bfloat16 return cls(shape, dtype, get_fn) @classmethod def from_array(cls, array: np.ndarray, get_fn: Callable[[], np.ndarray], dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray': """Create a LazyAwaitableArray based on an array or python number.""" if dtype is None: dtype = array.dtype else: dtype = jnp.dtype(dtype) return cls(array.shape, dtype, get_fn) @classmethod def from_tensor_store_spec_or_array( cls, maybe_ts_spec: Union[ts.Spec, np.ndarray], get_fn: Callable[[], np.ndarray], dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray': """Create a LazyAwaitableArray based on an array or a tensorstore.Spec.""" if isinstance(maybe_ts_spec, ts.Spec): return cls.from_tensor_store_spec(maybe_ts_spec, get_fn, dtype=dtype) return cls.from_array(maybe_ts_spec, get_fn, dtype=dtype) class CheckpointTranslator: """Utility class for defining mapping rules from one flatdict to another. We assume a checkpoint is loaded as a dictionary with flattened keys of the form: 'name0/name1/name2/.../nameN' A rule is added with the 'add' decorator, which takes a regex matching rule and wraps a conversion function, feeding it (opts, key, val, **regex_groups) where opts is a dict containing apply-time keyword options for use by the conversion functions. """ def __init__(self): self.rules = [] def add(self, pattern): """Adds a new keyval conversion rule. Args: pattern: regex with capture groups for matching given sets of model variables. We terminate all regexes with '$' to force complete matches. Returns: Translation function decorator for associating with the provided pattern. """ def register_translation_fn_decorator(fn): # We force a complete match by adding end-of-string match. self.rules.append((re.compile(pattern + '$'), fn)) return fn return register_translation_fn_decorator def apply(self, flatdict, **opts): """Applies rules to a flattened dictionary. Args: flatdict: flat-key dictionary of variables. **opts: additional config options for translation rules supplied at application time. Returns: Checkpoint data with translated key/values in flat-key dict format. """ new_dict = {} unmatched = {} for k, v in flatdict.items(): matched = False for rule_pat, rule_fn in self.rules: if rule_pat.match(k): groups = rule_pat.match(k).groups() new_k, new_v = rule_fn(opts, k, v, *groups) if new_k is not None: new_dict[new_k] = new_v matched = True break if not matched: unmatched[k] = v # We force every key-value pair in checkpoint to have a rule associated with # it. if unmatched: raise ValueError('Unmapped tensor keys exist: %s' % unmatched) return new_dict # Create a translation rule set for importing T5 & T5.1.1 model checkpoints. # ----------------------------------------------------------------------------- t5_importer = CheckpointTranslator() # Name mappings. SLOT_MAP = {'_slot_vc': 'v_col', '_slot_vr': 'v_row', '_slot_v': 'v'} TOWER_MAP = {'transformer': 'decoder'} @t5_importer.add(r'global_step') def global_step(opts, key, val): del opts, key return 'state/step', val.astype(np.int32).get() if isinstance( val, LazyArray) else val @t5_importer.add(r'shared/embedding(\w*)') def shared_embeddings(opts, key, val, slot): del opts, key prefix = 'state/param_states' if slot else 'target' suffix = '/' + SLOT_MAP[slot] if slot else '' newkey = f'{prefix}/token_embedder/embedding{suffix}' return newkey, val @t5_importer.add(r'(encoder|decoder|transformer)/embedding(\w*)') def separate_embeddings(opts, key, val, encdec, slot): del opts, key prefix = 'state/param_states' if slot else 'target' suffix = '/' + SLOT_MAP[slot] if slot else '' encdec = TOWER_MAP.get(encdec, encdec) newkey = f'{prefix}/{encdec}/token_embedder/embedding{suffix}' return newkey, val # In the Mesh TensorFlow T5 code, relative_attention_bias always occurs in layer # 0 because SelfAttention precedes other sublayers within the same block. @t5_importer.add( r'(encoder|decoder|transformer)/block_(\d+)/layer_000/SelfAttention/relative_attention_bias(\w*)' ) def rel_embeddings(opts, key, val, encdec, blocknum, slot): """Process relpos bias assuming that they are not shared across layers.""" del opts, key prefix = 'state/param_states' if slot else 'target' suffix = '/' + SLOT_MAP[slot] if slot else '' blocknum = int(blocknum) encdec = TOWER_MAP.get(encdec, encdec) # At this point, we can't determine whether the relpos bias was shared across # layers or not. We first assume that it was not shared. During post # processing, we remove the layers_0 scope if it was shared. newkey = f'{prefix}/{encdec}/layers_{blocknum}/relpos_bias/rel_embedding{suffix}' return newkey, val @t5_importer.add( r'(encoder|decoder|transformer)/block_(\d+)/layer_\d+/(SelfAttention|EncDecAttention)/(q|k|v|o)(\w*)' ) def attention_layers(opts, key, val, encdec, blocknum, attntype, qkvo, slot): """Process attention layers.""" del opts, key prefix = 'state/param_states' if slot else 'target' suffix = '/' + SLOT_MAP[slot] if slot else '' blocknum = int(blocknum) encdec = TOWER_MAP.get(encdec, encdec) matrix = {'q': 'query', 'k': 'key', 'v': 'value', 'o': 'out'}[qkvo] if encdec == 'encoder': attntype = 'attention' else: attntype = { 'SelfAttention': 'self_attention', 'EncDecAttention': 'encoder_decoder_attention' }[attntype] newkey = f'{prefix}/{encdec}/layers_{blocknum}/{attntype}/{matrix}/kernel{suffix}' return newkey, val @t5_importer.add( r'(encoder|decoder|transformer)/block_(\d+)/layer_\d+/DenseReluDense/(wi|wo)(?:_(\d+))?/kernel(\w*)' ) def mlpblock(opts, key, val, encdec, blocknum, io_name, io_num, slot): """Process MLP blocks.""" del opts, key prefix = 'state/param_states' if slot else 'target' suffix = '/' + SLOT_MAP[slot] if slot else '' blocknum = int(blocknum) encdec = TOWER_MAP.get(encdec, encdec) io_num = f'_{io_num}' if io_num else '' newkey = f'{prefix}/{encdec}/layers_{blocknum}/mlp/{io_name}{io_num}/kernel{suffix}' return newkey, val @t5_importer.add( r'(encoder|decoder|transformer)/block_(\d+)/layer_(\d+)/(?:layer|rms)_norm/scale(\w*)' ) def layernorms(opts, key, val, encdec, blocknum, lyrnum, slot): """Process layer norms assuming that they are pre-layernorms.""" del opts, key prefix = 'state/param_states' if slot else 'target' suffix = '/' + SLOT_MAP[slot] if slot else '' lyrnum = int(lyrnum) if encdec == 'transformer': layernorm_type = ['pre_self_attention_layer_norm', 'pre_mlp_layer_norm'][lyrnum] elif encdec == 'encoder': layernorm_type = ['pre_attention_layer_norm', 'pre_mlp_layer_norm'][lyrnum] else: # decoder layernorm_type = [ 'pre_self_attention_layer_norm', 'pre_cross_attention_layer_norm', 'pre_mlp_layer_norm' ][lyrnum] encdec = TOWER_MAP.get(encdec, encdec) newkey = f'{prefix}/{encdec}/layers_{int(blocknum)}/{layernorm_type}/scale{suffix}' return newkey, val @t5_importer.add( r'(encoder|decoder|transformer)/(?:final_layer|rms)_norm/scale(\w*)') def final_layernorms(opts, key, val, encdec, slot): """Process final layer norms.""" del opts, key prefix = 'state/param_states' if slot else 'target' suffix = '/' + SLOT_MAP[slot] if slot else '' norm = { 'encoder': 'encoder_norm', 'decoder': 'decoder_norm', 'transformer': 'decoder_norm' }[encdec] encdec = TOWER_MAP.get(encdec, encdec) newkey = f'{prefix}/{encdec}/{norm}/scale{suffix}' return newkey, val @t5_importer.add(r'(?:decoder|transformer)/logits/kernel(\w*)') def final_logits(opts, key, val, slot): del opts, key prefix = 'state/param_states' if slot else 'target' suffix = '/' + SLOT_MAP[slot] if slot else '' newkey = f'{prefix}/decoder/logits_dense/kernel{suffix}' return newkey, val def _add_missing_param_states(t5_data): """Add dummy slots that Flax Adafactor requires but TF does not.""" updates = {} for k in t5_data: if k.startswith('target'): state_leaf = 'state/param_states' + k[len('target'):] updates[state_leaf + '/m'] = np.zeros((1,), np.float32) if state_leaf + '/v' in t5_data: updates[state_leaf + '/v_row'] = np.zeros((1,), np.float32) updates[state_leaf + '/v_col'] = np.zeros((1,), np.float32) elif state_leaf + '/v_row' in t5_data: updates[state_leaf + '/v'] = np.zeros((1,), np.float32) t5_data.update(**updates) return t5_data def _maybe_correct_relpos_bias(t5_data): """Correct the relpos_bias format if it is shared across layers.""" max_layer_ind = 0 for k, v in t5_data.items(): match = re.search(r'layers_(\d+)/relpos_bias', k) if match: layer_ind = int(match.groups()[0]) max_layer_ind = max(max_layer_ind, layer_ind) modified_dict = {} if max_layer_ind == 0: # Relative position biases are shared across layers for k, v in t5_data.items(): new_k = re.sub(r'layers_\d+/relpos_bias', 'relpos_bias', k) modified_dict[new_k] = v else: # Relative position biases are unique in each layer. No more processing is # necessary. modified_dict = t5_data return modified_dict # Load checkpoint, translate, and update flax optimizer and model. # ----------------------------------------------------------------------------- def load_tf_ckpt(path): """Load a TF checkpoint as a flat dictionary of numpy arrays.""" ckpt_reader = tf.train.load_checkpoint(path) ckpt_shape_map = ckpt_reader.get_variable_to_shape_map() ckpt_dtype_map = ckpt_reader.get_variable_to_dtype_map() datamap = { # pylint: disable=g-complex-comprehension k: LazyThreadPoolArray( s, jnp.dtype(ckpt_dtype_map[k].as_numpy_dtype), lambda x=k: ckpt_reader.get_tensor(x)) for k, s in ckpt_shape_map.items() } return datamap def _update_state_dict(state_dict: Mapping[str, Any], t5_data: MutableMapping[str, LazyArray], strict: bool = True) -> Mapping[str, Any]: """Update flax optimizer for T5 model. Args: state_dict: Optimizer to update with T5 parameters. t5_data: T5 model parameters, typically loaded from a checkpoint. strict: If True requires that optimizer and t5_data mappings contain the same set of names (variables). If False, updating will succeed even if t5_data contains variables not in the optimizer. If the optimizer has variables not in t5_data, this function will still fail. Returns: Updated optimizer. """ flat_state_dict = traverse_util.flatten_dict(state_dict, sep='/') # Remove parameters from the checkpoint not found in the optimizer (this # allows us to load checkpoints that contain more parameters than our current # model). if not strict: for k in list(t5_data): if k not in flat_state_dict: t5_data.pop(k) # Shape check. for k, v in t5_data.items(): if flat_state_dict[k].shape != v.shape: raise ValueError( f'Variable {k} has shape {v.shape} != {flat_state_dict[k].shape}') flat_state_dict = t5_data state_dict = traverse_util.unflatten_dict( {tuple(k.split('/')): v for k, v in flat_state_dict.items()}) return state_dict def restore_from_t5_checkpoint( state_dict: Mapping[str, Any], path: str, lazy_parameters: bool = False, strict: bool = True, translator: Optional[CheckpointTranslator] = None) -> Mapping[str, Any]: """Load T5 checkpoint and update Adafactor optimizer and T5 model from it. We require that the final translated checkpoint structure exactly matches that of the Flax Adafactor + Transformer data, up to shape agreement of the leaves. Args: state_dict: Flax Adafactor Optimizer for T5 transformer encoder-decoder. path: a path to checkpoint file or directory. lazy_parameters: whether to leave the parameters as LazyArrays to preserve memory. strict: If True requires that optimizer and t5_data mappings contain the same set of names (variables). If False, updating will succeed even if t5_data contains variables not in the optimizer. If the optimizer has variables not in t5_data, this function will still fail. translator: The mapping rules for conversion. If None, then default T5 conversion rules will be used. Returns: Adafactor optimizer updated with parameters and optimizer state from T5 checkpoint. """ if translator is None: translator = t5_importer ckpt_data = load_tf_ckpt(path) t5_data = translator.apply(ckpt_data) t5_data = _add_missing_param_states(t5_data) t5_data = _maybe_correct_relpos_bias(t5_data) state_dict = _update_state_dict(state_dict, t5_data, strict=strict) if not lazy_parameters: state_dict = jax.tree_map( lambda x: x.get() if isinstance(x, LazyArray) else x, state_dict) return state_dict