"""Input/output checkpointing.""" |
|
|
|
import os |
|
import random |
|
import sys |
|
import numpy as np |
|
|
|
import torch |
|
|
|
from megatron import (mpu, |
|
update_num_microbatches) |
|
from .global_vars import get_args |
|
from .utils import (unwrap_model, |
|
print_rank_0) |
|
|
|
|
|
_CHECKPOINT_VERSION = None |
|
|
|
def set_checkpoint_version(value): |
|
global _CHECKPOINT_VERSION |
|
if _CHECKPOINT_VERSION is not None: |
|
assert _CHECKPOINT_VERSION == value, \ |
|
"checkpoint versions do not match" |
|
_CHECKPOINT_VERSION = value |
|
|
|
def get_checkpoint_version(): |
|
global _CHECKPOINT_VERSION |
|
return _CHECKPOINT_VERSION |
|
|
|
def check_checkpoint_args(checkpoint_args): |
|
"""Ensure fixed arguments for a model are the same for the input |
|
arguments and the one retrieved from checkpoint.""" |
|
args = get_args() |
|
|
|
def _compare(arg_name, old_arg_name=None): |
|
if old_arg_name is not None: |
|
checkpoint_value = getattr(checkpoint_args, old_arg_name) |
|
else: |
|
checkpoint_value = getattr(checkpoint_args, arg_name) |
|
args_value = getattr(args, arg_name) |
|
error_message = '{} value from checkpoint ({}) is not equal to the ' \ |
|
'input argument value ({}).'.format( |
|
arg_name, checkpoint_value, args_value) |
|
assert checkpoint_value == args_value, error_message |
|
|
|
_compare('num_layers') |
|
_compare('hidden_size') |
|
_compare('num_attention_heads') |
|
if args.vocab_file: |
|
_compare('max_position_embeddings') |
|
_compare('make_vocab_size_divisible_by') |
|
_compare('padded_vocab_size') |
|
_compare('tokenizer_type') |
|
if args.data_parallel_random_init: |
|
_compare('data_parallel_random_init') |
|
if get_checkpoint_version() < 3.0: |
|
_compare('tensor_model_parallel_size', |
|
old_arg_name='model_parallel_size') |
|
if get_checkpoint_version() >= 3.0: |
|
_compare('tensor_model_parallel_size') |
|
_compare('pipeline_model_parallel_size') |
|
|
|
def ensure_directory_exists(filename): |
|
"""Build filename's path if it does not already exists.""" |
|
    dirname = os.path.dirname(filename)
    # exist_ok avoids a race if several processes create the directory at once.
    os.makedirs(dirname, exist_ok=True)
|
|
|
|
|
def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release=False, |
|
pipeline_parallel=None, tensor_rank=None, pipeline_rank=None): |
|
"""Determine the directory name for this rank's checkpoint.""" |
|
if release: |
|
directory = 'release' |
|
else: |
|
directory = 'iter_{:07d}'.format(iteration) |
|
|
|
|
|
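    # Fill in any parallel-state values that were not passed in explicitly.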
if pipeline_parallel is None: |
|
pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1) |
|
if tensor_rank is None: |
|
tensor_rank = mpu.get_tensor_model_parallel_rank() |
|
if pipeline_rank is None: |
|
pipeline_rank = mpu.get_pipeline_model_parallel_rank() |
|
|
|
|
|
|
|
|
|
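    # Use both the tensor and pipeline MP rank in the directory name. With the
    # distributed optimizer, the optimizer file additionally gets the data
    # parallel rank appended to its directory.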
if not pipeline_parallel: |
|
common_path = os.path.join(checkpoints_path, directory, |
|
f'mp_rank_{tensor_rank:02d}') |
|
else: |
|
common_path = os.path.join(checkpoints_path, directory, |
|
f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}') |
|
|
|
if use_distributed_optimizer: |
|
model_name = os.path.join(common_path, "model_rng.pt") |
|
optim_name = os.path.join( |
|
common_path + "_%03d" % mpu.get_data_parallel_rank(), |
|
"optim.pt") |
|
else: |
|
model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt") |
|
return model_name, optim_name |
|
|
|
def find_checkpoint_rank_0(checkpoints_path, iteration, use_distributed_optimizer, release=False): |
|
"""Finds the checkpoint for rank 0 without knowing if we are using |
|
pipeline parallelism or not. |
|
|
|
Since the checkpoint naming scheme changes if pipeline parallelism |
|
is present, we need to look for both naming schemes if we don't |
|
know if the checkpoint has pipeline parallelism. |
|
|
|
""" |
|
|
|
|
|
filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release, |
|
pipeline_parallel=False, |
|
tensor_rank=0, pipeline_rank=0) |
|
if os.path.isfile(filenames[0]): |
|
return filenames |
|
|
|
|
|
filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release, |
|
pipeline_parallel=True, |
|
tensor_rank=0, pipeline_rank=0) |
|
if os.path.isfile(filenames[0]): |
|
return filenames |
|
|
|
return None, None |
|
|
|
def get_checkpoint_tracker_filename(checkpoints_path): |
|
|
|
"""Tracker file rescords the latest chckpoint during |
|
training to restart from.""" |
|
return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') |
|
|
|
|
|
def read_metadata(tracker_filename): |
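    """Read the tracker file and return the iteration to load and whether
    this is a release checkpoint."""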
|
|
|
|
|
iteration = 0 |
|
release = False |
|
with open(tracker_filename, 'r') as f: |
|
metastring = f.read().strip() |
|
try: |
|
iteration = int(metastring) |
|
except ValueError: |
|
release = metastring == 'release' |
|
if not release: |
|
print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format( |
|
tracker_filename)) |
|
sys.exit() |
|
assert iteration > 0 or release, 'error parsing metadata file {}'.format( |
|
tracker_filename) |
|
|
|
|
|
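    # Get the max iteration retrieved across the ranks.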
if torch.distributed.is_initialized(): |
|
iters_cuda = torch.cuda.LongTensor([iteration]) |
|
torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX) |
|
max_iter = iters_cuda[0].item() |
|
|
|
|
|
|
|
|
|
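        # All ranks should agree on the iteration. If not, warn and use the
        # maximum iteration across all ranks.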
        if iteration != max_iter:
            rank = torch.distributed.get_rank()
            print('WARNING: on rank {} found iteration {} in the '
                  'metadata while max iteration across the ranks '
                  'is {}, replacing it with max iteration.'.format(
                      rank, iteration, max_iter), flush=True)
|
else: |
|
|
|
|
|
|
|
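        # torch.distributed may not be initialized (e.g. when inspecting a
        # checkpoint outside of training); use the iteration as read.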
max_iter = iteration |
|
return max_iter, release |
|
|
|
|
|
def get_rng_state(): |
|
""" collect rng state across data parallel ranks """ |
|
args = get_args() |
|
rng_state = { |
|
'random_rng_state': random.getstate(), |
|
'np_rng_state': np.random.get_state(), |
|
'torch_rng_state': torch.get_rng_state(), |
|
'cuda_rng_state': torch.cuda.get_rng_state(), |
|
'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()} |
|
|
|
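    # If every data parallel rank was initialized with a different seed, gather
    # the RNG state from all of them; otherwise a single copy is enough.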
rng_state_list = None |
|
if torch.distributed.is_initialized() and \ |
|
mpu.get_data_parallel_world_size() > 1 and \ |
|
args.data_parallel_random_init: |
|
rng_state_list = \ |
|
[None for i in range(mpu.get_data_parallel_world_size())] |
|
torch.distributed.all_gather_object( |
|
rng_state_list, |
|
rng_state, |
|
group=mpu.get_data_parallel_group()) |
|
else: |
|
rng_state_list = [rng_state] |
|
|
|
return rng_state_list |
|
|
|
|
|
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): |
|
"""Save a model checkpoint.""" |
|
args = get_args() |
|
|
|
|
|
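    # Strip any wrappers (e.g. DDP) so the underlying model's state is saved.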
model = unwrap_model(model) |
|
|
|
print_rank_0('saving checkpoint at iteration {:7d} to {}'.format( |
|
iteration, args.save)) |
|
|
|
|
|
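    # Collect rng state across data parallel ranks.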
rng_state = get_rng_state() |
|
|
|
|
|
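    # Checkpoint file names.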
model_checkpoint_name, optim_checkpoint_name = \ |
|
get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer) |
|
|
|
|
|
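    # Collect args, model state, and RNG state. Only data parallel rank 0
    # populates (and later writes) this dict.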
model_state_dict = {} |
|
if not torch.distributed.is_initialized() \ |
|
or mpu.get_data_parallel_rank() == 0: |
|
|
|
|
|
model_state_dict['args'] = args |
|
model_state_dict['checkpoint_version'] = 3.0 |
|
model_state_dict['iteration'] = iteration |
|
if len(model) == 1: |
|
model_state_dict['model'] = model[0].state_dict_for_save_checkpoint() |
|
else: |
|
for i in range(len(model)): |
|
mpu.set_virtual_pipeline_model_parallel_rank(i) |
|
model_state_dict['model%d' % i] = \ |
|
model[i].state_dict_for_save_checkpoint() |
|
|
|
|
|
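        # RNG states.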
if not args.no_save_rng: |
|
model_state_dict["rng_state"] = rng_state |
|
|
|
|
|
|
|
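    # Collect the optimizer state. With the distributed optimizer, every data
    # parallel rank owns a shard of the optimizer state and must save it.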
optim_state_dict = {} |
|
if not args.no_save_optim \ |
|
and (not torch.distributed.is_initialized() |
|
or mpu.get_data_parallel_rank() == 0 |
|
or args.use_distributed_optimizer): |
|
|
|
|
|
if optimizer is not None: |
|
optim_state_dict['optimizer'] = optimizer.state_dict() |
|
if opt_param_scheduler is not None: |
|
optim_state_dict['opt_param_scheduler'] = \ |
|
opt_param_scheduler.state_dict() |
|
|
|
|
|
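    # Write the collected state to disk.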
if args.use_distributed_optimizer: |
|
|
|
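        # Model and optimizer state go to separate files.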
if model_state_dict: |
|
ensure_directory_exists(model_checkpoint_name) |
|
torch.save(model_state_dict, model_checkpoint_name) |
|
if optim_state_dict: |
|
ensure_directory_exists(optim_checkpoint_name) |
|
torch.save(optim_state_dict, optim_checkpoint_name) |
|
else: |
|
|
|
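        # Model and optimizer state share a single file.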
state_dict = {**model_state_dict, **optim_state_dict} |
|
if state_dict: |
|
ensure_directory_exists(model_checkpoint_name) |
|
torch.save(state_dict, model_checkpoint_name) |
|
|
|
|
|
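    # Wait so everyone is done writing.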
if torch.distributed.is_initialized(): |
|
torch.distributed.barrier() |
|
|
|
print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format( |
|
iteration, args.save)) |
|
|
|
|
|
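    # Update the tracker file on rank 0 so restarts resume from this iteration.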
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: |
|
tracker_filename = get_checkpoint_tracker_filename(args.save) |
|
with open(tracker_filename, 'w') as f: |
|
f.write(str(iteration)) |
|
|
|
|
|
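    # Wait so everyone is done.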
if torch.distributed.is_initialized(): |
|
torch.distributed.barrier() |
|
|
|
def _transpose_first_dim(t, num_splits, num_splits_first, model): |
|
input_shape = t.size() |
|
|
|
|
|
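    # Strip module wrappers to reach the transformer's attention configuration.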
while hasattr(model, 'module'): |
|
model = model.module |
|
attention_module = model.language_model.encoder.layers[0].self_attention |
|
hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head |
|
num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition |
|
if num_splits_first: |
|
"""[num_splits * np * hn, h] |
|
-->(view) [num_splits, np, hn, h] |
|
        -->(transpose) [np, num_splits, hn, h]
|
-->(view) [np * num_splits * hn, h] """ |
|
|
|
intermediate_shape = \ |
|
(num_splits, num_attention_heads_per_partition, |
|
hidden_size_per_attention_head) + input_shape[1:] |
|
|
|
t = t.view(*intermediate_shape) |
|
t = t.transpose(0, 1).contiguous() |
|
else: |
|
"""[np * hn * num_splits, h] |
|
-->(view) [np, hn, num_splits, h] |
|
        -->(transpose) [np, num_splits, hn, h]
|
-->(view) [np * num_splits * hn, h] """ |
|
|
|
intermediate_shape = \ |
|
(num_attention_heads_per_partition, |
|
hidden_size_per_attention_head, num_splits) +\ |
|
input_shape[1:] |
|
|
|
t = t.view(*intermediate_shape) |
|
t = t.transpose(1, 2).contiguous() |
|
t = t.view(*input_shape) |
|
|
|
return t |
|
|
|
def fix_query_key_value_ordering(model, checkpoint_version): |
|
"""Fix up query/key/value matrix ordering if checkpoint |
|
version is smaller than 2.0 |
|
""" |
|
if checkpoint_version < 2.0: |
|
if isinstance(model, list): |
|
assert len(model)==1 |
|
model = model[0] |
|
for name, param in model.named_parameters(): |
|
if name.endswith(('.query_key_value.weight', '.query_key_value.bias')): |
|
if checkpoint_version == 0: |
|
fixed_param = _transpose_first_dim(param.data, 3, True, model) |
|
elif checkpoint_version == 1.0: |
|
fixed_param = _transpose_first_dim(param.data, 3, False, model) |
|
else: |
|
print_rank_0(f"Invalid checkpoint version {checkpoint_version}.") |
|
sys.exit() |
|
param.data.copy_(fixed_param) |
|
if name.endswith(('.key_value.weight', '.key_value.bias')): |
|
if checkpoint_version == 0: |
|
fixed_param = _transpose_first_dim(param.data, 2, True, model) |
|
elif checkpoint_version == 1.0: |
|
fixed_param = _transpose_first_dim(param.data, 2, False, model) |
|
else: |
|
print_rank_0(f"Invalid checkpoint version {checkpoint_version}.") |
|
sys.exit() |
|
param.data.copy_(fixed_param) |
|
print_rank_0(" succesfully fixed query-key-values ordering for" |
|
" checkpoint version {}".format(checkpoint_version)) |
|
|
|
def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False): |
|
""" Load the base state_dict from the given directory |
|
|
|
If rank0 is true, just loads rank 0 checkpoint, ignoring arguments. |
|
""" |
|
|
|
|
|
|
|
tracker_filename = get_checkpoint_tracker_filename(load_dir) |
|
|
|
|
|
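    # If no tracker file, assume we are starting from scratch.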
if not os.path.isfile(tracker_filename): |
|
if not rank0: |
|
print_rank_0('WARNING: could not find the metadata file {} '.format( |
|
tracker_filename)) |
|
print_rank_0(' will not load any checkpoints and will start from ' |
|
'random') |
|
return None, None, False |
|
|
|
|
|
|
|
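    # Read the tracker file: it gives either the iteration to load or marks a
    # release checkpoint.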
iteration, release = read_metadata(tracker_filename) |
|
|
|
|
|
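    # Determine the checkpoint file names for this rank.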
if rank0: |
|
checkpoint_names = find_checkpoint_rank_0(load_dir, iteration, use_distributed_optimizer, |
|
release) |
|
else: |
|
checkpoint_names = get_checkpoint_names(load_dir, iteration, use_distributed_optimizer, |
|
release) |
|
if release: |
|
print_rank_0(f' loading release checkpoint from {load_dir}') |
|
else: |
|
print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}') |
|
|
|
model_checkpoint_name, optim_checkpoint_name = checkpoint_names |
|
|
|
|
|
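    # Load the checkpoint(s) onto CPU memory.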
try: |
|
model_state_dict = torch.load(model_checkpoint_name, map_location='cpu') |
|
if use_distributed_optimizer: |
|
optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu') |
|
else: |
|
optim_state_dict = model_state_dict |
|
except ModuleNotFoundError: |
|
from megatron.fp16_deprecated import loss_scaler |
|
|
|
if not rank0: |
|
print_rank_0(' > deserializing using the old code structure ...') |
|
sys.modules['fp16.loss_scaler'] = sys.modules[ |
|
'megatron.fp16_deprecated.loss_scaler'] |
|
sys.modules['megatron.fp16.loss_scaler'] = sys.modules[ |
|
'megatron.fp16_deprecated.loss_scaler'] |
|
model_state_dict = torch.load(model_checkpoint_name, map_location='cpu') |
|
optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu') |
|
sys.modules.pop('fp16.loss_scaler', None) |
|
sys.modules.pop('megatron.fp16.loss_scaler', None) |
|
except BaseException as e: |
|
print_rank_0('could not load the checkpoint') |
|
print_rank_0(e) |
|
sys.exit() |
|
|
|
return model_state_dict, optim_state_dict, release |
|
|
|
def load_args_from_checkpoint(args, load_arg='load'): |
|
"""Set required arguments from the checkpoint specified in the |
|
arguments. |
|
|
|
    Arguments that are still None are filled in from the checkpoint; arguments
    that already have a value are left as set, except the tensor and pipeline
    model parallel sizes, which are always taken from checkpoints of
    version 3.0 or newer.

    Returns the same args Namespace with the new values added/updated.

    If no checkpoint is specified in args, or if the checkpoint is
    there but invalid, the arguments will not be modified.
|
|
|
""" |
|
load_dir = getattr(args, load_arg) |
|
|
|
if load_dir is None: |
|
print_rank_0('No load directory specified, using provided arguments.') |
|
return args |
|
|
|
model_state_dict, optim_state_dict, release = \ |
|
_load_base_checkpoint(load_dir, |
|
use_distributed_optimizer=args.use_distributed_optimizer, |
|
rank0=True) |
|
|
|
|
|
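    # The saved args live in the model checkpoint, not the optimizer one.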
state_dict = model_state_dict |
|
|
|
if not state_dict: |
|
print_rank_0('Checkpoint not found to provide arguments, using provided arguments.') |
|
return args |
|
|
|
if 'args' not in state_dict: |
|
print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.') |
|
return args |
|
|
|
checkpoint_args = state_dict['args'] |
|
checkpoint_version = state_dict.get('checkpoint_version', 0) |
|
args.iteration = state_dict['iteration'] |
|
|
|
def _set_arg(arg_name, old_arg_name=None, force=False): |
|
if not force and getattr(args, arg_name, None) is not None: |
|
return |
|
|
|
if old_arg_name is not None: |
|
checkpoint_value = getattr(checkpoint_args, old_arg_name, None) |
|
else: |
|
checkpoint_value = getattr(checkpoint_args, arg_name, None) |
|
|
|
if checkpoint_value is not None: |
|
print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint") |
|
setattr(args, arg_name, checkpoint_value) |
|
|
|
_set_arg('num_layers') |
|
_set_arg('hidden_size') |
|
_set_arg('ffn_hidden_size') |
|
_set_arg('seq_length') |
|
_set_arg('num_attention_heads') |
|
_set_arg('kv_channels') |
|
_set_arg('max_position_embeddings') |
|
_set_arg('tokenizer_type') |
|
_set_arg('padded_vocab_size') |
|
if checkpoint_version < 3.0: |
|
_set_arg('tensor_model_parallel_size', |
|
'model_parallel_size') |
|
else: |
|
_set_arg('tensor_model_parallel_size', force=True) |
|
_set_arg('pipeline_model_parallel_size', force=True) |
|
_set_arg('num_layers_per_virtual_pipeline_stage') |
|
return args |
|
|
|
|
|
def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True): |
|
"""Load a model checkpoint and return the iteration. |
|
strict (bool): whether to strictly enforce that the keys in |
|
:attr:`state_dict` of the checkpoint match the names of |
|
parameters and buffers in model. |
|
""" |
|
args = get_args() |
|
load_dir = getattr(args, load_arg) |
|
|
|
model = unwrap_model(model) |
|
|
|
model_state_dict, optim_state_dict, release = \ |
|
_load_base_checkpoint(load_dir, |
|
use_distributed_optimizer=args.use_distributed_optimizer, |
|
rank0=False) |
|
|
|
if model_state_dict is None: |
|
return 0 |
|
|
|
|
|
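    # Set checkpoint version.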
set_checkpoint_version(model_state_dict.get('checkpoint_version', 0)) |
|
|
|
|
|
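    # Determine the iteration to resume from.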
if args.finetune or release: |
|
iteration = 0 |
|
else: |
|
try: |
|
iteration = model_state_dict['iteration'] |
|
except KeyError: |
|
try: |
|
iteration = model_state_dict['total_iters'] |
|
except KeyError: |
|
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 load_dir))
|
sys.exit() |
|
|
|
|
|
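    # Check arguments.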
assert args.consumed_train_samples == 0 |
|
assert args.consumed_valid_samples == 0 |
|
if 'args' in model_state_dict: |
|
checkpoint_args = model_state_dict['args'] |
|
check_checkpoint_args(checkpoint_args) |
|
args.consumed_train_samples = getattr(checkpoint_args, |
|
'consumed_train_samples', 0) |
|
update_num_microbatches(consumed_samples=args.consumed_train_samples) |
|
args.consumed_valid_samples = getattr(checkpoint_args, |
|
'consumed_valid_samples', 0) |
|
else: |
|
print_rank_0('could not find arguments in the checkpoint ...') |
|
|
|
|
|
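    # Load the model state dict(s).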
if len(model) == 1: |
|
model[0].load_state_dict(model_state_dict['model'], strict=strict) |
|
else: |
|
for i in range(len(model)): |
|
mpu.set_virtual_pipeline_model_parallel_rank(i) |
|
model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict) |
|
|
|
|
|
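    # Fix up query/key/value matrix ordering for older checkpoint versions.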
checkpoint_version = get_checkpoint_version() |
|
print_rank_0(f' checkpoint version {checkpoint_version}') |
|
fix_query_key_value_ordering(model, checkpoint_version) |
|
|
|
|
|
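    # Load the optimizer and learning rate scheduler state.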
if not release and not args.finetune and not args.no_load_optim: |
|
try: |
|
if optimizer is not None: |
|
optimizer.load_state_dict(optim_state_dict['optimizer']) |
|
if opt_param_scheduler is not None: |
|
if 'lr_scheduler' in optim_state_dict: |
|
opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler']) |
|
else: |
|
opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler']) |
|
except KeyError: |
|
print_rank_0('Unable to load optimizer from checkpoint {}. ' |
|
'Specify --no-load-optim or --finetune to prevent ' |
|
'attempting to load the optimizer state, ' |
|
                         'exiting ...'.format(load_dir))
|
sys.exit() |
|
|
|
|
|
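    # Load the RNG state.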
if not release and not args.finetune and not args.no_load_rng: |
|
try: |
|
if 'rng_state' in model_state_dict: |
|
|
|
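                # Newer checkpoints store one RNG state per data parallel rank.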
if args.data_parallel_random_init: |
|
|
|
rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()] |
|
else: |
|
rng_state = model_state_dict['rng_state'][0] |
|
random.setstate(rng_state['random_rng_state']) |
|
np.random.set_state(rng_state['np_rng_state']) |
|
torch.set_rng_state(rng_state['torch_rng_state']) |
|
torch.cuda.set_rng_state(rng_state['cuda_rng_state']) |
|
|
|
if not rng_state['rng_tracker_states']: |
|
raise KeyError |
|
mpu.get_cuda_rng_tracker().set_states( |
|
rng_state['rng_tracker_states']) |
|
else: |
|
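                # Older checkpoints store the RNG state at the top level of the
                # model state dict.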
random.setstate(model_state_dict['random_rng_state']) |
|
np.random.set_state(model_state_dict['np_rng_state']) |
|
torch.set_rng_state(model_state_dict['torch_rng_state']) |
|
torch.cuda.set_rng_state(model_state_dict['cuda_rng_state']) |
|
|
|
if not model_state_dict['rng_tracker_states']: |
|
raise KeyError |
|
mpu.get_cuda_rng_tracker().set_states( |
|
model_state_dict['rng_tracker_states']) |
|
except KeyError: |
|
print_rank_0('Unable to load rng state from checkpoint {}. ' |
|
'Specify --no-load-rng or --finetune to prevent ' |
|
'attempting to load the rng state, ' |
|
                         'exiting ...'.format(load_dir))
|
sys.exit() |
|
|
|
|
|
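    # Wait so everyone is done loading.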
if torch.distributed.is_initialized(): |
|
torch.distributed.barrier() |
|
|
|
print_rank_0(f' successfully loaded checkpoint from {args.load} ' |
|
f'at iteration {iteration}') |
|
|
|
return iteration |
|
|
|
|
|
def load_biencoder_checkpoint(model, only_query_model=False, |
|
only_context_model=False, custom_load_path=None): |
|
""" |
|
selectively load retrieval models for indexing/retrieving |
|
from saved checkpoints |
|
""" |
|
|
|
args = get_args() |
|
|
|
model = unwrap_model(model) |
|
|
|
load_path = custom_load_path if custom_load_path is not None else args.load |
|
|
|
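    # Look up the latest saved iteration from the tracker file.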
tracker_filename = get_checkpoint_tracker_filename(load_path) |
|
with open(tracker_filename, 'r') as f: |
|
iteration = int(f.read().strip()) |
|
|
|
checkpoint_name, _ = get_checkpoint_names(load_path, iteration, |
|
args.use_distributed_optimizer, |
|
release=False) |
|
|
|
if mpu.get_data_parallel_rank() == 0: |
|
print('global rank {} is loading checkpoint {}'.format( |
|
torch.distributed.get_rank(), checkpoint_name)) |
|
|
|
    state_dict = torch.load(checkpoint_name, map_location='cpu')
|
ret_state_dict = state_dict['model'] |
|
|
|
if only_query_model: |
|
ret_state_dict.pop('context_model') |
|
if only_context_model: |
|
ret_state_dict.pop('query_model') |
|
|
|
assert len(model) == 1 |
|
model[0].load_state_dict(ret_state_dict) |
|
torch.distributed.barrier() |
|
|
|
if mpu.get_data_parallel_rank() == 0: |
|
print(' successfully loaded {}'.format(checkpoint_name)) |
|
|
|
return model |
|
|