MykolaL commited on
Commit
93cb4db
1 Parent(s): d716966

Upload mask_predictor.py

Browse files
Files changed (1) hide show
  1. refer/lib/mask_predictor.py +72 -0
refer/lib/mask_predictor.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from collections import OrderedDict
5
+
6
+
7
+ class SimpleDecoding(nn.Module):
8
+ def __init__(self, dims, factor=2):
9
+ super(SimpleDecoding, self).__init__()
10
+
11
+ hidden_size = dims[-1]//factor
12
+ c4_size = dims[-1]
13
+ c3_size = dims[-2]
14
+ c2_size = dims[-3]
15
+ c1_size = dims[-4]
16
+
17
+ self.conv1_4 = nn.Conv2d(c4_size+c3_size, hidden_size, 3, padding=1, bias=False)
18
+ self.bn1_4 = nn.BatchNorm2d(hidden_size)
19
+ self.relu1_4 = nn.ReLU()
20
+ self.conv2_4 = nn.Conv2d(hidden_size, hidden_size, 3, padding=1, bias=False)
21
+ self.bn2_4 = nn.BatchNorm2d(hidden_size)
22
+ self.relu2_4 = nn.ReLU()
23
+
24
+ self.conv1_3 = nn.Conv2d(hidden_size + c2_size, hidden_size, 3, padding=1, bias=False)
25
+ self.bn1_3 = nn.BatchNorm2d(hidden_size)
26
+ self.relu1_3 = nn.ReLU()
27
+ self.conv2_3 = nn.Conv2d(hidden_size, hidden_size, 3, padding=1, bias=False)
28
+ self.bn2_3 = nn.BatchNorm2d(hidden_size)
29
+ self.relu2_3 = nn.ReLU()
30
+
31
+ self.conv1_2 = nn.Conv2d(hidden_size + c1_size, hidden_size, 3, padding=1, bias=False)
32
+ self.bn1_2 = nn.BatchNorm2d(hidden_size)
33
+ self.relu1_2 = nn.ReLU()
34
+ self.conv2_2 = nn.Conv2d(hidden_size, hidden_size, 3, padding=1, bias=False)
35
+ self.bn2_2 = nn.BatchNorm2d(hidden_size)
36
+ self.relu2_2 = nn.ReLU()
37
+
38
+ self.conv1_1 = nn.Conv2d(hidden_size, 2, 1)
39
+
40
+ def forward(self, x_c4, x_c3, x_c2, x_c1):
41
+ # fuse Y4 and Y3
42
+ if x_c4.size(-2) < x_c3.size(-2) or x_c4.size(-1) < x_c3.size(-1):
43
+ x_c4 = F.interpolate(input=x_c4, size=(x_c3.size(-2), x_c3.size(-1)), mode='bilinear', align_corners=True)
44
+ x = torch.cat([x_c4, x_c3], dim=1)
45
+ x = self.conv1_4(x)
46
+ x = self.bn1_4(x)
47
+ x = self.relu1_4(x)
48
+ x = self.conv2_4(x)
49
+ x = self.bn2_4(x)
50
+ x = self.relu2_4(x)
51
+ # fuse top-down features and Y2 features
52
+ if x.size(-2) < x_c2.size(-2) or x.size(-1) < x_c2.size(-1):
53
+ x = F.interpolate(input=x, size=(x_c2.size(-2), x_c2.size(-1)), mode='bilinear', align_corners=True)
54
+ x = torch.cat([x, x_c2], dim=1)
55
+ x = self.conv1_3(x)
56
+ x = self.bn1_3(x)
57
+ x = self.relu1_3(x)
58
+ x = self.conv2_3(x)
59
+ x = self.bn2_3(x)
60
+ x = self.relu2_3(x)
61
+ # fuse top-down features and Y1 features
62
+ if x.size(-2) < x_c1.size(-2) or x.size(-1) < x_c1.size(-1):
63
+ x = F.interpolate(input=x, size=(x_c1.size(-2), x_c1.size(-1)), mode='bilinear', align_corners=True)
64
+ x = torch.cat([x, x_c1], dim=1)
65
+ x = self.conv1_2(x)
66
+ x = self.bn1_2(x)
67
+ x = self.relu1_2(x)
68
+ x = self.conv2_2(x)
69
+ x = self.bn2_2(x)
70
+ x = self.relu2_2(x)
71
+
72
+ return self.conv1_1(x)