import warnings

import torch
import torch.nn as nn

from roma.models.matcher import *
from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
from roma.models.encoders import *


def roma_model(
    resolution, upsample_preds, device=None, weights=None, dinov2_weights=None, **kwargs
):
    # RoMa weights and DINOv2 weights are loaded separately, as the DINOv2 weights are not parameters
    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    warnings.filterwarnings(
        "ignore", category=UserWarning, message="TypedStorage is deprecated"
    )
    gp_dim = 512
    feat_dim = 512
    decoder_dim = gp_dim + feat_dim
    cls_to_coord_res = 64
    # Transformer decoder that classifies coarse matches over a
    # cls_to_coord_res x cls_to_coord_res grid of anchor coordinates
    coordinate_decoder = TransformerDecoder(
        nn.Sequential(
            *[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]
        ),
        decoder_dim,
        cls_to_coord_res**2 + 1,
        is_classifier=True,
        amp=True,
        pos_enc=False,
    )
    dw = True
    hidden_blocks = 8
    kernel_size = 5
    displacement_emb = "linear"
    disable_local_corr_grad = True
    # Per-scale convolutional refiners (keys are downsampling factors)
    conv_refiner = nn.ModuleDict(
        {
            "16": ConvRefiner(
                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=128,
                local_corr_radius=7,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "8": ConvRefiner(
                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=64,
                local_corr_radius=3,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "4": ConvRefiner(
                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=32,
                local_corr_radius=2,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "2": ConvRefiner(
                2 * 64 + 16,
                128 + 16,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=16,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "1": ConvRefiner(
                2 * 9 + 6,
                24,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=6,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
        }
    )
    kernel_temperature = 0.2
    learn_temperature = False
    no_cov = True
    kernel = CosKernel
    only_attention = False
    basis = "fourier"
    # Gaussian process match encoder, used only at the coarsest (1/16) scale
    gp16 = GP(
        kernel,
        T=kernel_temperature,
        learn_temperature=learn_temperature,
        only_attention=only_attention,
        gp_dim=gp_dim,
        basis=basis,
        no_cov=no_cov,
    )
    gps = nn.ModuleDict({"16": gp16})
    # 1x1 projections that map encoder features to the decoder dims at each scale
    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
    proj = nn.ModuleDict(
        {
            "16": proj16,
            "8": proj8,
            "4": proj4,
            "2": proj2,
            "1": proj1,
        }
    )
    displacement_dropout_p = 0.0
    gm_warp_dropout_p = 0.0
    decoder = Decoder(
        coordinate_decoder,
        gps,
        proj,
        conv_refiner,
        detach=True,
        scales=["16", "8", "4", "2", "1"],
        displacement_dropout_p=displacement_dropout_p,
        gm_warp_dropout_p=gm_warp_dropout_p,
    )
    # Encoder combining a CNN (use_vgg=True) for fine features with DINOv2 for coarse features
    encoder = CNNandDinov2(
        cnn_kwargs=dict(pretrained=False, amp=True),
        amp=True,
        use_vgg=True,
        dinov2_weights=dinov2_weights,
    )
    h, w = resolution
    symmetric = True
    attenuate_cert = True
    matcher = RegressionMatcher(
        encoder,
        decoder,
        h=h,
        w=w,
        upsample_preds=upsample_preds,
        symmetric=symmetric,
        attenuate_cert=attenuate_cert,
        **kwargs
    ).to(device)
    matcher.load_state_dict(weights)
    return matcher
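

# Minimal usage sketch (not part of the original module): how one might build a
# matcher from this factory. The checkpoint paths, resolution, and device below
# are illustrative assumptions; in practice `weights` must be a state dict
# compatible with RegressionMatcher and `dinov2_weights` with the DINOv2 backbone.
#
#   weights = torch.load("roma_checkpoint.pth", map_location="cpu")                # hypothetical path
#   dinov2_weights = torch.load("dinov2_vitl14_pretrain.pth", map_location="cpu")  # hypothetical path
#   matcher = roma_model(
#       resolution=(560, 560),   # (h, w) passed to RegressionMatcher; value is illustrative
#       upsample_preds=True,
#       device="cuda",
#       weights=weights,
#       dinov2_weights=dinov2_weights,
#   )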