'''
Reference:
https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py
https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_resnet_cifar.py
'''
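# CIFAR-style ResNet (three stages of basic blocks) used as a feature
# extractor: forward() returns the intermediate feature maps of each stage
# plus a pooled 64-d feature vector, rather than class logits.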
import torch
import torch.nn as nn
import torch.nn.functional as F
# from convs.modified_linear import CosineLinear
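# Four shortcut-downsampling variants for the residual connection. DownsampleA
# is parameter-free (stride-2 subsampling plus zero-padded channels, in the
# style of option (A) in the ResNet paper); B-D use strided 1x1 or 2x2
# convolutions, with or without batch norm. Only DownsampleB is actually used
# by _make_layer below.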
class DownsampleA(nn.Module):
def __init__(self, nIn, nOut, stride):
super(DownsampleA, self).__init__()
        assert stride == 2
        # kernel_size=1 with stride=2 keeps every other pixel (parameter-free).
        self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)

    def forward(self, x):
        x = self.avg(x)
        # Pad the channel dimension with zeros, doubling nIn to nOut.
        return torch.cat((x, x.mul(0)), 1)
class DownsampleB(nn.Module):
def __init__(self, nIn, nOut, stride):
super(DownsampleB, self).__init__()
self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
self.bn = nn.BatchNorm2d(nOut)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class DownsampleC(nn.Module):
def __init__(self, nIn, nOut, stride):
super(DownsampleC, self).__init__()
assert stride != 1 or nIn != nOut
self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
def forward(self, x):
x = self.conv(x)
return x
class DownsampleD(nn.Module):
def __init__(self, nIn, nOut, stride):
super(DownsampleD, self).__init__()
assert stride == 2
self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False)
self.bn = nn.BatchNorm2d(nOut)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
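# Standard post-activation basic block: 3x3 conv -> BN -> ReLU -> 3x3 conv ->
# BN -> add shortcut -> ReLU. The `last` flag skips the final ReLU, which some
# incremental-learning setups want before a cosine classifier (cf. the
# commented-out CosineLinear import above).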
class ResNetBasicblock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, last=False):
super(ResNetBasicblock, self).__init__()
self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn_a = nn.BatchNorm2d(planes)
self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn_b = nn.BatchNorm2d(planes)
self.downsample = downsample
self.last = last
def forward(self, x):
residual = x
basicblock = self.conv_a(x)
basicblock = self.bn_a(basicblock)
basicblock = F.relu(basicblock, inplace=True)
basicblock = self.conv_b(basicblock)
basicblock = self.bn_b(basicblock)
if self.downsample is not None:
residual = self.downsample(x)
out = residual + basicblock
if not self.last:
out = F.relu(out, inplace=True)
return out
class CifarResNet(nn.Module):
"""
ResNet optimized for the Cifar Dataset, as specified in
https://arxiv.org/abs/1512.03385.pdf
"""
def __init__(self, block, depth, channels=3):
super(CifarResNet, self).__init__()
        # depth = 6n + 2, where n is the number of blocks per stage
        # (e.g. depth 32 -> 5 blocks in each of the three stages).
        assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn_1 = nn.BatchNorm2d(16)
self.inplanes = 16
self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
self.stage_3 = self._make_layer(block, 64, layer_blocks, 2, last_phase=True)
self.avgpool = nn.AvgPool2d(8)
self.out_dim = 64 * block.expansion
# self.fc = CosineLinear(64*block.expansion, 10)
        # Kaiming-normal init for conv weights; BN scales/biases start at 1/0.
        for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def _make_layer(self, block, planes, blocks, stride=1, last_phase=False):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = DownsampleB(self.inplanes, planes * block.expansion, stride) # DownsampleA => DownsampleB
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
        if last_phase:
            # Build all but the final block normally, then mark the final block
            # with last=True so its trailing ReLU is skipped.
            for i in range(1, blocks - 1):
                layers.append(block(self.inplanes, planes))
            layers.append(block(self.inplanes, planes, last=True))
else:
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv_1_3x3(x) # [bs, 16, 32, 32]
x = F.relu(self.bn_1(x), inplace=True)
x_1 = self.stage_1(x) # [bs, 16, 32, 32]
x_2 = self.stage_2(x_1) # [bs, 32, 16, 16]
x_3 = self.stage_3(x_2) # [bs, 64, 8, 8]
pooled = self.avgpool(x_3) # [bs, 64, 1, 1]
features = pooled.view(pooled.size(0), -1) # [bs, 64]
# out = self.fc(vector)
return {
'fmaps': [x_1, x_2, x_3],
'features': features
}
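    # Handle on the last 3x3 convolution of the network; presumably exposed so
    # callers can register forward hooks on the final feature maps.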
@property
def last_conv(self):
return self.stage_3[-1].conv_b
def resnet20mnist():
"""Constructs a ResNet-20 model for MNIST."""
model = CifarResNet(ResNetBasicblock, 20, 1)
return model
def resnet32mnist():
"""Constructs a ResNet-32 model for MNIST."""
model = CifarResNet(ResNetBasicblock, 32, 1)
return model
def resnet20():
"""Constructs a ResNet-20 model for CIFAR-10."""
model = CifarResNet(ResNetBasicblock, 20)
return model
def resnet32():
"""Constructs a ResNet-32 model for CIFAR-10."""
model = CifarResNet(ResNetBasicblock, 32)
return model
def resnet44():
"""Constructs a ResNet-44 model for CIFAR-10."""
model = CifarResNet(ResNetBasicblock, 44)
return model
def resnet56():
"""Constructs a ResNet-56 model for CIFAR-10."""
model = CifarResNet(ResNetBasicblock, 56)
return model
def resnet110():
"""Constructs a ResNet-110 model for CIFAR-10."""
model = CifarResNet(ResNetBasicblock, 110)
return model
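
# A minimal smoke test sketching the expected output shapes; the batch size and
# the 32x32 CIFAR input resolution here are illustrative assumptions, not part
# of the original file.
if __name__ == '__main__':
    net = resnet32()
    dummy = torch.randn(2, 3, 32, 32)  # hypothetical batch of two CIFAR images
    out = net(dummy)
    assert out['features'].shape == (2, 64)
    assert [f.shape[1] for f in out['fmaps']] == [16, 32, 64]
    print('features:', out['features'].shape)  # torch.Size([2, 64])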