Realcat
add: efficientloftr
e02ffe6
raw
history blame
1.29 kB
import torch.nn as nn
import torch.nn.functional as F
from .repvgg import create_RepVGG
class RepVGG_8_1_align(nn.Module):
    """RepVGG feature backbone returning one coarse map and two finer maps.

    The four stages of the network built by ``create_RepVGG(False)`` are
    reused as ``layer0`` .. ``layer3``. The per-stage strides noted in the
    comments below (1/2, 1/2, 1/4, 1/8 of the input) follow the stage
    annotations in this module; the actual strides are decided inside
    ``create_RepVGG`` and are not visible here.

    Args:
        config: accepted for interface compatibility; not read by this class.
    """

    def __init__(self, config):
        super().__init__()
        net = create_RepVGG(False)
        # Register the stages in the same order as before (0 -> 3) so the
        # submodule / state_dict ordering is unchanged.
        self.layer0 = net.stage0
        self.layer1 = net.stage1
        self.layer2 = net.stage2
        self.layer3 = net.stage3
        self._init_stage_weights()

    def _init_stage_weights(self):
        """Re-initialise conv and norm parameters of every stage, stage by stage."""
        stages = (self.layer0, self.layer1, self.layer2, self.layer3)
        for stage in stages:
            for mod in stage.modules():
                if isinstance(mod, nn.Conv2d):
                    # He/Kaiming init with fan_out scaling for ReLU networks.
                    nn.init.kaiming_normal_(mod.weight, mode='fan_out', nonlinearity='relu')
                    continue
                if isinstance(mod, (nn.BatchNorm2d, nn.GroupNorm)):
                    # Identity affine transform: scale 1, shift 0.
                    nn.init.constant_(mod.weight, 1)
                    nn.init.constant_(mod.bias, 0)

    def forward(self, x):
        """Run all four stages and collect the intermediate feature maps.

        Args:
            x: input tensor passed straight to ``layer0``.

        Returns:
            dict with keys, in this order:
              'feats_c'  -- output of ``layer3`` (annotated 1/8),
              'feats_f'  -- always ``None``,
              'feats_x2' -- output of ``layer2`` (annotated 1/4),
              'feats_x1' -- output of ``layer1`` (annotated 1/2).
        """
        feat = self.layer0(x)  # 1/2
        # Stages are iterated block by block rather than called directly,
        # since they may be plain ModuleLists without a forward().
        for blk in self.layer1:
            feat = blk(feat)
        feats_x1 = feat  # 1/2
        for blk in self.layer2:
            feat = blk(feat)
        feats_x2 = feat  # 1/4
        for blk in self.layer3:
            feat = blk(feat)
        return {
            'feats_c': feat,  # 1/8
            'feats_f': None,
            'feats_x2': feats_x2,
            'feats_x1': feats_x1,
        }