# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Tool to convert a T5/MeshTF checkpoint to T5X.

While T5X can load these checkpoints on the fly, the process can be slow for
very large checkpoints. For frequently used checkpoints, it's recommended to
convert them once to a T5X checkpoint.

Example usage:

CUDA_VISIBLE_DEVICES="" python -m t5x.scripts.convert_tf_checkpoint \
 --gin_file=t5x/examples/t5/t5_1_0/small.gin \
 --gin.convert_checkpoint.model=%MODEL \
 --gin.convert_checkpoint.tf_checkpoint_path=\
 \"gs://t5-data/pretrained_models/small/model.ckpt-1000000\" \
 --gin.convert_checkpoint.output_dir=\"/tmp/t5x_checkpoints/t5_small\" \
 --logtostderr
"""

import jax
import jax.numpy as jnp
from t5x import checkpoints
from t5x import models
from t5x import partitioning
from t5x import train_state as train_state_lib


def convert_checkpoint(model: models.BaseModel,
                       tf_checkpoint_path: str,
                       output_dir: str,
                       save_dtype: jnp.dtype = jnp.float32,
                       concurrent_gb: int = 16):
  """Converts a TensorFlow checkpoint to a T5X checkpoint.

  Args:
    model: The model whose parameter structure the checkpoint is mapped onto;
      only used to derive the shapes and dtypes of the train state.
    tf_checkpoint_path: Path to a TensorFlow checkpoint to convert.
    output_dir: Path to a directory to write the converted checkpoint.
    save_dtype: What dtype to store the target parameters as.
    concurrent_gb: Number of gigabytes of parameters to convert in parallel.
      Actual RAM usage may be 4X this number.
  """

  def initialize_train_state(rng):
    initial_variables = model.get_initial_variables(
        rng=rng,
        input_shapes={
            'encoder_input_tokens': (1, 1),
            'decoder_input_tokens': (1, 1)
        })
    return train_state_lib.FlaxOptimTrainState.create(model.optimizer_def,
                                                      initial_variables)

  # Evaluate the train state abstractly: only shapes/dtypes are computed, no
  # parameter memory is allocated.
  train_state = jax.eval_shape(initialize_train_state, jax.random.PRNGKey(0))

  partitioner = partitioning.PjitPartitioner(1)

  checkpointer = checkpoints.Checkpointer(
      train_state, partitioner, output_dir, save_dtype=jnp.dtype(save_dtype))

  checkpointer.convert_from_tf_checkpoint(
      tf_checkpoint_path, concurrent_gb=concurrent_gb)


if __name__ == '__main__':
  # pylint:disable=g-import-not-at-top
  from absl import flags
  import gin
  from t5x import gin_utils
  # pylint:enable=g-import-not-at-top

  FLAGS = flags.FLAGS

  jax.config.parse_flags_with_absl()

  flags.DEFINE_multi_string(
      'gin_file',
      default=None,
      help='Path to gin configuration file. Multiple paths may be passed and '
      'will be imported in the given order, with later configurations '
      'overriding earlier ones.')

  flags.DEFINE_multi_string(
      'gin_bindings', default=[], help='Individual gin bindings.')

  flags.DEFINE_list(
      'gin_search_paths',
      default=['t5x/configs'],
      help='Comma-separated list of gin config path prefixes to be prepended '
      'to suffixes given via `--gin_file`. Only the first prefix that '
      'produces a valid path for each suffix will be used.')

  def main(_):
    """True main function."""
    # Get the gin-configured version of `convert_checkpoint`.
    convert_checkpoint_using_gin = gin.configurable(convert_checkpoint)

    gin_utils.parse_gin_flags(FLAGS.gin_search_paths, FLAGS.gin_file,
                              FLAGS.gin_bindings)
    convert_checkpoint_using_gin()

  gin_utils.run(main)
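
# Programmatic usage sketch (not part of the original script): `convert_checkpoint`
# can also be called directly from Python without gin, assuming `model` is a
# `t5x.models.BaseModel` instance constructed elsewhere (e.g. from one of the gin
# configs shipped with T5X). The paths below are taken from the module docstring
# and are only illustrative.
#
#   from t5x.scripts.convert_tf_checkpoint import convert_checkpoint
#
#   convert_checkpoint(
#       model=model,
#       tf_checkpoint_path=(
#           'gs://t5-data/pretrained_models/small/model.ckpt-1000000'),
#       output_dir='/tmp/t5x_checkpoints/t5_small',
#       save_dtype=jnp.bfloat16,
#       concurrent_gb=16)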