eP-ALM / TimeSformer /timesformer /utils /c2_model_loading.py
mshukor
init
3eb682b
raw
history blame
No virus
4.98 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Caffe2 to PyTorch checkpoint name converting utility."""
import re
def get_name_convert_func():
    """
    Get the function to convert Caffe2 layer names to PyTorch layer names.

    The conversion is a fixed, ordered list of regular-expression rewrites.
    Every rule is applied to the output of the previous one, so early rules
    produce intermediate names (e.g. ``s4.pathway0_res4.branch2.c_bn_b``)
    that the generic suffix rules at the end of the list finish
    (``..._bn.bias``). The order of the rules therefore matters.

    Returns:
        (func): function to convert parameter name from Caffe2 format to
            PyTorch format.
    """
    # Each entry is [pattern, replacement]. The example in each comment shows
    # the result of that single rule, not of the whole chain.
    pairs = [
        # ------------------------------------------------------------
        # 'nonlocal_conv3_1_theta_w' -> 's3.pathway0_nonlocal1_theta_w'
        [
            r"^nonlocal_conv([0-9]+)_([0-9]+)_(.*)",
            r"s\1.pathway0_nonlocal\2_\3",
        ],
        # '..._nonlocal1_theta_w' -> '..._nonlocal1.conv_theta_w'
        [r"^(.*)_nonlocal([0-9]+)_(theta)(.*)", r"\1_nonlocal\2.conv_\3\4"],
        # '..._nonlocal1_g_w' -> '..._nonlocal1.conv_g_w'
        [r"^(.*)_nonlocal([0-9]+)_(g)(.*)", r"\1_nonlocal\2.conv_\3\4"],
        # '..._nonlocal1_phi_w' -> '..._nonlocal1.conv_phi_w'
        [r"^(.*)_nonlocal([0-9]+)_(phi)(.*)", r"\1_nonlocal\2.conv_\3\4"],
        # '..._nonlocal1_out_w' -> '..._nonlocal1.conv_out_w'
        [r"^(.*)_nonlocal([0-9]+)_(out)(.*)", r"\1_nonlocal\2.conv_\3\4"],
        # 's4.pathway0_nonlocal5_bn_s' -> 's4.pathway0_nonlocal5.bn.s'
        [r"^(.*)_nonlocal([0-9]+)_(bn)_(.*)", r"\1_nonlocal\2.\3.\4"],
        # ------------------------------------------------------------
        # 't_pool1_subsample_bn_rm' -> 's1_fuse.bn.rm'
        # (must precede the generic 't_pool1_subsample_' rule below)
        [r"^t_pool1_subsample_bn_(.*)", r"s1_fuse.bn.\1"],
        # 't_pool1_subsample_w' -> 's1_fuse.conv_f2s.w'
        [r"^t_pool1_subsample_(.*)", r"s1_fuse.conv_f2s.\1"],
        # 't_res4_5_branch2c_bn_subsample_bn_rm' -> 's4_fuse.bn.rm'
        # (must precede the generic '..._subsample_' rule below)
        [
            r"^t_res([0-9]+)_([0-9]+)_branch2c_bn_subsample_bn_(.*)",
            r"s\1_fuse.bn.\3",
        ],
        # 't_res4_5_branch2c_bn_subsample_w' -> 's4_fuse.conv_f2s.w'
        [
            r"^t_res([0-9]+)_([0-9]+)_branch2c_bn_subsample_(.*)",
            r"s\1_fuse.conv_f2s.\3",
        ],
        # ------------------------------------------------------------
        # 'res4_4_branch2c_bn_b' -> 's4.pathway0_res4.branch2.c_bn_b'
        [
            r"^res([0-9]+)_([0-9]+)_branch([0-9]+)([a-z])_(.*)",
            r"s\1.pathway0_res\2.branch\3.\4_\5",
        ],
        # 'res_conv1_bn_s' -> 's1.pathway0_stem.bn.s'
        [r"^res_conv1_bn_(.*)", r"s1.pathway0_stem.bn.\1"],
        # 'conv1_xy_w' -> 's1.pathway0_stem.conv_xy_w'
        # (must precede the generic 'conv1_' rule below)
        [r"^conv1_xy(.*)", r"s1.pathway0_stem.conv_xy\1"],
        # 'conv1_w' -> 's1.pathway0_stem.conv.w'
        [r"^conv1_(.*)", r"s1.pathway0_stem.conv.\1"],
        # 'res4_0_branch1_w' -> 's4.pathway0_res0.branch1_w'
        [
            r"^res([0-9]+)_([0-9]+)_branch([0-9]+)_(.*)",
            r"s\1.pathway0_res\2.branch\3_\4",
        ],
        # 'res_conv1_w' -> 's1.pathway0_stem.conv.w'
        [r"^res_conv1_(.*)", r"s1.pathway0_stem.conv.\1"],
        # ------------------------------------------------------------
        # Same rules as above for names carrying the 't_' prefix; these map
        # to 'pathway1' instead of 'pathway0'.
        # 't_res4_4_branch2c_bn_b' -> 's4.pathway1_res4.branch2.c_bn_b'
        [
            r"^t_res([0-9]+)_([0-9]+)_branch([0-9]+)([a-z])_(.*)",
            r"s\1.pathway1_res\2.branch\3.\4_\5",
        ],
        # 't_res_conv1_bn_s' -> 's1.pathway1_stem.bn.s'
        [r"^t_res_conv1_bn_(.*)", r"s1.pathway1_stem.bn.\1"],
        # 't_conv1_w' -> 's1.pathway1_stem.conv.w'
        [r"^t_conv1_(.*)", r"s1.pathway1_stem.conv.\1"],
        # 't_res4_0_branch1_w' -> 's4.pathway1_res0.branch1_w'
        [
            r"^t_res([0-9]+)_([0-9]+)_branch([0-9]+)_(.*)",
            r"s\1.pathway1_res\2.branch\3_\4",
        ],
        # 't_res_conv1_w' -> 's1.pathway1_stem.conv.w'
        [r"^t_res_conv1_(.*)", r"s1.pathway1_stem.conv.\1"],
        # ------------------------------------------------------------
        # 'pred_w' -> 'head.projection.w'
        [r"pred_(.*)", r"head.projection.\1"],
        # '...b_bn_fc...' -> '...se.fc...'
        [r"(.*)b_bn_fc(.*)", r"\1se.fc\2"],
        # 'conv_5...' -> 'head.conv_5...'
        [r"conv_5(.*)", r"head.conv_5\1"],
        # 'lin_5...' -> 'head.lin_5...'
        [r"lin_5(.*)", r"head.lin_5\1"],
        # Suffix rules. The unescaped '.' after 'bn' matches any single
        # character, so both 'bn_b' and 'bn.b' are rewritten.
        # '...bn_b' / '...bn.b' -> '...bn.bias'
        [r"(.*)bn.b\Z", r"\1bn.bias"],
        # '...bn_s' / '...bn.s' -> '...bn.weight'
        [r"(.*)bn.s\Z", r"\1bn.weight"],
        # '...bn_rm' / '...bn.rm' -> '...bn.running_mean'
        [r"(.*)bn.rm\Z", r"\1bn.running_mean"],
        # '...bn_riv' / '...bn.riv' -> '...bn.running_var'
        [r"(.*)bn.riv\Z", r"\1bn.running_var"],
        # '..._b' / '....b' -> '....bias'
        [r"(.*)[\._]b\Z", r"\1.bias"],
        # '..._w' / '....w' -> '....weight'
        [r"(.*)[\._]w\Z", r"\1.weight"],
    ]

    # Compile every pattern once, when the converter is built, instead of
    # resolving each pattern string through the re module cache on every call.
    compiled_pairs = [
        (re.compile(source), dest) for source, dest in pairs
    ]

    def convert_caffe2_name_to_pytorch(caffe2_layer_name):
        """
        Convert the caffe2_layer_name to pytorch format by applying the list
        of regular expressions in order.

        Args:
            caffe2_layer_name (str): caffe2 layer name.
        Returns:
            (str): pytorch layer name. A name that matches no rule is
                returned unchanged.
        """
        for pattern, dest in compiled_pairs:
            caffe2_layer_name = pattern.sub(dest, caffe2_layer_name)
        return caffe2_layer_name

    return convert_caffe2_name_to_pytorch