youtube-music-transcribe / t5x /checkpoint_importer.py
juancopi81's picture
Add t5x and mt3 models
b100e1c
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""T5 Checkpoint Importer."""
import asyncio
from concurrent.futures import thread
import re
from typing import Any, Callable, Mapping, MutableMapping, Optional, Union
from flax import traverse_util
import jax
from jax import numpy as jnp
import numpy as np
import orbax.checkpoint
import tensorflow as tf
import tensorstore as ts
# TODO(b/233659813): Cleanup clients depending on t5x.checkpoint_importer for
# LazyArray. Reconcile divergence in subclass implementation when possible.
# Module-level alias for orbax's `LazyArray`, so that code referring to
# `LazyArray` through this module keeps resolving. It is also the base class of
# the lazy array variants defined in this file.
LazyArray = orbax.checkpoint.lazy_array.LazyArray
# TODO(brianlester): The choice between using a `LazyThreadPoolArray` or a
# `LazyAwaitableArray` depends on whether the user-provided `get_fn` is blocking
# or async, respectively. If we could detect which it is, we could automatically
# proxy to the correct subclass. We cannot detect if `get_fn` is a lambda that
# wraps an async call, so this isn't possible yet. Add this dispatch once we are
# able to detect that; python3.8+ can detect async for partial'ed functions but
# not lambdas.
class LazyThreadPoolArray(LazyArray):
  """LazyArray variant for a blocking `get_fn`, loaded through a thread pool."""

  # A single executor stored on the class, so asynchronous loads requested by
  # any instance are submitted to the same pool.
  executor = thread.ThreadPoolExecutor()

  def get_async(self) -> asyncio.Future:
    """Submits `get` to the executor and wraps the result as an asyncio future."""
    pending = self.executor.submit(self.get)
    return asyncio.wrap_future(pending)

  def get(self) -> np.ndarray:
    """Calls `_get_fn` in the current thread and casts to `self.dtype` if needed."""
    loaded = self._get_fn()
    if loaded.dtype == self.dtype:
      return loaded
    return loaded.astype(self.dtype)
class LazyAwaitableArray(LazyArray):
  """Lazily and asynchronously loads an array when the `get_fn` is async.

  Note:
    The synchronous load method `.get` requires the asyncio event loop and
    calling `.run_until_complete`. This is not supported when the event loop is
    already running (for example, from inside another async function).

  Note:
    Currently, this class has a few helper methods for creating a
    LazyAwaitableArray when the input could be either an array, or a TensorStore
    spec. Most people use async code when dealing with TensorStore so the
    classmethods have been placed here. When someone eventually uses a blocking
    function to read from TensorStore they can be moved to the LazyArray base
    class.
  """

  def get_async(self) -> asyncio.Future:
    """Returns a future that resolves to the array, cast to `self.dtype`."""

    async def _get_and_cast():
      # Pytype has a false positive here, where it treats our _get_fn (_read_ts
      # in this case) as having a return type of `np.ndarray` instead of
      # wrapping it in an Awaitable. Related to this bug
      # https://github.com/google/pytype/issues/527
      arr = await self._get_fn()  # pytype: disable=bad-return-type
      if arr.dtype != self.dtype:
        arr = arr.astype(self.dtype)
      return arr

    return asyncio.ensure_future(_get_and_cast())

  def get(self) -> np.ndarray:
    """Synchronously loads the array by driving the event loop to completion.

    Returns:
      The loaded array, cast to `self.dtype`.

    Raises:
      RuntimeError: if the event loop is already running (raised by
        `run_until_complete`).
    """
    try:
      loop = asyncio.get_event_loop()
    except RuntimeError:
      # No current event loop: this happens in non-main threads and, on newer
      # Python versions, whenever no loop has been set. Create one and install
      # it so that `asyncio.ensure_future` in `get_async` picks up the same
      # loop.
      loop = asyncio.new_event_loop()
      asyncio.set_event_loop(loop)
    return loop.run_until_complete(self.get_async())

  @classmethod
  def from_tensor_store_spec(
      cls,
      ts_spec: ts.Spec,
      get_fn: Callable[[], np.ndarray],
      dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray':
    """Create a LazyAwaitableArray based on a tensorstore.Spec.

    Args:
      ts_spec: spec whose JSON form supplies the shape (`metadata/shape`) and,
        when `dtype` is None, the dtype.
      get_fn: function used later to load the array.
      dtype: optional dtype override for the loaded array.

    Returns:
      A lazy array with the shape taken from the spec.
    """
    ts_spec = ts_spec.to_json()
    shape = ts_spec['metadata']['shape']
    if dtype is None:
      dtype = jnp.dtype(ts_spec['dtype'])
    else:
      dtype = jnp.dtype(dtype)
    # v2 T5X checkpoints use uint16 as the TensorStore datatype and store the
    # bfloat16 bytes as-is in the 16 bits that uint16 has (no actual cast).
    # When reading the dtype from the TensorStore, if we keep the dtype of these
    # v2 checkpoints as np.uint16 then the _get_fn (which has a possible cast to
    # support the `restore_dtype` parameter for the checkpointer) will actually
    # cast the bfloat16 values to uint16, generally resulting in an array of all
    # zeros. This check avoids the actual cast to uint16 by replacing the dtype.
    if dtype == np.uint16:
      dtype = jnp.bfloat16
    return cls(shape, dtype, get_fn)

  @classmethod
  def from_array(cls,
                 array: np.ndarray,
                 get_fn: Callable[[], np.ndarray],
                 dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray':
    """Create a LazyAwaitableArray based on an array or python number.

    Args:
      array: object providing the shape and, when `dtype` is None, the dtype.
        Objects without `shape`/`dtype` attributes (e.g. python numbers) are
        inspected through numpy.
      get_fn: function used later to load the array.
      dtype: optional dtype override for the loaded array.

    Returns:
      A lazy array with the shape of `array`.
    """
    # `np.shape` returns `array.shape` when that attribute exists and only
    # converts through `np.asarray` otherwise, so array inputs are not copied.
    shape = np.shape(array)
    if dtype is None:
      dtype = getattr(array, 'dtype', None)
      if dtype is None:
        dtype = np.asarray(array).dtype
    else:
      dtype = jnp.dtype(dtype)
    return cls(shape, dtype, get_fn)

  @classmethod
  def from_tensor_store_spec_or_array(
      cls,
      maybe_ts_spec: Union[ts.Spec, np.ndarray],
      get_fn: Callable[[], np.ndarray],
      dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray':
    """Create a LazyAwaitableArray based on an array or a tensorstore.Spec."""
    if isinstance(maybe_ts_spec, ts.Spec):
      return cls.from_tensor_store_spec(maybe_ts_spec, get_fn, dtype=dtype)
    return cls.from_array(maybe_ts_spec, get_fn, dtype=dtype)
class CheckpointTranslator:
  """Utility class for defining mapping rules from one flatdict to another.

  We assume a checkpoint is loaded as a dictionary with flattened keys of the
  form: 'name0/name1/name2/.../nameN'

  A rule is added with the 'add' decorator, which takes a regex matching rule
  and wraps a conversion function, feeding it (opts, key, val, *regex_groups)
  where opts is a dict containing apply-time keyword options for use by the
  conversion functions.
  """

  def __init__(self):
    # List of (compiled regex, translation fn) pairs in registration order; the
    # first rule that matches a key is the one applied to it.
    self.rules = []

  def add(self, pattern):
    """Adds a new keyval conversion rule.

    Args:
      pattern: regex with capture groups for matching given sets of model
        variables. We terminate all regexes with '$' to force complete matches.

    Returns:
      Translation function decorator for associating with the provided
      pattern.
    """

    def register_translation_fn_decorator(fn):
      # We force a complete match by adding end-of-string match.
      self.rules.append((re.compile(pattern + '$'), fn))
      return fn

    return register_translation_fn_decorator

  def apply(self, flatdict, **opts):
    """Applies rules to a flattened dictionary.

    Args:
      flatdict: flat-key dictionary of variables.
      **opts: additional config options for translation rules supplied at
        application time.

    Returns:
      Checkpoint data with translated key/values in flat-key dict format. A
      rule may return None as the new key to drop the entry.

    Raises:
      ValueError: if any key in `flatdict` is not matched by a rule.
    """
    new_dict = {}
    unmatched = {}
    for k, v in flatdict.items():
      for rule_pat, rule_fn in self.rules:
        # `fullmatch` (rather than `match`) is required for a truly complete
        # match: with a top-level alternation such as 'a|b' the appended '$'
        # only binds to the last alternative, so `match` would accept any key
        # merely starting with 'a'.
        found = rule_pat.fullmatch(k)
        if found is None:
          continue
        new_k, new_v = rule_fn(opts, k, v, *found.groups())
        if new_k is not None:
          new_dict[new_k] = new_v
        break
      else:
        # No rule matched this key.
        unmatched[k] = v

    # We force every key-value pair in checkpoint to have a rule associated with
    # it.
    if unmatched:
      raise ValueError('Unmapped tensor keys exist: %s' % unmatched)

    return new_dict
# Create a translation rule set for importing T5 & T5.1.1 model checkpoints.
# -----------------------------------------------------------------------------
# The rule functions below register themselves on this instance through the
# `@t5_importer.add(...)` decorator.
t5_importer = CheckpointTranslator()
# Name mappings.
# SLOT_MAP: suffix captured after a variable name by the rule regexes -> name
# of the entry written under 'state/param_states/...'.
SLOT_MAP = {'_slot_vc': 'v_col', '_slot_vr': 'v_row', '_slot_v': 'v'}
# TOWER_MAP: renames the top-level 'transformer' scope to 'decoder'; towers not
# listed here keep their name (rules use `TOWER_MAP.get(name, name)`).
TOWER_MAP = {'transformer': 'decoder'}
@t5_importer.add(r'global_step')
def global_step(opts, key, val):
  """Maps the `global_step` variable to the 'state/step' entry.

  A LazyArray value is cast to int32 and loaded right away; any other value is
  passed through untouched.
  """
  del opts, key  # Unused.
  if isinstance(val, LazyArray):
    step = val.astype(np.int32).get()
  else:
    step = val
  return 'state/step', step
@t5_importer.add(r'shared/embedding(\w*)')
def shared_embeddings(opts, key, val, slot):
  """Maps the shared token embedding or one of its optimizer slots.

  `slot` is the text captured after 'embedding': empty for the parameter
  itself, otherwise a SLOT_MAP key naming an optimizer slot.
  """
  del opts, key  # Unused.
  if slot:
    prefix = 'state/param_states'
    suffix = '/' + SLOT_MAP[slot]
  else:
    prefix = 'target'
    suffix = ''
  return f'{prefix}/token_embedder/embedding{suffix}', val
@t5_importer.add(r'(encoder|decoder|transformer)/embedding(\w*)')
def separate_embeddings(opts, key, val, encdec, slot):
  """Maps a per-tower token embedding or one of its optimizer slots."""
  del opts, key  # Unused.
  if slot:
    prefix = 'state/param_states'
    suffix = '/' + SLOT_MAP[slot]
  else:
    prefix = 'target'
    suffix = ''
  # 'transformer' is renamed through TOWER_MAP; other towers keep their name.
  encdec = TOWER_MAP.get(encdec, encdec)
  return f'{prefix}/{encdec}/token_embedder/embedding{suffix}', val
# In the Mesh TensorFlow T5 code, relative_attention_bias always occurs in layer
# 0 because SelfAttention precedes other sublayers within the same block.
@t5_importer.add(
    r'(encoder|decoder|transformer)/block_(\d+)/layer_000/SelfAttention/relative_attention_bias(\w*)'
)
def rel_embeddings(opts, key, val, encdec, blocknum, slot):
  """Maps relative position biases, treating them as per-layer parameters."""
  del opts, key  # Unused.
  if slot:
    prefix = 'state/param_states'
    suffix = '/' + SLOT_MAP[slot]
  else:
    prefix = 'target'
    suffix = ''
  encdec = TOWER_MAP.get(encdec, encdec)
  # Converting to int drops the zero padding of the block number.
  blocknum = int(blocknum)
  # Whether the bias was shared across layers cannot be decided from a single
  # key, so the per-layer form is emitted here; the `layers_0` scope is removed
  # in post processing if the bias turns out to be shared.
  return (
      f'{prefix}/{encdec}/layers_{blocknum}/relpos_bias/rel_embedding{suffix}',
      val)
@t5_importer.add(
    r'(encoder|decoder|transformer)/block_(\d+)/layer_\d+/(SelfAttention|EncDecAttention)/(q|k|v|o)(\w*)'
)
def attention_layers(opts, key, val, encdec, blocknum, attntype, qkvo, slot):
  """Maps attention projection kernels and their optimizer slots."""
  del opts, key  # Unused.
  if slot:
    prefix = 'state/param_states'
    suffix = '/' + SLOT_MAP[slot]
  else:
    prefix = 'target'
    suffix = ''
  blocknum = int(blocknum)
  encdec = TOWER_MAP.get(encdec, encdec)
  matrix = {'q': 'query', 'k': 'key', 'v': 'value', 'o': 'out'}[qkvo]
  # The attention module name is decided on the already-renamed tower: the
  # encoder has a single module, other towers distinguish self/cross attention.
  if encdec != 'encoder':
    attntype = {
        'SelfAttention': 'self_attention',
        'EncDecAttention': 'encoder_decoder_attention'
    }[attntype]
  else:
    attntype = 'attention'
  return (
      f'{prefix}/{encdec}/layers_{blocknum}/{attntype}/{matrix}/kernel{suffix}',
      val)
@t5_importer.add(
    r'(encoder|decoder|transformer)/block_(\d+)/layer_\d+/DenseReluDense/(wi|wo)(?:_(\d+))?/kernel(\w*)'
)
def mlpblock(opts, key, val, encdec, blocknum, io_name, io_num, slot):
  """Maps MLP kernels (`wi`/`wo`, optionally numbered) and their slots."""
  del opts, key  # Unused.
  if slot:
    prefix = 'state/param_states'
    suffix = '/' + SLOT_MAP[slot]
  else:
    prefix = 'target'
    suffix = ''
  blocknum = int(blocknum)
  encdec = TOWER_MAP.get(encdec, encdec)
  # The numeric suffix group is optional in the pattern; keep it only if it was
  # captured.
  if io_num:
    io_num = f'_{io_num}'
  else:
    io_num = ''
  return (
      f'{prefix}/{encdec}/layers_{blocknum}/mlp/{io_name}{io_num}/kernel{suffix}',
      val)
@t5_importer.add(
    r'(encoder|decoder|transformer)/block_(\d+)/layer_(\d+)/(?:layer|rms)_norm/scale(\w*)'
)
def layernorms(opts, key, val, encdec, blocknum, lyrnum, slot):
  """Maps per-block layer norm scales, assuming they are pre-layernorms."""
  del opts, key  # Unused.
  if slot:
    prefix = 'state/param_states'
    suffix = '/' + SLOT_MAP[slot]
  else:
    prefix = 'target'
    suffix = ''
  lyrnum = int(lyrnum)
  # The sublayer index selects the norm name; the list of names depends on the
  # tower as written in the checkpoint (before TOWER_MAP renaming).
  if encdec == 'encoder':
    norm_names = ['pre_attention_layer_norm', 'pre_mlp_layer_norm']
  elif encdec == 'transformer':
    norm_names = ['pre_self_attention_layer_norm', 'pre_mlp_layer_norm']
  else:  # decoder
    norm_names = [
        'pre_self_attention_layer_norm', 'pre_cross_attention_layer_norm',
        'pre_mlp_layer_norm'
    ]
  layernorm_type = norm_names[lyrnum]
  encdec = TOWER_MAP.get(encdec, encdec)
  return (
      f'{prefix}/{encdec}/layers_{int(blocknum)}/{layernorm_type}/scale{suffix}',
      val)
@t5_importer.add(
    r'(encoder|decoder|transformer)/(?:final_layer|rms)_norm/scale(\w*)')
def final_layernorms(opts, key, val, encdec, slot):
  """Maps the final layer norm scale of a tower and its optimizer slots."""
  del opts, key  # Unused.
  if slot:
    prefix = 'state/param_states'
    suffix = '/' + SLOT_MAP[slot]
  else:
    prefix = 'target'
    suffix = ''
  # The norm name is looked up with the tower name as written in the
  # checkpoint, before the tower itself is renamed.
  norm_by_tower = {
      'encoder': 'encoder_norm',
      'decoder': 'decoder_norm',
      'transformer': 'decoder_norm'
  }
  norm = norm_by_tower[encdec]
  encdec = TOWER_MAP.get(encdec, encdec)
  return f'{prefix}/{encdec}/{norm}/scale{suffix}', val
@t5_importer.add(r'(?:decoder|transformer)/logits/kernel(\w*)')
def final_logits(opts, key, val, slot):
  """Maps the output logits kernel and its optimizer slots."""
  del opts, key  # Unused.
  if slot:
    prefix = 'state/param_states'
    suffix = '/' + SLOT_MAP[slot]
  else:
    prefix = 'target'
    suffix = ''
  return f'{prefix}/decoder/logits_dense/kernel{suffix}', val
def _add_missing_param_states(t5_data):
  """Adds placeholder optimizer slots that Flax Adafactor needs but TF lacks.

  For every key starting with 'target', an 'm' slot is added under
  'state/param_states'. In addition, whichever of the 'v' / ('v_row', 'v_col')
  forms is absent is added when the other form is present. Placeholders are
  float32 zeros of shape (1,). `t5_data` is updated in place and returned.
  """

  def placeholder():
    return np.zeros((1,), np.float32)

  additions = {}
  for name in t5_data:
    if not name.startswith('target'):
      continue
    state_leaf = 'state/param_states' + name[len('target'):]
    additions[state_leaf + '/m'] = placeholder()
    if state_leaf + '/v' in t5_data:
      additions[state_leaf + '/v_row'] = placeholder()
      additions[state_leaf + '/v_col'] = placeholder()
    elif state_leaf + '/v_row' in t5_data:
      additions[state_leaf + '/v'] = placeholder()
  t5_data.update(**additions)
  return t5_data
def _maybe_correct_relpos_bias(t5_data):
  """Drops the per-layer scope of relpos biases when they are shared.

  The biases are treated as shared when no key mentions a relpos bias in a
  layer with index above 0. In that case a new dict is returned with the
  'layers_N/' scope removed from relpos bias keys; otherwise `t5_data` itself
  is returned unchanged.
  """
  layer_inds = []
  for name in t5_data:
    found = re.search(r'layers_(\d+)/relpos_bias', name)
    if found:
      layer_inds.append(int(found.group(1)))

  if max(layer_inds, default=0) != 0:
    # Relative position biases are unique in each layer. No more processing is
    # necessary.
    return t5_data

  # Relative position biases are shared across layers.
  return {
      re.sub(r'layers_\d+/relpos_bias', 'relpos_bias', name): value
      for name, value in t5_data.items()
  }
# Load checkpoint, translate, and update flax optimizer and model.
# -----------------------------------------------------------------------------
def load_tf_ckpt(path):
  """Loads a TF checkpoint as a flat dict of lazily-read arrays.

  Each variable listed in the checkpoint's shape map becomes a
  LazyThreadPoolArray that reads its tensor from the checkpoint reader only
  when the value is requested.
  """
  reader = tf.train.load_checkpoint(path)
  shapes = reader.get_variable_to_shape_map()
  dtypes = reader.get_variable_to_dtype_map()

  datamap = {}
  for name, shape in shapes.items():
    np_dtype = dtypes[name].as_numpy_dtype
    # The variable name is bound as a default argument so that every lambda
    # reads its own tensor rather than the last name of the loop.
    datamap[name] = LazyThreadPoolArray(
        shape,
        jnp.dtype(np_dtype),
        lambda x=name: reader.get_tensor(x))
  return datamap
def _update_state_dict(state_dict: Mapping[str, Any],
                       t5_data: MutableMapping[str, LazyArray],
                       strict: bool = True) -> Mapping[str, Any]:
  """Build a nested state dict from T5 checkpoint data, checked against a model.

  `state_dict` is only used as a reference for names and shapes: the returned
  structure is built from the entries of `t5_data` alone, so entries of
  `state_dict` that are missing from `t5_data` are not carried over.

  Args:
    state_dict: Optimizer state to validate the T5 parameters against.
    t5_data: T5 model parameters keyed by '/'-joined names, typically loaded
      from a checkpoint. When `strict` is False, entries not present in
      `state_dict` are removed from this mapping in place.
    strict: If True, every variable in `t5_data` must exist in `state_dict`. If
      False, variables in `t5_data` that are not in `state_dict` are dropped.

  Returns:
    Nested dict holding the values of `t5_data`.

  Raises:
    KeyError: if `strict` is True and `t5_data` has variables that are not in
      `state_dict`.
    ValueError: if a variable's shape differs between `t5_data` and
      `state_dict`.
  """
  flat_state_dict = traverse_util.flatten_dict(state_dict, sep='/')

  # Remove parameters from the checkpoint not found in the optimizer (this
  # allows us to load checkpoints that contain more parameters than our current
  # model).
  if not strict:
    for k in list(t5_data):
      if k not in flat_state_dict:
        t5_data.pop(k)

  # Shape check.
  for k, v in t5_data.items():
    if k not in flat_state_dict:
      # Only reachable in strict mode. Report every offending name at once
      # instead of surfacing a bare KeyError for the first one.
      unexpected = sorted(
          name for name in t5_data if name not in flat_state_dict)
      raise KeyError(
          f'Checkpoint variables not present in state_dict: {unexpected}')
    if flat_state_dict[k].shape != v.shape:
      raise ValueError(
          f'Variable {k} has shape {v.shape} != {flat_state_dict[k].shape}')

  return traverse_util.unflatten_dict(
      {tuple(k.split('/')): v for k, v in t5_data.items()})
def restore_from_t5_checkpoint(
    state_dict: Mapping[str, Any],
    path: str,
    lazy_parameters: bool = False,
    strict: bool = True,
    translator: Optional[CheckpointTranslator] = None) -> Mapping[str, Any]:
  """Load T5 checkpoint and update Adafactor optimizer and T5 model from it.

  The checkpoint is read lazily, its keys are translated with `translator`,
  missing optimizer slots are filled in, shared relative position biases are
  renamed, and the result is validated against `state_dict`.

  Args:
    state_dict: Flax Adafactor Optimizer for T5 transformer encoder-decoder.
    path: a path to checkpoint file or directory.
    lazy_parameters: whether to leave the parameters as LazyArrays to preserve
      memory.
    strict: If True requires that every translated checkpoint variable exists
      in `state_dict`. If False, checkpoint variables that are not in
      `state_dict` are dropped.
    translator: The mapping rules for conversion. If None, then default T5
      conversion rules will be used.

  Returns:
    Adafactor optimizer state built from the parameters and optimizer state of
    the T5 checkpoint.
  """
  if translator is None:
    translator = t5_importer
  ckpt_data = load_tf_ckpt(path)
  t5_data = translator.apply(ckpt_data)
  t5_data = _add_missing_param_states(t5_data)
  t5_data = _maybe_correct_relpos_bias(t5_data)
  state_dict = _update_state_dict(state_dict, t5_data, strict=strict)
  if not lazy_parameters:
    # `jax.tree_util.tree_map` is the stable spelling; the top-level
    # `jax.tree_map` alias is deprecated and absent from newer JAX releases.
    state_dict = jax.tree_util.tree_map(
        lambda x: x.get() if isinstance(x, LazyArray) else x, state_dict)
  return state_dict