# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Tool to convert a T5/MeshTF checkpoint to T5X. | |
While T5X can be load these checkpoints on-the-fly, the process can be slow | |
for very large checkpoints. For frequently used checkpoints, it's recommended to | |
convert them once to a T5X checkpoint. | |
Example usage: | |
CUDA_VISIBLE_DEVICES="" | |
python -m t5x.scripts.convert_tf_checkpoint \ | |
--gin_file=t5x/examples/t5/t5_1_0/small.gin\ | |
--gin.convert_checkpoint.model=%MODEL\ | |
--gin.convert_checkpoint.tf_checkpoint_path=\ | |
\"gs://t5-data/pretrained_models/small/model.ckpt-1000000\"\ | |
--gin.convert_checkpoint.output_dir=\"/tmp/t5x_checkpoints/t5_small\"\ | |
--logtostderr | |
""" | |
import jax
import jax.numpy as jnp

from t5x import checkpoints
from t5x import models
from t5x import partitioning
from t5x import train_state as train_state_lib


def convert_checkpoint(model: models.BaseModel,
                       tf_checkpoint_path: str,
                       output_dir: str,
                       save_dtype: jnp.dtype = jnp.float32,
                       concurrent_gb: int = 16):
  """Converts a TensorFlow checkpoint to a T5X checkpoint.

  Args:
    model: The T5X model whose parameter structure the converted checkpoint
      should match.
    tf_checkpoint_path: Path to a TensorFlow checkpoint to convert.
    output_dir: Path to a directory to write the converted checkpoint.
    save_dtype: What dtype to store the target parameters as.
    concurrent_gb: Number of gigabytes of parameters to convert in parallel.
      Actual RAM usage may be 4X this number.
  """
  def initialize_train_state(rng):
    # Dummy (1, 1) input shapes suffice: we only need parameter shapes and
    # dtypes, never real activations.
    initial_variables = model.get_initial_variables(
        rng=rng,
        input_shapes={
            'encoder_input_tokens': (1, 1),
            'decoder_input_tokens': (1, 1)
        })
    return train_state_lib.FlaxOptimTrainState.create(model.optimizer_def,
                                                      initial_variables)

  # Trace the initializer abstractly so the train state holds shape/dtype
  # structs only; no parameters are materialized in host memory.
  train_state = jax.eval_shape(initialize_train_state, jax.random.PRNGKey(0))
  partitioner = partitioning.PjitPartitioner(num_partitions=1)
  checkpointer = checkpoints.Checkpointer(
      train_state, partitioner, output_dir, save_dtype=jnp.dtype(save_dtype))
  checkpointer.convert_from_tf_checkpoint(
      tf_checkpoint_path, concurrent_gb=concurrent_gb)
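
# A minimal sketch of calling `convert_checkpoint` directly from Python rather
# than through the gin flags below. It assumes `model` is an already-built
# `models.BaseModel` instance (e.g. constructed from a parsed gin config); the
# paths mirror the usage example in the module docstring:
#
#   convert_checkpoint(
#       model,
#       tf_checkpoint_path='gs://t5-data/pretrained_models/small/model.ckpt-1000000',
#       output_dir='/tmp/t5x_checkpoints/t5_small',
#       save_dtype=jnp.float32)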

if __name__ == '__main__':
  # pylint:disable=g-import-not-at-top
  from absl import flags
  import gin
  from t5x import gin_utils
  # pylint:enable=g-import-not-at-top

  FLAGS = flags.FLAGS

  jax.config.parse_flags_with_absl()

  flags.DEFINE_multi_string(
      'gin_file',
      default=None,
      help='Path to gin configuration file. Multiple paths may be passed and '
      'will be imported in the given order, with later configurations '
      'overriding earlier ones.')
  flags.DEFINE_multi_string(
      'gin_bindings', default=[], help='Individual gin bindings.')
  flags.DEFINE_list(
      'gin_search_paths',
      default=['t5x/configs'],
      help='Comma-separated list of gin config path prefixes to be prepended '
      'to suffixes given via `--gin_file`. Only the first prefix that '
      'produces a valid path for each suffix will be used.')
  def main(_):
    """True main function."""
    # Get the gin-configured version of `convert_checkpoint`.
    convert_checkpoint_using_gin = gin.configurable(convert_checkpoint)
    gin_utils.parse_gin_flags(FLAGS.gin_search_paths, FLAGS.gin_file,
                              FLAGS.gin_bindings)
    # Run it with the bindings supplied via the gin flags.
    convert_checkpoint_using_gin()

  gin_utils.run(main)