import torch import torch.nn as nn import torch.nn.functional as F import torchvision.models as tvm class ResNet18(nn.Module): def __init__(self, pretrained=False) -> None: super().__init__() self.net = tvm.resnet18(pretrained=pretrained) def forward(self, x): self = self.net x1 = x x = self.conv1(x1) x = self.bn1(x) x2 = self.relu(x) x = self.maxpool(x2) x4 = self.layer1(x) x8 = self.layer2(x4) x16 = self.layer3(x8) x32 = self.layer4(x16) return {32: x32, 16: x16, 8: x8, 4: x4, 2: x2, 1: x1} def train(self, mode=True): super().train(mode) for m in self.modules(): if isinstance(m, nn.BatchNorm2d): m.eval() pass class ResNet50(nn.Module): def __init__( self, pretrained=False, high_res=False, weights=None, dilation=None, freeze_bn=True, anti_aliased=False, ) -> None: super().__init__() if dilation is None: dilation = [False, False, False] if anti_aliased: pass else: if weights is not None: self.net = tvm.resnet50( weights=weights, replace_stride_with_dilation=dilation ) else: self.net = tvm.resnet50( pretrained=pretrained, replace_stride_with_dilation=dilation ) self.high_res = high_res self.freeze_bn = freeze_bn def forward(self, x): net = self.net feats = {1: x} x = net.conv1(x) x = net.bn1(x) x = net.relu(x) feats[2] = x x = net.maxpool(x) x = net.layer1(x) feats[4] = x x = net.layer2(x) feats[8] = x x = net.layer3(x) feats[16] = x x = net.layer4(x) feats[32] = x return feats def train(self, mode=True): super().train(mode) if self.freeze_bn: for m in self.modules(): if isinstance(m, nn.BatchNorm2d): m.eval() pass class ResNet101(nn.Module): def __init__(self, pretrained=False, high_res=False, weights=None) -> None: super().__init__() if weights is not None: self.net = tvm.resnet101(weights=weights) else: self.net = tvm.resnet101(pretrained=pretrained) self.high_res = high_res self.scale_factor = 1 if not high_res else 1.5 def forward(self, x): net = self.net feats = {1: x} sf = self.scale_factor if self.high_res: x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic") x = net.conv1(x) x = net.bn1(x) x = net.relu(x) feats[2] = ( x if not self.high_res else F.interpolate( x, scale_factor=1 / sf, align_corners=False, mode="bilinear" ) ) x = net.maxpool(x) x = net.layer1(x) feats[4] = ( x if not self.high_res else F.interpolate( x, scale_factor=1 / sf, align_corners=False, mode="bilinear" ) ) x = net.layer2(x) feats[8] = ( x if not self.high_res else F.interpolate( x, scale_factor=1 / sf, align_corners=False, mode="bilinear" ) ) x = net.layer3(x) feats[16] = ( x if not self.high_res else F.interpolate( x, scale_factor=1 / sf, align_corners=False, mode="bilinear" ) ) x = net.layer4(x) feats[32] = ( x if not self.high_res else F.interpolate( x, scale_factor=1 / sf, align_corners=False, mode="bilinear" ) ) return feats def train(self, mode=True): super().train(mode) for m in self.modules(): if isinstance(m, nn.BatchNorm2d): m.eval() pass class WideResNet50(nn.Module): def __init__(self, pretrained=False, high_res=False, weights=None) -> None: super().__init__() if weights is not None: self.net = tvm.wide_resnet50_2(weights=weights) else: self.net = tvm.wide_resnet50_2(pretrained=pretrained) self.high_res = high_res self.scale_factor = 1 if not high_res else 1.5 def forward(self, x): net = self.net feats = {1: x} sf = self.scale_factor if self.high_res: x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic") x = net.conv1(x) x = net.bn1(x) x = net.relu(x) feats[2] = ( x if not self.high_res else F.interpolate( x, scale_factor=1 / sf, align_corners=False, mode="bilinear" ) ) x = net.maxpool(x) x = net.layer1(x) feats[4] = ( x if not self.high_res else F.interpolate( x, scale_factor=1 / sf, align_corners=False, mode="bilinear" ) ) x = net.layer2(x) feats[8] = ( x if not self.high_res else F.interpolate( x, scale_factor=1 / sf, align_corners=False, mode="bilinear" ) ) x = net.layer3(x) feats[16] = ( x if not self.high_res else F.interpolate( x, scale_factor=1 / sf, align_corners=False, mode="bilinear" ) ) x = net.layer4(x) feats[32] = ( x if not self.high_res else F.interpolate( x, scale_factor=1 / sf, align_corners=False, mode="bilinear" ) ) return feats def train(self, mode=True): super().train(mode) for m in self.modules(): if isinstance(m, nn.BatchNorm2d): m.eval() pass