NeMo / examples /nlp /language_modeling /megatron_t5_seq2seq_finetune.py
camenduru's picture
thanks to NVIDIA ❤
7934b29
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from nemo.collections.nlp.models.language_modeling.megatron_finetune_model import MegatronT5FinetuneModel
from nemo.collections.nlp.models.language_modeling.megatron_glue_model import MegatronT5GLUEModel
from nemo.collections.nlp.models.language_modeling.megatron_t0_model import MegatronT0Model
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.parts.nlp_overrides import (
GradScaler,
MegatronHalfPrecisionPlugin,
NLPDDPStrategy,
NLPSaveRestoreConnector,
PipelineMixedPrecisionPlugin,
)
from nemo.core.config import hydra_runner
from nemo.utils import AppState, logging
from nemo.utils.exp_manager import exp_manager
from nemo.utils.model_utils import inject_model_parallel_rank
mp.set_start_method("spawn", force=True)
def _modify_config(t5_cfg, cfg, add_cfg_to_tree=False):
"""
This function modifies the original t5 pre-training config (t5_cfg) with attributes from the finetuning config (cfg).
The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`.
"""
OmegaConf.set_struct(t5_cfg, True)
with open_dict(t5_cfg):
t5_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False)
if hasattr(t5_cfg, 'encoder') and hasattr(t5_cfg, 'decoder'):
t5_cfg.encoder.masked_softmax_fusion = False
t5_cfg.decoder.masked_softmax_fusion = False
t5_cfg.encoder.hidden_dropout = cfg.model.get('hidden_dropout', 0.1)
t5_cfg.decoder.hidden_dropout = cfg.model.get('hidden_dropout', 0.1)
if hasattr(t5_cfg.encoder, 'ffn_dropout'):
t5_cfg.encoder.ffn_dropout = cfg.model.get('ffn_dropout', 0.1)
if hasattr(t5_cfg.decoder, 'ffn_dropout'):
t5_cfg.decoder.ffn_dropout = cfg.model.get('ffn_dropout', 0.1)
else:
t5_cfg.hidden_dropout = cfg.model.get('hidden_dropout', 0.1)
t5_cfg.attention_dropout = cfg.model.get('attention_dropout', 0.1)
t5_cfg.masked_softmax_fusion = False
t5_cfg.data = cfg.model.data
t5_cfg.precision = cfg.trainer.precision
t5_cfg.optim = cfg.model.optim
t5_cfg.micro_batch_size = cfg.model.data.train_ds.micro_batch_size
t5_cfg.global_batch_size = cfg.model.data.train_ds.global_batch_size
# XNLI has eval languages in the yaml config.
if hasattr(cfg.model, 'eval_languages'):
t5_cfg.eval_languages = cfg.model.eval_languages
# This is needed when modifying a hparam file directly to load `.ckpt` files.
# This is not needed to modify the cfg in `.nemo` files.
if add_cfg_to_tree:
OmegaConf.resolve(t5_cfg)
t5_cfg.cfg = t5_cfg
return t5_cfg
def load_from_nemo(cls, cfg, trainer, t5_cfg, modify_confg_fn):
t5_cfg = modify_confg_fn(t5_cfg, cfg, add_cfg_to_tree=False)
model = cls.restore_from(
restore_path=cfg.model.restore_from_path,
trainer=trainer,
override_config_path=t5_cfg,
save_restore_connector=NLPSaveRestoreConnector(),
)
return model
def load_from_checkpoint_dir(cls, cfg, trainer, modify_confg_fn):
app_state = AppState()
if cfg.model.tensor_model_parallel_size > 1 or cfg.model.pipeline_model_parallel_size > 1:
app_state.model_parallel_size = cfg.model.tensor_model_parallel_size * cfg.model.pipeline_model_parallel_size
app_state.tensor_model_parallel_size = cfg.model.tensor_model_parallel_size
app_state.pipeline_model_parallel_size = cfg.model.pipeline_model_parallel_size
(
app_state.tensor_model_parallel_rank,
app_state.pipeline_model_parallel_rank,
app_state.model_parallel_size,
app_state.data_parallel_size,
app_state.pipeline_model_parallel_split_rank,
app_state.virtual_pipeline_model_parallel_rank,
) = fake_initialize_model_parallel(
world_size=app_state.model_parallel_size,
rank=trainer.global_rank,
tensor_model_parallel_size_=cfg.model.tensor_model_parallel_size,
pipeline_model_parallel_size_=cfg.model.pipeline_model_parallel_size,
pipeline_model_parallel_split_rank_=cfg.model.pipeline_model_parallel_split_rank,
)
checkpoint_path = inject_model_parallel_rank(
os.path.join(cfg.model.pretrained_checkpoint.checkpoint_dir, cfg.model.pretrained_checkpoint.checkpoint_name)
)
hparams_file = OmegaConf.load(cfg.model.pretrained_checkpoint.hparams_file)
t5_cfg = modify_confg_fn(hparams_file.cfg, cfg, add_cfg_to_tree=True)
with tempfile.NamedTemporaryFile(suffix='.yaml') as f:
OmegaConf.save(config=t5_cfg, f=f.name)
model = cls.load_from_checkpoint(checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name,)
return model
def validate_checkpoint_loading_args(cfg):
if cfg.checkpoint_dir is None or not os.path.isdir(cfg.checkpoint_dir):
raise ValueError(f'Checkpoint directory {cfg.checkpoint_dir} does not exist or is not a directory.')
if cfg.checkpoint_name is None:
raise ValueError(f'Checkpoint name {cfg.checkpoint_name} is not valid.')
if cfg.hparams_file is None or not os.path.isfile(cfg.hparams_file):
raise ValueError(f'Hparams file {cfg.hparams_file} does not exist or is not a file.')
@hydra_runner(config_path="conf", config_name="megatron_t5_config_finetune_glue_mnli")
def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
plugins = []
strategy = NLPDDPStrategy(
no_ddp_communication_hook=True,
gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
)
if cfg.trainer.precision in [16, 'bf16']:
scaler = None
if cfg.trainer.precision == 16:
scaler = GradScaler(
init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
hysteresis=cfg.model.get('hysteresis', 2),
)
if megatron_amp_o2:
plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
else:
plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
if cfg.get('cluster_type', None) == 'BCP':
plugins.append(TorchElasticEnvironment())
trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)
exp_manager(trainer, cfg.exp_manager)
# update resume from checkpoint found by exp_manager
if cfg.model.resume_from_checkpoint is not None:
resume_from_checkpoint = cfg.model.resume_from_checkpoint
else:
resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')
trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
if hasattr(cfg.model.data.train_ds, 'task_name'):
if cfg.model.restore_from_path:
t5_cfg = MegatronT5GLUEModel.restore_from(
restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True
)
model = load_from_nemo(MegatronT5GLUEModel, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config)
else:
validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint)
model = load_from_checkpoint_dir(MegatronT5GLUEModel, cfg, trainer, modify_confg_fn=_modify_config)
elif hasattr(cfg.model.data.train_ds, 'file_names'):
if cfg.model.restore_from_path:
t5_cfg = MegatronT0Model.restore_from(
restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True
)
model = load_from_nemo(MegatronT0Model, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config)
else:
validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint)
model = load_from_checkpoint_dir(MegatronT0Model, cfg, trainer, modify_confg_fn=_modify_config)
else:
if cfg.model.restore_from_path:
t5_cfg = MegatronT5FinetuneModel.restore_from(
restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True
)
model = load_from_nemo(MegatronT5FinetuneModel, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config)
else:
validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint)
model = load_from_checkpoint_dir(MegatronT5FinetuneModel, cfg, trainer, modify_confg_fn=_modify_config)
trainer.fit(model)
trainer.validate(model)
if hasattr(cfg.model.data, 'test_ds'):
trainer.test(model)
if __name__ == '__main__':
main()