leonelhs committed on
Commit
def3395
1 Parent(s): aa7afdd
.gitignore ADDED
@@ -0,0 +1,2 @@
+ .idea/
+ __pycache__/
app.py ADDED
@@ -0,0 +1,128 @@
+ # https://github.com/MarcoForte/FBA_Matting
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ from networks.models import build_model
+ from networks.transforms import trimap_transform, normalise_image
+
+ REPO_ID = "leonelhs/FBA-Matting"
+
+ weights = hf_hub_download(repo_id=REPO_ID, filename="FBA.pth")
+ model = build_model(weights)
+ model.eval().cpu()
+
+
+ def np_to_torch(x, permute=True):
+     if permute:
+         return torch.from_numpy(x).permute(2, 0, 1)[None, :, :, :].float().cpu()
+     else:
+         return torch.from_numpy(x)[None, :, :, :].float().cpu()
+
+
+ def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray:
+     ''' Scales inputs to a multiple of 8. '''
+     h, w = x.shape[:2]
+     h1 = int(np.ceil(scale * h / 8) * 8)
+     w1 = int(np.ceil(scale * w / 8) * 8)
+     x_scale = cv2.resize(x, (w1, h1), interpolation=scale_type)
+     return x_scale
+
+
+ def inference(image_np: np.ndarray, trimap_np: np.ndarray):
+     ''' Predict alpha, foreground and background.
+     Parameters:
+     image_np -- the image in rgb format between 0 and 1. Dimensions: (h, w, 3)
+     trimap_np -- two-channel trimap, first background then foreground. Dimensions: (h, w, 2)
+     Returns:
+     fg: foreground image in rgb format between 0 and 1. Dimensions: (h, w, 3)
+     bg: background image in rgb format between 0 and 1. Dimensions: (h, w, 3)
+     alpha: alpha matte image between 0 and 1. Dimensions: (h, w)
+     '''
+     h, w = trimap_np.shape[:2]
+     image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4)
+     trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4)
+
+     with torch.no_grad():
+         image_torch = np_to_torch(image_scale_np)
+         trimap_torch = np_to_torch(trimap_scale_np)
+
+         trimap_transformed_torch = np_to_torch(
+             trimap_transform(trimap_scale_np), permute=False)
+         image_transformed_torch = normalise_image(
+             image_torch.clone())
+
+         output = model(
+             image_torch,
+             trimap_torch,
+             image_transformed_torch,
+             trimap_transformed_torch)
+         output = cv2.resize(
+             output[0].cpu().numpy().transpose(
+                 (1, 2, 0)), (w, h), interpolation=cv2.INTER_LANCZOS4)
+
+     alpha = output[:, :, 0]
+     fg = output[:, :, 1:4]
+     bg = output[:, :, 4:7]
+
+     alpha[trimap_np[:, :, 0] == 1] = 0
+     alpha[trimap_np[:, :, 1] == 1] = 1
+     fg[alpha == 1] = image_np[alpha == 1]
+     bg[alpha == 0] = image_np[alpha == 0]
+
+     return fg, bg, alpha
+
+
+ def read_image(name):
+     return (cv2.imread(name) / 255.0)[:, :, ::-1]
+
+
+ def read_trimap(name):
+     trimap_im = cv2.imread(name, 0) / 255.0
+     h, w = trimap_im.shape
+     trimap_np = np.zeros((h, w, 2))
+     trimap_np[trimap_im == 1, 1] = 1
+     trimap_np[trimap_im == 0, 0] = 1
+     return trimap_np
+
+
+ def predict(image, trimap):
+     image_np = read_image(image)
+     trimap_np = read_trimap(trimap)
+     return inference(image_np, trimap_np)
+
+
+ footer = r"""
+ <center>
+ <b>
+ Demo for <a href='https://github.com/MarcoForte/FBA_Matting'>FBA Matting</a>
+ </b>
+ </center>
+ """
+
+ with gr.Blocks(title="FBA Matting") as app:
+     gr.HTML("<center><h1>FBA Matting</h1></center>")
+     gr.HTML("<center><h3>Foreground, Background, Alpha Matting Generator.</h3></center>")
+     with gr.Row().style(equal_height=False):
+         with gr.Column():
+             input_img = gr.Image(type="filepath", label="Input image")
+             input_trimap = gr.Image(type="filepath", label="Trimap image")
+             run_btn = gr.Button(variant="primary")
+         with gr.Column():
+             fg = gr.Image(type="numpy", label="Foreground")
+             bg = gr.Image(type="numpy", label="Background")
+             alpha = gr.Image(type="numpy", label="Alpha")
+
+     run_btn.click(predict, [input_img, input_trimap], [fg, bg, alpha])
+
+     with gr.Row():
+         examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
+         examples = gr.Dataset(components=[input_img], samples=examples_data)
+         examples.click(lambda x: x[0], [examples], [input_img])
+
+     with gr.Row():
+         gr.HTML(footer)
+
+ app.launch(share=False, debug=True, enable_queue=True, show_error=True)
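For quick local testing outside the Gradio UI, `predict` can be called directly; a minimal sketch, assuming a local `example.jpg` and a matching `example_trimap.png` (grayscale, 0 = background, 255 = foreground, anything in between = unknown):

```python
# hypothetical smoke test; the file names are placeholders
import cv2
import numpy as np
from app import predict  # note: importing app downloads the weights and builds the model

fg, bg, alpha = predict("example.jpg", "example_trimap.png")
cv2.imwrite("alpha.png", (alpha * 255).astype(np.uint8))
cv2.imwrite("fg.png", (fg[:, :, ::-1] * 255).astype(np.uint8))  # RGB back to BGR for OpenCV
```

Note that `read_trimap` only treats exact 0 and 255 pixels as known regions, so an anti-aliased trimap leaves its soft edge pixels in the unknown band.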
networks/__init__.py ADDED
File without changes
networks/layers_WS.py ADDED
@@ -0,0 +1,33 @@
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+
+ class Conv2d(nn.Conv2d):
+
+     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                  padding=0, dilation=1, groups=1, bias=True, eps=1e-5):
+         super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
+                                      padding, dilation, groups, bias)
+         self.out_channels = out_channels
+         self.eps = eps
+
+     def normalize_weight(self):
+         weight = F.batch_norm(
+             self.weight.view(1, self.out_channels, -1), None, None,
+             training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
+         self.weight.data = weight
+
+     def forward(self, x):
+         if self.training:
+             self.normalize_weight()
+         return F.conv2d(x, self.weight, self.bias, self.stride,
+                         self.padding, self.dilation, self.groups)
+
+     def train(self, mode: bool = True):
+         super().train(mode=mode)
+         self.normalize_weight()
+         return self  # nn.Module.train() returns self; preserve that contract
+
+
+ def norm(dim):
+     return nn.GroupNorm(32, dim)
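The `F.batch_norm` call over the reshaped weight implements Weight Standardization: each output filter is renormalised to zero mean and unit variance before the convolution, the GroupNorm + WS combination the FBA Matting paper builds on. What it computes, for filter $i$ with $n = C_{in} k_h k_w$ weights:

```latex
\hat{W}_{i,j} = \frac{W_{i,j} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}},
\qquad
\mu_i = \frac{1}{n}\sum_{j=1}^{n} W_{i,j},
\qquad
\sigma_i^2 = \frac{1}{n}\sum_{j=1}^{n} \left(W_{i,j} - \mu_i\right)^2
```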
networks/models.py ADDED
@@ -0,0 +1,221 @@
+ import torch
+ import torch.nn as nn
+ from networks.resnet_GN_WS import ResNet
+ import networks.layers_WS as L
+
+
+ def build_model(weights):
+     net_encoder = fba_encoder()
+
+     net_decoder = fba_decoder()
+
+     model = MattingModule(net_encoder, net_decoder)
+
+     if weights != 'default':
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         sd = torch.load(weights, map_location=device)
+         model.load_state_dict(sd, strict=True)
+
+     return model
+
+
+ class MattingModule(nn.Module):
+     def __init__(self, net_enc, net_dec):
+         super(MattingModule, self).__init__()
+         self.encoder = net_enc
+         self.decoder = net_dec
+
+     def forward(self, image, two_chan_trimap, image_n, trimap_transformed):
+         resnet_input = torch.cat((image_n, trimap_transformed, two_chan_trimap), 1)
+         conv_out, indices = self.encoder(resnet_input, return_feature_maps=True)
+         return self.decoder(conv_out, image, indices, two_chan_trimap)
+
+
+ def fba_encoder():
+     orig_resnet = ResNet()
+     net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
+
+     num_channels = 3 + 6 + 2  # RGB + multi-scale trimap encoding + two-channel trimap
+
+     print(f'modifying input layer to accept {num_channels} channels')
+     net_encoder_sd = net_encoder.state_dict()
+     conv1_weights = net_encoder_sd['conv1.weight']
+
+     c_out, c_in, h, w = conv1_weights.size()
+     conv1_mod = torch.zeros(c_out, num_channels, h, w)
+     conv1_mod[:, :3, :, :] = conv1_weights  # keep the RGB filters, zero-init the rest
+
+     conv1 = net_encoder.conv1
+     conv1.in_channels = num_channels
+     conv1.weight = torch.nn.Parameter(conv1_mod)
+
+     net_encoder.conv1 = conv1
+
+     net_encoder_sd['conv1.weight'] = conv1_mod
+
+     net_encoder.load_state_dict(net_encoder_sd)
+     return net_encoder
+
+
+ class ResnetDilated(nn.Module):
+     def __init__(self, orig_resnet, dilate_scale=8):
+         super(ResnetDilated, self).__init__()
+         from functools import partial
+
+         if dilate_scale == 8:
+             orig_resnet.layer3.apply(
+                 partial(self._nostride_dilate, dilate=2))
+             orig_resnet.layer4.apply(
+                 partial(self._nostride_dilate, dilate=4))
+         elif dilate_scale == 16:
+             orig_resnet.layer4.apply(
+                 partial(self._nostride_dilate, dilate=2))
+
+         # take the pretrained resnet, except AvgPool and FC
+         self.conv1 = orig_resnet.conv1
+         self.bn1 = orig_resnet.bn1
+         self.relu = orig_resnet.relu
+         self.maxpool = orig_resnet.maxpool
+         self.layer1 = orig_resnet.layer1
+         self.layer2 = orig_resnet.layer2
+         self.layer3 = orig_resnet.layer3
+         self.layer4 = orig_resnet.layer4
+
+     def _nostride_dilate(self, m, dilate):
+         classname = m.__class__.__name__
+         if classname.find('Conv') != -1:
+             # the convolution with stride
+             if m.stride == (2, 2):
+                 m.stride = (1, 1)
+                 if m.kernel_size == (3, 3):
+                     m.dilation = (dilate // 2, dilate // 2)
+                     m.padding = (dilate // 2, dilate // 2)
+             # other convolutions
+             else:
+                 if m.kernel_size == (3, 3):
+                     m.dilation = (dilate, dilate)
+                     m.padding = (dilate, dilate)
+
+     def forward(self, x, return_feature_maps=False):
+         conv_out = [x]
+         x = self.relu(self.bn1(self.conv1(x)))
+         conv_out.append(x)
+         x, indices = self.maxpool(x)
+         x = self.layer1(x)
+         conv_out.append(x)
+         x = self.layer2(x)
+         conv_out.append(x)
+         x = self.layer3(x)
+         conv_out.append(x)
+         x = self.layer4(x)
+         conv_out.append(x)
+
+         if return_feature_maps:
+             return conv_out, indices
+         return [x]
+
+
+ def fba_fusion(alpha, img, F, B):
+     F = (alpha * img + (1 - alpha ** 2) * F - alpha * (1 - alpha) * B)
+     B = ((1 - alpha) * img + (2 * alpha - alpha ** 2) * B - alpha * (1 - alpha) * F)
+
+     F = torch.clamp(F, 0, 1)
+     B = torch.clamp(B, 0, 1)
+     la = 0.1  # regulariser for the closed-form alpha update
+     alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (
+         torch.sum((F - B) * (F - B), 1, keepdim=True) + la)
+     alpha = torch.clamp(alpha, 0, 1)
+     return alpha, F, B
+
+
+ class fba_decoder(nn.Module):
+     def __init__(self):
+         super(fba_decoder, self).__init__()
+         pool_scales = (1, 2, 3, 6)
+
+         self.ppm = []
+
+         for scale in pool_scales:
+             self.ppm.append(nn.Sequential(
+                 nn.AdaptiveAvgPool2d(scale),
+                 L.Conv2d(2048, 256, kernel_size=1, bias=True),
+                 L.norm(256),
+                 nn.LeakyReLU()
+             ))
+         self.ppm = nn.ModuleList(self.ppm)
+
+         self.conv_up1 = nn.Sequential(
+             L.Conv2d(2048 + len(pool_scales) * 256, 256,
+                      kernel_size=3, padding=1, bias=True),
+
+             L.norm(256),
+             nn.LeakyReLU(),
+             L.Conv2d(256, 256, kernel_size=3, padding=1),
+             L.norm(256),
+             nn.LeakyReLU()
+         )
+
+         self.conv_up2 = nn.Sequential(
+             L.Conv2d(256 + 256, 256,
+                      kernel_size=3, padding=1, bias=True),
+             L.norm(256),
+             nn.LeakyReLU()
+         )
+         self.conv_up3 = nn.Sequential(
+             L.Conv2d(256 + 64, 64,
+                      kernel_size=3, padding=1, bias=True),
+             L.norm(64),
+             nn.LeakyReLU()
+         )
+
+         self.unpool = nn.MaxUnpool2d(2, stride=2)
+
+         self.conv_up4 = nn.Sequential(
+             nn.Conv2d(64 + 3 + 3 + 2, 32,
+                       kernel_size=3, padding=1, bias=True),
+             nn.LeakyReLU(),
+             nn.Conv2d(32, 16,
+                       kernel_size=3, padding=1, bias=True),
+
+             nn.LeakyReLU(),
+             nn.Conv2d(16, 7, kernel_size=1, padding=0, bias=True)
+         )
+
+     def forward(self, conv_out, img, indices, two_chan_trimap):
+         conv5 = conv_out[-1]
+
+         input_size = conv5.size()
+         ppm_out = [conv5]
+         for pool_scale in self.ppm:
+             ppm_out.append(nn.functional.interpolate(
+                 pool_scale(conv5),
+                 (input_size[2], input_size[3]),
+                 mode='bilinear', align_corners=False))
+         ppm_out = torch.cat(ppm_out, 1)
+         x = self.conv_up1(ppm_out)
+
+         x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
+
+         x = torch.cat((x, conv_out[-4]), 1)
+
+         x = self.conv_up2(x)
+         x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
+
+         x = torch.cat((x, conv_out[-5]), 1)
+         x = self.conv_up3(x)
+
+         x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
+         x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1)
+
+         output = self.conv_up4(x)
+
+         alpha = torch.clamp(output[:, 0][:, None], 0, 1)
+         F = torch.sigmoid(output[:, 1:4])
+         B = torch.sigmoid(output[:, 4:7])
+
+         # FBA Fusion
+         alpha, F, B = fba_fusion(alpha, img, F, B)
+
+         output = torch.cat((alpha, F, B), 1)
+
+         return output
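For readers cross-checking `fba_fusion` above against the FBA Matting paper, the sequential update it implements is, with image $I$, regulariser $\lambda = 0.1$ (`la` in the code), and sums over the three colour channels (note the $B$ update already sees the new $F$):

```latex
F' = \alpha I + (1 - \alpha^{2})\,F - \alpha(1 - \alpha)\,B \\
B' = (1 - \alpha)\,I + (2\alpha - \alpha^{2})\,B - \alpha(1 - \alpha)\,F' \\
\alpha' = \frac{\lambda\,\alpha + \sum_{c}(I - B')(F' - B')}{\lambda + \sum_{c}(F' - B')^{2}}
```

$F'$ and $B'$ are clamped to $[0, 1]$ before the $\alpha$ update, and $\alpha'$ afterwards.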
networks/resnet_GN_WS.py ADDED
@@ -0,0 +1,104 @@
+ import torch.nn as nn
+ import networks.layers_WS as L
+
+ __all__ = ['ResNet']
+
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     """3x3 convolution with padding"""
+     return L.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+ def conv1x1(in_planes, out_planes, stride=1):
+     """1x1 convolution"""
+     return L.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(Bottleneck, self).__init__()
+         self.conv1 = conv1x1(inplanes, planes)
+         self.bn1 = L.norm(planes)
+         self.conv2 = conv3x3(planes, planes, stride)
+         self.bn2 = L.norm(planes)
+         self.conv3 = conv1x1(planes, planes * self.expansion)
+         self.bn3 = L.norm(planes * self.expansion)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         identity = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu(out)
+
+         return out
+
+
+ class ResNet(nn.Module):
+
+     def __init__(self, block=Bottleneck, layers=[3, 4, 6, 3], num_classes=1000):
+         super(ResNet, self).__init__()
+         self.inplanes = 64
+         self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+         self.bn1 = L.norm(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+         self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+     def _make_layer(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 conv1x1(self.inplanes, planes * block.expansion, stride),
+                 L.norm(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(block(self.inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x, _ = self.maxpool(x)  # maxpool is built with return_indices=True
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+
+         x = self.avgpool(x)
+         x = x.view(x.size(0), -1)
+         x = self.fc(x)
+
+         return x
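As a standalone sanity check (illustrative only; the app never calls this `forward`, since `ResnetDilated` in networks/models.py re-wires the stages and drops the classifier head):

```python
# quick shape check of the GN + WS backbone
import torch
from networks.resnet_GN_WS import ResNet

net = ResNet().eval()  # .eval() still standardises weights once via the Conv2d.train() override
with torch.no_grad():
    out = net(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```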
networks/resnet_bn.py ADDED
@@ -0,0 +1,153 @@
+ import torch.nn as nn
+ import math
+ from torch.nn import BatchNorm2d
+
+ __all__ = ['ResNet']
+
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     "3x3 convolution with padding"
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=1, bias=False)
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = BatchNorm2d(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = BatchNorm2d(planes)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(Bottleneck, self).__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+         self.bn1 = BatchNorm2d(planes)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                                padding=1, bias=False)
+         self.bn2 = BatchNorm2d(planes, momentum=0.01)
+         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+         self.bn3 = BatchNorm2d(planes * 4)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class ResNet(nn.Module):
+
+     def __init__(self, block, layers, num_classes=1000):
+         self.inplanes = 128
+         super(ResNet, self).__init__()
+         self.conv1 = conv3x3(3, 64, stride=2)
+         self.bn1 = BatchNorm2d(64)
+         self.relu1 = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(64, 64)
+         self.bn2 = BatchNorm2d(64)
+         self.relu2 = nn.ReLU(inplace=True)
+         self.conv3 = conv3x3(64, 128)
+         self.bn3 = BatchNorm2d(128)
+         self.relu3 = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
+
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+         self.avgpool = nn.AvgPool2d(7, stride=1)
+         self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 m.weight.data.normal_(0, math.sqrt(2. / n))
+             elif isinstance(m, BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+     def _make_layer(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 BatchNorm2d(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(block(self.inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.relu1(self.bn1(self.conv1(x)))
+         x = self.relu2(self.bn2(self.conv2(x)))
+         x = self.relu3(self.bn3(self.conv3(x)))
+         x, indices = self.maxpool(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+
+         x = self.avgpool(x)
+         x = x.view(x.size(0), -1)
+         x = self.fc(x)
+         return x
+
+
+ def l_resnet50():
+     """Constructs a ResNet-50 model."""
+     model = ResNet(Bottleneck, [3, 4, 6, 3])
+     return model
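Only `l_resnet50` is exported here, and `build_model` imports the GN + WS variant instead, so this BatchNorm backbone is an unused alternative in this commit. The same two block classes cover the other standard depths; hypothetical constructors in the same style, not part of this commit:

```python
# standard ResNet depth configurations, for reference
def l_resnet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])   # BasicBlock: expansion = 1

def l_resnet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])

def l_resnet101():
    return ResNet(Bottleneck, [3, 4, 23, 3])  # Bottleneck: expansion = 4
```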
networks/transforms.py ADDED
@@ -0,0 +1,27 @@
+ import numpy as np
+ import torch
+ import cv2
+
+
+ def dt(a):
+     return cv2.distanceTransform((a * 255).astype(np.uint8), cv2.DIST_L2, 0)
+
+
+ def trimap_transform(trimap, L=320):
+     clicks = []
+     for k in range(2):
+         dt_mask = -dt(1 - trimap[:, :, k]) ** 2
+         clicks.append(np.exp(dt_mask / (2 * ((0.02 * L) ** 2))))
+         clicks.append(np.exp(dt_mask / (2 * ((0.08 * L) ** 2))))
+         clicks.append(np.exp(dt_mask / (2 * ((0.16 * L) ** 2))))
+     clicks = np.array(clicks)
+     return clicks
+
+
+ # ImageNet normalisation constants (RGB order)
+ imagenet_norm_std = torch.from_numpy(np.array([0.229, 0.224, 0.225])).float().cpu()[None, :, None, None]
+ imagenet_norm_mean = torch.from_numpy(np.array([0.485, 0.456, 0.406])).float().cpu()[None, :, None, None]
+
+
+ def normalise_image(image, mean=imagenet_norm_mean, std=imagenet_norm_std):
+     return (image - mean) / std
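`trimap_transform` encodes each trimap channel as Gaussian functions of its distance transform at three scales, producing the 6 extra input channels the encoder expects (`num_channels = 3 + 6 + 2` in `fba_encoder`). Per channel $k \in \{\text{bg}, \text{fg}\}$, with $d_k(x)$ the L2 distance from pixel $x$ to the nearest pixel of region $k$ and $L = 320$:

```latex
T_{k,\sigma}(x) = \exp\!\left(-\frac{d_k(x)^{2}}{2\,(\sigma L)^{2}}\right),
\qquad \sigma \in \{0.02,\ 0.08,\ 0.16\}
```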
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch>=1.4.0
+ numpy
+ opencv-python
+ gradio
+ huggingface_hub