# Provenance (from the hosting site's file view): author Vincentqyw,
# commit 62c7319 ("add: roma"), file size 4.6 kB.
import torch
from torch import nn
from ..dkm import *
from ..encoders import *
def DKMv3(
    weights,
    h,
    w,
    symmetric=True,
    sample_mode="threshold_balanced",
    device=None,
    **kwargs,
):
    """Build the DKMv3 matcher and load the given state dict into it.

    The architecture is hard-coded here: a ``ResNet50`` encoder, a GP + DFN
    coarse decoder for the scales keyed ``"32"`` and ``"16"``, and a stack of
    ``ConvRefiner`` modules keyed ``"16"``, ``"8"``, ``"4"``, ``"2"`` and
    ``"1"``. Every module name and shape must match ``weights`` exactly,
    because loading is strict.

    Args:
        weights: State dict passed to ``load_state_dict`` with
            ``strict=True``.
        h, w: Forwarded unchanged to ``RegressionMatcher`` as ``h``/``w``
            (presumably the working image resolution; confirm against
            ``RegressionMatcher``).
        symmetric: Forwarded unchanged to ``RegressionMatcher``.
        sample_mode: Forwarded unchanged to ``RegressionMatcher``.
        device: Device the matcher is moved to. ``None`` selects CUDA when
            available, otherwise CPU.
        **kwargs: Extra keyword arguments forwarded to ``RegressionMatcher``.

    Returns:
        The ``RegressionMatcher`` on ``device`` with ``weights`` loaded.

    Raises:
        RuntimeError: From ``load_state_dict`` when ``weights`` has missing or
            unexpected keys, or mismatched tensor shapes.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    gp_dim = 256
    dfn_dim = 384
    feat_dim = 256
    coarse_scales = ("32", "16")

    def per_coarse_scale(make_module):
        # Call the factory once per coarse scale ("32" first, then "16") so
        # each scale gets its own, independently parameterised module.
        return nn.ModuleDict({scale: make_module() for scale in coarse_scales})

    coordinate_decoder = DFN(
        internal_dim=dfn_dim,
        feat_input_modules=per_coarse_scale(lambda: nn.Conv2d(512, feat_dim, 1, 1)),
        pred_input_modules=per_coarse_scale(nn.Identity),
        rrb_d_dict=per_coarse_scale(lambda: RRB(gp_dim + feat_dim, dfn_dim)),
        cab_dict=per_coarse_scale(lambda: CAB(2 * dfn_dim, dfn_dim)),
        rrb_u_dict=per_coarse_scale(lambda: RRB(dfn_dim, dfn_dim)),
        terminal_module=per_coarse_scale(lambda: nn.Conv2d(dfn_dim, 3, 1, 1, 0)),
    )

    dw = True
    hidden_blocks = 8
    kernel_size = 5
    displacement_emb = "linear"

    def refiner(in_dim, hidden_dim, emb_dim, **corr_kwargs):
        # Shared ConvRefiner settings. The output dim is always 3. Any
        # local-correlation kwargs are passed only when the caller supplies
        # them, so the "2"/"1" refiners keep ConvRefiner's own defaults.
        return ConvRefiner(
            in_dim,
            hidden_dim,
            3,
            kernel_size=kernel_size,
            dw=dw,
            hidden_blocks=hidden_blocks,
            displacement_emb=displacement_emb,
            displacement_emb_dim=emb_dim,
            **corr_kwargs,
        )

    def local_corr_refiner(feat_channels, emb_dim, radius):
        # Input (and hidden) width = 2 * feat_channels (presumably the features
        # of both images) + displacement-embedding dim + the (2r+1)**2 cells
        # of the local correlation window.
        dim = 2 * feat_channels + emb_dim + (2 * radius + 1) ** 2
        return refiner(dim, dim, emb_dim, local_corr_radius=radius, corr_in_other=True)

    conv_refiner = nn.ModuleDict(
        {
            "16": local_corr_refiner(512, 128, radius=7),
            "8": local_corr_refiner(512, 64, radius=3),
            "4": local_corr_refiner(256, 32, radius=2),
            "2": refiner(2 * 64 + 16, 128 + 16, 16),
            "1": refiner(2 * 3 + 6, 24, 6),
        }
    )

    kernel_temperature = 0.2
    learn_temperature = False
    no_cov = True
    kernel = CosKernel
    only_attention = False
    basis = "fourier"
    # The two coarse-scale GPs are identically configured but must be separate
    # module instances ("32" is built first, as before).
    gps = nn.ModuleDict(
        {
            scale: GP(
                kernel,
                T=kernel_temperature,
                learn_temperature=learn_temperature,
                only_attention=only_attention,
                gp_dim=gp_dim,
                basis=basis,
                no_cov=no_cov,
            )
            for scale in coarse_scales
        }
    )
    proj = nn.ModuleDict(
        {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
    )
    decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
    # pretrained=False: all encoder parameters are overwritten by the strict
    # state-dict load below (presumably this also avoids a download; confirm
    # in ResNet50).
    encoder = ResNet50(pretrained=False, high_res=False, freeze_bn=False)
    matcher = RegressionMatcher(
        encoder,
        decoder,
        h=h,
        w=w,
        name="DKMv3",
        sample_mode=sample_mode,
        symmetric=symmetric,
        **kwargs,
    ).to(device)
    # load_state_dict copies values into the existing (already on-device)
    # parameters, so the loaded weights end up on `device`. strict=True (the
    # default, made explicit) raises on any key mismatch, so the returned
    # missing/unexpected-key report carries no extra information and is not kept.
    matcher.load_state_dict(weights, strict=True)
    return matcher