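"""Train, fine-tune, or distill GPT models with NeMo 2.0.

Example invocation (hypothetical script name, paths, and sizes; every flag below is defined in get_args()):

    torchrun --nproc-per-node=8 train_gpt.py \
        --name pretrain-test --model_path /ckpts/gpt-base-nemo2 \
        --log_dir /results --devices 8 --max_steps 1000 \
        --gbs 128 --mbs 2 --seq_length 4096 --use_mock_data
"""
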
import os
from argparse import ArgumentParser

import torch
from lightning.pytorch.loggers import TensorBoardLogger
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig

from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.llm.gpt.data import ChatDataModule, MockDataModule
from nemo.collections.nlp.modules.common.tokenizer_utils import get_tokenizer
from nemo.lightning.pytorch.callbacks import ModelCheckpoint
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.utils import logging

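# Silence the HuggingFace tokenizers fork/parallelism warning in dataloader workers.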
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def get_args():
    """Parse the command-line arguments."""
    parser = ArgumentParser(
        description="""
        Script for training GPT models. Supports 4 modes, each requiring different arguments in addition to the required ones:
        1. Pretrain: no additional arguments required
        2. SFT: --use-chat-data required
        3. Distillation: --teacher_path required
        4. SFT Distillation: --use-chat-data and --teacher_path required
        """
    )
    parser.add_argument("--name", type=str, required=True, help="Experiment name")
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to NeMo 2 checkpoint. If only --model_path is provided, the model is trained (pretrain or SFT). If --teacher_path is also provided, the model is distilled.",
    )
    parser.add_argument(
        "--teacher_path",
        type=str,
        required=False,
        help="Path to NeMo 2 checkpoint to use as the distillation teacher. Providing it triggers distillation mode.",
    )
    parser.add_argument("--kd_config", type=str, help="Path to knowledge-distillation config file")
    parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--cp_size", type=int, default=1, help="Context parallel size")
    parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
    parser.add_argument("--ep_size", type=int, default=1, help="Expert parallel size")
    parser.add_argument("--precision", type=str, default="bf16-mixed", help="Datatype for models and optimizer")
    parser.add_argument("--devices", type=int, default=1, help="Number of GPUs to use per node")
    parser.add_argument("--num_nodes", type=int, default=1, help="Number of nodes to use")
    parser.add_argument("--log_dir", type=str, required=True, help="Folder for logging and checkpoint saving")
    parser.add_argument("--max_steps", type=int, required=True, help="Number of global batches to process")
    parser.add_argument("--gbs", type=int, required=True, help="Global batch size")
    parser.add_argument("--mbs", type=int, required=True, help="Micro-batch size")
    parser.add_argument(
        "--data_paths",
        nargs="+",
        help="List of tokenized data paths to load from. If using chat data, provide a single path.",
    )
    parser.add_argument("--split", type=str, default="99,1,0", help="Train,Val,Test ratios to split data")
    parser.add_argument("--index_mapping_dir", type=str, help="Folder to write cached data indices")
    parser.add_argument("--use-chat-data", action="store_true", help="Use chat data for fine-tuning.")
    parser.add_argument(
        "--chat-template-path",
        type=str,
        help="Path to chat-template .txt file to use for chat data. Provide only to override the default chat template in the HuggingFace tokenizer.",
    )
    parser.add_argument(
        "--use_mock_data", action="store_true", help="Use mock data instead of custom data in --data_paths"
    )
    parser.add_argument("--seq_length", type=int, required=True, help="Number of tokens per input sample")
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Name of tokenizer model to override the default. Required if using chat data (--use-chat-data).",
    )
    parser.add_argument("--lr", type=float, default=1e-4, help="Base LR for cosine-annealing scheduler")
    parser.add_argument("--min_lr", type=float, default=1e-5, help="Minimum LR for cosine-annealing scheduler")
    parser.add_argument("--warmup_steps", type=int, default=50, help="Number of scheduler warmup steps")
    parser.add_argument("--val_check_interval", type=int, default=100, help="Validate + checkpoint every N steps")
    parser.add_argument("--limit_val_batches", type=int, default=32, help="Number of batches per validation stage")
    parser.add_argument("--log_interval", type=int, default=10, help="Write to log every N steps")
    parser.add_argument("--legacy_ckpt", action="store_true", help="Load checkpoints saved with Transformer Engine < 1.14")
    return parser.parse_args()


def _read_chat_template(template_path: str):
    """Read a custom chat template from a file, if one was provided."""
    if not template_path:
        return None
    with open(template_path, 'r') as f:
        return f.read().strip()


if __name__ == "__main__":
    args = get_args()

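    # Megatron parallelism strategy: TP/PP/CP/EP sizes plus DDP communication-overlap settings.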
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=args.tp_size,
        pipeline_model_parallel_size=args.pp_size,
        context_parallel_size=args.cp_size,
        expert_model_parallel_size=args.ep_size,
        sequence_parallel=(args.tp_size > 1),
        ddp=DistributedDataParallelConfig(
            grad_reduce_in_fp32=True,
            overlap_grad_reduce=True,
            overlap_param_gather=True,
            check_for_nan_in_grad=True,
            average_in_collective=True,
        ),
        ckpt_load_strictness=StrictHandling.LOG_ALL if args.legacy_ckpt else None,
    )
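    # GPU trainer with Megatron mixed precision; params are kept in bf16 when a bf16 recipe is selected.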
    trainer = nl.Trainer(
        devices=args.devices,
        num_nodes=args.num_nodes,
        max_steps=args.max_steps,
        log_every_n_steps=args.log_interval,
        val_check_interval=args.val_check_interval,
        limit_val_batches=args.limit_val_batches,
        strategy=strategy,
        accelerator="gpu",
        plugins=nl.MegatronMixedPrecision(
            precision=args.precision,
            params_dtype=torch.bfloat16 if "bf16" in args.precision else torch.float32,
            autocast_enabled=False,
            grad_reduce_in_fp32=True,
        ),
    )

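    # Choose the data module: mock data, chat-style SFT data, or pretraining corpora from --data_paths.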
    if not args.use_mock_data and not args.data_paths:
        raise ValueError("Must provide either custom dataset(s) in --data_paths or set --use_mock_data.")

    if args.use_mock_data:
        logging.warning("Using Mock Data for training!")
        data = MockDataModule(seq_length=args.seq_length, global_batch_size=args.gbs, micro_batch_size=args.mbs)
    elif args.use_chat_data:
        assert len(args.data_paths) == 1, "If using chat data, provide a single path."
        assert args.tokenizer is not None, "Tokenizer is required if using chat data."

        chat_template = _read_chat_template(args.chat_template_path)
        tokenizer = get_tokenizer(args.tokenizer, chat_template=chat_template)
        if '{% generation %}' not in tokenizer.tokenizer.chat_template:
            if not args.chat_template_path:
                raise ValueError(
                    "Tokenizer does not contain the '{% generation %}' keyword. "
                    "Please provide a chat template path using --chat-template-path."
                )
            raise ValueError(
                "Please ensure the chat template includes a '{% generation %}' keyword so the assistant "
                "tokens can be masked correctly during training. See https://github.com/huggingface/transformers/pull/30650"
            )
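        # Illustrative only: in a compatible Jinja chat template, assistant turns are wrapped as
        #   {% generation %}{{ message['content'] }}{% endgeneration %}
        # so HuggingFace can build the assistant-token mask (see the PR linked above).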
        data = ChatDataModule(
            dataset_root=args.data_paths[0],
            seq_length=args.seq_length,
            tokenizer=tokenizer,
            global_batch_size=args.gbs,
            micro_batch_size=args.mbs,
            use_hf_tokenizer_chat_template=True,
        )
    else:
        data = llm.PreTrainingDataModule(
            paths=args.data_paths,
            seq_length=args.seq_length,
            global_batch_size=args.gbs,
            micro_batch_size=args.mbs,
            split=args.split,
            index_mapping_dir=args.index_mapping_dir,
        )

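    # Distributed Adam optimizer with a cosine-annealing LR schedule (warmup, then decay to --min_lr).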
    optim_config = OptimizerConfig(
        optimizer="adam",
        lr=args.lr,
        bf16=("bf16" in args.precision),
        use_distributed_optimizer=True,
    )
    sched = CosineAnnealingScheduler(
        max_steps=args.max_steps,
        warmup_steps=args.warmup_steps,
        constant_steps=0,
        min_lr=args.min_lr,
    )
    optim = nl.MegatronOptimizerModule(optim_config, sched)

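    # Keep the best checkpoint by validation loss; TensorBoard logs go to <log_dir>/<name>.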
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        save_top_k=1,
        every_n_train_steps=args.val_check_interval,
    )
    logger = nl.NeMoLogger(
        name=args.name,
        log_dir=args.log_dir,
        ckpt=checkpoint_callback,
        tensorboard=TensorBoardLogger(os.path.join(args.log_dir, args.name)),
        update_logger_directory=False,
    )

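    # Restore initial weights from --model_path, but resume a prior run in log_dir if a checkpoint exists.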
    resume = nl.AutoResume(
        resume_if_exists=True,
        resume_ignore_no_checkpoint=True,
        restore_config=nl.RestoreConfig(path=args.model_path),
    )

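    # Dispatch: knowledge distillation when a teacher checkpoint is given, otherwise standard pretraining/SFT.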
    if args.teacher_path:
        llm.distill(
            student_model_path=args.model_path,
            teacher_model_path=args.teacher_path,
            distillation_config_path=args.kd_config,
            data=data,
            trainer=trainer,
            log=logger,
            resume=resume,
            optim=optim,
            tokenizer=get_tokenizer(args.tokenizer) if args.tokenizer else None,
        )
    else:
        llm.train(
            model=args.model_path,
            data=data,
            trainer=trainer,
            optim=optim,
            log=logger,
            resume=resume,
            tokenizer="data" if args.use_chat_data else "model",
        )