Spaces commit: init app

Files changed:
- .gitignore +2 -0
- app.py +128 -0
- networks/__init__.py +0 -0
- networks/layers_WS.py +32 -0
- networks/models.py +221 -0
- networks/resnet_GN_WS.py +104 -0
- networks/resnet_bn.py +156 -0
- networks/transforms.py +27 -0
- requirements.txt +3 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
.idea/
__pycache__/
app.py
ADDED
@@ -0,0 +1,128 @@
# https://github.com/MarcoForte/FBA_Matting
import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download

from networks.models import build_model
from networks.transforms import trimap_transform, normalise_image

REPO_ID = "leonelhs/FBA-Matting"

weights = hf_hub_download(repo_id=REPO_ID, filename="FBA.pth")
model = build_model(weights)
model.eval().cpu()


def np_to_torch(x, permute=True):
    if permute:
        return torch.from_numpy(x).permute(2, 0, 1)[None, :, :, :].float().cpu()
    else:
        return torch.from_numpy(x)[None, :, :, :].float().cpu()


def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray:
    ''' Scales inputs to multiple of 8. '''
    h, w = x.shape[:2]
    h1 = int(np.ceil(scale * h / 8) * 8)
    w1 = int(np.ceil(scale * w / 8) * 8)
    x_scale = cv2.resize(x, (w1, h1), interpolation=scale_type)
    return x_scale


def inference(image_np: np.ndarray, trimap_np: np.ndarray):
    ''' Predict alpha, foreground and background.
    Parameters:
    image_np -- the image in rgb format between 0 and 1. Dimensions: (h, w, 3)
    trimap_np -- two channel trimap, first background then foreground. Dimensions: (h, w, 2)
    Returns:
    fg: foreground image in rgb format between 0 and 1. Dimensions: (h, w, 3)
    bg: background image in rgb format between 0 and 1. Dimensions: (h, w, 3)
    alpha: alpha matte image between 0 and 1. Dimensions: (h, w)
    '''
    h, w = trimap_np.shape[:2]
    image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4)
    trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4)

    with torch.no_grad():
        image_torch = np_to_torch(image_scale_np)
        trimap_torch = np_to_torch(trimap_scale_np)

        trimap_transformed_torch = np_to_torch(
            trimap_transform(trimap_scale_np), permute=False)
        image_transformed_torch = normalise_image(
            image_torch.clone())

        output = model(
            image_torch,
            trimap_torch,
            image_transformed_torch,
            trimap_transformed_torch)
        # note: interpolation must be passed by keyword; the third positional
        # argument of cv2.resize is dst, not the interpolation flag
        output = cv2.resize(
            output[0].cpu().numpy().transpose(
                (1, 2, 0)), (w, h), interpolation=cv2.INTER_LANCZOS4)

    alpha = output[:, :, 0]
    fg = output[:, :, 1:4]
    bg = output[:, :, 4:7]

    alpha[trimap_np[:, :, 0] == 1] = 0
    alpha[trimap_np[:, :, 1] == 1] = 1
    fg[alpha == 1] = image_np[alpha == 1]
    bg[alpha == 0] = image_np[alpha == 0]

    return fg, bg, alpha


def read_image(name):
    return (cv2.imread(name) / 255.0)[:, :, ::-1]


def read_trimap(name):
    trimap_im = cv2.imread(name, 0) / 255.0
    h, w = trimap_im.shape
    trimap_np = np.zeros((h, w, 2))
    trimap_np[trimap_im == 1, 1] = 1
    trimap_np[trimap_im == 0, 0] = 1
    return trimap_np


def predict(image, trimap):
    image_np = read_image(image)
    trimap_np = read_trimap(trimap)
    return inference(image_np, trimap_np)


footer = r"""
<center>
<b>
Demo for <a href='https://github.com/MarcoForte/FBA_Matting'>FBA Matting</a>
</b>
</center>
"""

with gr.Blocks(title="FBA Matting") as app:
    gr.HTML("<center><h1>FBA Matting</h1></center>")
    gr.HTML("<center><h3>Foreground, Background, Alpha Matting Generator.</h3></center>")
    with gr.Row().style(equal_height=False):
        with gr.Column():
            input_img = gr.Image(type="filepath", label="Input image")
            input_trimap = gr.Image(type="filepath", label="Trimap image")
            run_btn = gr.Button(variant="primary")
        with gr.Column():
            fg = gr.Image(type="numpy", label="Foreground")
            bg = gr.Image(type="numpy", label="Background")
            alpha = gr.Image(type="numpy", label="Alpha")

    run_btn.click(predict, [input_img, input_trimap], [fg, bg, alpha])

    with gr.Row():
        examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
        examples = gr.Dataset(components=[input_img], samples=examples_data)
        examples.click(lambda x: x[0], [examples], [input_img])

    with gr.Row():
        gr.HTML(footer)

app.launch(share=False, debug=True, enable_queue=True, show_error=True)
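Note: outside the Gradio UI, `inference` can be called directly with the helpers above. A minimal sketch, assuming an RGB image plus a grayscale trimap where 0 is background, 255 is foreground, and grey is unknown (the file paths here are placeholders, not files shipped with the Space):

# hypothetical paths for illustration
image_np = read_image("examples/01.jpg")
trimap_np = read_trimap("examples/01_trimap.png")

fg, bg, alpha = inference(image_np, trimap_np)
cv2.imwrite("alpha.png", (alpha * 255).astype("uint8"))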
networks/__init__.py
ADDED
File without changes
networks/layers_WS.py
ADDED
@@ -0,0 +1,32 @@
import torch.nn as nn
from torch.nn import functional as F


class Conv2d(nn.Conv2d):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, eps=1e-5):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                     padding, dilation, groups, bias)
        self.out_channels = out_channels
        self.eps = eps

    def normalize_weight(self):
        weight = F.batch_norm(
            self.weight.view(1, self.out_channels, -1), None, None,
            training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
        self.weight.data = weight

    def forward(self, x):
        if self.training:
            self.normalize_weight()
        return F.conv2d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def train(self, mode: bool = True):
        super().train(mode=mode)
        self.normalize_weight()


def norm(dim):
    return nn.GroupNorm(32, dim)
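This `Conv2d` implements Weight Standardization: the `F.batch_norm` call over a `(1, out_channels, -1)` view standardizes each output channel's kernel to zero mean and unit variance before convolving. A quick sketch to verify that behaviour (illustrative only, not part of the Space):

import torch
from networks.layers_WS import Conv2d

conv = Conv2d(3, 8, kernel_size=3, padding=1)
conv.normalize_weight()              # also run on every forward pass in training mode

w = conv.weight.view(8, -1)          # one row per output channel
print(w.mean(dim=1))                 # ~0 for every channel
print(w.std(dim=1, unbiased=False))  # ~1 for every channel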
networks/models.py
ADDED
@@ -0,0 +1,221 @@
import torch
import torch.nn as nn
from networks.resnet_GN_WS import ResNet
import networks.layers_WS as L


def build_model(weights):
    net_encoder = fba_encoder()

    net_decoder = fba_decoder()

    model = MattingModule(net_encoder, net_decoder)

    if weights != 'default':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        sd = torch.load(weights, map_location=device)
        model.load_state_dict(sd, strict=True)

    return model


class MattingModule(nn.Module):
    def __init__(self, net_enc, net_dec):
        super(MattingModule, self).__init__()
        self.encoder = net_enc
        self.decoder = net_dec

    def forward(self, image, two_chan_trimap, image_n, trimap_transformed):
        resnet_input = torch.cat((image_n, trimap_transformed, two_chan_trimap), 1)
        conv_out, indices = self.encoder(resnet_input, return_feature_maps=True)
        return self.decoder(conv_out, image, indices, two_chan_trimap)


def fba_encoder():
    orig_resnet = ResNet()
    net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)

    num_channels = 3 + 6 + 2

    print(f'modifying input layer to accept {num_channels} channels')
    net_encoder_sd = net_encoder.state_dict()
    conv1_weights = net_encoder_sd['conv1.weight']

    c_out, c_in, h, w = conv1_weights.size()
    conv1_mod = torch.zeros(c_out, num_channels, h, w)
    conv1_mod[:, :3, :, :] = conv1_weights

    conv1 = net_encoder.conv1
    conv1.in_channels = num_channels
    conv1.weight = torch.nn.Parameter(conv1_mod)

    net_encoder.conv1 = conv1

    net_encoder_sd['conv1.weight'] = conv1_mod

    net_encoder.load_state_dict(net_encoder_sd)
    return net_encoder


class ResnetDilated(nn.Module):
    def __init__(self, orig_resnet, dilate_scale=8):
        super(ResnetDilated, self).__init__()
        from functools import partial

        if dilate_scale == 8:
            orig_resnet.layer3.apply(
                partial(self._nostride_dilate, dilate=2))
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=2))

        # take pretrained resnet, except AvgPool and FC
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu = orig_resnet.relu
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x, return_feature_maps=False):
        conv_out = [x]
        x = self.relu(self.bn1(self.conv1(x)))
        conv_out.append(x)
        x, indices = self.maxpool(x)
        x = self.layer1(x)
        conv_out.append(x)
        x = self.layer2(x)
        conv_out.append(x)
        x = self.layer3(x)
        conv_out.append(x)
        x = self.layer4(x)
        conv_out.append(x)

        if return_feature_maps:
            return conv_out, indices
        return [x]


def fba_fusion(alpha, img, F, B):
    F = (alpha * img + (1 - alpha ** 2) * F - alpha * (1 - alpha) * B)
    B = ((1 - alpha) * img + (2 * alpha - alpha ** 2) * B - alpha * (1 - alpha) * F)

    F = torch.clamp(F, 0, 1)
    B = torch.clamp(B, 0, 1)
    la = 0.1
    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (
            torch.sum((F - B) * (F - B), 1, keepdim=True) + la)
    alpha = torch.clamp(alpha, 0, 1)
    return alpha, F, B


class fba_decoder(nn.Module):
    def __init__(self):
        super(fba_decoder, self).__init__()
        pool_scales = (1, 2, 3, 6)

        self.ppm = []

        for scale in pool_scales:
            self.ppm.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                L.Conv2d(2048, 256, kernel_size=1, bias=True),
                L.norm(256),
                nn.LeakyReLU()
            ))
        self.ppm = nn.ModuleList(self.ppm)

        self.conv_up1 = nn.Sequential(
            L.Conv2d(2048 + len(pool_scales) * 256, 256,
                     kernel_size=3, padding=1, bias=True),
            L.norm(256),
            nn.LeakyReLU(),
            L.Conv2d(256, 256, kernel_size=3, padding=1),
            L.norm(256),
            nn.LeakyReLU()
        )

        self.conv_up2 = nn.Sequential(
            L.Conv2d(256 + 256, 256,
                     kernel_size=3, padding=1, bias=True),
            L.norm(256),
            nn.LeakyReLU()
        )
        self.conv_up3 = nn.Sequential(
            L.Conv2d(256 + 64, 64,
                     kernel_size=3, padding=1, bias=True),
            L.norm(64),
            nn.LeakyReLU()
        )

        self.unpool = nn.MaxUnpool2d(2, stride=2)

        self.conv_up4 = nn.Sequential(
            nn.Conv2d(64 + 3 + 3 + 2, 32,
                      kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(),
            nn.Conv2d(32, 16,
                      kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(),
            nn.Conv2d(16, 7, kernel_size=1, padding=0, bias=True)
        )

    def forward(self, conv_out, img, indices, two_chan_trimap):
        conv5 = conv_out[-1]

        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale in self.ppm:
            ppm_out.append(nn.functional.interpolate(
                pool_scale(conv5),
                (input_size[2], input_size[3]),
                mode='bilinear', align_corners=False))
        ppm_out = torch.cat(ppm_out, 1)
        x = self.conv_up1(ppm_out)

        x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

        x = torch.cat((x, conv_out[-4]), 1)

        x = self.conv_up2(x)
        x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

        x = torch.cat((x, conv_out[-5]), 1)
        x = self.conv_up3(x)

        x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1)

        output = self.conv_up4(x)

        alpha = torch.clamp(output[:, 0][:, None], 0, 1)
        F = torch.sigmoid(output[:, 1:4])
        B = torch.sigmoid(output[:, 4:7])

        # FBA Fusion
        alpha, F, B = fba_fusion(alpha, img, F, B)

        output = torch.cat((alpha, F, B), 1)

        return output
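A shape-level smoke test of the assembled network, mirroring what app.py feeds it. This is a sketch under two assumptions: passing `'default'` skips the checkpoint (so weights are random), and input sides must be multiples of 8, as `scale_input` guarantees in app.py:

import numpy as np
import torch
from networks.models import build_model
from networks.transforms import trimap_transform, normalise_image

model = build_model('default').eval()  # random weights, no checkpoint
h, w = 64, 64
image = torch.rand(1, 3, h, w)
trimap_np = np.zeros((h, w, 2), dtype=np.float32)
trimap_np[:16, :, 0] = 1   # known background strip
trimap_np[-16:, :, 1] = 1  # known foreground strip
trimap = torch.from_numpy(trimap_np).permute(2, 0, 1)[None].float()
trimap_tf = torch.from_numpy(trimap_transform(trimap_np))[None].float()
with torch.no_grad():
    out = model(image, trimap, normalise_image(image.clone()), trimap_tf)
print(out.shape)  # torch.Size([1, 7, 64, 64]): alpha, F, B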
networks/resnet_GN_WS.py
ADDED
@@ -0,0 +1,104 @@
import torch.nn as nn
import networks.layers_WS as L

__all__ = ['ResNet']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return L.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                    padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return L.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = L.norm(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = L.norm(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = L.norm(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block=Bottleneck, layers=[3, 4, 6, 3], num_classes=1000):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                              bias=False)
        self.bn1 = L.norm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                L.norm(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # maxpool is built with return_indices=True, so unpack the indices
        x, _ = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
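`maxpool` is created with `return_indices=True`; `ResnetDilated` in networks/models.py unpacks `x, indices = self.maxpool(x)` and passes the indices to the decoder, which declares a matching `nn.MaxUnpool2d` (though its forward pass does not use it). A tiny illustration of that pool/unpool pairing:

import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
unpool = nn.MaxUnpool2d(2, stride=2)  # as declared in fba_decoder

x = torch.rand(1, 64, 32, 32)
y, indices = pool(x)                  # indices record the argmax positions
restored = unpool(y, indices, output_size=x.size())
print(y.shape, restored.shape)        # (1, 64, 16, 16) then (1, 64, 32, 32)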
networks/resnet_bn.py
ADDED
@@ -0,0 +1,156 @@
import torch.nn as nn
import math
from torch.nn import BatchNorm2d

__all__ = ['ResNet']


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=0.01)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 128
        super(ResNet, self).__init__()
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x, indices = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def l_resnet50():
    """Constructs a ResNet-50 model."""
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    return model
networks/transforms.py
ADDED
@@ -0,0 +1,27 @@
import numpy as np
import torch
import cv2


def dt(a):
    return cv2.distanceTransform((a * 255).astype(np.uint8), cv2.DIST_L2, 0)


def trimap_transform(trimap, L=320):
    clicks = []
    for k in range(2):
        dt_mask = -dt(1 - trimap[:, :, k]) ** 2
        clicks.append(np.exp(dt_mask / (2 * ((0.02 * L) ** 2))))
        clicks.append(np.exp(dt_mask / (2 * ((0.08 * L) ** 2))))
        clicks.append(np.exp(dt_mask / (2 * ((0.16 * L) ** 2))))
    clicks = np.array(clicks)
    return clicks


# For RGB!
imagenet_norm_std = torch.from_numpy(np.array([0.229, 0.224, 0.225])).float().cpu()[None, :, None, None]
imagenet_norm_mean = torch.from_numpy(np.array([0.485, 0.456, 0.406])).float().cpu()[None, :, None, None]


def normalise_image(image, mean=imagenet_norm_mean, std=imagenet_norm_std):
    return (image - mean) / std
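`trimap_transform` encodes each trimap channel as three Gaussians of increasing bandwidth over its squared distance transform, producing the six extra encoder channels (3 image + 6 transformed + 2 trimap = 11 in fba_encoder). A small shape check, with an illustrative trimap:

import numpy as np
from networks.transforms import trimap_transform

trimap = np.zeros((32, 32, 2), dtype=np.float32)
trimap[:8, :, 0] = 1   # background strip
trimap[-8:, :, 1] = 1  # foreground strip
feats = trimap_transform(trimap)
print(feats.shape)  # (6, 32, 32): three scales per trimap channel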
requirements.txt
ADDED
@@ -0,0 +1,3 @@
torch>=1.4.0
numpy
opencv-python
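Note that app.py also imports gradio and huggingface_hub, which are not listed here; the Spaces runtime provides them, but a local run needs them installed, e.g.:

pip install gradio huggingface_hub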