from argparse import ArgumentParser

import torch
from pytorch_lightning import Trainer

from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.utils import logging, model_utils
from nemo.utils.app_state import AppState
""" |
|
Usage: |
|
python megatron_change_num_partitions.py \ |
|
--model_file=PATH_TO_SRC_FILE \ |
|
--target_file=PATH_TO_TGT_FILE \ |
|
--tensor_model_parallel_size=2 \ |
|
--target_tensor_model_parallel_size=1 |
|
""" |


def merge_partition(model, partitions, write_path=None):
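    """
    Merge a list of tensor parallel partitions into a single (TP=1) model.

    The merge rule is inferred per parameter from its shapes: if the target
    parameter already matches one partition's shape, the weight is not
    TP-sharded and is taken as-is; if only the first dim matches, the shards
    are concatenated along the last dim; otherwise they are concatenated
    along dim 0.

    Args:
        model: Model instantiated with tensor_model_parallel_size=1 whose
            parameters are overwritten in place.
        partitions: List of per-rank parameter lists, one entry per source TP rank.
        write_path: If given, save the merged model to this .nemo path.
    """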
    idx = 0
    for name, param in model.named_parameters():
        if param.shape == partitions[0][idx].shape:
            # Parameter is not TP-sharded; take it from the first partition as-is.
            concated = partitions[0][idx].data
        elif param.shape[0] == partitions[0][idx].shape[0]:
            # First dim already matches, so the shards differ in the last dim.
            concated = torch.cat([partitions[i][idx].data for i in range(len(partitions))], dim=-1)
        else:
            # Otherwise the shards are stacked along dim 0.
            concated = torch.cat([partitions[i][idx].data for i in range(len(partitions))], dim=0)

        if concated.shape != param.shape:
            logging.info(
                f"Warning: shape mismatch for parameter {name}. Required shape: {param.shape}, "
                f"merged shape: {concated.shape}. Narrowing to match the required size."
            )
            if concated.shape[1:] == param.shape[1:]:
                concated = torch.narrow(concated, 0, 0, param.shape[0])
            elif concated.shape[:-1] == param.shape[:-1]:
                concated = torch.narrow(concated, -1, 0, param.shape[-1])
            else:
                raise RuntimeError(
                    f"Cannot handle parameter {name}, required shape: {param.shape}, "
                    f"merged shape: {concated.shape}."
                )
        param.data = concated
        idx += 1

    if write_path is not None:
        model.save_to(write_path)


def split_partition(model, partitions, tp_size, write_path=None, megatron_legacy=False):
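    """
    Split a TP=1 model's parameters into `tp_size` tensor parallel partitions.

    The split rule mirrors `merge_partition`: parameters whose shape already
    matches the target per-rank shape are replicated to every rank; parameters
    whose first dim matches are split along the last dim; all others are split
    along dim 0. With `megatron_legacy=True`, fused Q,K,V (and K,V) weights are
    de-interleaved so that each rank receives its own slices.

    Args:
        model: Model instantiated with tensor_model_parallel_size=tp_size.
        partitions: A single-element list holding the TP=1 parameter list.
        tp_size: Number of tensor parallel partitions to create.
        write_path: If given, each rank's weights are saved to this .nemo path.
        megatron_legacy: Handle the legacy interleaved Q,K,V weight layout.
    """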
    if len(partitions) != 1:
        raise ValueError(
            "Can only split partitions of a model with TP=1. For partitions of models with TP>1, merge them first."
        )

    if tp_size < 1:
        raise ValueError("TP size must be >= 1.")

    app_state = AppState()
    app_state.data_parallel_rank = 0
    app_state.pipeline_model_parallel_size = 1
    app_state.tensor_model_parallel_size = tp_size
    app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size

    app_state.tensor_model_parallel_rank = tp_size - 1

    # Compute the per-rank split of every parameter up front.
    idx = 0
    splits = []
    for param_name, param in model.named_parameters():
        if param.shape == partitions[0][idx].shape:
            # Not TP-sharded: replicate the full tensor to every rank.
            split = [partitions[0][idx].data] * tp_size
        elif param.shape[0] == partitions[0][idx].shape[0]:
            # First dim matches: split into chunks of the per-rank size along the last dim.
            split = torch.split(partitions[0][idx].data, param.shape[-1], dim=-1)
        else:
            if 'query_key_value.weight' in param_name and megatron_legacy:
                # Legacy Megatron stores the fused Q,K,V weight interleaved, so each
                # rank must gather every tp_size-th chunk rather than one contiguous slice.
                split_dim = partitions[0][idx].data.shape[0]
                if split_dim % (tp_size * 3) != 0:
                    raise ValueError(
                        f"Cannot split Q,K,V parameter {param_name} with shape {param.shape} "
                        f"into tensor parallel size {tp_size}. Not divisible by {tp_size * 3}."
                    )
                tp_qkv_splits = torch.chunk(partitions[0][idx].data, tp_size * 3, dim=0)
                split = []
                for i in range(tp_size):
                    tp_qkv = torch.cat([tp_qkv_splits[item] for item in range(i, tp_size * 3, tp_size)])
                    split.append(tp_qkv)
            elif 'key_value.weight' in param_name and megatron_legacy:
                # Same de-interleaving for fused K,V weights.
                split_dim = partitions[0][idx].data.shape[0]
                if split_dim % (tp_size * 2) != 0:
                    raise ValueError(
                        f"Cannot split K,V parameter {param_name} with shape {param.shape} "
                        f"into tensor parallel size {tp_size}. Not divisible by {tp_size * 2}."
                    )
                tp_qkv_splits = torch.chunk(partitions[0][idx].data, tp_size * 2, dim=0)
                split = []
                for i in range(tp_size):
                    tp_qkv = torch.cat([tp_qkv_splits[item] for item in range(i, tp_size * 2, tp_size)])
                    split.append(tp_qkv)
            else:
                # Default: split into chunks of the per-rank size along dim 0.
                split = torch.split(partitions[0][idx].data, param.shape[0], dim=0)
        splits.append(split)
        idx += 1
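
    # Assign and save each rank's partition. Iterating from the last rank down
    # to rank 0 leaves the model holding rank 0's weights once the loop finishes.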
    for i in range(tp_size - 1, -1, -1):
        app_state.tensor_model_parallel_rank = i

        idx = 0
        for name, param in model.named_parameters():
            split_val = splits[idx][i].clone()

            if param.shape != split_val.shape:
                logging.info(
                    f"Warning: shape mismatch for parameter {name}. Required shape: {param.shape}, "
                    f"split shape: {split_val.shape}. Padding to match the required size."
                )

                if split_val.shape[1:] == param.shape[1:]:
                    # Pad the end of dim 0. F.pad takes pad widths for the last
                    # dimension first, so the dim-0 entry is the final element.
                    pad = [0, 0] * len(split_val.shape)
                    pad[-1] = param.shape[0] - split_val.shape[0]
                    split_val = torch.nn.functional.pad(split_val, pad, 'constant')
                elif split_val.shape[:-1] == param.shape[:-1]:
                    # Pad the end of the last dimension.
                    pad = [0, param.shape[-1] - split_val.shape[-1]]
                    split_val = torch.nn.functional.pad(split_val, pad, 'constant')
                else:
                    raise RuntimeError(
                        f"Cannot handle parameter {name}, required shape: {param.shape}, "
                        f"split shape: {split_val.shape}."
                    )

            param.data = split_val
            idx += 1

        if write_path is not None:
            model.save_to(write_path)


def main():
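    """
    Parse arguments, then merge and/or split tensor parallel partitions.

    If the source TP size is > 1, all source partitions are loaded and merged
    into a TP=1 model. If the target TP size is > 1, the (possibly just merged)
    TP=1 parameters are then split into the target number of partitions and
    saved to the target file.
    """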
    parser = ArgumentParser()
    parser.add_argument("--model_file", type=str, required=True, help="Path to source .nemo file")
    parser.add_argument("--target_file", type=str, required=True, help="Path to write target .nemo file")
    parser.add_argument("--tensor_model_parallel_size", type=int, required=True, help="TP size of source model")
    parser.add_argument("--target_tensor_model_parallel_size", type=int, required=True, help="TP size of target model")
    parser.add_argument(
        "--model_class",
        type=str,
        default="nemo.collections.nlp.models.language_modeling.megatron_gpt_model.MegatronGPTModel",
        help="NeMo model class. This script should support all NeMo Megatron models that use tensor parallelism.",
    )
    parser.add_argument("--precision", default=16, help="PyTorch Lightning Trainer precision flag")
    parser.add_argument(
        "--megatron_legacy",
        action="store_true",
        help="Converter for legacy Megatron models that have a different Q,K,V weight split",
    )
    parser.add_argument(
        "--tokenizer_model_path",
        type=str,
        required=False,
        default=None,
        help="Path to the tokenizer model if your model uses one as an artifact. This is needed for models with a SentencePiece tokenizer.",
    )

    args = parser.parse_args()

    precision = args.precision
    # The Trainer precision flag accepts ints (32, 16) or strings such as 'bf16';
    # convert numeric strings to ints and pass everything else through unchanged.
    if args.precision in ["32", "16"]:
        precision = int(float(args.precision))
    tp_size = args.tensor_model_parallel_size
    tgt_tp_size = args.target_tensor_model_parallel_size
    cls = model_utils.import_class_by_path(args.model_class)

    trainer = Trainer(devices=1, strategy=NLPDDPStrategy(), accelerator="cpu", precision=precision)
    app_state = AppState()
    app_state.data_parallel_rank = 0
    app_state.pipeline_model_parallel_size = 1
    app_state.tensor_model_parallel_size = tp_size
    app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size
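
    # NOTE: NeMo's save/restore path reads this global AppState to decide which
    # mp_rank_XX subfolder of the .nemo archive to load or write for each rank.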

    if tp_size > 1:
        # Load every source TP partition and collect its parameters.
        partitions = []
        for i in range(tp_size):
            app_state.tensor_model_parallel_rank = i
            model = cls.restore_from(restore_path=args.model_file, trainer=trainer, map_location=torch.device("cpu"))
            params = [p for _, p in model.named_parameters()]
            partitions.append(params)

        # Rebuild the model with TP=1 and merge the partitions into it.
        model.cfg.tensor_model_parallel_size = 1
        app_state.model_parallel_size = 1
        trainer = Trainer(devices=1, strategy=NLPDDPStrategy(), accelerator="cpu", precision=precision)
        if args.tokenizer_model_path is not None:
            model.cfg.tokenizer.model = args.tokenizer_model_path
        model = cls(model.cfg, trainer).to('cpu')
        model._save_restore_connector = NLPSaveRestoreConnector()

        # Only write the merged model out here if TP=1 is the final target;
        # otherwise it is re-split below.
        if tgt_tp_size > 1:
            merge_partition(model, partitions)
        else:
            merge_partition(model, partitions, args.target_file)
    else:
        # Source is already TP=1; just load it.
        app_state.model_parallel_size = 1
        model = cls.restore_from(restore_path=args.model_file, trainer=trainer, map_location=torch.device("cpu"))

    # If the target is TP>1, split the (now TP=1) parameters into the target
    # number of partitions; TP>1 -> TP>1 conversions are handled as merge-then-split.
    if tgt_tp_size > 1:
        partitions = []
        params = [p for _, p in model.named_parameters()]
        partitions.append(params)

        model.cfg.tensor_model_parallel_size = tgt_tp_size
        app_state.model_parallel_size = tgt_tp_size
        trainer = Trainer(devices=1, strategy=NLPDDPStrategy(), accelerator="cpu", precision=precision)
        if args.tokenizer_model_path is not None:
            model.cfg.tokenizer.model = args.tokenizer_model_path
        model = cls(model.cfg, trainer).to('cpu')
        model._save_restore_connector = NLPSaveRestoreConnector()
        split_partition(model, partitions, tgt_tp_size, args.target_file, args.megatron_legacy)

    logging.info("Successfully finished changing partitions!")


if __name__ == '__main__':
    main()