bensonsantos committed on
Commit
dfc786f
1 Parent(s): f5b5f26

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +81 -0
model.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ from torch.nn import functional as F
4
+ from torchvision import models
5
+
6
class ContextualModule(nn.Module):
    """Scale-aware context aggregation head (Context-Aware Network, CAN).

    Pools the input feature map to several spatial sizes, projects and
    upsamples each pooled map back to the input resolution, gates each
    scale with a learned contrast weight (sigmoid of a 1x1 conv over
    ``feats - scale_feats``), and fuses the weighted average of the
    scales with the original features through a 1x1 bottleneck conv.

    Args:
        features: number of input channels.
        out_features: number of output channels of the bottleneck conv.
        sizes: adaptive-pool output sizes; one pyramid branch per entry.
    """

    def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
        super(ContextualModule, self).__init__()
        # One (adaptive avg-pool -> 1x1 conv) branch per pyramid scale.
        # (The original also had a dead `self.scales = []` assignment that
        # was immediately overwritten; removed.)
        self.scales = nn.ModuleList([self._make_scale(features, size) for size in sizes])
        # Fuses [weighted multi-scale context, original features] -> out_features.
        self.bottleneck = nn.Conv2d(features * 2, out_features, kernel_size=1)
        self.relu = nn.ReLU()
        # Maps the contrast features to per-pixel, per-channel gate logits.
        self.weight_net = nn.Conv2d(features, features, kernel_size=1)

    def __make_weight(self, feature, scale_feature):
        """Contrast gate in (0, 1): sigmoid(conv(feats - pooled_feats))."""
        weight_feature = feature - scale_feature
        # torch.sigmoid replaces the deprecated F.sigmoid (same function).
        return torch.sigmoid(self.weight_net(weight_feature))

    def _make_scale(self, features, size):
        """Build one pyramid branch: adaptive average pool -> 1x1 conv."""
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
        return nn.Sequential(prior, conv)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)
        # F.interpolate replaces the deprecated F.upsample; with
        # mode='bilinear' both default to align_corners=False, so the
        # numerics are unchanged.
        multi_scales = [
            F.interpolate(input=stage(feats), size=(h, w), mode='bilinear')
            for stage in self.scales
        ]
        weights = [self.__make_weight(feats, scale_feature) for scale_feature in multi_scales]
        # Weighted average over scale branches. Generalized to any number
        # of branches: the original hard-coded exactly four indices, which
        # silently disagreed with a non-default `sizes` argument.
        weighted_sum = sum(s * w for s, w in zip(multi_scales, weights))
        weight_total = sum(weights)
        overall_features = [weighted_sum / weight_total, feats]
        bottle = self.bottleneck(torch.cat(overall_features, 1))
        return self.relu(bottle)
31
+
32
class CANNet(nn.Module):
    """Context-Aware Network for crowd counting.

    Pipeline: VGG-16-style frontend (first 10 convs) -> ContextualModule
    -> dilated, batch-normed backend -> 1x1 conv emitting a
    single-channel density map.

    Args:
        load_weights: when False (the default), randomly initialize all
            weights and then copy ImageNet-pretrained VGG-16 weights into
            the frontend.
    """

    def __init__(self, load_weights=False):
        super(CANNet, self).__init__()
        # Training-iteration counter maintained by the training loop.
        self.seen = 0
        self.context = ContextualModule(512, 512)
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat, in_channels=512, batch_norm=True, dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            # NOTE(review): `pretrained=True` is deprecated in
            # torchvision >= 0.13 in favour of
            # `weights=models.VGG16_Weights.IMAGENET1K_V1` — confirm the
            # project's torchvision version before migrating.
            mod = models.vgg16(pretrained=True)
            self._initialize_weights()
            # Copy pretrained VGG-16 weights into the frontend. Materialize
            # both item lists once (the original rebuilt them on every loop
            # iteration). VGG-16's state_dict begins with its `features`
            # conv parameters, so positional pairing lines up with the
            # frontend layers; zip stops at the shorter (frontend) list.
            frontend_params = list(self.frontend.state_dict().items())
            vgg_params = list(mod.state_dict().items())
            for (_, dst), (_, src) in zip(frontend_params, vgg_params):
                dst.data[:] = src.data[:]

    def forward(self, x):
        """Map an image batch (N, 3, H, W) to a density map (N, 1, H/8, W/8)."""
        x = self.frontend(x)
        x = self.context(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x

    def _initialize_weights(self):
        """Init convs with N(0, 0.01) weights / zero bias; BN with ones/zeros."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
64
+
65
def make_layers(cfg, in_channels = 3,batch_norm=False,dilation = False):
    """Assemble a VGG-style convolutional stack from a layer spec.

    Each integer entry in ``cfg`` appends a 3x3 conv with that many output
    channels (optionally followed by BatchNorm2d) plus an in-place ReLU;
    the string 'M' appends a 2x2 max pool. With ``dilation`` set, convs
    use dilation and padding of 2 (keeping spatial size) instead of 1.
    Returns the layers wrapped in an nn.Sequential.
    """
    d_rate = 2 if dilation else 1
    layers = []
    for spec in cfg:
        if spec == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        conv = nn.Conv2d(in_channels, spec, kernel_size=3,
                         padding=d_rate, dilation=d_rate)
        if batch_norm:
            layers.extend([conv, nn.BatchNorm2d(spec), nn.ReLU(inplace=True)])
        else:
            layers.extend([conv, nn.ReLU(inplace=True)])
        in_channels = spec
    return nn.Sequential(*layers)