josedolot committed on
Commit
3e89704
·
1 Parent(s): 2eb9939

Upload encoders/timm_sknet.py

Browse files
Files changed (1) hide show
  1. encoders/timm_sknet.py +103 -0
encoders/timm_sknet.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._base import EncoderMixin
2
+ from timm.models.resnet import ResNet
3
+ from timm.models.sknet import SelectiveKernelBottleneck, SelectiveKernelBasic
4
+ import torch.nn as nn
5
+
6
+
7
class SkNetEncoder(ResNet, EncoderMixin):
    """timm ``ResNet`` (built with selective-kernel blocks) used as a feature encoder.

    The network is constructed by ``ResNet.__init__`` from ``**kwargs``; its
    ``fc`` and ``global_pool`` attributes are then deleted, and ``forward``
    returns the intermediate outputs of every stage instead of a single tensor.
    """

    def __init__(self, out_channels, depth=5, **kwargs):
        """Build the backbone and drop the parts the encoder does not use.

        Args:
            out_channels: Stored as ``_out_channels``. Presumably the channel
                count of each feature returned by ``forward`` and read through
                ``EncoderMixin`` -- confirm against ``_base.py``.
            depth: Number of stages (after the identity stage) that
                ``forward`` runs; ``get_stages`` provides at most 5.
            **kwargs: Forwarded unchanged to ``ResNet.__init__``.
        """
        super().__init__(**kwargs)
        self._depth = depth
        self._out_channels = out_channels
        self._in_channels = 3

        # Neither module is referenced by get_stages()/forward().
        del self.fc
        del self.global_pool

    def get_stages(self):
        """Return the ordered list of six callables that ``forward`` iterates.

        Index 0 is an identity, so the unmodified input is the first feature.
        """
        return [
            nn.Identity(),
            nn.Sequential(self.conv1, self.bn1, self.act1),
            nn.Sequential(self.maxpool, self.layer1),
            self.layer2,
            self.layer3,
            self.layer4,
        ]

    def forward(self, x):
        """Run the first ``depth + 1`` stages and collect each stage's output.

        Args:
            x: Input passed to the first stage.

        Returns:
            A list of ``depth + 1`` outputs; element 0 is ``x`` after the
            identity stage, element ``i`` is the output of stage ``i``.

        Raises:
            IndexError: If ``depth`` exceeds the number of available stages.
        """
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        """Load weights, ignoring the ``fc`` entries this encoder no longer has.

        ``fc.bias`` and ``fc.weight`` are removed from ``state_dict`` in place
        (if present) before delegating to the parent implementation.

        Args:
            state_dict: Mapping of parameter names to tensors. It is mutated.
            **kwargs: Forwarded to the parent ``load_state_dict``.

        Returns:
            The value returned by the parent ``load_state_dict`` (in PyTorch, a
            named tuple with ``missing_keys`` and ``unexpected_keys``), so that
            callers using ``strict=False`` can inspect what was not loaded.
        """
        state_dict.pop("fc.bias", None)
        state_dict.pop("fc.weight", None)
        # Propagate the result instead of discarding it.
        return super().load_state_dict(state_dict, **kwargs)
41
+
42
+
43
# Checkpoint download URLs, keyed first by encoder name and then by the
# name of the weight source.
sknet_weights = {
    "timm-skresnet18": {
        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth",
    },
    "timm-skresnet34": {
        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth",
    },
    "timm-skresnext50_32x4d": {
        "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth",
    },
}
54
+
55
# Expand each (encoder name, weight source) URL from ``sknet_weights`` into a
# full settings record. Every record carries the same input size, input range,
# normalisation mean/std and class count; only the URL differs.
pretrained_settings = {model_name: {} for model_name in sknet_weights}
for model_name, sources in sknet_weights.items():
    for source_name, source_url in sources.items():
        pretrained_settings[model_name][source_name] = {
            "url": source_url,
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "num_classes": 1000,
        }
67
+
68
# Registry of the SK-Net encoders defined in this module. Each entry names the
# encoder class, its pretrained-weight settings, and the keyword arguments
# ("params") used to construct it: ``out_channels`` is consumed by
# ``SkNetEncoder.__init__`` itself, the rest is forwarded to ``ResNet``.
timm_sknet_encoders = {
    "timm-skresnet18": {
        "encoder": SkNetEncoder,
        "pretrained_settings": pretrained_settings["timm-skresnet18"],
        "params": {
            "out_channels": (3, 64, 64, 128, 256, 512),
            "block": SelectiveKernelBasic,
            "layers": [2, 2, 2, 2],
            "zero_init_last_bn": False,
            "block_args": {"sk_kwargs": {"rd_ratio": 1 / 8, "split_input": True}},
        },
    },
    "timm-skresnet34": {
        "encoder": SkNetEncoder,
        "pretrained_settings": pretrained_settings["timm-skresnet34"],
        "params": {
            "out_channels": (3, 64, 64, 128, 256, 512),
            "block": SelectiveKernelBasic,
            "layers": [3, 4, 6, 3],
            "zero_init_last_bn": False,
            "block_args": {"sk_kwargs": {"rd_ratio": 1 / 8, "split_input": True}},
        },
    },
    # NOTE(review): unlike the two entries above, this one passes no
    # ``block_args``/``sk_kwargs`` -- confirm the block defaults match the
    # configuration the pretrained checkpoint was trained with.
    "timm-skresnext50_32x4d": {
        "encoder": SkNetEncoder,
        "pretrained_settings": pretrained_settings["timm-skresnext50_32x4d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": SelectiveKernelBottleneck,
            "layers": [3, 4, 6, 3],
            "zero_init_last_bn": False,
            "cardinality": 32,
            "base_width": 4,
        },
    },
}