import torch
import torch.nn as nn
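# The wildcard import below is expected to provide DFN, RRB, CAB, GP,
# CosKernel, ConvRefiner, Decoder, Encoder and RegressionMatcher.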
from dkm import *
from .local_corr import LocalCorr
from .corr_channels import NormedCorr
from torchvision.models import resnet as tv_resnet
dkm_pretrained_urls = {
"DKM": {
"mega_synthetic": "https://github.com/Parskatt/storage/releases/download/dkm_mega_synthetic/dkm_mega_synthetic.pth",
"mega": "https://github.com/Parskatt/storage/releases/download/dkm_mega/dkm_mega.pth",
},
"DKMv2": {
"outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_outdoor.pth",
"indoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_indoor.pth",
},
}
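# Only "DKM" and "DKMv2" checkpoints are listed above; the builders further
# down (local_corr, corr_channels, baseline, linear) look up keys that are not
# present in this table, so their pretrained=True path raises a KeyError
# unless the table is extended with matching entries.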
def DKM(pretrained=True, version="mega_synthetic", device=None):
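    """Build the DKM regression matcher: a ResNet-50 encoder, GP-based coarse
    matching with a DFN coordinate decoder at strides 32 and 16, and
    ConvRefiner refinement down to stride 1. Optionally loads the
    "mega_synthetic" or "mega" pretrained weights."""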
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gp_dim = 256
dfn_dim = 384
feat_dim = 256
coordinate_decoder = DFN(
internal_dim=dfn_dim,
feat_input_modules=nn.ModuleDict(
{
"32": nn.Conv2d(512, feat_dim, 1, 1),
"16": nn.Conv2d(512, feat_dim, 1, 1),
}
),
pred_input_modules=nn.ModuleDict(
{
"32": nn.Identity(),
"16": nn.Identity(),
}
),
rrb_d_dict=nn.ModuleDict(
{
"32": RRB(gp_dim + feat_dim, dfn_dim),
"16": RRB(gp_dim + feat_dim, dfn_dim),
}
),
cab_dict=nn.ModuleDict(
{
"32": CAB(2 * dfn_dim, dfn_dim),
"16": CAB(2 * dfn_dim, dfn_dim),
}
),
rrb_u_dict=nn.ModuleDict(
{
"32": RRB(dfn_dim, dfn_dim),
"16": RRB(dfn_dim, dfn_dim),
}
),
terminal_module=nn.ModuleDict(
{
"32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
"16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
}
),
)
dw = True
hidden_blocks = 8
kernel_size = 5
conv_refiner = nn.ModuleDict(
{
"16": ConvRefiner(
2 * 512,
1024,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"8": ConvRefiner(
2 * 512,
1024,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"4": ConvRefiner(
2 * 256,
512,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"2": ConvRefiner(
2 * 64,
128,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"1": ConvRefiner(
2 * 3,
24,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
}
)
kernel_temperature = 0.2
learn_temperature = False
no_cov = True
kernel = CosKernel
only_attention = False
basis = "fourier"
gp32 = GP(
kernel,
T=kernel_temperature,
learn_temperature=learn_temperature,
only_attention=only_attention,
gp_dim=gp_dim,
basis=basis,
no_cov=no_cov,
)
gp16 = GP(
kernel,
T=kernel_temperature,
learn_temperature=learn_temperature,
only_attention=only_attention,
gp_dim=gp_dim,
basis=basis,
no_cov=no_cov,
)
gps = nn.ModuleDict({"32": gp32, "16": gp16})
proj = nn.ModuleDict(
{"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
)
decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
h, w = 384, 512
encoder = Encoder(
tv_resnet.resnet50(pretrained=not pretrained),
) # only load pretrained weights if not loading a pretrained matcher ;)
matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
if pretrained:
weights = torch.hub.load_state_dict_from_url(
dkm_pretrained_urls["DKM"][version]
)
matcher.load_state_dict(weights)
return matcher
def DKMv2(pretrained=True, version="outdoor", resolution="low", device=None, **kwargs):
    """Build the DKMv2 regression matcher. Compared to DKM, the convolutional
    refiners additionally take linear displacement embeddings. Pretrained
    weights are available for "outdoor" and "indoor", at "low" (384x512) or
    "high" (480x640) resolution."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gp_dim = 256
dfn_dim = 384
feat_dim = 256
coordinate_decoder = DFN(
internal_dim=dfn_dim,
feat_input_modules=nn.ModuleDict(
{
"32": nn.Conv2d(512, feat_dim, 1, 1),
"16": nn.Conv2d(512, feat_dim, 1, 1),
}
),
pred_input_modules=nn.ModuleDict(
{
"32": nn.Identity(),
"16": nn.Identity(),
}
),
rrb_d_dict=nn.ModuleDict(
{
"32": RRB(gp_dim + feat_dim, dfn_dim),
"16": RRB(gp_dim + feat_dim, dfn_dim),
}
),
cab_dict=nn.ModuleDict(
{
"32": CAB(2 * dfn_dim, dfn_dim),
"16": CAB(2 * dfn_dim, dfn_dim),
}
),
rrb_u_dict=nn.ModuleDict(
{
"32": RRB(dfn_dim, dfn_dim),
"16": RRB(dfn_dim, dfn_dim),
}
),
terminal_module=nn.ModuleDict(
{
"32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
"16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
}
),
)
dw = True
hidden_blocks = 8
kernel_size = 5
displacement_emb = "linear"
conv_refiner = nn.ModuleDict(
{
"16": ConvRefiner(
2 * 512 + 128,
1024 + 128,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
displacement_emb=displacement_emb,
displacement_emb_dim=128,
),
"8": ConvRefiner(
2 * 512 + 64,
1024 + 64,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
displacement_emb=displacement_emb,
displacement_emb_dim=64,
),
"4": ConvRefiner(
2 * 256 + 32,
512 + 32,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
displacement_emb=displacement_emb,
displacement_emb_dim=32,
),
"2": ConvRefiner(
2 * 64 + 16,
128 + 16,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
displacement_emb=displacement_emb,
displacement_emb_dim=16,
),
"1": ConvRefiner(
2 * 3 + 6,
24,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
displacement_emb=displacement_emb,
displacement_emb_dim=6,
),
}
)
kernel_temperature = 0.2
learn_temperature = False
no_cov = True
kernel = CosKernel
only_attention = False
basis = "fourier"
gp32 = GP(
kernel,
T=kernel_temperature,
learn_temperature=learn_temperature,
only_attention=only_attention,
gp_dim=gp_dim,
basis=basis,
no_cov=no_cov,
)
gp16 = GP(
kernel,
T=kernel_temperature,
learn_temperature=learn_temperature,
only_attention=only_attention,
gp_dim=gp_dim,
basis=basis,
no_cov=no_cov,
)
gps = nn.ModuleDict({"32": gp32, "16": gp16})
proj = nn.ModuleDict(
{"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
)
decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
if resolution == "low":
h, w = 384, 512
elif resolution == "high":
h, w = 480, 640
encoder = Encoder(
tv_resnet.resnet50(pretrained=not pretrained),
) # only load pretrained weights if not loading a pretrained matcher ;)
matcher = RegressionMatcher(encoder, decoder, h=h, w=w, **kwargs).to(device)
if pretrained:
try:
weights = torch.hub.load_state_dict_from_url(
dkm_pretrained_urls["DKMv2"][version]
)
        except Exception:
            # Fall back to torch.load in case the table entry points to a
            # local checkpoint path rather than a URL.
            weights = torch.load(dkm_pretrained_urls["DKMv2"][version])
matcher.load_state_dict(weights)
return matcher
def local_corr(pretrained=True, version="mega_synthetic", device=None):
    """DKM variant whose refiners at strides 16, 8, 4 and 2 are LocalCorr
    modules rather than ConvRefiners."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gp_dim = 256
dfn_dim = 384
feat_dim = 256
coordinate_decoder = DFN(
internal_dim=dfn_dim,
feat_input_modules=nn.ModuleDict(
{
"32": nn.Conv2d(512, feat_dim, 1, 1),
"16": nn.Conv2d(512, feat_dim, 1, 1),
}
),
pred_input_modules=nn.ModuleDict(
{
"32": nn.Identity(),
"16": nn.Identity(),
}
),
rrb_d_dict=nn.ModuleDict(
{
"32": RRB(gp_dim + feat_dim, dfn_dim),
"16": RRB(gp_dim + feat_dim, dfn_dim),
}
),
cab_dict=nn.ModuleDict(
{
"32": CAB(2 * dfn_dim, dfn_dim),
"16": CAB(2 * dfn_dim, dfn_dim),
}
),
rrb_u_dict=nn.ModuleDict(
{
"32": RRB(dfn_dim, dfn_dim),
"16": RRB(dfn_dim, dfn_dim),
}
),
terminal_module=nn.ModuleDict(
{
"32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
"16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
}
),
)
dw = True
hidden_blocks = 8
kernel_size = 5
conv_refiner = nn.ModuleDict(
{
"16": LocalCorr(
81,
81 * 12,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"8": LocalCorr(
81,
81 * 12,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"4": LocalCorr(
81,
81 * 6,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"2": LocalCorr(
81,
81,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"1": ConvRefiner(
2 * 3,
24,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
}
)
kernel_temperature = 0.2
learn_temperature = False
no_cov = True
kernel = CosKernel
only_attention = False
basis = "fourier"
gp32 = GP(
kernel,
T=kernel_temperature,
learn_temperature=learn_temperature,
only_attention=only_attention,
gp_dim=gp_dim,
basis=basis,
no_cov=no_cov,
)
gp16 = GP(
kernel,
T=kernel_temperature,
learn_temperature=learn_temperature,
only_attention=only_attention,
gp_dim=gp_dim,
basis=basis,
no_cov=no_cov,
)
gps = nn.ModuleDict({"32": gp32, "16": gp16})
proj = nn.ModuleDict(
{"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
)
decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
h, w = 384, 512
encoder = Encoder(
tv_resnet.resnet50(pretrained=not pretrained)
) # only load pretrained weights if not loading a pretrained matcher ;)
matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
if pretrained:
weights = torch.hub.load_state_dict_from_url(
dkm_pretrained_urls["local_corr"][version]
)
matcher.load_state_dict(weights)
return matcher
def corr_channels(pretrained=True, version="mega_synthetic", device=None):
    """DKM variant that replaces the GP coarse matchers with NormedCorr
    correlation volumes, whose channel count equals the number of coarse grid
    cells at strides 32 and 16."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
h, w = 384, 512
gp_dim = (h // 32) * (w // 32), (h // 16) * (w // 16)
dfn_dim = 384
feat_dim = 256
coordinate_decoder = DFN(
internal_dim=dfn_dim,
feat_input_modules=nn.ModuleDict(
{
"32": nn.Conv2d(512, feat_dim, 1, 1),
"16": nn.Conv2d(512, feat_dim, 1, 1),
}
),
pred_input_modules=nn.ModuleDict(
{
"32": nn.Identity(),
"16": nn.Identity(),
}
),
rrb_d_dict=nn.ModuleDict(
{
"32": RRB(gp_dim[0] + feat_dim, dfn_dim),
"16": RRB(gp_dim[1] + feat_dim, dfn_dim),
}
),
cab_dict=nn.ModuleDict(
{
"32": CAB(2 * dfn_dim, dfn_dim),
"16": CAB(2 * dfn_dim, dfn_dim),
}
),
rrb_u_dict=nn.ModuleDict(
{
"32": RRB(dfn_dim, dfn_dim),
"16": RRB(dfn_dim, dfn_dim),
}
),
terminal_module=nn.ModuleDict(
{
"32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
"16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
}
),
)
dw = True
hidden_blocks = 8
kernel_size = 5
conv_refiner = nn.ModuleDict(
{
"16": ConvRefiner(
2 * 512,
1024,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"8": ConvRefiner(
2 * 512,
1024,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"4": ConvRefiner(
2 * 256,
512,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"2": ConvRefiner(
2 * 64,
128,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"1": ConvRefiner(
2 * 3,
24,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
}
)
gp32 = NormedCorr()
gp16 = NormedCorr()
gps = nn.ModuleDict({"32": gp32, "16": gp16})
proj = nn.ModuleDict(
{"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
)
decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
h, w = 384, 512
encoder = Encoder(
tv_resnet.resnet50(pretrained=not pretrained)
) # only load pretrained weights if not loading a pretrained matcher ;)
matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
if pretrained:
weights = torch.hub.load_state_dict_from_url(
dkm_pretrained_urls["corr_channels"][version]
)
matcher.load_state_dict(weights)
return matcher
def baseline(pretrained=True, version="mega_synthetic", device=None):
    """DKM variant combining NormedCorr coarse matching with LocalCorr
    refiners at strides 16, 8, 4 and 2."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
h, w = 384, 512
gp_dim = (h // 32) * (w // 32), (h // 16) * (w // 16)
dfn_dim = 384
feat_dim = 256
coordinate_decoder = DFN(
internal_dim=dfn_dim,
feat_input_modules=nn.ModuleDict(
{
"32": nn.Conv2d(512, feat_dim, 1, 1),
"16": nn.Conv2d(512, feat_dim, 1, 1),
}
),
pred_input_modules=nn.ModuleDict(
{
"32": nn.Identity(),
"16": nn.Identity(),
}
),
rrb_d_dict=nn.ModuleDict(
{
"32": RRB(gp_dim[0] + feat_dim, dfn_dim),
"16": RRB(gp_dim[1] + feat_dim, dfn_dim),
}
),
cab_dict=nn.ModuleDict(
{
"32": CAB(2 * dfn_dim, dfn_dim),
"16": CAB(2 * dfn_dim, dfn_dim),
}
),
rrb_u_dict=nn.ModuleDict(
{
"32": RRB(dfn_dim, dfn_dim),
"16": RRB(dfn_dim, dfn_dim),
}
),
terminal_module=nn.ModuleDict(
{
"32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
"16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
}
),
)
dw = True
hidden_blocks = 8
kernel_size = 5
conv_refiner = nn.ModuleDict(
{
"16": LocalCorr(
81,
81 * 12,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"8": LocalCorr(
81,
81 * 12,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"4": LocalCorr(
81,
81 * 6,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"2": LocalCorr(
81,
81,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"1": ConvRefiner(
2 * 3,
24,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
}
)
gp32 = NormedCorr()
gp16 = NormedCorr()
gps = nn.ModuleDict({"32": gp32, "16": gp16})
proj = nn.ModuleDict(
{"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
)
decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
h, w = 384, 512
encoder = Encoder(
tv_resnet.resnet50(pretrained=not pretrained)
) # only load pretrained weights if not loading a pretrained matcher ;)
matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
if pretrained:
weights = torch.hub.load_state_dict_from_url(
dkm_pretrained_urls["baseline"][version]
)
matcher.load_state_dict(weights)
return matcher
def linear(pretrained=True, version="mega_synthetic", device=None):
    """DKM variant using a "linear" basis in the GP modules instead of the
    default "fourier" basis."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gp_dim = 256
dfn_dim = 384
feat_dim = 256
coordinate_decoder = DFN(
internal_dim=dfn_dim,
feat_input_modules=nn.ModuleDict(
{
"32": nn.Conv2d(512, feat_dim, 1, 1),
"16": nn.Conv2d(512, feat_dim, 1, 1),
}
),
pred_input_modules=nn.ModuleDict(
{
"32": nn.Identity(),
"16": nn.Identity(),
}
),
rrb_d_dict=nn.ModuleDict(
{
"32": RRB(gp_dim + feat_dim, dfn_dim),
"16": RRB(gp_dim + feat_dim, dfn_dim),
}
),
cab_dict=nn.ModuleDict(
{
"32": CAB(2 * dfn_dim, dfn_dim),
"16": CAB(2 * dfn_dim, dfn_dim),
}
),
rrb_u_dict=nn.ModuleDict(
{
"32": RRB(dfn_dim, dfn_dim),
"16": RRB(dfn_dim, dfn_dim),
}
),
terminal_module=nn.ModuleDict(
{
"32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
"16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
}
),
)
dw = True
hidden_blocks = 8
kernel_size = 5
conv_refiner = nn.ModuleDict(
{
"16": ConvRefiner(
2 * 512,
1024,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"8": ConvRefiner(
2 * 512,
1024,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"4": ConvRefiner(
2 * 256,
512,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"2": ConvRefiner(
2 * 64,
128,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
"1": ConvRefiner(
2 * 3,
24,
3,
kernel_size=kernel_size,
dw=dw,
hidden_blocks=hidden_blocks,
),
}
)
kernel_temperature = 0.2
learn_temperature = False
no_cov = True
kernel = CosKernel
only_attention = False
basis = "linear"
gp32 = GP(
kernel,
T=kernel_temperature,
learn_temperature=learn_temperature,
only_attention=only_attention,
gp_dim=gp_dim,
basis=basis,
no_cov=no_cov,
)
gp16 = GP(
kernel,
T=kernel_temperature,
learn_temperature=learn_temperature,
only_attention=only_attention,
gp_dim=gp_dim,
basis=basis,
no_cov=no_cov,
)
gps = nn.ModuleDict({"32": gp32, "16": gp16})
proj = nn.ModuleDict(
{"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
)
decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
h, w = 384, 512
encoder = Encoder(
tv_resnet.resnet50(pretrained=not pretrained)
) # only load pretrained weights if not loading a pretrained matcher ;)
matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
if pretrained:
weights = torch.hub.load_state_dict_from_url(
dkm_pretrained_urls["linear"][version]
)
matcher.load_state_dict(weights)
return matcher
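

# Minimal usage sketch: build the outdoor DKMv2 matcher and switch it to eval
# mode (assumes the dkm package is importable and the pretrained checkpoint
# URL is reachable).
if __name__ == "__main__":
    matcher = DKMv2(pretrained=True, version="outdoor", resolution="low")
    matcher.eval()
    print(type(matcher).__name__, "ready on", next(matcher.parameters()).device)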