weiquan commited on
Commit
336f823
·
1 Parent(s): 98d8a15
ckpts/resnet/cifar10/Dense_SA_best.path.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7420c8d9158ebb948094e39a17cb22ec1bd269957a41930e7171b036baedd26c
3
+ size 89489613
ckpts/resnet/cifar10/FF/fisher_newcheckpoint.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa0010529135002e07c81491bf1985212fa28470c6e77f0c90c1f6d659037731
3
+ size 44775689
ckpts/resnet/cifar10/FF/fisher_neweval_result.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ce00ed2a06fbd1a20e8ece24df65f336af83b4dca24b7f305dff9437dc3d5ee
3
+ size 1199
ckpts/resnet/cifar10/FT/FTcheckpoint.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5bc1f381a0c2c0bdf35109b5b5195b90aa3826d20966df94f6cbf340f881f782
3
+ size 44775689
ckpts/resnet/cifar10/FT/FTeval_result.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:331b95208b29cc2202a4a92170cec0a8e3f52a338ee79c8f1437326056808c4f
3
+ size 1199
ckpts/resnet/cifar10/GA/GAcheckpoint.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:922ef6c25c17124489e6c0754c78c9d824179d34977d6206ac1387a75e7b5e6a
3
+ size 44775689
ckpts/resnet/cifar10/GA/GAeval_result.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94266550c31fa9507cc09496cfc0f81f4017401827039a0b98150dbe7c249be1
3
+ size 1199
ckpts/resnet/cifar10/IU/wfishercheckpoint.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4aee7cbb80a5961ce44ea68a6534643835edf57c7518fe35faaabba8a33147df
3
+ size 44775689
ckpts/resnet/cifar10/IU/wfishereval_result.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64b013e64a4db5a48085967e498b0e21f8d5ebd563804057a1fe3a1f11e9f597
3
+ size 1199
ckpts/resnet/cifar10/l1_sparse/FT_prunecheckpoint.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55f9adfd0450ac743a8ab08558463d32ad32b624a21657cd0bae1e9f10cbc845
3
+ size 44775689
ckpts/resnet/cifar10/l1_sparse/FT_pruneeval_result.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0117231e0a39b0adc48cd5348a9de7a3deaadc1efbec87ae1aa076e1d6ff7721
3
+ size 1199
ckpts/resnet/cifar10/retrain/retraincheckpoint.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41a05992d8ee665c9466f0468f3c447691b5e5cde015189e7e86b6f21359ca00
3
+ size 44775689
ckpts/resnet/cifar10/retrain/retraineval_result.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db34432b11e493aee52501966d3e0bd0f4d4502e154231f463430bb930805912
3
+ size 1199
models/ResNet.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ # from torchvision.models.utils import load_state_dict_from_url
5
+
6
+
7
+ class NormalizeByChannelMeanStd(torch.nn.Module):
8
+ def __init__(self, mean, std):
9
+ super(NormalizeByChannelMeanStd, self).__init__()
10
+ if not isinstance(mean, torch.Tensor):
11
+ mean = torch.tensor(mean)
12
+ if not isinstance(std, torch.Tensor):
13
+ std = torch.tensor(std)
14
+ self.register_buffer("mean", mean)
15
+ self.register_buffer("std", std)
16
+
17
+ def forward(self, tensor):
18
+ return self.normalize_fn(tensor, self.mean, self.std)
19
+
20
+ def extra_repr(self):
21
+ return "mean={}, std={}".format(self.mean, self.std)
22
+
23
+ def normalize_fn(self, tensor, mean, std):
24
+ """Differentiable version of torchvision.functional.normalize"""
25
+ # here we assume the color channel is in at dim=1
26
+ mean = mean[None, :, None, None]
27
+ std = std[None, :, None, None]
28
+ return tensor.sub(mean).div(std)
29
+
30
+
31
+ __all__ = [
32
+ "ResNet",
33
+ "resnet18",
34
+ "resnet34",
35
+ "resnet50",
36
+ "resnet101",
37
+ "resnet152",
38
+ "resnext50_32x4d",
39
+ "resnext101_32x8d",
40
+ "wide_resnet50_2",
41
+ "wide_resnet101_2",
42
+ ]
43
+
44
+
45
+ model_urls = {
46
+ "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
47
+ "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
48
+ "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
49
+ "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
50
+ "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
51
+ "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
52
+ "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
53
+ "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
54
+ "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
55
+ }
56
+
57
+
58
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
59
+ """3x3 convolution with padding"""
60
+ return nn.Conv2d(
61
+ in_planes,
62
+ out_planes,
63
+ kernel_size=3,
64
+ stride=stride,
65
+ padding=dilation,
66
+ groups=groups,
67
+ bias=False,
68
+ dilation=dilation,
69
+ )
70
+
71
+
72
+ def conv1x1(in_planes, out_planes, stride=1):
73
+ """1x1 convolution"""
74
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
75
+
76
+
77
+ class BasicBlock(nn.Module):
78
+ expansion = 1
79
+ __constants__ = ["downsample"]
80
+
81
+ def __init__(
82
+ self,
83
+ inplanes,
84
+ planes,
85
+ stride=1,
86
+ downsample=None,
87
+ groups=1,
88
+ base_width=64,
89
+ dilation=1,
90
+ norm_layer=None,
91
+ ):
92
+ super(BasicBlock, self).__init__()
93
+ if norm_layer is None:
94
+ norm_layer = nn.BatchNorm2d
95
+ if groups != 1 or base_width != 64:
96
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
97
+ if dilation > 1:
98
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
99
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
100
+ self.conv1 = conv3x3(inplanes, planes, stride)
101
+ self.bn1 = norm_layer(planes)
102
+ self.relu = nn.ReLU(inplace=True)
103
+ self.conv2 = conv3x3(planes, planes)
104
+ self.bn2 = norm_layer(planes)
105
+ self.downsample = downsample
106
+ self.stride = stride
107
+
108
+ def forward(self, x):
109
+ identity = x
110
+
111
+ out = self.conv1(x)
112
+ out = self.bn1(out)
113
+ out = self.relu(out)
114
+
115
+ out = self.conv2(out)
116
+ out = self.bn2(out)
117
+
118
+ if self.downsample is not None:
119
+ identity = self.downsample(x)
120
+
121
+ out += identity
122
+ out = self.relu(out)
123
+
124
+ return out
125
+
126
+
127
+ class Bottleneck(nn.Module):
128
+ expansion = 4
129
+ __constants__ = ["downsample"]
130
+
131
+ def __init__(
132
+ self,
133
+ inplanes,
134
+ planes,
135
+ stride=1,
136
+ downsample=None,
137
+ groups=1,
138
+ base_width=64,
139
+ dilation=1,
140
+ norm_layer=None,
141
+ ):
142
+ super(Bottleneck, self).__init__()
143
+ if norm_layer is None:
144
+ norm_layer = nn.BatchNorm2d
145
+ width = int(planes * (base_width / 64.0)) * groups
146
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
147
+ self.conv1 = conv1x1(inplanes, width)
148
+ self.bn1 = norm_layer(width)
149
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
150
+ self.bn2 = norm_layer(width)
151
+ self.conv3 = conv1x1(width, planes * self.expansion)
152
+ self.bn3 = norm_layer(planes * self.expansion)
153
+ self.relu = nn.ReLU(inplace=True)
154
+ self.downsample = downsample
155
+ self.stride = stride
156
+
157
+ def forward(self, x):
158
+ identity = x
159
+
160
+ out = self.conv1(x)
161
+ out = self.bn1(out)
162
+ out = self.relu(out)
163
+
164
+ out = self.conv2(out)
165
+ out = self.bn2(out)
166
+ out = self.relu(out)
167
+
168
+ out = self.conv3(out)
169
+ out = self.bn3(out)
170
+
171
+ if self.downsample is not None:
172
+ identity = self.downsample(x)
173
+
174
+ out += identity
175
+ out = self.relu(out)
176
+
177
+ return out
178
+
179
+
180
+ class ResNet(nn.Module):
181
+ def __init__(
182
+ self,
183
+ block,
184
+ layers,
185
+ num_classes=1000,
186
+ zero_init_residual=False,
187
+ groups=1,
188
+ width_per_group=64,
189
+ replace_stride_with_dilation=None,
190
+ norm_layer=None,
191
+ imagenet=False,
192
+ ):
193
+ super(ResNet, self).__init__()
194
+ if norm_layer is None:
195
+ norm_layer = nn.BatchNorm2d
196
+ self._norm_layer = norm_layer
197
+
198
+ self.inplanes = 64
199
+ self.dilation = 1
200
+ if replace_stride_with_dilation is None:
201
+ # each element in the tuple indicates if we should replace
202
+ # the 2x2 stride with a dilated convolution instead
203
+ replace_stride_with_dilation = [False, False, False]
204
+ if len(replace_stride_with_dilation) != 3:
205
+ raise ValueError(
206
+ "replace_stride_with_dilation should be None "
207
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
208
+ )
209
+ self.groups = groups
210
+ self.base_width = width_per_group
211
+
212
+ print("The normalize layer is contained in the network")
213
+ self.normalize = NormalizeByChannelMeanStd(
214
+ mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]
215
+ )
216
+
217
+ if not imagenet:
218
+ self.conv1 = nn.Conv2d(
219
+ 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
220
+ )
221
+ self.bn1 = norm_layer(self.inplanes)
222
+ self.relu = nn.ReLU(inplace=True)
223
+ self.maxpool = nn.Identity()
224
+ else:
225
+ self.conv1 = nn.Conv2d(
226
+ 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
227
+ )
228
+ self.bn1 = nn.BatchNorm2d(self.inplanes)
229
+ self.relu = nn.ReLU(inplace=True)
230
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
231
+
232
+ self.layer1 = self._make_layer(block, 64, layers[0])
233
+ self.layer2 = self._make_layer(
234
+ block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
235
+ )
236
+ self.layer3 = self._make_layer(
237
+ block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
238
+ )
239
+ self.layer4 = self._make_layer(
240
+ block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
241
+ )
242
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
243
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
244
+
245
+ for m in self.modules():
246
+ if isinstance(m, nn.Conv2d):
247
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
248
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
249
+ nn.init.constant_(m.weight, 1)
250
+ nn.init.constant_(m.bias, 0)
251
+
252
+ # Zero-initialize the last BN in each residual branch,
253
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
254
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
255
+ if zero_init_residual:
256
+ for m in self.modules():
257
+ if isinstance(m, Bottleneck):
258
+ nn.init.constant_(m.bn3.weight, 0)
259
+ elif isinstance(m, BasicBlock):
260
+ nn.init.constant_(m.bn2.weight, 0)
261
+
262
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
263
+ norm_layer = self._norm_layer
264
+ downsample = None
265
+ previous_dilation = self.dilation
266
+ if dilate:
267
+ self.dilation *= stride
268
+ stride = 1
269
+ if stride != 1 or self.inplanes != planes * block.expansion:
270
+ downsample = nn.Sequential(
271
+ conv1x1(self.inplanes, planes * block.expansion, stride),
272
+ norm_layer(planes * block.expansion),
273
+ )
274
+
275
+ layers = []
276
+ layers.append(
277
+ block(
278
+ self.inplanes,
279
+ planes,
280
+ stride,
281
+ downsample,
282
+ self.groups,
283
+ self.base_width,
284
+ previous_dilation,
285
+ norm_layer,
286
+ )
287
+ )
288
+ self.inplanes = planes * block.expansion
289
+ for _ in range(1, blocks):
290
+ layers.append(
291
+ block(
292
+ self.inplanes,
293
+ planes,
294
+ groups=self.groups,
295
+ base_width=self.base_width,
296
+ dilation=self.dilation,
297
+ norm_layer=norm_layer,
298
+ )
299
+ )
300
+
301
+ return nn.Sequential(*layers)
302
+
303
+ def _forward_impl(self, x):
304
+ # See note [TorchScript super()]
305
+ x = self.normalize(x)
306
+
307
+ x = self.conv1(x)
308
+ x = self.bn1(x)
309
+ x = self.relu(x)
310
+ x = self.maxpool(x)
311
+
312
+ x = self.layer1(x)
313
+ x = self.layer2(x)
314
+ x = self.layer3(x)
315
+ x = self.layer4(x)
316
+
317
+ x = self.avgpool(x)
318
+ x = torch.flatten(x, 1)
319
+ # print(x.shape)
320
+ x = self.fc(x)
321
+
322
+ return x
323
+
324
+ def forward(self, x):
325
+ return self._forward_impl(x)
326
+
327
+
328
+ def _resnet(arch, block, layers, pretrained, progress, **kwargs):
329
+ model = ResNet(block, layers, **kwargs)
330
+ if pretrained:
331
+ state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
332
+ model.load_state_dict(state_dict)
333
+ return model
334
+
335
+
336
+ def resnet18(pretrained=False, progress=True, **kwargs):
337
+ r"""ResNet-18 model from
338
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
339
+
340
+ Args:
341
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
342
+ progress (bool): If True, displays a progress bar of the download to stderr
343
+ """
344
+ return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
345
+
346
+
347
+ def resnet34(pretrained=False, progress=True, **kwargs):
348
+ r"""ResNet-34 model from
349
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
350
+
351
+ Args:
352
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
353
+ progress (bool): If True, displays a progress bar of the download to stderr
354
+ """
355
+ return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
356
+
357
+
358
+ def resnet50(pretrained=False, progress=True, **kwargs):
359
+ r"""ResNet-50 model from
360
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
361
+
362
+ Args:
363
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
364
+ progress (bool): If True, displays a progress bar of the download to stderr
365
+ """
366
+ return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
367
+
368
+
369
+ def resnet101(pretrained=False, progress=True, **kwargs):
370
+ r"""ResNet-101 model from
371
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
372
+
373
+ Args:
374
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
375
+ progress (bool): If True, displays a progress bar of the download to stderr
376
+ """
377
+ return _resnet(
378
+ "resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs
379
+ )
380
+
381
+
382
+ def resnet152(pretrained=False, progress=True, **kwargs):
383
+ r"""ResNet-152 model from
384
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
385
+
386
+ Args:
387
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
388
+ progress (bool): If True, displays a progress bar of the download to stderr
389
+ """
390
+ return _resnet(
391
+ "resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs
392
+ )
393
+
394
+
395
+ def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
396
+ r"""ResNeXt-50 32x4d model from
397
+ `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
398
+
399
+ Args:
400
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
401
+ progress (bool): If True, displays a progress bar of the download to stderr
402
+ """
403
+ kwargs["groups"] = 32
404
+ kwargs["width_per_group"] = 4
405
+ return _resnet(
406
+ "resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs
407
+ )
408
+
409
+
410
+ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
411
+ r"""ResNeXt-101 32x8d model from
412
+ `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
413
+
414
+ Args:
415
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
416
+ progress (bool): If True, displays a progress bar of the download to stderr
417
+ """
418
+ kwargs["groups"] = 32
419
+ kwargs["width_per_group"] = 8
420
+ return _resnet(
421
+ "resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs
422
+ )
423
+
424
+
425
+ def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
426
+ r"""Wide ResNet-50-2 model from
427
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
428
+
429
+ The model is the same as ResNet except for the bottleneck number of channels
430
+ which is twice larger in every block. The number of channels in outer 1x1
431
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
432
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
433
+
434
+ Args:
435
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
436
+ progress (bool): If True, displays a progress bar of the download to stderr
437
+ """
438
+ kwargs["width_per_group"] = 64 * 2
439
+ return _resnet(
440
+ "wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs
441
+ )
442
+
443
+
444
+ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
445
+ r"""Wide ResNet-101-2 model from
446
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
447
+
448
+ The model is the same as ResNet except for the bottleneck number of channels
449
+ which is twice larger in every block. The number of channels in outer 1x1
450
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
451
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
452
+
453
+ Args:
454
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
455
+ progress (bool): If True, displays a progress bar of the download to stderr
456
+ """
457
+ kwargs["width_per_group"] = 64 * 2
458
+ return _resnet(
459
+ "wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs
460
+ )