# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Tool to convert a T5/MeshTF checkpoint to T5X.

While T5X can load these checkpoints on the fly, the process can be slow for
very large checkpoints. For frequently used checkpoints, it's recommended to
convert them to the T5X format once.

Example usage:
CUDA_VISIBLE_DEVICES="" \
python -m t5x.scripts.convert_tf_checkpoint \
  --gin_file=t5x/examples/t5/t5_1_0/small.gin \
  --gin.convert_checkpoint.model=%MODEL \
  --gin.convert_checkpoint.tf_checkpoint_path=\
\"gs://t5-data/pretrained_models/small/model.ckpt-1000000\" \
  --gin.convert_checkpoint.output_dir=\"/tmp/t5x_checkpoints/t5_small\" \
  --logtostderr
"""
import jax
import jax.numpy as jnp
from t5x import checkpoints
from t5x import models
from t5x import partitioning
from t5x import train_state as train_state_lib


def convert_checkpoint(model: models.BaseModel,
                       tf_checkpoint_path: str,
                       output_dir: str,
                       save_dtype: jnp.dtype = jnp.float32,
                       concurrent_gb: int = 16):
"""Converts a TensorFlow checkpoint to a P5X checkpoint.
Args:
model:
tf_checkpoint_path: Path to a TensorFlow checkpoint to convert.
output_dir: Path to a directory to write the converted checkpoint.
save_dtype: What dtype to store the target parameters as.
concurrent_gb: Number of gigabtes of parameters to convert in parallel.
Actual RAM usage may be 4X this number.
"""

  def initialize_train_state(rng):
    initial_variables = model.get_initial_variables(
        rng=rng,
        input_shapes={
            'encoder_input_tokens': (1, 1),
            'decoder_input_tokens': (1, 1)
        })
    return train_state_lib.FlaxOptimTrainState.create(model.optimizer_def,
                                                      initial_variables)
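
  # jax.eval_shape traces `initialize_train_state` abstractly, yielding the
  # parameter tree's shapes and dtypes without allocating real arrays.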
  train_state = jax.eval_shape(initialize_train_state, jax.random.PRNGKey(0))
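
  # A single-partition partitioner is enough here: conversion only needs the
  # parameter structure, not a sharded model.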
  partitioner = partitioning.PjitPartitioner(1)
  checkpointer = checkpoints.Checkpointer(
      train_state, partitioner, output_dir, save_dtype=jnp.dtype(save_dtype))
  checkpointer.convert_from_tf_checkpoint(
      tf_checkpoint_path, concurrent_gb=concurrent_gb)
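
# A minimal sketch (not part of the original script) of driving
# `convert_checkpoint` directly from Python instead of through gin. The model
# construction is elided because it depends on the checkpoint being converted;
# the paths below are the illustrative ones from the module docstring.
#
#   model = ...  # any models.BaseModel whose parameters match the TF checkpoint
#   convert_checkpoint(
#       model,
#       tf_checkpoint_path='gs://t5-data/pretrained_models/small/model.ckpt-1000000',
#       output_dir='/tmp/t5x_checkpoints/t5_small')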


if __name__ == '__main__':
  # pylint:disable=g-import-not-at-top
  from absl import flags
  import gin
  from t5x import gin_utils
  # pylint:enable=g-import-not-at-top

  FLAGS = flags.FLAGS

  jax.config.parse_flags_with_absl()

  flags.DEFINE_multi_string(
      'gin_file',
      default=None,
      help='Path to gin configuration file. Multiple paths may be passed and '
      'will be imported in the given order, with later configurations '
      'overriding earlier ones.')

  flags.DEFINE_multi_string(
      'gin_bindings', default=[], help='Individual gin bindings.')

  flags.DEFINE_list(
      'gin_search_paths',
      default=['t5x/configs'],
      help='Comma-separated list of gin config path prefixes to be prepended '
      'to suffixes given via `--gin_file`. Only the first prefix that '
      'produces a valid path for each suffix will be used.')

  def main(_):
    """True main function."""
    # Get the gin-configured version of `convert_checkpoint`.
    convert_checkpoint_using_gin = gin.configurable(convert_checkpoint)

    gin_utils.parse_gin_flags(FLAGS.gin_search_paths, FLAGS.gin_file,
                              FLAGS.gin_bindings)
    convert_checkpoint_using_gin()

  gin_utils.run(main)