|
|
|
|
|
"""Functions that handle saving and loading of checkpoints.""" |
|
|
|
import copy |
|
import numpy as np |
|
import os |
|
import pickle |
|
from collections import OrderedDict |
|
import torch |
|
from fvcore.common.file_io import PathManager |
|
|
|
import timesformer.utils.distributed as du |
|
import timesformer.utils.logging as logging |
|
from timesformer.utils.c2_model_loading import get_name_convert_func |
|
import torch.nn.functional as F |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
def make_checkpoint_dir(path_to_job): |
|
""" |
|
Creates the checkpoint directory (if not present already). |
|
Args: |
|
path_to_job (string): the path to the folder of the current job. |
|
""" |
|
checkpoint_dir = os.path.join(path_to_job, "checkpoints") |
|
|
|
if du.is_master_proc() and not PathManager.exists(checkpoint_dir): |
|
        try:
            PathManager.mkdirs(checkpoint_dir)
        except Exception:
            # Another process may create the directory between the existence
            # check and mkdirs; ignore the failure and return the path.
            pass
|
return checkpoint_dir |
|
|
|
|
|
def get_checkpoint_dir(path_to_job): |
|
""" |
|
Get path for storing checkpoints. |
|
Args: |
|
path_to_job (string): the path to the folder of the current job. |
|
""" |
|
return os.path.join(path_to_job, "checkpoints") |
|
|
|
|
|
def get_path_to_checkpoint(path_to_job, epoch): |
|
""" |
|
Get the full path to a checkpoint file. |
|
Args: |
|
path_to_job (string): the path to the folder of the current job. |
|
        epoch (int): the epoch number for the checkpoint.
|
""" |
|
name = "checkpoint_epoch_{:05d}.pyth".format(epoch) |
|
return os.path.join(get_checkpoint_dir(path_to_job), name) |
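# Example: get_path_to_checkpoint("/tmp/job", 10) returns
# "/tmp/job/checkpoints/checkpoint_epoch_00010.pyth". The zero-padded epoch
# number keeps lexicographic order identical to chronological order, which
# get_last_checkpoint() below relies on when it picks sorted(names)[-1].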
|
|
|
|
|
def get_last_checkpoint(path_to_job): |
|
""" |
|
Get the last checkpoint from the checkpointing folder. |
|
Args: |
|
path_to_job (string): the path to the folder of the current job. |
|
""" |
|
|
|
d = get_checkpoint_dir(path_to_job) |
|
names = PathManager.ls(d) if PathManager.exists(d) else [] |
|
names = [f for f in names if "checkpoint" in f] |
|
assert len(names), "No checkpoints found in '{}'.".format(d) |
|
|
|
name = sorted(names)[-1] |
|
return os.path.join(d, name) |
|
|
|
|
|
def has_checkpoint(path_to_job): |
|
""" |
|
Determines if the given directory contains a checkpoint. |
|
Args: |
|
path_to_job (string): the path to the folder of the current job. |
|
""" |
|
d = get_checkpoint_dir(path_to_job) |
|
files = PathManager.ls(d) if PathManager.exists(d) else [] |
|
return any("checkpoint" in f for f in files) |
|
|
|
|
|
def is_checkpoint_epoch(cfg, cur_epoch, multigrid_schedule=None): |
|
""" |
|
Determine if a checkpoint should be saved on current epoch. |
|
Args: |
|
        cfg (CfgNode): configs.
        cur_epoch (int): current epoch number.
|
multigrid_schedule (List): schedule for multigrid training. |
|
""" |
|
if cur_epoch + 1 == cfg.SOLVER.MAX_EPOCH: |
|
return True |
|
if multigrid_schedule is not None: |
|
prev_epoch = 0 |
|
for s in multigrid_schedule: |
|
if cur_epoch < s[-1]: |
|
period = max( |
|
(s[-1] - prev_epoch) // cfg.MULTIGRID.EVAL_FREQ + 1, 1 |
|
) |
|
return (s[-1] - 1 - cur_epoch) % period == 0 |
|
prev_epoch = s[-1] |
|
|
|
return (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0 |
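# Illustrative walk-through of the multigrid branch (hypothetical values):
# with cfg.MULTIGRID.EVAL_FREQ = 3 and a schedule entry ending at epoch
# s[-1] = 10 (prev_epoch = 0), the period is max(10 // 3 + 1, 1) = 4, so a
# checkpoint is saved whenever (10 - 1 - cur_epoch) % 4 == 0, i.e. at
# cur_epoch 1, 5, and 9.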
|
|
|
|
|
def save_checkpoint(path_to_job, model, optimizer, epoch, cfg): |
|
""" |
|
    Save a checkpoint.
    Args:
        path_to_job (string): the path to the folder of the current job.
        model (model): model whose weights are saved to the checkpoint.
        optimizer (optim): optimizer to save the historical state.
        epoch (int): current epoch number.
        cfg (CfgNode): configs to save.
|
""" |
|
|
|
if not du.is_master_proc(cfg.NUM_GPUS * cfg.NUM_SHARDS): |
|
return |
|
|
|
PathManager.mkdirs(get_checkpoint_dir(path_to_job)) |
|
|
|
sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict() |
|
normalized_sd = sub_to_normal_bn(sd) |
|
|
|
|
|
checkpoint = { |
|
"epoch": epoch, |
|
"model_state": normalized_sd, |
|
"optimizer_state": optimizer.state_dict(), |
|
"cfg": cfg.dump(), |
|
} |
|
|
|
path_to_checkpoint = get_path_to_checkpoint(path_to_job, epoch + 1) |
|
with PathManager.open(path_to_checkpoint, "wb") as f: |
|
torch.save(checkpoint, f) |
|
return path_to_checkpoint |
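# Minimal usage sketch (illustrative; assumes `cfg`, `model`, `optimizer`,
# and `cur_epoch` were built by the caller):
#
#     if is_checkpoint_epoch(cfg, cur_epoch):
#         path = save_checkpoint(
#             cfg.OUTPUT_DIR, model, optimizer, cur_epoch, cfg
#         )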
|
|
|
|
|
def inflate_weight(state_dict_2d, state_dict_3d): |
|
""" |
|
Inflate 2D model weights in state_dict_2d to the 3D model weights in |
|
state_dict_3d. The details can be found in: |
|
Joao Carreira, and Andrew Zisserman. |
|
"Quo vadis, action recognition? a new model and the kinetics dataset." |
|
Args: |
|
state_dict_2d (OrderedDict): a dict of parameters from a 2D model. |
|
state_dict_3d (OrderedDict): a dict of parameters from a 3D model. |
|
Returns: |
|
state_dict_inflated (OrderedDict): a dict of inflated parameters. |
|
""" |
|
state_dict_inflated = OrderedDict() |
|
|
|
|
|
|
|
for k, v2d in state_dict_2d.items(): |
|
assert k in state_dict_3d.keys() |
|
v3d = state_dict_3d[k] |
|
|
|
if len(v2d.shape) == 4 and len(v3d.shape) == 5: |
|
logger.info( |
|
"Inflate {}: {} -> {}: {}".format(k, v2d.shape, k, v3d.shape) |
|
) |
|
|
|
            try:
                assert v2d.shape[-2:] == v3d.shape[-2:]
                assert v2d.shape[:2] == v3d.shape[:2]
                # Average inflation: tile the 2D kernel along the new
                # temporal axis and divide by its length, preserving the
                # response on a clip of repeated frames.
                v3d = (
                    v2d.unsqueeze(2).repeat(1, 1, v3d.shape[2], 1, 1)
                    / v3d.shape[2]
                )
            except AssertionError:
                # Channel counts differ: inflate the channels that exist and
                # zero-fill the rest (zeros_like keeps dtype and device).
                temp = (
                    v2d.unsqueeze(2).repeat(1, 1, v3d.shape[2], 1, 1)
                    / v3d.shape[2]
                )
                v3d = torch.zeros_like(v3d)
                v3d[:, : v2d.shape[1], :, :, :] = temp
|
|
|
|
|
elif v2d.shape == v3d.shape: |
|
v3d = v2d |
|
else: |
|
logger.info( |
|
"Unexpected {}: {} -|> {}: {}".format( |
|
k, v2d.shape, k, v3d.shape |
|
) |
|
) |
|
state_dict_inflated[k] = v3d.clone() |
|
return state_dict_inflated |
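# Shape example (illustrative): a 2D conv kernel of shape (64, 3, 7, 7)
# inflated to a 3D kernel of shape (64, 3, 5, 7, 7) is tiled 5 times along
# the new temporal axis and divided by 5, so the temporal sum equals the
# original kernel and a clip of identical frames yields the same response
# as the 2D model.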
|
|
|
|
|
def load_checkpoint( |
|
path_to_checkpoint, |
|
model, |
|
data_parallel=True, |
|
optimizer=None, |
|
inflation=False, |
|
convert_from_caffe2=False, |
|
epoch_reset=False, |
|
clear_name_pattern=(), |
|
): |
|
""" |
|
Load the checkpoint from the given file. If inflation is True, inflate the |
|
2D Conv weights from the checkpoint to 3D Conv. |
|
Args: |
|
path_to_checkpoint (string): path to the checkpoint to load. |
|
        model (model): model to load the checkpoint weights into.
        data_parallel (bool): if True, model is wrapped by
            torch.nn.parallel.DistributedDataParallel.
        optimizer (optim): optimizer to load the historical state into.
        inflation (bool): if True, inflate the 2D weights from the checkpoint.
        convert_from_caffe2 (bool): if True, load the model from Caffe2 and
            convert it to PyTorch.
        epoch_reset (bool): if True, ignore the epoch number stored in the
            checkpoint and start counting from scratch.
        clear_name_pattern (tuple): if given, each (sub)string in it is
            removed from any layer name it matches.
|
Returns: |
|
(int): the number of training epoch of the checkpoint. |
|
""" |
|
assert PathManager.exists( |
|
path_to_checkpoint |
|
), "Checkpoint '{}' not found".format(path_to_checkpoint) |
|
logger.info("Loading network weights from {}.".format(path_to_checkpoint)) |
|
|
|
|
|
    # DistributedDataParallel stores the wrapped model under `.module`; fall
    # back to the bare model when the attribute is absent.
    try:
        ms = model.module if data_parallel else model
    except AttributeError:
        ms = model
|
|
|
if convert_from_caffe2: |
|
with PathManager.open(path_to_checkpoint, "rb") as f: |
|
caffe2_checkpoint = pickle.load(f, encoding="latin1") |
|
state_dict = OrderedDict() |
|
name_convert_func = get_name_convert_func() |
|
for key in caffe2_checkpoint["blobs"].keys(): |
|
converted_key = name_convert_func(key) |
|
converted_key = c2_normal_to_sub_bn(converted_key, ms.state_dict()) |
|
if converted_key in ms.state_dict(): |
|
c2_blob_shape = caffe2_checkpoint["blobs"][key].shape |
|
model_blob_shape = ms.state_dict()[converted_key].shape |
|
|
|
|
|
if len(c2_blob_shape) < len(model_blob_shape): |
|
c2_blob_shape += (1,) * ( |
|
len(model_blob_shape) - len(c2_blob_shape) |
|
) |
|
caffe2_checkpoint["blobs"][key] = np.reshape( |
|
caffe2_checkpoint["blobs"][key], c2_blob_shape |
|
) |
|
|
|
if ( |
|
len(model_blob_shape) == 1 |
|
and len(c2_blob_shape) == 1 |
|
and model_blob_shape[0] > c2_blob_shape[0] |
|
and model_blob_shape[0] % c2_blob_shape[0] == 0 |
|
): |
|
caffe2_checkpoint["blobs"][key] = np.concatenate( |
|
[caffe2_checkpoint["blobs"][key]] |
|
* (model_blob_shape[0] // c2_blob_shape[0]) |
|
) |
|
c2_blob_shape = caffe2_checkpoint["blobs"][key].shape |
|
|
|
if c2_blob_shape == tuple(model_blob_shape): |
|
state_dict[converted_key] = torch.tensor( |
|
caffe2_checkpoint["blobs"][key] |
|
).clone() |
|
logger.info( |
|
"{}: {} => {}: {}".format( |
|
key, |
|
c2_blob_shape, |
|
converted_key, |
|
tuple(model_blob_shape), |
|
) |
|
) |
|
else: |
|
                    logger.warning(
|
"!! {}: {} does not match {}: {}".format( |
|
key, |
|
c2_blob_shape, |
|
converted_key, |
|
tuple(model_blob_shape), |
|
) |
|
) |
|
else: |
|
if not any( |
|
prefix in key for prefix in ["momentum", "lr", "model_iter"] |
|
): |
|
                    logger.warning(
|
"!! {}: can not be converted, got {}".format( |
|
key, converted_key |
|
) |
|
) |
|
diff = set(ms.state_dict()) - set(state_dict) |
|
diff = {d for d in diff if "num_batches_tracked" not in d} |
|
if len(diff) > 0: |
|
logger.warn("Not loaded {}".format(diff)) |
|
ms.load_state_dict(state_dict, strict=False) |
|
epoch = -1 |
|
else: |
|
|
|
with PathManager.open(path_to_checkpoint, "rb") as f: |
|
checkpoint = torch.load(f, map_location="cpu") |
|
        try:
            model_state_dict_3d = (
                model.module.state_dict()
                if data_parallel
                else model.state_dict()
            )
            checkpoint["model_state"] = normal_to_sub_bn(
                checkpoint["model_state"], model_state_dict_3d
            )
        except AttributeError:
            # The model is not wrapped, so there is no `.module`; use its
            # state dict directly.
            model_state_dict_3d = model.state_dict()
            checkpoint["model_state"] = normal_to_sub_bn(
                checkpoint["model_state"], model_state_dict_3d
            )
|
|
if inflation: |
|
|
|
inflated_model_dict = inflate_weight( |
|
checkpoint["model_state"], model_state_dict_3d |
|
) |
|
ms.load_state_dict(inflated_model_dict, strict=False) |
|
else: |
|
if clear_name_pattern: |
|
for item in clear_name_pattern: |
|
                    model_state_dict_new = OrderedDict()
                    for k in checkpoint["model_state"]:
                        if item in k:
                            k_re = k.replace(item, "")
                            model_state_dict_new[k_re] = checkpoint["model_state"][k]
                            logger.info("renaming: {} -> {}".format(k, k_re))
                        else:
                            model_state_dict_new[k] = checkpoint["model_state"][k]
                    checkpoint["model_state"] = model_state_dict_new
|
|
|
pre_train_dict = checkpoint["model_state"] |
|
model_dict = ms.state_dict() |
|
|
|
            # Resize the temporal position embedding when the checkpoint was
            # trained with a different number of frames:
            # (1, T_ckpt, D) -> (1, D, T_ckpt) -> interpolate over T ->
            # (1, T_model, D).
            if "model.time_embed" in pre_train_dict:
                k = "model.time_embed"
                v = pre_train_dict[k]
                v = v[0, :, :].unsqueeze(0).transpose(1, 2)
                new_v = F.interpolate(
                    v, size=model_dict[k].size(1), mode="nearest"
                )
                pre_train_dict[k] = new_v.transpose(1, 2)
|
|
|
|
|
|
|
            # Keep only the pre-trained weights whose names and shapes match
            # the target model.
            pre_train_dict_match = {
                k: v
                for k, v in pre_train_dict.items()
                if k in model_dict and v.size() == model_dict[k].size()
            }
|
|
not_load_layers = [ |
|
k |
|
for k in model_dict.keys() |
|
if k not in pre_train_dict_match.keys() |
|
] |
|
|
|
if not_load_layers: |
|
for k in not_load_layers: |
|
logger.info("Network weights {} not loaded.".format(k)) |
|
|
|
ms.load_state_dict(pre_train_dict_match, strict=False) |
|
epoch = -1 |
|
|
|
|
|
if "epoch" in checkpoint.keys() and not epoch_reset: |
|
epoch = checkpoint["epoch"] |
|
if optimizer: |
|
optimizer.load_state_dict(checkpoint["optimizer_state"]) |
|
else: |
|
epoch = -1 |
|
return epoch |
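# Minimal usage sketch (illustrative; `path`, `model`, and `optimizer` are
# assumed to exist):
#
#     epoch = load_checkpoint(path, model, data_parallel=False,
#                             optimizer=optimizer)
#     start_epoch = epoch + 1  # epoch is -1 when the file stores no epoch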
|
|
|
|
|
def sub_to_normal_bn(sd): |
|
""" |
|
    Convert the Sub-BN parameters to normal BN parameters in a state dict.
|
There are two copies of BN layers in a Sub-BN implementation: `bn.bn` and |
|
`bn.split_bn`. `bn.split_bn` is used during training and |
|
"compute_precise_bn". Before saving or evaluation, its stats are copied to |
|
`bn.bn`. We rename `bn.bn` to `bn` and store it to be consistent with normal |
|
BN layers. |
|
Args: |
|
        sd (OrderedDict): a dict of parameters which might contain Sub-BN
|
parameters. |
|
Returns: |
|
new_sd (OrderedDict): a dict with Sub-BN parameters reshaped to |
|
normal parameters. |
|
""" |
|
new_sd = copy.deepcopy(sd) |
|
modifications = [ |
|
("bn.bn.running_mean", "bn.running_mean"), |
|
("bn.bn.running_var", "bn.running_var"), |
|
("bn.split_bn.num_batches_tracked", "bn.num_batches_tracked"), |
|
] |
|
to_remove = ["bn.bn.", ".split_bn."] |
|
for key in sd: |
|
for before, after in modifications: |
|
if key.endswith(before): |
|
new_key = key.split(before)[0] + after |
|
new_sd[new_key] = new_sd.pop(key) |
|
|
|
for rm in to_remove: |
|
if rm in key and key in new_sd: |
|
del new_sd[key] |
|
|
|
for key in new_sd: |
|
if key.endswith("bn.weight") or key.endswith("bn.bias"): |
|
if len(new_sd[key].size()) == 4: |
|
assert all(d == 1 for d in new_sd[key].size()[1:]) |
|
new_sd[key] = new_sd[key][:, 0, 0, 0] |
|
|
|
return new_sd |
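# Key-mapping example (hypothetical layer name "layer1"):
#   "layer1.bn.bn.running_mean"              -> "layer1.bn.running_mean"
#   "layer1.bn.split_bn.num_batches_tracked" -> "layer1.bn.num_batches_tracked"
#   "layer1.bn.split_bn.running_mean"        -> dropped (split statistics are
#                                               not needed by a normal BN)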
|
|
|
|
|
def c2_normal_to_sub_bn(key, model_keys): |
|
""" |
|
Convert BN parameters to Sub-BN parameters if model contains Sub-BNs. |
|
Args: |
|
        key (string): a parameter name from the source state dict.
        model_keys (OrderedDict): target dict of model parameters, used to
            check which keys exist.
    Returns:
        (string): the converted parameter name.
|
""" |
|
if "bn.running_" in key: |
|
if key in model_keys: |
|
return key |
|
|
|
new_key = key.replace("bn.running_", "bn.split_bn.running_") |
|
if new_key in model_keys: |
|
return new_key |
|
else: |
|
return key |
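# Example (hypothetical key): "layer1.bn.running_mean" is returned unchanged
# when the model uses plain BN, but becomes
# "layer1.bn.split_bn.running_mean" when only the Sub-BN variant of that key
# exists in model_keys.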
|
|
|
|
|
def normal_to_sub_bn(checkpoint_sd, model_sd): |
|
""" |
|
Convert BN parameters to Sub-BN parameters if model contains Sub-BNs. |
|
Args: |
|
checkpoint_sd (OrderedDict): source dict of parameters. |
|
model_sd (OrderedDict): target dict of parameters. |
|
Returns: |
|
new_sd (OrderedDict): converted dict of parameters. |
|
""" |
|
    # Map plain BN parameter names in the checkpoint to the model's Sub-BN
    # names, keeping a duplicate under "bn.bn." for evaluation.
    for key in model_sd:
|
if key not in checkpoint_sd: |
|
if "bn.split_bn." in key: |
|
load_key = key.replace("bn.split_bn.", "bn.") |
|
bn_key = key.replace("bn.split_bn.", "bn.bn.") |
|
checkpoint_sd[key] = checkpoint_sd.pop(load_key) |
|
checkpoint_sd[bn_key] = checkpoint_sd[key] |
|
|
|
    # When the model's Sub-BN statistics are longer than the checkpoint's
    # (one chunk per split), tile the checkpoint tensor to match.
    for key in model_sd:
|
if key in checkpoint_sd: |
|
model_blob_shape = model_sd[key].shape |
|
c2_blob_shape = checkpoint_sd[key].shape |
|
|
|
if ( |
|
len(model_blob_shape) == 1 |
|
and len(c2_blob_shape) == 1 |
|
and model_blob_shape[0] > c2_blob_shape[0] |
|
and model_blob_shape[0] % c2_blob_shape[0] == 0 |
|
): |
|
before_shape = checkpoint_sd[key].shape |
|
checkpoint_sd[key] = torch.cat( |
|
[checkpoint_sd[key]] |
|
* (model_blob_shape[0] // c2_blob_shape[0]) |
|
) |
|
logger.info( |
|
"{} {} -> {}".format( |
|
key, before_shape, checkpoint_sd[key].shape |
|
) |
|
) |
|
return checkpoint_sd |
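# Example (hypothetical shapes): a checkpoint tensor "layer1.bn.running_mean"
# of shape (64,) is first renamed to "layer1.bn.split_bn.running_mean" (with
# a copy kept under "layer1.bn.bn.running_mean" for evaluation); if the model
# expects shape (128,) because it uses two splits, the tensor is then
# concatenated with itself to shape (128,).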
|
|
|
|
|
def load_test_checkpoint(cfg, model): |
|
""" |
|
Loading checkpoint logic for testing. |
|
""" |
|
|
|
if cfg.TEST.CHECKPOINT_FILE_PATH != "": |
|
        # Prefer the checkpoint given explicitly for testing.
|
load_checkpoint( |
|
cfg.TEST.CHECKPOINT_FILE_PATH, |
|
model, |
|
cfg.NUM_GPUS > 1, |
|
None, |
|
inflation=False, |
|
convert_from_caffe2=cfg.TEST.CHECKPOINT_TYPE == "caffe2", |
|
) |
|
elif has_checkpoint(cfg.OUTPUT_DIR): |
|
last_checkpoint = get_last_checkpoint(cfg.OUTPUT_DIR) |
|
load_checkpoint(last_checkpoint, model, cfg.NUM_GPUS > 1) |
|
elif cfg.TRAIN.CHECKPOINT_FILE_PATH != "": |
|
        # Otherwise fall back to the training checkpoint, if one is given.
|
load_checkpoint( |
|
cfg.TRAIN.CHECKPOINT_FILE_PATH, |
|
model, |
|
cfg.NUM_GPUS > 1, |
|
None, |
|
inflation=False, |
|
convert_from_caffe2=cfg.TRAIN.CHECKPOINT_TYPE == "caffe2", |
|
) |
|
else: |
|
logger.info( |
|
"Unknown way of loading checkpoint. Using with random initialization, only for debugging." |
|
) |
|
|
|
|
|
def load_train_checkpoint(cfg, model, optimizer): |
|
""" |
|
Loading checkpoint logic for training. |
|
""" |
|
if cfg.TRAIN.AUTO_RESUME and has_checkpoint(cfg.OUTPUT_DIR): |
|
last_checkpoint = get_last_checkpoint(cfg.OUTPUT_DIR) |
|
logger.info("Load from last checkpoint, {}.".format(last_checkpoint)) |
|
checkpoint_epoch = load_checkpoint( |
|
last_checkpoint, model, cfg.NUM_GPUS > 1, optimizer |
|
) |
|
start_epoch = checkpoint_epoch + 1 |
|
elif cfg.TRAIN.CHECKPOINT_FILE_PATH != "": |
|
logger.info("Load from given checkpoint file.") |
|
checkpoint_epoch = load_checkpoint( |
|
cfg.TRAIN.CHECKPOINT_FILE_PATH, |
|
model, |
|
cfg.NUM_GPUS > 1, |
|
optimizer, |
|
inflation=cfg.TRAIN.CHECKPOINT_INFLATE, |
|
convert_from_caffe2=cfg.TRAIN.CHECKPOINT_TYPE == "caffe2", |
|
epoch_reset=cfg.TRAIN.CHECKPOINT_EPOCH_RESET, |
|
clear_name_pattern=cfg.TRAIN.CHECKPOINT_CLEAR_NAME_PATTERN, |
|
) |
|
start_epoch = checkpoint_epoch + 1 |
|
else: |
|
start_epoch = 0 |
|
|
|
return start_epoch |
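# Minimal usage sketch of the training entry point (illustrative; the
# training loop itself lives elsewhere):
#
#     start_epoch = load_train_checkpoint(cfg, model, optimizer)
#     for cur_epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH):
#         ...  # train one epoch, then save/evaluate as configured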
|
|