import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

"""
SimCLRv2 ResNet backbone and projection head,
adapted from https://github.com/Separius/SimCLRv2-Pytorch
"""

BATCH_NORM_EPSILON = 1e-5
BATCH_NORM_DECAY = 0.9  # decay 0.9 corresponds to PyTorch's default BatchNorm momentum of 0.1

class BatchNormRelu(nn.Sequential):
    """BatchNorm2d optionally followed by ReLU."""
    def __init__(self, num_channels, relu=True):
        super().__init__(nn.BatchNorm2d(num_channels, eps=BATCH_NORM_EPSILON),
                         nn.ReLU() if relu else nn.Identity())

def conv(in_channels, out_channels, kernel_size=3, stride=1, bias=False):
    """'SAME'-padded convolution (padding = (kernel_size - 1) // 2)."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                     stride=stride, padding=(kernel_size - 1) // 2, bias=bias)

class SelectiveKernel(nn.Module):
    """Selective-kernel convolution: a single convolution with doubled output
    channels produces two branches, which are then mixed with a learned
    per-channel softmax attention."""
    def __init__(self, in_channels, out_channels, stride, sk_ratio, min_dim=32):
        super().__init__()
        assert sk_ratio > 0.0
        # One conv produces both branches (2 * out_channels); they are split in forward().
        self.main_conv = nn.Sequential(conv(in_channels, 2 * out_channels, stride=stride),
                                       BatchNormRelu(2 * out_channels))
        mid_dim = max(int(out_channels * sk_ratio), min_dim)
        # Bottleneck of 1x1 convs that predicts the mixing weights for the two branches.
        self.mixing_conv = nn.Sequential(conv(out_channels, mid_dim, kernel_size=1),
                                         BatchNormRelu(mid_dim),
                                         conv(mid_dim, 2 * out_channels, kernel_size=1))

    def forward(self, x):
        x = self.main_conv(x)
        x = torch.stack(torch.chunk(x, 2, dim=1), dim=0)  # (2, B, C, H, W)
        g = x.sum(dim=0).mean(dim=[2, 3], keepdim=True)   # pooled descriptor, (B, C, 1, 1)
        m = self.mixing_conv(g)
        m = torch.stack(torch.chunk(m, 2, dim=1), dim=0)  # (2, B, C, 1, 1)
        # Softmax over the two branches, then a weighted sum back to out_channels.
        return (x * F.softmax(m, dim=0)).sum(dim=0)

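# Illustrative shape check (an assumption added for clarity, not part of the original
# file), using the sk_ratio=0.0625 value recommended below:
#   SelectiveKernel(32, 64, stride=2, sk_ratio=0.0625) maps a (4, 32, 56, 56) input
#   to a (4, 64, 28, 28) output, since the two stacked branches are mixed back into
#   a single out_channels-wide feature map.
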
class Projection(nn.Module):
    """1x1 projection shortcut; the SK variant downsamples with pad + average pooling."""
    def __init__(self, in_channels, out_channels, stride, sk_ratio=0):
        super().__init__()
        if sk_ratio > 0:
            self.shortcut = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)),
                                          nn.AvgPool2d(kernel_size=2, stride=stride, padding=0),
                                          conv(in_channels, out_channels, kernel_size=1))
        else:
            self.shortcut = conv(in_channels, out_channels, kernel_size=1, stride=stride)
        self.bn = BatchNormRelu(out_channels, relu=False)

    def forward(self, x):
        return self.bn(self.shortcut(x))

class BottleneckBlock(nn.Module):
    """Bottleneck residual block: 1x1 conv -> 3x3 conv (or SelectiveKernel) -> 1x1 conv."""
    expansion = 4

    def __init__(self, in_channels, out_channels, stride, sk_ratio=0, use_projection=False):
        super().__init__()
        if use_projection:
            self.projection = Projection(in_channels, out_channels * 4, stride, sk_ratio)
        else:
            self.projection = nn.Identity()
        ops = [conv(in_channels, out_channels, kernel_size=1), BatchNormRelu(out_channels)]
        if sk_ratio > 0:
            # SelectiveKernel already contains its own BN-ReLU.
            ops.append(SelectiveKernel(out_channels, out_channels, stride, sk_ratio))
        else:
            ops.append(conv(out_channels, out_channels, stride=stride))
            ops.append(BatchNormRelu(out_channels))
        ops.append(conv(out_channels, out_channels * 4, kernel_size=1))
        ops.append(BatchNormRelu(out_channels * 4, relu=False))
        self.net = nn.Sequential(*ops)

    def forward(self, x):
        shortcut = self.projection(x)
        return F.relu(shortcut + self.net(x))

class Blocks(nn.Module):
    """One ResNet stage; the first block projects and (optionally) downsamples."""
    def __init__(self, num_blocks, in_channels, out_channels, stride, sk_ratio=0):
        super().__init__()
        self.blocks = nn.ModuleList([BottleneckBlock(in_channels, out_channels, stride, sk_ratio, True)])
        self.channels_out = out_channels * BottleneckBlock.expansion
        for _ in range(num_blocks - 1):
            self.blocks.append(BottleneckBlock(self.channels_out, out_channels, 1, sk_ratio))

    def forward(self, x):
        for b in self.blocks:
            x = b(x)
        return x

class Stem(nn.Sequential):
    """ResNet stem: a single 7x7 conv, or three 3x3 convs in the SK (ResNet-D) variant."""
    def __init__(self, sk_ratio, width_multiplier):
        ops = []
        channels = 64 * width_multiplier // 2
        if sk_ratio > 0:
            # SK variant replaces the 7x7 conv with three 3x3 convs.
            ops.append(conv(3, channels, stride=2))
            ops.append(BatchNormRelu(channels))
            ops.append(conv(channels, channels))
            ops.append(BatchNormRelu(channels))
            ops.append(conv(channels, channels * 2))
        else:
            ops.append(conv(3, channels * 2, kernel_size=7, stride=2))
        # BN-ReLU and max-pool are shared by both variants.
        ops.append(BatchNormRelu(channels * 2))
        ops.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        super().__init__(*ops)

class ResNet(nn.Module):
    """SimCLRv2 ResNet backbone: stem plus four bottleneck stages with global pooling."""
    def __init__(self, layers, width_multiplier, sk_ratio):
        super().__init__()
        ops = [Stem(sk_ratio, width_multiplier)]
        channels_in = 64 * width_multiplier
        ops.append(Blocks(layers[0], channels_in, 64 * width_multiplier, 1, sk_ratio))
        channels_in = ops[-1].channels_out
        ops.append(Blocks(layers[1], channels_in, 128 * width_multiplier, 2, sk_ratio))
        channels_in = ops[-1].channels_out
        ops.append(Blocks(layers[2], channels_in, 256 * width_multiplier, 2, sk_ratio))
        channels_in = ops[-1].channels_out
        ops.append(Blocks(layers[3], channels_in, 512 * width_multiplier, 2, sk_ratio))
        channels_in = ops[-1].channels_out
        self.channels_out = channels_in
        self.net = nn.Sequential(*ops)
        self.fc = nn.Linear(channels_in, 1000)  # ImageNet classifier head

    def forward(self, x, apply_fc=False):
        h = self.net(x).mean(dim=[2, 3])  # global average pooling
        if apply_fc:
            h = self.fc(h)
        return h

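# Illustrative shapes (an assumption added for clarity, not part of the original file):
# for the depth-50, width-1 backbone, a (B, 3, 224, 224) batch yields (B, 2048) pooled
# features, or (B, 1000) logits when apply_fc=True.
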
class ContrastiveHead(nn.Module):
    """SimCLRv2 projection head: an MLP with BatchNorm after every linear layer
    and ReLU on all but the last layer."""
    def __init__(self, channels_in, out_dim=128, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            if i != num_layers - 1:
                # Hidden layers keep the input width, so channels_in never changes.
                dim, relu = channels_in, True
            else:
                dim, relu = out_dim, False
            self.layers.append(nn.Linear(channels_in, dim, bias=False))
            bn = nn.BatchNorm1d(dim, eps=BATCH_NORM_EPSILON, affine=True)
            if i == num_layers - 1:
                nn.init.zeros_(bn.bias)  # final BN starts with a zero bias
            self.layers.append(bn)
            if relu:
                self.layers.append(nn.ReLU())

    def forward(self, x):
        for b in self.layers:
            x = b(x)
        return x

def get_resnet(depth=50, width_multiplier=1, sk_ratio=0):  # sk_ratio=0.0625 is recommended
    """Build a SimCLRv2 ResNet backbone and its contrastive projection head."""
    layers = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]}[depth]
    resnet = ResNet(layers, width_multiplier, sk_ratio)
    return resnet, ContrastiveHead(resnet.channels_out)

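# Minimal sketch of direct backbone use (an assumption added for clarity, not part of
# the original file):
#   backbone, head = get_resnet(depth=50, width_multiplier=1, sk_ratio=0.0625)
#   feats = backbone(torch.randn(2, 3, 224, 224))   # (2, 2048) pooled features
#   proj = head(feats)                              # (2, 128) contrastive projections
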
def name_to_params(checkpoint):
    """Infer (depth, width_multiplier, sk_ratio) from a SimCLRv2 checkpoint filename."""
    sk_ratio = 0.0625 if '_sk1' in checkpoint else 0
    if 'r50_' in checkpoint:
        depth = 50
    elif 'r101_' in checkpoint:
        depth = 101
    elif 'r152_' in checkpoint:
        depth = 152
    else:
        raise NotImplementedError
    if '_1x_' in checkpoint:
        width = 1
    elif '_2x_' in checkpoint:
        width = 2
    elif '_3x_' in checkpoint:
        width = 3
    else:
        raise NotImplementedError
    return depth, width, sk_ratio

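# Example (the filename below is hypothetical, shown only to illustrate the
# r<depth>_<width>x_sk<0|1> naming convention this parser expects):
#   name_to_params('r50_1x_sk1.pth')  # -> (50, 1, 0.0625)
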
class SimCLRv2(nn.Module):
    """Wrapper that chains the ResNet encoder and the contrastive projection head."""
    def __init__(self, model, head):
        super().__init__()
        self.encoder = model
        self.contrastive_head = head

    def forward(self, x):
        x = self.encoder(x)
        x = self.contrastive_head(x)
        return x

def get_simclr2_model(ckpt_path):
    """Build the model described by ckpt_path and load its pretrained weights."""
    depth, width, sk_ratio = name_to_params(ckpt_path)
    model, head = get_resnet(depth, width, sk_ratio)
    checkpoint = torch.load('pretrained_models/simclr2_models/' + ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['resnet'])
    head.load_state_dict(checkpoint['head'])
    del model.fc  # the ImageNet classifier is not needed for contrastive features
    simclr2 = SimCLRv2(model, head)
    return simclr2.to(device)
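

# A minimal usage sketch (an assumption added for clarity: the checkpoint filename and
# the 'pretrained_models/simclr2_models/' directory are illustrative and must exist
# locally for this to run).
if __name__ == "__main__":
    simclr2 = get_simclr2_model('r50_1x_sk1.pth')
    simclr2.eval()
    with torch.no_grad():
        images = torch.randn(2, 3, 224, 224, device=device)
        features = simclr2.encoder(images)  # (2, 2048) pooled backbone features
        projections = simclr2(images)       # (2, 128) contrastive-head outputs
    print(features.shape, projections.shape)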