"""SimCLRv2 ResNet backbone and contrastive projection head.

Adapted from https://github.com/Separius/SimCLRv2-Pytorch
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BATCH_NORM_EPSILON = 1e-5
BATCH_NORM_DECAY = 0.9  # TF-style decay; equivalent to PyTorch's default momentum of 0.1 (momentum = 1 - decay)


class BatchNormRelu(nn.Sequential):
    def __init__(self, num_channels, relu=True):
        super().__init__(nn.BatchNorm2d(num_channels, eps=BATCH_NORM_EPSILON),
                         nn.ReLU() if relu else nn.Identity())


def conv(in_channels, out_channels, kernel_size=3, stride=1, bias=False):
    # "Same" padding for odd kernel sizes.
    return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                     padding=(kernel_size - 1) // 2, bias=bias)


class SelectiveKernel(nn.Module):
    def __init__(self, in_channels, out_channels, stride, sk_ratio, min_dim=32):
        super().__init__()
        assert sk_ratio > 0.0
        # A single conv produces both branches (2 * out_channels), split in forward().
        self.main_conv = nn.Sequential(conv(in_channels, 2 * out_channels, stride=stride),
                                       BatchNormRelu(2 * out_channels))
        mid_dim = max(int(out_channels * sk_ratio), min_dim)
        self.mixing_conv = nn.Sequential(conv(out_channels, mid_dim, kernel_size=1),
                                         BatchNormRelu(mid_dim),
                                         conv(mid_dim, 2 * out_channels, kernel_size=1))

    def forward(self, x):
        x = self.main_conv(x)
        x = torch.stack(torch.chunk(x, 2, dim=1), dim=0)  # 2, B, C, H, W
        g = x.sum(dim=0).mean(dim=[2, 3], keepdim=True)   # global descriptor over both branches
        m = self.mixing_conv(g)
        m = torch.stack(torch.chunk(m, 2, dim=1), dim=0)  # 2, B, C, 1, 1
        # Softmax across the two branches yields per-channel mixing weights.
        return (x * F.softmax(m, dim=0)).sum(dim=0)


class Projection(nn.Module):
    def __init__(self, in_channels, out_channels, stride, sk_ratio=0):
        super().__init__()
        if sk_ratio > 0:
            # SK variants downsample with average pooling (TF-style asymmetric padding)
            # followed by a 1x1 conv, instead of a strided 1x1 conv.
            self.shortcut = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)),
                                          nn.AvgPool2d(kernel_size=2, stride=stride, padding=0),
                                          conv(in_channels, out_channels, kernel_size=1))
        else:
            self.shortcut = conv(in_channels, out_channels, kernel_size=1, stride=stride)
        self.bn = BatchNormRelu(out_channels, relu=False)

    def forward(self, x):
        return self.bn(self.shortcut(x))


class BottleneckBlock(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, stride, sk_ratio=0, use_projection=False):
        super().__init__()
        if use_projection:
            self.projection = Projection(in_channels, out_channels * 4, stride, sk_ratio)
        else:
            self.projection = nn.Identity()
        ops = [conv(in_channels, out_channels, kernel_size=1), BatchNormRelu(out_channels)]
        if sk_ratio > 0:
            ops.append(SelectiveKernel(out_channels, out_channels, stride, sk_ratio))
        else:
            ops.append(conv(out_channels, out_channels, stride=stride))
        ops.append(BatchNormRelu(out_channels))
        ops.append(conv(out_channels, out_channels * 4, kernel_size=1))
        ops.append(BatchNormRelu(out_channels * 4, relu=False))
        self.net = nn.Sequential(*ops)

    def forward(self, x):
        shortcut = self.projection(x)
        return F.relu(shortcut + self.net(x))


class Blocks(nn.Module):
    def __init__(self, num_blocks, in_channels, out_channels, stride, sk_ratio=0):
        super().__init__()
        # Only the first block downsamples and uses a projection shortcut.
        self.blocks = nn.ModuleList([BottleneckBlock(in_channels, out_channels, stride, sk_ratio, True)])
        self.channels_out = out_channels * BottleneckBlock.expansion
        for _ in range(num_blocks - 1):
            self.blocks.append(BottleneckBlock(self.channels_out, out_channels, 1, sk_ratio))

    def forward(self, x):
        for b in self.blocks:
            x = b(x)
        return x


class Stem(nn.Sequential):
    def __init__(self, sk_ratio, width_multiplier):
        ops = []
        channels = 64 * width_multiplier // 2
        if sk_ratio > 0:
            # SK variants replace the 7x7 stem conv with three 3x3 convs.
            ops.append(conv(3, channels, stride=2))
            ops.append(BatchNormRelu(channels))
            ops.append(conv(channels, channels))
            ops.append(BatchNormRelu(channels))
            ops.append(conv(channels, channels * 2))
        else:
            ops.append(conv(3, channels * 2, kernel_size=7, stride=2))
        ops.append(BatchNormRelu(channels * 2))
        ops.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        super().__init__(*ops)


class ResNet(nn.Module):
    def __init__(self, layers, width_multiplier, sk_ratio):
        super().__init__()
        ops = [Stem(sk_ratio, width_multiplier)]
        channels_in = 64 * width_multiplier
        ops.append(Blocks(layers[0], channels_in, 64 * width_multiplier, 1, sk_ratio))
        channels_in = ops[-1].channels_out
        ops.append(Blocks(layers[1], channels_in, 128 * width_multiplier, 2, sk_ratio))
        channels_in = ops[-1].channels_out
        ops.append(Blocks(layers[2], channels_in, 256 * width_multiplier, 2, sk_ratio))
        channels_in = ops[-1].channels_out
        ops.append(Blocks(layers[3], channels_in, 512 * width_multiplier, 2, sk_ratio))
        channels_in = ops[-1].channels_out
        self.channels_out = channels_in
        self.net = nn.Sequential(*ops)
        self.fc = nn.Linear(channels_in, 1000)

    def forward(self, x, apply_fc=False):
        h = self.net(x).mean(dim=[2, 3])  # global average pooling
        if apply_fc:
            h = self.fc(h)
        return h


class ContrastiveHead(nn.Module):
    def __init__(self, channels_in, out_dim=128, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            if i != num_layers - 1:
                dim, relu = channels_in, True
            else:
                dim, relu = out_dim, False
            self.layers.append(nn.Linear(channels_in, dim, bias=False))
            bn = nn.BatchNorm1d(dim, eps=BATCH_NORM_EPSILON, affine=True)
            if i == num_layers - 1:
                nn.init.zeros_(bn.bias)
            self.layers.append(bn)
            if relu:
                self.layers.append(nn.ReLU())

    def forward(self, x):
        for b in self.layers:
            x = b(x)
        return x


def get_resnet(depth=50, width_multiplier=1, sk_ratio=0):  # sk_ratio=0.0625 is recommended
    layers = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]}[depth]
    resnet = ResNet(layers, width_multiplier, sk_ratio)
    return resnet, ContrastiveHead(resnet.channels_out)


def name_to_params(checkpoint):
    """Parse (depth, width, sk_ratio) from a checkpoint filename such as 'r50_1x_sk1'."""
    sk_ratio = 0.0625 if '_sk1' in checkpoint else 0
    if 'r50_' in checkpoint:
        depth = 50
    elif 'r101_' in checkpoint:
        depth = 101
    elif 'r152_' in checkpoint:
        depth = 152
    else:
        raise NotImplementedError
    if '_1x_' in checkpoint:
        width = 1
    elif '_2x_' in checkpoint:
        width = 2
    elif '_3x_' in checkpoint:
        width = 3
    else:
        raise NotImplementedError
    return depth, width, sk_ratio


class SimCLRv2(nn.Module):
    def __init__(self, model, head):
        super().__init__()
        self.encoder = model
        self.contrastive_head = head

    def forward(self, x):
        x = self.encoder(x)
        x = self.contrastive_head(x)
        return x


def get_simclr2_model(ckpt_path):
    depth, width, sk_ratio = name_to_params(ckpt_path)
    model, head = get_resnet(depth, width, sk_ratio)
    # map_location keeps loading robust on CPU-only machines.
    checkpoint = torch.load('pretrained_models/simclr2_models/' + ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['resnet'])
    head.load_state_dict(checkpoint['head'])
    del model.fc  # the supervised 1000-way classifier is not needed for contrastive features
    simclr2 = SimCLRv2(model, head)
    return simclr2.to(device)
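

# Minimal usage sketch (not part of the original file): load a converted
# checkpoint, then pull backbone features and projection-head embeddings.
# The filename 'r50_1x_sk0.pth' and the dummy input batch are assumptions;
# substitute whatever converted SimCLRv2 checkpoint you actually have.
if __name__ == '__main__':
    simclr2 = get_simclr2_model('r50_1x_sk0.pth')  # hypothetical checkpoint name
    simclr2.eval()
    images = torch.randn(4, 3, 224, 224, device=device)  # dummy batch
    with torch.no_grad():
        features = simclr2.encoder(images)  # backbone features, e.g. (4, 2048) for r50_1x
        embeddings = simclr2(images)        # projection-head output, (4, 128)
    print(features.shape, embeddings.shape)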