Spaces:
Running
Running
import torch.nn as nn | |
import torch.nn.functional as F | |
from .repvgg import create_RepVGG | |
class RepVGG_8_1_align(nn.Module):
    """RepVGG backbone that returns intermediate feature maps of several stages.

    The four stages ``stage0``..``stage3`` of the network built by
    ``create_RepVGG(False)`` are kept as ``layer0``..``layer3``; nothing else
    from that network is retained.

    NOTE(review): an earlier description of this class said the outputs are at
    1/8 and full resolution, while the per-stage comments in ``forward`` say
    1/2, 1/4 and 1/8. The real strides are decided by ``create_RepVGG`` (not
    visible here) -- TODO confirm and keep the comments in sync.
    """

    def __init__(self, config):
        """Build the backbone stages and re-initialise their weights.

        Args:
            config: Not read by this class.
        """
        super().__init__()
        backbone = create_RepVGG(False)
        self.layer0 = backbone.stage0
        self.layer1 = backbone.stage1
        self.layer2 = backbone.stage2
        self.layer3 = backbone.stage3
        self._init_weights([self.layer0, self.layer1, self.layer2, self.layer3])

    @staticmethod
    def _init_weights(layers):
        """Kaiming-init every Conv2d and reset every BatchNorm2d/GroupNorm.

        Conv weights use Kaiming-normal (fan_out, relu); conv biases are left
        untouched. Norm layers get weight=1 and bias=0. Norm layers created
        with ``affine=False`` have ``weight``/``bias`` set to ``None``; those
        parameters are skipped instead of crashing ``nn.init.constant_``.

        Args:
            layers: Iterable of ``nn.Module`` objects whose sub-modules
                (recursively, via ``.modules()``) are initialised.
        """
        for layer in layers:
            for m in layer.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    if m.weight is not None:
                        nn.init.constant_(m.weight, 1)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the four stages and collect intermediate feature maps.

        Args:
            x: Input tensor. NOTE(review): presumably an image batch
                (B, C, H, W); the expected channel count depends on
                ``create_RepVGG`` -- confirm.

        Returns:
            dict with keys:
              - ``'feats_c'``: output of ``layer3``
              - ``'feats_f'``: always ``None``
              - ``'feats_x2'``: output of ``layer2``
              - ``'feats_x1'``: output of ``layer1``
        """
        # Resolution annotations below are the original author's; the actual
        # strides are set inside create_RepVGG.
        out = self.layer0(x)  # 1/2
        # Stages 1-3 are applied block by block instead of being called
        # directly, so they only need to be iterable containers of modules.
        for module in self.layer1:
            out = module(out)  # 1/2
        x1 = out
        for module in self.layer2:
            out = module(out)  # 1/4
        x2 = out
        for module in self.layer3:
            out = module(out)  # 1/8
        x3 = out
        return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}