#!/usr/bin/env python3
"""Functions that handle saving and loading of checkpoints."""
import copy
import os
import re
from typing import Dict, List

import torch
import torch.nn as nn
from tabulate import tabulate

import utils.distributed as du
import utils.logging as logging
from utils.env import checkpoint_pathmgr as pathmgr

logger = logging.get_logger(__name__)
def convert_basic_c2_names(original_keys):
"""
Apply some basic name conversion to names in C2 weights.
It only deals with typical backbone models.
Args:
original_keys (list[str]):
Returns:
list[str]: The same number of strings matching those in original_keys.
"""
layer_keys = copy.deepcopy(original_keys)
layer_keys = [
{"pred_b": "linear_b", "pred_w": "linear_w"}.get(k, k) for k in layer_keys
] # some hard-coded mappings
layer_keys = [k.replace("_", ".") for k in layer_keys]
layer_keys = [re.sub("\\.b$", ".bias", k) for k in layer_keys]
layer_keys = [re.sub("\\.w$", ".weight", k) for k in layer_keys]
    # Unify both bn and gn names to "norm"
layer_keys = [re.sub("bn\\.s$", "norm.weight", k) for k in layer_keys]
layer_keys = [re.sub("bn\\.bias$", "norm.bias", k) for k in layer_keys]
layer_keys = [re.sub("bn\\.rm", "norm.running_mean", k) for k in layer_keys]
layer_keys = [re.sub("bn\\.running.mean$", "norm.running_mean", k) for k in layer_keys]
layer_keys = [re.sub("bn\\.riv$", "norm.running_var", k) for k in layer_keys]
layer_keys = [re.sub("bn\\.running.var$", "norm.running_var", k) for k in layer_keys]
layer_keys = [re.sub("bn\\.gamma$", "norm.weight", k) for k in layer_keys]
layer_keys = [re.sub("bn\\.beta$", "norm.bias", k) for k in layer_keys]
layer_keys = [re.sub("gn\\.s$", "norm.weight", k) for k in layer_keys]
layer_keys = [re.sub("gn\\.bias$", "norm.bias", k) for k in layer_keys]
# stem
layer_keys = [re.sub("^res\\.conv1\\.norm\\.", "conv1.norm.", k) for k in layer_keys]
    # to avoid mismatching the "conv1" in other components (e.g. the detection head)
    layer_keys = [re.sub("^conv1\\.", "stem.conv1.", k) for k in layer_keys]
    # torchvision uses layer1-4; we instead follow the C2 naming strategy (res2-5)
# layer_keys = [re.sub("^res2.", "layer1.", k) for k in layer_keys]
# layer_keys = [re.sub("^res3.", "layer2.", k) for k in layer_keys]
# layer_keys = [re.sub("^res4.", "layer3.", k) for k in layer_keys]
# layer_keys = [re.sub("^res5.", "layer4.", k) for k in layer_keys]
# blocks
layer_keys = [k.replace(".branch1.", ".shortcut.") for k in layer_keys]
layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys]
layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys]
layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys]
# DensePose substitutions
layer_keys = [re.sub("^body.conv.fcn", "body_conv_fcn", k) for k in layer_keys]
layer_keys = [k.replace("AnnIndex.lowres", "ann_index_lowres") for k in layer_keys]
layer_keys = [k.replace("Index.UV.lowres", "index_uv_lowres") for k in layer_keys]
layer_keys = [k.replace("U.lowres", "u_lowres") for k in layer_keys]
layer_keys = [k.replace("V.lowres", "v_lowres") for k in layer_keys]
return layer_keys
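# A minimal illustrative sketch of the conversion above; the C2-style keys here
# are hypothetical, not taken from a real checkpoint:
#   convert_basic_c2_names(["res2_0_branch2a_bn_s", "pred_w"])
#   -> ["res2.0.conv1.norm.weight", "linear.weight"]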
def convert_c2_detectron_names(weights):
"""
Map Caffe2 Detectron weight names to Detectron2 names.
Args:
weights (dict): name -> tensor
Returns:
dict: detectron2 names -> tensor
dict: detectron2 names -> C2 names
"""
    logger.info("Renaming Caffe2 weights ...")
original_keys = sorted(weights.keys())
layer_keys = copy.deepcopy(original_keys)
layer_keys = convert_basic_c2_names(layer_keys)
# --------------------------------------------------------------------------
# RPN hidden representation conv
# --------------------------------------------------------------------------
# FPN case
# In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
# shared for all other levels, hence the appearance of "fpn2"
layer_keys = [
k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys
]
# Non-FPN case
layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]
# --------------------------------------------------------------------------
# RPN box transformation conv
# --------------------------------------------------------------------------
# FPN case (see note above about "fpn2")
layer_keys = [
k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas")
for k in layer_keys
]
layer_keys = [
k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits")
for k in layer_keys
]
# Non-FPN case
layer_keys = [
k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys
]
layer_keys = [
k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits")
for k in layer_keys
]
# --------------------------------------------------------------------------
# Fast R-CNN box head
# --------------------------------------------------------------------------
layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys]
layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys]
layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys]
layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys]
# 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys]
# --------------------------------------------------------------------------
# FPN lateral and output convolutions
# --------------------------------------------------------------------------
def fpn_map(name):
"""
Look for keys with the following patterns:
1) Starts with "fpn.inner."
Example: "fpn.inner.res2.2.sum.lateral.weight"
Meaning: These are lateral pathway convolutions
2) Starts with "fpn.res"
Example: "fpn.res2.2.sum.weight"
Meaning: These are FPN output convolutions
"""
splits = name.split(".")
norm = ".norm" if "norm" in splits else ""
if name.startswith("fpn.inner."):
# splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
stage = int(splits[2][len("res") :])
return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1])
elif name.startswith("fpn.res"):
# splits example: ['fpn', 'res2', '2', 'sum', 'weight']
stage = int(splits[1][len("res") :])
return "fpn_output{}{}.{}".format(stage, norm, splits[-1])
return name
layer_keys = [fpn_map(k) for k in layer_keys]
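    # Sanity-check sketch of fpn_map on the docstring's own examples:
    #   fpn_map("fpn.inner.res2.2.sum.lateral.weight") -> "fpn_lateral2.weight"
    #   fpn_map("fpn.res2.2.sum.weight")               -> "fpn_output2.weight"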
# --------------------------------------------------------------------------
# Mask R-CNN mask head
# --------------------------------------------------------------------------
# roi_heads.StandardROIHeads case
layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
# roi_heads.Res5ROIHeads case
layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]
# --------------------------------------------------------------------------
# Keypoint R-CNN head
# --------------------------------------------------------------------------
# interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
layer_keys = [
k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys
]
layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]
# --------------------------------------------------------------------------
# Done with replacements
# --------------------------------------------------------------------------
assert len(set(layer_keys)) == len(layer_keys)
assert len(original_keys) == len(layer_keys)
new_weights = {}
new_keys_to_original_keys = {}
for orig, renamed in zip(original_keys, layer_keys):
new_keys_to_original_keys[renamed] = orig
if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."):
# remove the meaningless prediction weight for background class
new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
new_weights[renamed] = weights[orig][new_start_idx:]
logger.info(
"Remove prediction weight for background class in {}. The shape changes from "
"{} to {}.".format(
renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape)
)
)
elif renamed.startswith("cls_score."):
# move weights of bg class from original index 0 to last index
logger.info(
"Move classification weights for background class in {} from index 0 to "
"index {}.".format(renamed, weights[orig].shape[0] - 1)
)
new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
else:
new_weights[renamed] = weights[orig]
return new_weights, new_keys_to_original_keys
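# Illustrative sketch of the full pipeline on a hypothetical C2 blob name:
# convert_basic_c2_names() turns "conv_rpn_fpn2_w" into "conv.rpn.fpn2.weight",
# and the FPN-case replacement above then maps it to
# "proposal_generator.rpn_head.conv.weight".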
def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]):
"""
Params in the same submodule are grouped together.
Args:
keys: names of all parameters
original_names: mapping from parameter name to their name in the checkpoint
Returns:
dict[name -> all other names in the same group]
"""
def _submodule_name(key):
pos = key.rfind(".")
if pos < 0:
return None
prefix = key[: pos + 1]
return prefix
all_submodules = [_submodule_name(k) for k in keys]
all_submodules = [x for x in all_submodules if x]
all_submodules = sorted(all_submodules, key=len)
ret = {}
for prefix in all_submodules:
group = [k for k in keys if k.startswith(prefix)]
if len(group) <= 1:
continue
original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group])
if len(original_name_lcp) == 0:
# don't group weights if original names don't share prefix
continue
for k in group:
if k in ret:
continue
ret[k] = group
return ret
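# Illustrative sketch with hypothetical keys: given ["head.fc.weight",
# "head.fc.bias"] whose checkpoint names are {"head.fc.weight": "fc6_w",
# "head.fc.bias": "fc6_b"} (shared prefix "fc6_"), both keys are mapped to the
# same group and will be reported as one row in the matching table below.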
def _longest_common_prefix(names):
"""
["abc.zfg", "abc.zef"] -> "abc."
"""
names = [n.split(".") for n in names]
m1, m2 = min(names), max(names)
ret = [a for a, b in zip(m1, m2) if a == b]
ret = ".".join(ret) + "." if len(ret) else ""
return ret
def _longest_common_prefix_str(names):
m1, m2 = min(names), max(names)
lcp = []
for a, b in zip(m1, m2):
if a == b:
lcp.append(a)
else:
break
lcp = "".join(lcp)
return lcp
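# Character-level variant of the dotted version above, e.g.:
#   _longest_common_prefix_str(["bn_beta", "bn_gamma"]) -> "bn_"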
def _group_str(names):
"""
Turn "common1", "common2", "common3" into "common{1,2,3}"
"""
lcp = _longest_common_prefix_str(names)
rest = [x[len(lcp) :] for x in names]
rest = "{" + ",".join(rest) + "}"
ret = lcp + rest
# add some simplification for BN specifically
ret = ret.replace("bn_{beta,running_mean,running_var,gamma}", "bn_*")
ret = ret.replace("bn_beta,bn_running_mean,bn_running_var,bn_gamma", "bn_*")
return ret
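# Illustrative sketch with hypothetical names:
#   _group_str(["layer1.bn_beta", "layer1.bn_gamma"]) -> "layer1.bn_{beta,gamma}"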
def make_checkpoint_dir(path_to_job):
"""
Creates the checkpoint directory (if not present already).
Args:
path_to_job (string): the path to the folder of the current job.
"""
checkpoint_dir = os.path.join(path_to_job, "checkpoints")
# Create the checkpoint dir from the master process
if du.is_master_proc() and not pathmgr.exists(checkpoint_dir):
try:
pathmgr.mkdirs(checkpoint_dir)
except Exception:
pass
return checkpoint_dir
def get_checkpoint_dir(path_to_job):
"""
Get path for storing checkpoints.
Args:
path_to_job (string): the path to the folder of the current job.
"""
return os.path.join(path_to_job, "checkpoints")
def get_path_to_checkpoint(path_to_job, epoch):
"""
Get the full path to a checkpoint file.
Args:
path_to_job (string): the path to the folder of the current job.
epoch (int): the number of epoch for the checkpoint.
"""
name = "checkpoint_epoch_{:05d}.pyth".format(epoch)
return os.path.join(get_checkpoint_dir(path_to_job), name)
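# Illustrative sketch with a hypothetical job path:
#   get_path_to_checkpoint("/path/to/job", 3)
#   -> "/path/to/job/checkpoints/checkpoint_epoch_00003.pyth"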
def get_last_checkpoint(path_to_job):
"""
Get the last checkpoint from the checkpointing folder.
Args:
path_to_job (string): the path to the folder of the current job.
"""
name = "checkpoint_latest.pyth"
return os.path.join(get_checkpoint_dir(path_to_job), name)
def has_checkpoint(path_to_job):
"""
Determines if the given directory contains a checkpoint.
Args:
path_to_job (string): the path to the folder of the current job.
"""
d = get_checkpoint_dir(path_to_job)
files = pathmgr.ls(d) if pathmgr.exists(d) else []
return any("checkpoint" in f for f in files)
def is_checkpoint_epoch(cfg, cur_iter):
"""
Determine if a checkpoint should be saved on current epoch.
Args:
cfg (CfgNode): configs to save.
cur_epoch (int): current number of epoch of the model.
"""
if cur_iter + 1 == cfg.SOLVER.MAX_EPOCH:
return True
return (cur_iter + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0
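# Illustrative sketch: with cfg.TRAIN.CHECKPOINT_PERIOD == 5, iterations
# 4, 9, 14, ... (0-based) trigger a save, as does the final iteration
# cfg.SOLVER.MAX_EPOCH - 1.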
def save_checkpoint(path_to_job, model, optimizer, iter, cfg, scaler=None):
"""
Save a checkpoint.
Args:
model (model): model to save the weight to the checkpoint.
optimizer (optim): optimizer to save the historical state.
epoch (int): current number of epoch of the model.
cfg (CfgNode): configs to save.
scaler (GradScaler): the mixed precision scale.
"""
# Save checkpoints only from the master process.
if not du.is_master_proc(cfg.NUM_GPUS * cfg.NUM_SHARDS):
return
# Ensure that the checkpoint dir exists.
pathmgr.mkdirs(get_checkpoint_dir(path_to_job))
# Omit the DDP wrapper in the multi-gpu setting.
sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
# Record the state.
checkpoint = {
"epoch": iter,
"model_state": sd,
"optimizer_state": optimizer.state_dict(),
"cfg": cfg.dump(),
}
if scaler is not None:
checkpoint["scaler_state"] = scaler.state_dict()
# Write the current epoch checkpoint & update the latest epoch checkpoint
path_to_checkpoint = get_path_to_checkpoint(path_to_job, iter + 1)
with pathmgr.open(path_to_checkpoint, "wb") as f:
torch.save(checkpoint, f)
path_to_latest_checkpoint = get_last_checkpoint(path_to_job)
with pathmgr.open(path_to_latest_checkpoint, "wb") as f:
torch.save(checkpoint, f)
return path_to_checkpoint
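# Minimal usage sketch; model/optimizer/cur_iter are assumed to come from the
# surrounding training loop:
#   path = save_checkpoint(cfg.OUTPUT_DIR, model, optimizer, cur_iter, cfg, scaler)
#   # writes checkpoints/checkpoint_epoch_<cur_iter + 1>.pyth and refreshes
#   # checkpoints/checkpoint_latest.pyth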
def load_checkpoint(
    path_to_checkpoint,
    models,
    optimizer=None,
    model_keys=('model',),
    exclude_key=None,
    to_match=None,
    to_print=True,
    scaler=None,
):
    """
    Load the checkpoint from the given file.
    Args:
        path_to_checkpoint (string): path to the checkpoint file.
        models (list): models to load the weights into.
        optimizer (optim): optional optimizer whose state is restored.
        model_keys (sequence): per-model keys of the state dicts in the checkpoint.
        exclude_key (string): substring of model parameter names to skip.
        to_match (dict): extra name substitutions used when matching keys.
        to_print (bool): whether to log matching details.
        scaler (GradScaler): optional mixed precision scaler whose state is restored.
    Returns:
        tuple: (epoch stored in the checkpoint, best validation stats or None).
    """
    assert pathmgr.exists(path_to_checkpoint), "Checkpoint '{}' not found".format(
        path_to_checkpoint
    )
    if to_print:
        logger.info("Loading network weights from {}.".format(path_to_checkpoint))

    def find_model_key(keys, model_key):
        # Prefer a key containing the requested name; otherwise fall back to any
        # key containing "model".
        for k in keys:
            if model_key in k:
                return k
        for k in keys:
            if 'model' in k:
                if to_print:
                    logger.info(
                        'Could not find a model state_dict for the given key; '
                        'using the "model" key instead.'
                    )
                return k

    # Load the checkpoint on CPU to avoid a GPU memory spike.
    with pathmgr.open(path_to_checkpoint, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
    for i, model in enumerate(models):
        # NOTE: if a model is wrapped in DDP, callers should pass model.module.
        ms = model
        model_dict = ms.state_dict()
        k = find_model_key(checkpoint.keys(), model_keys[i])
        pre_train_dict = checkpoint[k]
        ms.load_state_dict(
            align_and_update_state_dicts(
                model_dict,
                pre_train_dict,
                exclude_key=exclude_key,
                to_print=to_print,
                to_match=to_match,
            ),
            strict=False,
        )
    if optimizer is not None:
        # save_checkpoint() above stores the optimizer under "optimizer_state";
        # also accept the plain "optimizer" key for external checkpoints.
        if 'optimizer_state' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state'])
        elif 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
    if scaler is not None and 'scaler_state' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler_state'])
    best_val_stats = checkpoint.get('best_val_stats')
    return checkpoint.get('epoch', -1), best_val_stats
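# Minimal usage sketch for a single model; the checkpoint path is hypothetical:
#   epoch, best_val_stats = load_checkpoint(
#       "/path/to/checkpoint_latest.pyth", [model], optimizer=optimizer
#   )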
def load_test_checkpoint(cfg, model):
    """
    Checkpoint-loading logic for testing.
    """
    # Load a checkpoint to test if applicable.
    if cfg.TEST.CHECKPOINT_FILE_PATH != "":
        load_checkpoint(cfg.TEST.CHECKPOINT_FILE_PATH, [model])
    elif has_checkpoint(cfg.OUTPUT_DIR):
        last_checkpoint = get_last_checkpoint(cfg.OUTPUT_DIR)
        load_checkpoint(last_checkpoint, [model])
    elif cfg.TRAIN.CHECKPOINT_FILE_PATH != "":
        # If no checkpoint is found in TEST.CHECKPOINT_FILE_PATH or in the
        # current checkpoint folder, try to load the checkpoint from
        # TRAIN.CHECKPOINT_FILE_PATH and test it.
        load_checkpoint(cfg.TRAIN.CHECKPOINT_FILE_PATH, [model])
    else:
        logger.info(
            "Unknown way of loading checkpoint. Using random initialization; only for debugging."
        )
def load_train_checkpoint(cfg, model, optimizer, scaler=None):
    """
    Checkpoint-loading logic for training.
    """
    if cfg.TRAIN.AUTO_RESUME and has_checkpoint(cfg.OUTPUT_DIR):
        last_checkpoint = get_last_checkpoint(cfg.OUTPUT_DIR)
        logger.info("Load from last checkpoint, {}.".format(last_checkpoint))
        checkpoint_epoch, _ = load_checkpoint(
            last_checkpoint, [model], optimizer, scaler=scaler
        )
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.CHECKPOINT_FILE_PATH != "" and cfg.TRAIN.FINETUNE:
        logger.info("Finetune from given checkpoint file.")
        checkpoint_epoch, _ = load_checkpoint(
            cfg.TRAIN.CHECKPOINT_FILE_PATH, [model], optimizer, scaler=scaler
        )
        if cfg.TRAIN.CHECKPOINT_EPOCH_RESET:
            checkpoint_epoch = -1
        start_epoch = (
            checkpoint_epoch + 1
            if cfg.TRAIN.FINETUNE_START_EPOCH == 0
            else cfg.TRAIN.FINETUNE_START_EPOCH
        )
    elif cfg.TRAIN.CHECKPOINT_FILE_PATH != "":
        logger.info("Load from given checkpoint file.")
        checkpoint_epoch, _ = load_checkpoint(
            cfg.TRAIN.CHECKPOINT_FILE_PATH, [model], optimizer, scaler=scaler
        )
        if cfg.TRAIN.CHECKPOINT_EPOCH_RESET:
            checkpoint_epoch = -1
        start_epoch = checkpoint_epoch + 1
    else:
        start_epoch = 0
    return start_epoch
# NOTE: the current matching is not symmetric;
# it assumes model_state_dict will have longer names.
def align_and_update_state_dicts(
    model_state_dict, ckpt_state_dict, exclude_key=None, to_print=True, to_match=None
):
    """
    Match names between the two state dicts, and return a new ckpt_state_dict with
    names converted to match model_state_dict via heuristics. The returned dict can
    later be loaded with the fvcore checkpointer.
    """
    if to_match is None:
        to_match = {}
    if exclude_key is not None:
        model_keys = sorted(k for k in model_state_dict.keys() if exclude_key not in k)
    else:
        model_keys = sorted(model_state_dict.keys())
    original_keys = {x: x for x in ckpt_state_dict.keys()}
    ckpt_keys = sorted(ckpt_state_dict.keys())
    def in_to_match(a, b):
        # Apply the user-provided substitutions to the checkpoint key and check
        # whether the substituted name then matches the model key.
        for k in to_match.keys():
            c = b.replace(k, to_match[k])
            if a == c or a.endswith("." + c):
                return True
        return False

    def match(a, b):
        matched = a == b or a.endswith("." + b) or in_to_match(a, b)
        if matched and to_print:
            logger.info('matched: {} -- {}'.format(a, b))
        return matched
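    # Illustrative sketch of the heuristic with hypothetical keys: the model key
    # "backbone.stem.conv1.weight" matches the checkpoint key "stem.conv1.weight"
    # because it ends with "." + the checkpoint key; with
    # to_match={"gamma": "weight"}, the checkpoint key "stem.conv1.gamma" would
    # match as well.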
    # Get a matrix of string matches, where each (i, j) entry corresponds to the
    # length of the ckpt_key string if model_keys[i] matches ckpt_keys[j], else 0
match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys]
match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys))
# use the matched one with longest size in case of multiple matches
max_match_size, idxs = match_matrix.max(1)
# remove indices that correspond to no-match
idxs[max_match_size == 0] = -1
    # matched_keys: matched checkpoint key -> matched model key
matched_keys = {}
result_state_dict = {}
for idx_model, idx_ckpt in enumerate(idxs.tolist()):
if idx_ckpt == -1:
continue
key_model = model_keys[idx_model]
key_ckpt = ckpt_keys[idx_ckpt]
value_ckpt = ckpt_state_dict[key_ckpt]
shape_in_model = model_state_dict[key_model].shape
if shape_in_model != value_ckpt.shape:
logger.warning(
"Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
key_ckpt, value_ckpt.shape, key_model, shape_in_model
)
)
if shape_in_model[0] != value_ckpt.shape[0] and len(shape_in_model) == len(value_ckpt.shape): # different embed_dim setup
logger.warning(
"{} will not be loaded. Please double check and see if this is desired.".format(
key_ckpt
)
)
logger.warning('--- shape_in_model: {}'.format(shape_in_model))
logger.warning('--- ckpt shape: {}'.format(value_ckpt.shape))
            else:
                logger.warning(
                    "{} will be loaded for the center frame with the weights from the "
                    "2D conv layers of the pre-trained model; the other weights are "
                    "initialized as zero. Please double check that this is desired.".format(
                        key_ckpt
                    )
                )
                assert key_model not in result_state_dict
                logger.warning('--- shape_in_model: {}'.format(shape_in_model))
                logger.warning('--- ckpt shape: {}'.format(value_ckpt.shape))
                # Load the pre-trained 2D weights into the parameter's center temporal
                # frame and zero out the rest. Shapes: (B, C, (T,) H, W).
                nn.init.constant_(model_state_dict[key_model], 0.0)
                model_state_dict[key_model][:, :, shape_in_model[2] // 2] = value_ckpt
                result_state_dict[key_model] = model_state_dict[key_model]
                logger.warning('--- loaded to T: {}'.format(shape_in_model[2] // 2))
                logger.warning('--- reshaped ckpt: {}'.format(result_state_dict[key_model].shape))
matched_keys[key_ckpt] = key_model
else:
assert key_model not in result_state_dict
result_state_dict[key_model] = value_ckpt
if key_ckpt in matched_keys: # already added to matched_keys
logger.error(
"Ambiguity found for {} in checkpoint!"
"It matches at least two keys in the model ({} and {}).".format(
key_ckpt, key_model, matched_keys[key_ckpt]
)
)
raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")
if to_print:
logger.info('Matching {} to {}'.format(key_ckpt, key_model))
matched_keys[key_ckpt] = key_model
# logging:
matched_model_keys = sorted(matched_keys.values())
if len(matched_model_keys) == 0:
logger.warning("No weights in checkpoint matched with model.")
return ckpt_state_dict
common_prefix = _longest_common_prefix(matched_model_keys)
rev_matched_keys = {v: k for k, v in matched_keys.items()}
original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys}
model_key_groups = _group_keys_by_module(matched_model_keys, original_keys)
table = []
memo = set()
for key_model in matched_model_keys:
        if to_print:
            logger.info('matched model key: {}'.format(key_model))
if key_model in memo:
continue
if key_model in model_key_groups:
group = model_key_groups[key_model]
memo |= set(group)
shapes = [tuple(model_state_dict[k].shape) for k in group]
table.append(
(
_longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*",
_group_str([original_keys[k] for k in group]),
" ".join([str(x).replace(" ", "") for x in shapes]),
)
)
else:
key_checkpoint = original_keys[key_model]
shape = str(tuple(model_state_dict[key_model].shape))
table.append((key_model[len(common_prefix) :], key_checkpoint, shape))
table_str = tabulate(
table, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"]
)
if to_print:
logger.info(
"Following weights matched with "
+ (f"submodule {common_prefix[:-1]}" if common_prefix else "model")
+ ":\n"
+ table_str
)
    unmatched_ckpt_keys = [k for k in ckpt_keys if k not in set(matched_keys.keys())]
    unmatched_model_keys = [k for k in model_keys if k not in set(matched_keys.values())]
    if to_print and unmatched_ckpt_keys:
        logger.info(
            "Keys in the checkpoint but not matched to the model: {}".format(unmatched_ckpt_keys)
        )
    # Keep the model's own initialization for parameters without a match in the
    # checkpoint, so that loading with strict=False leaves them untouched.
    for k in unmatched_model_keys:
        result_state_dict[k] = model_state_dict[k]
    return result_state_dict