# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Precompiles and generates HLO from the TPU metadata backend.

The TPU metadata backend is a TPU backend without real TPU devices that still
supports any TPU topology. It allows work that doesn't require real TPUs to run
as if they were present, e.g., compiling/lowering an HLO graph with the backend.

Ideally, the precompile defaults to the cpu backend for default device array
placement, since the metadata backend does not have memory allocation. The pjit
function is pinned to use the available TPU metadata backend, for getting a
proper lowering under the TPU mesh.
"""

import os
from typing import Iterator, Optional

import jax
from jax import random
import numpy as np
import t5.data.mixtures  # pylint:disable=unused-import
from t5x import models
from t5x import partitioning
from t5x import trainer as trainer_lib
from t5x import utils
import tensorflow as tf


def precompile(*,
               model: models.BaseTransformerModel,
               train_dataset_cfg: utils.DatasetConfig,
               partitioner: partitioning.BasePartitioner,
               model_dir: str,
               random_seed: Optional[int],
               get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset):
  """Lowers and compiles one train step, dumping the results to `model_dir`.

  Three files are written under `model_dir` (which is created if missing):
    * `lowered_hlo_pre_optimization`: serialized HLO module proto of the
      lowered (not yet compiled) partitioned train step.
    * `lowered_hlo_post_optimization`: serialized HLO module proto taken from
      the compiled train step.
    * `assignment`: `partitioner.mesh.device_ids` saved with `np.save`.

  Args:
    model: model whose vocabularies are checked against the dataset and whose
      `optimizer_def`, `get_initial_variables` and `FEATURE_CONVERTER_CLS` are
      used to build the train state shapes and the input pipeline.
    train_dataset_cfg: training dataset configuration; its `batch_size` is
      used to query the partitioner's data layout.
    partitioner: provides the data layout, partitions the train step and
      supplies the mesh whose device ids are dumped.
    model_dir: output directory for the dumps.
    random_seed: seed for the PRNG key the dropout RNG is split from. `None`
      falls back to 42; every other value, including 0, is used as given.
    get_dataset_fn: callable invoked as `get_dataset_fn(cfg, shard_id,
      num_shards, feature_converter_cls)` to build the training dataset.

  Raises:
    ValueError: if the dataset vocabularies differ from the model's.
  """
  # Only fall back to the default when no seed was given; `random_seed or 42`
  # would silently replace a legitimate seed of 0.
  seed = 42 if random_seed is None else random_seed
  rng = random.PRNGKey(seed)
  _, trainer_rng = random.split(rng, 2)

  # TODO(hthu): Find a better way of getting dataset shapes instead of actually
  # reading database and iterate on it.
  data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size)
  ds_shard_id = data_layout.shard_id
  num_ds_shards = data_layout.num_shards

  def _verify_matching_vocabs(cfg: utils.DatasetConfig):
    """Raises ValueError unless `cfg`'s vocabularies equal the model's."""
    ds_vocabs = utils.get_vocabulary(cfg)
    if (ds_vocabs[0] != model.input_vocabulary or
        ds_vocabs[1] != model.output_vocabulary):
      raise ValueError(f'Model and Task vocabularies do not match:\n'
                       f'  task={cfg.mixture_or_task_name}\n'
                       f'  ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n'
                       f'  model.input_vocabulary={model.input_vocabulary}\n'
                       f'  model.output_vocabulary={model.output_vocabulary}\n')

  _verify_matching_vocabs(train_dataset_cfg)

  train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
                            model.FEATURE_CONVERTER_CLS)

  # Need to use full batch size.
  input_shapes = {
      k: (data_layout.batch_size, *v.shape[1:])
      for k, v in train_ds.element_spec.items()
  }
  input_types = {
      k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items()
  }

  # Convert every leaf of each dataset element with `np.array`.
  # `jax.tree_util.tree_map` is used because the top-level `jax.tree_map`
  # alias is deprecated and absent from newer JAX releases.
  checkpointable_train_iter = iter(train_ds)
  train_iter: Iterator[trainer_lib.BatchType] = map(
      lambda x: jax.tree_util.tree_map(np.array, x), checkpointable_train_iter)
  batch = next(train_iter)

  # Compiling does not care about loading real weights.
  train_state_initializer = utils.TrainStateInitializer(
      optimizer_def=model.optimizer_def,
      init_fn=model.get_initial_variables,
      input_shapes=input_shapes,
      input_types=input_types,
      partitioner=partitioner)
  train_state_shape = train_state_initializer.global_train_state_shape
  train_state_axes = train_state_initializer.train_state_axes

  def train_step(train_state, batch):
    """Runs one training step with a fixed learning rate of 1e-3."""
    return trainer_lib.train_with_lr(
        train_state,
        batch,
        learning_rate=1e-3,
        dropout_rng=trainer_rng,
        model=model,
        num_microbatches=None,
        weight_metrics_computer=None)

  partitioned_step = partitioner.partition(
      train_step,
      in_axis_resources=(train_state_axes,
                         partitioning.PartitionSpec('data',)),
      out_axis_resources=(train_state_axes, None),
      donate_argnums=(0,))

  # PartitionedTrainCallable has lower() defined but isn't exposed in pytype.
  # TODO(hthu): Explicitly expose the lower() interface.
  # pytype: disable=attribute-error
  lowered = partitioned_step.lower(train_state_shape, batch)
  # pytype: enable=attribute-error

  # TODO(hthu): Make this a proper library without writing files by default.
  tf.io.gfile.makedirs(model_dir)

  # The serialized protos are bytes, so the files are opened in binary mode,
  # matching the `assignment` dump below.
  pre_optimization_path = os.path.join(model_dir,
                                       'lowered_hlo_pre_optimization')
  with tf.io.gfile.GFile(pre_optimization_path, 'wb') as f:
    f.write(lowered.compiler_ir(dialect='hlo').as_serialized_hlo_module_proto())

  compiled = lowered.compile()
  output_path = os.path.join(model_dir, 'lowered_hlo_post_optimization')
  with tf.io.gfile.GFile(output_path, 'wb') as f:
    f.write(compiled.compiler_ir()[0].as_serialized_hlo_module_proto())

  with tf.io.gfile.GFile(os.path.join(model_dir, 'assignment'), 'wb') as f:
    np.save(f, partitioner.mesh.device_ids)