Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from ..dkm import * | |
from ..encoders import * | |
def DKMv3(
    weights,
    h,
    w,
    symmetric=True,
    sample_mode="threshold_balanced",
    device=None,
    **kwargs
):
    """Build the DKMv3 ``RegressionMatcher`` and load ``weights`` into it.

    Args:
        weights: State dict handed to ``matcher.load_state_dict`` (strict
            loading, the PyTorch default, so missing/unexpected keys raise).
            If ``None``, nothing is loaded and the freshly initialized model
            is returned.
        h: Forwarded to ``RegressionMatcher`` as ``h``.
        w: Forwarded to ``RegressionMatcher`` as ``w``.
        symmetric: Forwarded to ``RegressionMatcher``.
        sample_mode: Forwarded to ``RegressionMatcher``.
        device: Device the matcher is moved to. Defaults to CUDA when
            ``torch.cuda.is_available()``, otherwise CPU.
        **kwargs: Extra keyword arguments forwarded to ``RegressionMatcher``.

    Returns:
        The assembled ``RegressionMatcher``, already moved to ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    gp_dim = 256
    dfn_dim = 384
    feat_dim = 256
    # Scales handled by the GP + DFN coarse stage. The tuple order fixes both
    # module construction order and ModuleDict key order ("32" before "16").
    coarse_scales = ("32", "16")

    coordinate_decoder = DFN(
        internal_dim=dfn_dim,
        feat_input_modules=nn.ModuleDict(
            {s: nn.Conv2d(512, feat_dim, 1, 1) for s in coarse_scales}
        ),
        pred_input_modules=nn.ModuleDict({s: nn.Identity() for s in coarse_scales}),
        rrb_d_dict=nn.ModuleDict(
            {s: RRB(gp_dim + feat_dim, dfn_dim) for s in coarse_scales}
        ),
        cab_dict=nn.ModuleDict(
            {s: CAB(2 * dfn_dim, dfn_dim) for s in coarse_scales}
        ),
        rrb_u_dict=nn.ModuleDict({s: RRB(dfn_dim, dfn_dim) for s in coarse_scales}),
        terminal_module=nn.ModuleDict(
            {s: nn.Conv2d(dfn_dim, 3, 1, 1, 0) for s in coarse_scales}
        ),
    )

    # Settings shared by every ConvRefiner below.
    refiner_kwargs = dict(
        kernel_size=5,
        dw=True,
        hidden_blocks=8,
        displacement_emb="linear",
    )

    def local_corr_refiner(feat_channels, emb_dim, radius):
        # Build a ConvRefiner that uses a local correlation of the given radius.
        # Input and hidden widths are both
        # 2 * feat_channels + emb_dim + (2 * radius + 1) ** 2. Presumably these
        # are two feature maps, the displacement embedding and a
        # (2r+1)x(2r+1) correlation window (TODO confirm against ConvRefiner).
        in_dim = 2 * feat_channels + emb_dim + (2 * radius + 1) ** 2
        return ConvRefiner(
            in_dim,
            in_dim,
            3,
            displacement_emb_dim=emb_dim,
            local_corr_radius=radius,
            corr_in_other=True,
            **refiner_kwargs,
        )

    conv_refiner = nn.ModuleDict(
        {
            "16": local_corr_refiner(512, 128, 7),
            "8": local_corr_refiner(512, 64, 3),
            "4": local_corr_refiner(256, 32, 2),
            # The two finest scales use no local correlation.
            "2": ConvRefiner(
                2 * 64 + 16,
                128 + 16,
                3,
                displacement_emb_dim=16,
                **refiner_kwargs,
            ),
            "1": ConvRefiner(
                2 * 3 + 6,
                24,
                3,
                displacement_emb_dim=6,
                **refiner_kwargs,
            ),
        }
    )

    # The two coarse-scale GPs are configured identically.
    gps = nn.ModuleDict(
        {
            s: GP(
                CosKernel,
                T=0.2,
                learn_temperature=False,
                only_attention=False,
                gp_dim=gp_dim,
                basis="fourier",
                no_cov=True,
            )
            for s in coarse_scales
        }
    )
    # 1x1 projections to 512 channels. The input widths (1024 at "16", 2048 at
    # "32") presumably match the ResNet50 stage outputs (TODO confirm).
    # Key order "16" then "32" is kept from the original definition.
    proj = nn.ModuleDict(
        {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
    )

    decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
    # pretrained=False: the backbone parameters are expected to come from
    # `weights` rather than from a separate pretrained checkpoint.
    encoder = ResNet50(pretrained=False, high_res=False, freeze_bn=False)
    matcher = RegressionMatcher(
        encoder,
        decoder,
        h=h,
        w=w,
        name="DKMv3",
        sample_mode=sample_mode,
        symmetric=symmetric,
        **kwargs
    ).to(device)

    if weights is not None:
        # Strict loading raises on key mismatch, so the returned
        # missing/unexpected-keys record carries no extra information.
        matcher.load_state_dict(weights)
    return matcher