marta-0 commited on
Commit
6da6215
1 Parent(s): 6789c0e
Files changed (38) hide show
  1. app.py +53 -0
  2. configs/modnet/modnet_hrnet_w18.yml +5 -0
  3. configs/modnet/modnet_mobilenetv2.yml +49 -0
  4. configs/modnet/modnet_resnet50_vd.yml +5 -0
  5. images/armchair.jpg +0 -0
  6. images/cat.jpg +0 -0
  7. images/plant.jpg +0 -0
  8. matting/__pycache__/transforms.cpython-37.pyc +0 -0
  9. matting/__pycache__/utils.cpython-37.pyc +0 -0
  10. matting/core/__init__.py +1 -0
  11. matting/core/__pycache__/__init__.cpython-37.pyc +0 -0
  12. matting/core/__pycache__/predict.cpython-37.pyc +0 -0
  13. matting/core/predict.py +163 -0
  14. matting/dataset/__init__.py +15 -0
  15. matting/dataset/__pycache__/__init__.cpython-37.pyc +0 -0
  16. matting/dataset/__pycache__/matting_dataset.cpython-37.pyc +0 -0
  17. matting/dataset/matting_dataset.py +229 -0
  18. matting/model/__init__.py +21 -0
  19. matting/model/__pycache__/__init__.cpython-37.pyc +0 -0
  20. matting/model/__pycache__/dim.cpython-37.pyc +0 -0
  21. matting/model/__pycache__/hrnet.cpython-37.pyc +0 -0
  22. matting/model/__pycache__/loss.cpython-37.pyc +0 -0
  23. matting/model/__pycache__/mobilenet_v2.cpython-37.pyc +0 -0
  24. matting/model/__pycache__/modnet.cpython-37.pyc +0 -0
  25. matting/model/__pycache__/resnet_vd.cpython-37.pyc +0 -0
  26. matting/model/__pycache__/vgg.cpython-37.pyc +0 -0
  27. matting/model/dim.py +203 -0
  28. matting/model/hrnet.py +835 -0
  29. matting/model/loss.py +51 -0
  30. matting/model/mobilenet_v2.py +241 -0
  31. matting/model/modnet.py +481 -0
  32. matting/model/resnet_vd.py +368 -0
  33. matting/model/vgg.py +166 -0
  34. matting/transforms.py +530 -0
  35. matting/utils.py +70 -0
  36. requirements.txt +2 -0
  37. train.txt +0 -0
  38. val.txt +0 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import gradio as gr
3
+
4
+ import paddle
5
+ from paddleseg.cvlibs import Config
6
+
7
+ from matting.core import predict
8
+ from matting.model import *
9
+ from matting.dataset import MattingDataset
10
+
11
+
12
+ def download_file(http_address, file_name):
13
+ r = requests.get(http_address, allow_redirects=True)
14
+ open(file_name, 'wb').write(r.content)
15
+
16
+
17
+ cfgs = ['configs/modnet/modnet_mobilenetv2.yml', 'configs/modnet/modnet_resnet50_vd.yml', 'configs/modnet/modnet_hrnet_w18.yml']
18
+
19
+ download_file('https://paddleseg.bj.bcebos.com/matting/models/modnet-mobilenetv2.pdparams', 'modnet-mobilenetv2.pdparams')
20
+ download_file('https://paddleseg.bj.bcebos.com/matting/models/modnet-resnet50_vd.pdparams', 'modnet-resnet50_vd.pdparams')
21
+ download_file('https://paddleseg.bj.bcebos.com/matting/models/modnet-hrnet_w18.pdparams', 'modnet-hrnet_w18.pdparams')
22
+ models_paths = ['modnet-mobilenetv2.pdparams', 'modnet-resnet50_vd.pdparams', 'modnet-hrnet_w18.pdparams']
23
+
24
+
25
+ def inference(image, chosen_model):
26
+ paddle.set_device('cpu')
27
+ cfg = Config(cfgs[chosen_model])
28
+
29
+ val_dataset = cfg.val_dataset
30
+ model = cfg.model
31
+ img_transforms = val_dataset.transforms
32
+
33
+ alpha_pred = predict(model,
34
+ model_path=models_paths[chosen_model],
35
+ transforms=img_transforms,
36
+ image_list=[image])
37
+
38
+ return alpha_pred
39
+
40
+
41
+ inputs = [gr.inputs.Image(label='Input Image'),
42
+ gr.inputs.Radio(['MobileNetV2', 'ResNet50_vd', 'HRNet_W18'], label='Model', type='index')]
43
+
44
+
45
+ gr.Interface(
46
+ inference,
47
+ inputs,
48
+ gr.outputs.Image(label='Output'),
49
+ title='PaddleSeg - Matting',
50
+ examples=[['images/armchair.jpg', 'MobileNetV2'],
51
+ ['images/cat.jpg', 'ResNet50_vd'],
52
+ ['images/plant.jpg', 'HRNet_W18']]
53
+ ).launch()
configs/modnet/modnet_hrnet_w18.yml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ _base_: modnet_mobilenetv2.yml
2
+ model:
3
+ backbone:
4
+ type: HRNet_W18
5
+ pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
configs/modnet/modnet_mobilenetv2.yml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size: 16
2
+ iters: 100000
3
+
4
+ train_dataset:
5
+ type: MattingDataset
6
+ dataset_root: .
7
+ train_file: train.txt
8
+ transforms:
9
+ # - type: LoadImages
10
+ - type: ResizeByShort
11
+ short_size: 512
12
+ - type: RandomCrop
13
+ crop_size: [512, 512]
14
+ - type: RandomDistort
15
+ - type: RandomBlur
16
+ - type: RandomHorizontalFlip
17
+ - type: Normalize
18
+ mode: train
19
+
20
+ val_dataset:
21
+ type: MattingDataset
22
+ dataset_root: .
23
+ val_file: val.txt
24
+ transforms:
25
+ # - type: LoadImages
26
+ - type: ResizeByShort
27
+ short_size: 512
28
+ - type: ResizeToIntMult
29
+ mult_int: 32
30
+ - type: Normalize
31
+ mode: val
32
+ get_trimap: False
33
+
34
+ model:
35
+ type: MODNet
36
+ backbone:
37
+ type: MobileNetV2
38
+ pretrained: https://paddleseg.bj.bcebos.com/matting/models/MobileNetV2_pretrained/model.pdparams
39
+ pretrained: Null
40
+
41
+ optimizer:
42
+ type: sgd
43
+ momentum: 0.9
44
+ weight_decay: 4.0e-5
45
+
46
+ lr_scheduler:
47
+ type: PiecewiseDecay
48
+ boundaries: [40000, 80000]
49
+ values: [0.02, 0.002, 0.0002]
configs/modnet/modnet_resnet50_vd.yml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ _base_: modnet_mobilenetv2.yml
2
+ model:
3
+ backbone:
4
+ type: ResNet50_vd
5
+ pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
images/armchair.jpg ADDED
images/cat.jpg ADDED
images/plant.jpg ADDED
matting/__pycache__/transforms.cpython-37.pyc ADDED
Binary file (15.8 kB). View file
matting/__pycache__/utils.cpython-37.pyc ADDED
Binary file (1.66 kB). View file
matting/core/__init__.py ADDED
@@ -0,0 +1 @@
 
1
+ from .predict import predict
matting/core/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (196 Bytes). View file
matting/core/__pycache__/predict.cpython-37.pyc ADDED
Binary file (4.07 kB). View file
matting/core/predict.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import math
17
+ import time
18
+
19
+ import cv2
20
+ import numpy as np
21
+ import paddle
22
+ import paddle.nn.functional as F
23
+ from paddleseg import utils
24
+ from paddleseg.core import infer
25
+ from paddleseg.utils import logger, progbar, TimeAverager
26
+
27
+ from matting.utils import mkdir
28
+
29
+
30
+ def partition_list(arr, m):
31
+ """split the list 'arr' into m pieces"""
32
+ n = int(math.ceil(len(arr) / float(m)))
33
+ return [arr[i:i + n] for i in range(0, len(arr), n)]
34
+
35
+
36
+ def save_alpha_pred(alpha, path, trimap=None):
37
+ """
38
+ The value of alpha is range [0, 1], shape should be [h,w]
39
+ """
40
+ dirname = os.path.dirname(path)
41
+ if not os.path.exists(dirname):
42
+ os.makedirs(dirname)
43
+
44
+ trimap = cv2.imread(trimap, 0)
45
+ alpha[trimap == 0] = 0
46
+ alpha[trimap == 255] = 255
47
+ alpha = (alpha).astype('uint8')
48
+ cv2.imwrite(path, alpha)
49
+
50
+
51
+ def reverse_transform(alpha, trans_info):
52
+ """recover pred to origin shape"""
53
+ for item in trans_info[::-1]:
54
+ if item[0] == 'resize':
55
+ h, w = item[1][0], item[1][1]
56
+ alpha = F.interpolate(alpha, [h, w], mode='bilinear')
57
+ elif item[0] == 'padding':
58
+ h, w = item[1][0], item[1][1]
59
+ alpha = alpha[:, :, 0:h, 0:w]
60
+ else:
61
+ raise Exception("Unexpected info '{}' in im_info".format(item[0]))
62
+ return alpha
63
+
64
+
65
+ def preprocess(img, transforms, trimap=None):
66
+ data = {}
67
+ data['img'] = img
68
+ if trimap is not None:
69
+ data['trimap'] = trimap
70
+ data['gt_fields'] = ['trimap']
71
+ data['trans_info'] = []
72
+ data = transforms(data)
73
+ data['img'] = paddle.to_tensor(data['img'])
74
+ data['img'] = data['img'].unsqueeze(0)
75
+ if trimap is not None:
76
+ data['trimap'] = paddle.to_tensor(data['trimap'])
77
+ data['trimap'] = data['trimap'].unsqueeze((0, 1))
78
+
79
+ return data
80
+
81
+
82
+ def predict(model,
83
+ model_path,
84
+ transforms,
85
+ image_list,
86
+ image_dir=None,
87
+ trimap_list=None,
88
+ save_dir='output'):
89
+ """
90
+ predict and visualize the image_list.
91
+
92
+ Args:
93
+ model (nn.Layer): Used to predict for input image.
94
+ model_path (str): The path of pretrained model.
95
+ transforms (transforms.Compose): Preprocess for input image.
96
+ image_list (list): A list of image path to be predicted.
97
+ image_dir (str, optional): The root directory of the images predicted. Default: None.
98
+ trimap_list (list, optional): A list of trimap of image_list. Default: None.
99
+ save_dir (str, optional): The directory to save the visualized results. Default: 'output'.
100
+ """
101
+ utils.utils.load_entire_model(model, model_path)
102
+ model.eval()
103
+ nranks = paddle.distributed.get_world_size()
104
+ local_rank = paddle.distributed.get_rank()
105
+ if nranks > 1:
106
+ img_lists = partition_list(image_list, nranks)
107
+ trimap_lists = partition_list(
108
+ trimap_list, nranks) if trimap_list is not None else None
109
+ else:
110
+ img_lists = [image_list]
111
+ trimap_lists = [trimap_list] if trimap_list is not None else None
112
+
113
+ logger.info("Start to predict...")
114
+ progbar_pred = progbar.Progbar(target=len(img_lists[0]), verbose=1)
115
+ preprocess_cost_averager = TimeAverager()
116
+ infer_cost_averager = TimeAverager()
117
+ postprocess_cost_averager = TimeAverager()
118
+ batch_start = time.time()
119
+ with paddle.no_grad():
120
+ for i, im_path in enumerate(img_lists[local_rank]):
121
+ preprocess_start = time.time()
122
+ trimap = trimap_lists[local_rank][
123
+ i] if trimap_list is not None else None
124
+ data = preprocess(img=im_path, transforms=transforms, trimap=trimap)
125
+ preprocess_cost_averager.record(time.time() - preprocess_start)
126
+
127
+ infer_start = time.time()
128
+ alpha_pred = model(data)
129
+ infer_cost_averager.record(time.time() - infer_start)
130
+
131
+ postprocess_start = time.time()
132
+ alpha_pred = reverse_transform(alpha_pred, data['trans_info'])
133
+ alpha_pred = (alpha_pred.numpy()).squeeze()
134
+ alpha_pred = (alpha_pred * 255).astype('uint8')
135
+
136
+ # get the saved name
137
+ # if image_dir is not None:
138
+ # im_file = im_path.replace(image_dir, '')
139
+ # else:
140
+ # im_file = os.path.basename(im_path)
141
+ # if im_file[0] == '/' or im_file[0] == '\\':
142
+ # im_file = im_file[1:]
143
+
144
+ # save_path = os.path.join(save_dir, im_file)
145
+ # mkdir(save_path)
146
+ # save_alpha_pred(alpha_pred, save_path, trimap=trimap)
147
+
148
+ postprocess_cost_averager.record(time.time() - postprocess_start)
149
+
150
+ preprocess_cost = preprocess_cost_averager.get_average()
151
+ infer_cost = infer_cost_averager.get_average()
152
+ postprocess_cost = postprocess_cost_averager.get_average()
153
+ if local_rank == 0:
154
+ progbar_pred.update(i + 1,
155
+ [('preprocess_cost', preprocess_cost),
156
+ ('infer_cost cost', infer_cost),
157
+ ('postprocess_cost', postprocess_cost)])
158
+
159
+ preprocess_cost_averager.reset()
160
+ infer_cost_averager.reset()
161
+ postprocess_cost_averager.reset()
162
+
163
+ return alpha_pred
matting/dataset/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .matting_dataset import MattingDataset
matting/dataset/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (235 Bytes). View file
matting/dataset/__pycache__/matting_dataset.cpython-37.pyc ADDED
Binary file (5.67 kB). View file
matting/dataset/matting_dataset.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import math
17
+
18
+ import cv2
19
+ import numpy as np
20
+ import random
21
+ import paddle
22
+ from paddleseg.cvlibs import manager
23
+
24
+ import matting.transforms as T
25
+
26
+
27
+ @manager.DATASETS.add_component
28
+ class MattingDataset(paddle.io.Dataset):
29
+ """
30
+ Pass in a dataset that conforms to the format.
31
+ matting_dataset/
32
+ |--bg/
33
+ |
34
+ |--train/
35
+ | |--fg/
36
+ | |--alpha/
37
+ |
38
+ |--val/
39
+ | |--fg/
40
+ | |--alpha/
41
+ | |--trimap/ (if existing)
42
+ |
43
+ |--train.txt
44
+ |
45
+ |--val.txt
46
+ See README.md for more information of dataset.
47
+
48
+ Args:
49
+ dataset_root(str): The root path of dataset.
50
+ transforms(list): Transforms for image.
51
+ mode (str, optional): which part of dataset to use. it is one of ('train', 'val', 'trainval'). Default: 'train'.
52
+ train_file (str|list, optional): File list is used to train. It should be `foreground_image.png background_image.png`
53
+ or `foreground_image.png`. It shold be provided if mode equal to 'train'. Default: None.
54
+ val_file (str|list, optional): File list is used to evaluation. It should be `foreground_image.png background_image.png`
55
+ or `foreground_image.png` or ``foreground_image.png background_image.png trimap_image.png`.
56
+ It shold be provided if mode equal to 'val'. Default: None.
57
+ get_trimap (bool, optional): Whether to get triamp. Default: True.
58
+ separator (str, optional): The separator of train_file or val_file. If file name contains ' ', '|' may be perfect. Default: ' '.
59
+ """
60
+
61
+ def __init__(self,
62
+ dataset_root,
63
+ transforms,
64
+ mode='train',
65
+ train_file=None,
66
+ val_file=None,
67
+ get_trimap=True,
68
+ separator=' '):
69
+ super().__init__()
70
+ self.dataset_root = dataset_root
71
+ self.transforms = T.Compose(transforms)
72
+ self.mode = mode
73
+ self.get_trimap = get_trimap
74
+ self.separator = separator
75
+
76
+ # check file
77
+ if mode == 'train' or mode == 'trainval':
78
+ if train_file is None:
79
+ raise ValueError(
80
+ "When `mode` is 'train' or 'trainval', `train_file must be provided!"
81
+ )
82
+ if isinstance(train_file, str):
83
+ train_file = [train_file]
84
+ file_list = train_file
85
+
86
+ if mode == 'val' or mode == 'trainval':
87
+ if val_file is None:
88
+ raise ValueError(
89
+ "When `mode` is 'val' or 'trainval', `val_file must be provided!"
90
+ )
91
+ if isinstance(val_file, str):
92
+ val_file = [val_file]
93
+ file_list = val_file
94
+
95
+ if mode == 'trainval':
96
+ file_list = train_file + val_file
97
+
98
+ # read file
99
+ self.fg_bg_list = []
100
+ for file in file_list:
101
+ file = os.path.join(dataset_root, file)
102
+ with open(file, 'r') as f:
103
+ lines = f.readlines()
104
+ for line in lines:
105
+ line = line.strip()
106
+ self.fg_bg_list.append(line)
107
+
108
+ def __getitem__(self, idx):
109
+ data = {}
110
+ fg_bg_file = self.fg_bg_list[idx]
111
+ fg_bg_file = fg_bg_file.split(self.separator)
112
+ data['img_name'] = fg_bg_file[0] # using in save prediction results
113
+ fg_file = os.path.join(self.dataset_root, fg_bg_file[0])
114
+ alpha_file = fg_file.replace('/fg', '/alpha')
115
+ fg = cv2.imread(fg_file)
116
+ alpha = cv2.imread(alpha_file, 0)
117
+ data['alpha'] = alpha
118
+ data['gt_fields'] = []
119
+
120
+ # line is: fg [bg] [trimap]
121
+ if len(fg_bg_file) >= 2:
122
+ bg_file = os.path.join(self.dataset_root, fg_bg_file[1])
123
+ bg = cv2.imread(bg_file)
124
+ data['img'], data['bg'] = self.composite(fg, alpha, bg)
125
+ data['fg'] = fg
126
+ if self.mode in ['train', 'trainval']:
127
+ data['gt_fields'].append('fg')
128
+ data['gt_fields'].append('bg')
129
+ data['gt_fields'].append('alpha')
130
+ if len(fg_bg_file) == 3 and self.get_trimap:
131
+ if self.mode == 'val':
132
+ trimap_path = os.path.join(self.dataset_root, fg_bg_file[2])
133
+ if os.path.exists(trimap_path):
134
+ data['trimap'] = trimap_path
135
+ data['gt_fields'].append('trimap')
136
+ data['ori_trimap'] = cv2.imread(trimap_path, 0)
137
+ else:
138
+ raise FileNotFoundError(
139
+ 'trimap is not Found: {}'.format(fg_bg_file[2]))
140
+ else:
141
+ data['img'] = fg
142
+ if self.mode in ['train', 'trainval']:
143
+ data['fg'] = fg.copy()
144
+ data['bg'] = fg.copy()
145
+ data['gt_fields'].append('fg')
146
+ data['gt_fields'].append('bg')
147
+ data['gt_fields'].append('alpha')
148
+
149
+ data['trans_info'] = [] # Record shape change information
150
+
151
+ # Generate trimap from alpha if no trimap file provided
152
+ if self.get_trimap:
153
+ if 'trimap' not in data:
154
+ data['trimap'] = self.gen_trimap(
155
+ data['alpha'], mode=self.mode).astype('float32')
156
+ data['gt_fields'].append('trimap')
157
+ if self.mode == 'val':
158
+ data['ori_trimap'] = data['trimap'].copy()
159
+
160
+ data = self.transforms(data)
161
+
162
+ # When evaluation, gt should not be transforms.
163
+ if self.mode == 'val':
164
+ data['gt_fields'].append('alpha')
165
+
166
+ data['img'] = data['img'].astype('float32')
167
+ for key in data.get('gt_fields', []):
168
+ data[key] = data[key].astype('float32')
169
+
170
+ if 'trimap' in data:
171
+ data['trimap'] = data['trimap'][np.newaxis, :, :]
172
+ if 'ori_trimap' in data:
173
+ data['ori_trimap'] = data['ori_trimap'][np.newaxis, :, :]
174
+
175
+ data['alpha'] = data['alpha'][np.newaxis, :, :] / 255.
176
+
177
+ return data
178
+
179
+ def __len__(self):
180
+ return len(self.fg_bg_list)
181
+
182
+ def composite(self, fg, alpha, ori_bg):
183
+ fg_h, fg_w = fg.shape[:2]
184
+ ori_bg_h, ori_bg_w = ori_bg.shape[:2]
185
+
186
+ wratio = fg_w / ori_bg_w
187
+ hratio = fg_h / ori_bg_h
188
+ ratio = wratio if wratio > hratio else hratio
189
+
190
+ # Resize ori_bg if it is smaller than fg.
191
+ if ratio > 1:
192
+ resize_h = math.ceil(ori_bg_h * ratio)
193
+ resize_w = math.ceil(ori_bg_w * ratio)
194
+ bg = cv2.resize(
195
+ ori_bg, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
196
+ else:
197
+ bg = ori_bg
198
+
199
+ bg = bg[0:fg_h, 0:fg_w, :]
200
+ alpha = alpha / 255
201
+ alpha = np.expand_dims(alpha, axis=2)
202
+ image = alpha * fg + (1 - alpha) * bg
203
+ image = image.astype(np.uint8)
204
+ return image, bg
205
+
206
+ @staticmethod
207
+ def gen_trimap(alpha, mode='train', eval_kernel=7):
208
+ if mode == 'train':
209
+ k_size = random.choice(range(2, 5))
210
+ iterations = np.random.randint(5, 15)
211
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
212
+ (k_size, k_size))
213
+ dilated = cv2.dilate(alpha, kernel, iterations=iterations)
214
+ eroded = cv2.erode(alpha, kernel, iterations=iterations)
215
+ trimap = np.zeros(alpha.shape)
216
+ trimap.fill(128)
217
+ trimap[eroded > 254.5] = 255
218
+ trimap[dilated < 0.5] = 0
219
+ else:
220
+ k_size = eval_kernel
221
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
222
+ (k_size, k_size))
223
+ dilated = cv2.dilate(alpha, kernel)
224
+ trimap = np.zeros(alpha.shape)
225
+ trimap.fill(128)
226
+ trimap[alpha >= 250] = 255
227
+ trimap[dilated <= 5] = 0
228
+
229
+ return trimap
matting/model/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .vgg import *
16
+ from .resnet_vd import *
17
+ from .mobilenet_v2 import *
18
+ from .hrnet import *
19
+ from .dim import DIM
20
+ from .loss import MRSD
21
+ from .modnet import MODNet
matting/model/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (362 Bytes). View file
matting/model/__pycache__/dim.cpython-37.pyc ADDED
Binary file (5.55 kB). View file
matting/model/__pycache__/hrnet.cpython-37.pyc ADDED
Binary file (17.7 kB). View file
matting/model/__pycache__/loss.cpython-37.pyc ADDED
Binary file (1.45 kB). View file
matting/model/__pycache__/mobilenet_v2.cpython-37.pyc ADDED
Binary file (6.21 kB). View file
matting/model/__pycache__/modnet.cpython-37.pyc ADDED
Binary file (11.9 kB). View file
matting/model/__pycache__/resnet_vd.cpython-37.pyc ADDED
Binary file (7.74 kB). View file
matting/model/__pycache__/vgg.cpython-37.pyc ADDED
Binary file (4.08 kB). View file
matting/model/dim.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from collections import defaultdict
16
+ import paddle
17
+ import paddle.nn as nn
18
+ import paddle.nn.functional as F
19
+ from paddleseg.models import layers
20
+ from paddleseg import utils
21
+ from paddleseg.cvlibs import manager
22
+
23
+ from .loss import MRSD
24
+
25
+
26
+ @manager.MODELS.add_component
27
+ class DIM(nn.Layer):
28
+ """
29
+ The DIM implementation based on PaddlePaddle.
30
+
31
+ The original article refers to
32
+ Ning Xu, et, al. "Deep Image Matting"
33
+ (https://arxiv.org/pdf/1908.07919.pdf).
34
+
35
+ Args:
36
+ backbone: backbone model.
37
+ stage (int, optional): The stage of model. Defautl: 3.
38
+ decoder_input_channels(int, optional): The channel of decoder input. Default: 512.
39
+ pretrained(str, optional): The path of pretrianed model. Defautl: None.
40
+
41
+ """
42
+
43
+ def __init__(self,
44
+ backbone,
45
+ stage=3,
46
+ decoder_input_channels=512,
47
+ pretrained=None):
48
+ super().__init__()
49
+ self.backbone = backbone
50
+ self.pretrained = pretrained
51
+ self.stage = stage
52
+
53
+ decoder_output_channels = [64, 128, 256, 512]
54
+ self.decoder = Decoder(
55
+ input_channels=decoder_input_channels,
56
+ output_channels=decoder_output_channels)
57
+ if self.stage == 2:
58
+ for param in self.backbone.parameters():
59
+ param.stop_gradient = True
60
+ for param in self.decoder.parameters():
61
+ param.stop_gradient = True
62
+ if self.stage >= 2:
63
+ self.refine = Refine()
64
+ self.init_weight()
65
+
66
+ def forward(self, inputs):
67
+ input_shape = paddle.shape(inputs['img'])[-2:]
68
+ x = paddle.concat([inputs['img'], inputs['trimap'] / 255], axis=1)
69
+ fea_list = self.backbone(x)
70
+
71
+ # decoder stage
72
+ up_shape = []
73
+ for i in range(5):
74
+ up_shape.append(paddle.shape(fea_list[i])[-2:])
75
+ alpha_raw = self.decoder(fea_list, up_shape)
76
+ alpha_raw = F.interpolate(
77
+ alpha_raw, input_shape, mode='bilinear', align_corners=False)
78
+ logit_dict = {'alpha_raw': alpha_raw}
79
+ if self.stage < 2:
80
+ return logit_dict
81
+
82
+ if self.stage >= 2:
83
+ # refine stage
84
+ refine_input = paddle.concat([inputs['img'], alpha_raw], axis=1)
85
+ alpha_refine = self.refine(refine_input)
86
+
87
+ # finally alpha
88
+ alpha_pred = alpha_refine + alpha_raw
89
+ alpha_pred = F.interpolate(
90
+ alpha_pred, input_shape, mode='bilinear', align_corners=False)
91
+ if not self.training:
92
+ alpha_pred = paddle.clip(alpha_pred, min=0, max=1)
93
+ logit_dict['alpha_pred'] = alpha_pred
94
+ if self.training:
95
+ return logit_dict
96
+ else:
97
+ return alpha_pred
98
+
99
+ def loss(self, logit_dict, label_dict, loss_func_dict=None):
100
+ if loss_func_dict is None:
101
+ loss_func_dict = defaultdict(list)
102
+ loss_func_dict['alpha_raw'].append(MRSD())
103
+ loss_func_dict['comp'].append(MRSD())
104
+ loss_func_dict['alpha_pred'].append(MRSD())
105
+
106
+ loss = {}
107
+ mask = label_dict['trimap'] == 128
108
+ loss['all'] = 0
109
+
110
+ if self.stage != 2:
111
+ loss['alpha_raw'] = loss_func_dict['alpha_raw'][0](
112
+ logit_dict['alpha_raw'], label_dict['alpha'], mask)
113
+ loss['alpha_raw'] = 0.5 * loss['alpha_raw']
114
+ loss['all'] = loss['all'] + loss['alpha_raw']
115
+
116
+ if self.stage == 1 or self.stage == 3:
117
+ comp_pred = logit_dict['alpha_raw'] * label_dict['fg'] + \
118
+ (1 - logit_dict['alpha_raw']) * label_dict['bg']
119
+ loss['comp'] = loss_func_dict['comp'][0](comp_pred,
120
+ label_dict['img'], mask)
121
+ loss['comp'] = 0.5 * loss['comp']
122
+ loss['all'] = loss['all'] + loss['comp']
123
+
124
+ if self.stage == 2 or self.stage == 3:
125
+ loss['alpha_pred'] = loss_func_dict['alpha_pred'][0](
126
+ logit_dict['alpha_pred'], label_dict['alpha'], mask)
127
+ loss['all'] = loss['all'] + loss['alpha_pred']
128
+
129
+ return loss
130
+
131
+ def init_weight(self):
132
+ if self.pretrained is not None:
133
+ utils.load_entire_model(self, self.pretrained)
134
+
135
+
136
+ # bilinear interpolate skip connect
137
+ class Up(nn.Layer):
138
+ def __init__(self, input_channels, output_channels):
139
+ super().__init__()
140
+ self.conv = layers.ConvBNReLU(
141
+ input_channels,
142
+ output_channels,
143
+ kernel_size=5,
144
+ padding=2,
145
+ bias_attr=False)
146
+
147
+ def forward(self, x, skip, output_shape):
148
+ x = F.interpolate(
149
+ x, size=output_shape, mode='bilinear', align_corners=False)
150
+ x = x + skip
151
+ x = self.conv(x)
152
+ x = F.relu(x)
153
+
154
+ return x
155
+
156
+
157
+ class Decoder(nn.Layer):
158
+ def __init__(self, input_channels, output_channels=(64, 128, 256, 512)):
159
+ super().__init__()
160
+ self.deconv6 = nn.Conv2D(
161
+ input_channels, input_channels, kernel_size=1, bias_attr=False)
162
+ self.deconv5 = Up(input_channels, output_channels[-1])
163
+ self.deconv4 = Up(output_channels[-1], output_channels[-2])
164
+ self.deconv3 = Up(output_channels[-2], output_channels[-3])
165
+ self.deconv2 = Up(output_channels[-3], output_channels[-4])
166
+ self.deconv1 = Up(output_channels[-4], 64)
167
+
168
+ self.alpha_conv = nn.Conv2D(
169
+ 64, 1, kernel_size=5, padding=2, bias_attr=False)
170
+
171
+ def forward(self, fea_list, shape_list):
172
+ x = fea_list[-1]
173
+ x = self.deconv6(x)
174
+ x = self.deconv5(x, fea_list[4], shape_list[4])
175
+ x = self.deconv4(x, fea_list[3], shape_list[3])
176
+ x = self.deconv3(x, fea_list[2], shape_list[2])
177
+ x = self.deconv2(x, fea_list[1], shape_list[1])
178
+ x = self.deconv1(x, fea_list[0], shape_list[0])
179
+ alpha = self.alpha_conv(x)
180
+ alpha = F.sigmoid(alpha)
181
+
182
+ return alpha
183
+
184
+
185
+ class Refine(nn.Layer):
186
+ def __init__(self):
187
+ super().__init__()
188
+ self.conv1 = layers.ConvBNReLU(
189
+ 4, 64, kernel_size=3, padding=1, bias_attr=False)
190
+ self.conv2 = layers.ConvBNReLU(
191
+ 64, 64, kernel_size=3, padding=1, bias_attr=False)
192
+ self.conv3 = layers.ConvBNReLU(
193
+ 64, 64, kernel_size=3, padding=1, bias_attr=False)
194
+ self.alpha_pred = layers.ConvBNReLU(
195
+ 64, 1, kernel_size=3, padding=1, bias_attr=False)
196
+
197
+ def forward(self, x):
198
+ x = self.conv1(x)
199
+ x = self.conv2(x)
200
+ x = self.conv3(x)
201
+ alpha = self.alpha_pred(x)
202
+
203
+ return alpha
matting/model/hrnet.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import paddle
18
+ import paddle.nn as nn
19
+ import paddle.nn.functional as F
20
+
21
+ from paddleseg.cvlibs import manager, param_init
22
+ from paddleseg.models import layers
23
+ from paddleseg.utils import utils
24
+
25
+ __all__ = [
26
+ "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
27
+ "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", "HRNet_W60", "HRNet_W64"
28
+ ]
29
+
30
+
31
+ class HRNet(nn.Layer):
32
+ """
33
+ The HRNet implementation based on PaddlePaddle.
34
+
35
+ The original article refers to
36
+ Jingdong Wang, et, al. "HRNet:Deep High-Resolution Representation Learning for Visual Recognition"
37
+ (https://arxiv.org/pdf/1908.07919.pdf).
38
+
39
+ Args:
40
+ pretrained (str, optional): The path of pretrained model.
41
+ stage1_num_modules (int, optional): Number of modules for stage1. Default 1.
42
+ stage1_num_blocks (list, optional): Number of blocks per module for stage1. Default (4).
43
+ stage1_num_channels (list, optional): Number of channels per branch for stage1. Default (64).
44
+ stage2_num_modules (int, optional): Number of modules for stage2. Default 1.
45
+ stage2_num_blocks (list, optional): Number of blocks per module for stage2. Default (4, 4).
46
+ stage2_num_channels (list, optional): Number of channels per branch for stage2. Default (18, 36).
47
+ stage3_num_modules (int, optional): Number of modules for stage3. Default 4.
48
+ stage3_num_blocks (list, optional): Number of blocks per module for stage3. Default (4, 4, 4).
49
+ stage3_num_channels (list, optional): Number of channels per branch for stage3. Default [18, 36, 72).
50
+ stage4_num_modules (int, optional): Number of modules for stage4. Default 3.
51
+ stage4_num_blocks (list, optional): Number of blocks per module for stage4. Default (4, 4, 4, 4).
52
+ stage4_num_channels (list, optional): Number of channels per branch for stage4. Default (18, 36, 72. 144).
53
+ has_se (bool, optional): Whether to use Squeeze-and-Excitation module. Default False.
54
+ align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
55
+ e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
56
+ """
57
+
58
+ def __init__(self,
59
+ input_channels=3,
60
+ pretrained=None,
61
+ stage1_num_modules=1,
62
+ stage1_num_blocks=(4, ),
63
+ stage1_num_channels=(64, ),
64
+ stage2_num_modules=1,
65
+ stage2_num_blocks=(4, 4),
66
+ stage2_num_channels=(18, 36),
67
+ stage3_num_modules=4,
68
+ stage3_num_blocks=(4, 4, 4),
69
+ stage3_num_channels=(18, 36, 72),
70
+ stage4_num_modules=3,
71
+ stage4_num_blocks=(4, 4, 4, 4),
72
+ stage4_num_channels=(18, 36, 72, 144),
73
+ has_se=False,
74
+ align_corners=False,
75
+ padding_same=True):
76
+ super(HRNet, self).__init__()
77
+ self.pretrained = pretrained
78
+ self.stage1_num_modules = stage1_num_modules
79
+ self.stage1_num_blocks = stage1_num_blocks
80
+ self.stage1_num_channels = stage1_num_channels
81
+ self.stage2_num_modules = stage2_num_modules
82
+ self.stage2_num_blocks = stage2_num_blocks
83
+ self.stage2_num_channels = stage2_num_channels
84
+ self.stage3_num_modules = stage3_num_modules
85
+ self.stage3_num_blocks = stage3_num_blocks
86
+ self.stage3_num_channels = stage3_num_channels
87
+ self.stage4_num_modules = stage4_num_modules
88
+ self.stage4_num_blocks = stage4_num_blocks
89
+ self.stage4_num_channels = stage4_num_channels
90
+ self.has_se = has_se
91
+ self.align_corners = align_corners
92
+
93
+ self.feat_channels = [i for i in stage4_num_channels]
94
+ self.feat_channels = [64] + self.feat_channels
95
+
96
+ self.conv_layer1_1 = layers.ConvBNReLU(
97
+ in_channels=input_channels,
98
+ out_channels=64,
99
+ kernel_size=3,
100
+ stride=2,
101
+ padding=1 if not padding_same else 'same',
102
+ bias_attr=False)
103
+
104
+ self.conv_layer1_2 = layers.ConvBNReLU(
105
+ in_channels=64,
106
+ out_channels=64,
107
+ kernel_size=3,
108
+ stride=2,
109
+ padding=1 if not padding_same else 'same',
110
+ bias_attr=False)
111
+
112
+ self.la1 = Layer1(
113
+ num_channels=64,
114
+ num_blocks=self.stage1_num_blocks[0],
115
+ num_filters=self.stage1_num_channels[0],
116
+ has_se=has_se,
117
+ name="layer2",
118
+ padding_same=padding_same)
119
+
120
+ self.tr1 = TransitionLayer(
121
+ in_channels=[self.stage1_num_channels[0] * 4],
122
+ out_channels=self.stage2_num_channels,
123
+ name="tr1",
124
+ padding_same=padding_same)
125
+
126
+ self.st2 = Stage(
127
+ num_channels=self.stage2_num_channels,
128
+ num_modules=self.stage2_num_modules,
129
+ num_blocks=self.stage2_num_blocks,
130
+ num_filters=self.stage2_num_channels,
131
+ has_se=self.has_se,
132
+ name="st2",
133
+ align_corners=align_corners,
134
+ padding_same=padding_same)
135
+
136
+ self.tr2 = TransitionLayer(
137
+ in_channels=self.stage2_num_channels,
138
+ out_channels=self.stage3_num_channels,
139
+ name="tr2",
140
+ padding_same=padding_same)
141
+ self.st3 = Stage(
142
+ num_channels=self.stage3_num_channels,
143
+ num_modules=self.stage3_num_modules,
144
+ num_blocks=self.stage3_num_blocks,
145
+ num_filters=self.stage3_num_channels,
146
+ has_se=self.has_se,
147
+ name="st3",
148
+ align_corners=align_corners,
149
+ padding_same=padding_same)
150
+
151
+ self.tr3 = TransitionLayer(
152
+ in_channels=self.stage3_num_channels,
153
+ out_channels=self.stage4_num_channels,
154
+ name="tr3",
155
+ padding_same=padding_same)
156
+ self.st4 = Stage(
157
+ num_channels=self.stage4_num_channels,
158
+ num_modules=self.stage4_num_modules,
159
+ num_blocks=self.stage4_num_blocks,
160
+ num_filters=self.stage4_num_channels,
161
+ has_se=self.has_se,
162
+ name="st4",
163
+ align_corners=align_corners,
164
+ padding_same=padding_same)
165
+
166
+ self.init_weight()
167
+
168
+ def forward(self, x):
169
+ feat_list = []
170
+ conv1 = self.conv_layer1_1(x)
171
+ feat_list.append(conv1)
172
+ conv2 = self.conv_layer1_2(conv1)
173
+
174
+ la1 = self.la1(conv2)
175
+
176
+ tr1 = self.tr1([la1])
177
+ st2 = self.st2(tr1)
178
+
179
+ tr2 = self.tr2(st2)
180
+ st3 = self.st3(tr2)
181
+
182
+ tr3 = self.tr3(st3)
183
+ st4 = self.st4(tr3)
184
+
185
+ feat_list = feat_list + st4
186
+
187
+ return feat_list
188
+
189
+ def init_weight(self):
190
+ for layer in self.sublayers():
191
+ if isinstance(layer, nn.Conv2D):
192
+ param_init.normal_init(layer.weight, std=0.001)
193
+ elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
194
+ param_init.constant_init(layer.weight, value=1.0)
195
+ param_init.constant_init(layer.bias, value=0.0)
196
+ if self.pretrained is not None:
197
+ utils.load_pretrained_model(self, self.pretrained)
198
+
199
+
200
+ class Layer1(nn.Layer):
201
+ def __init__(self,
202
+ num_channels,
203
+ num_filters,
204
+ num_blocks,
205
+ has_se=False,
206
+ name=None,
207
+ padding_same=True):
208
+ super(Layer1, self).__init__()
209
+
210
+ self.bottleneck_block_list = []
211
+
212
+ for i in range(num_blocks):
213
+ bottleneck_block = self.add_sublayer(
214
+ "bb_{}_{}".format(name, i + 1),
215
+ BottleneckBlock(
216
+ num_channels=num_channels if i == 0 else num_filters * 4,
217
+ num_filters=num_filters,
218
+ has_se=has_se,
219
+ stride=1,
220
+ downsample=True if i == 0 else False,
221
+ name=name + '_' + str(i + 1),
222
+ padding_same=padding_same))
223
+ self.bottleneck_block_list.append(bottleneck_block)
224
+
225
+ def forward(self, x):
226
+ conv = x
227
+ for block_func in self.bottleneck_block_list:
228
+ conv = block_func(conv)
229
+ return conv
230
+
231
+
232
+ class TransitionLayer(nn.Layer):
233
+ def __init__(self, in_channels, out_channels, name=None, padding_same=True):
234
+ super(TransitionLayer, self).__init__()
235
+
236
+ num_in = len(in_channels)
237
+ num_out = len(out_channels)
238
+ self.conv_bn_func_list = []
239
+ for i in range(num_out):
240
+ residual = None
241
+ if i < num_in:
242
+ if in_channels[i] != out_channels[i]:
243
+ residual = self.add_sublayer(
244
+ "transition_{}_layer_{}".format(name, i + 1),
245
+ layers.ConvBNReLU(
246
+ in_channels=in_channels[i],
247
+ out_channels=out_channels[i],
248
+ kernel_size=3,
249
+ padding=1 if not padding_same else 'same',
250
+ bias_attr=False))
251
+ else:
252
+ residual = self.add_sublayer(
253
+ "transition_{}_layer_{}".format(name, i + 1),
254
+ layers.ConvBNReLU(
255
+ in_channels=in_channels[-1],
256
+ out_channels=out_channels[i],
257
+ kernel_size=3,
258
+ stride=2,
259
+ padding=1 if not padding_same else 'same',
260
+ bias_attr=False))
261
+ self.conv_bn_func_list.append(residual)
262
+
263
+ def forward(self, x):
264
+ outs = []
265
+ for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
266
+ if conv_bn_func is None:
267
+ outs.append(x[idx])
268
+ else:
269
+ if idx < len(x):
270
+ outs.append(conv_bn_func(x[idx]))
271
+ else:
272
+ outs.append(conv_bn_func(x[-1]))
273
+ return outs
274
+
275
+
276
+ class Branches(nn.Layer):
277
+ def __init__(self,
278
+ num_blocks,
279
+ in_channels,
280
+ out_channels,
281
+ has_se=False,
282
+ name=None,
283
+ padding_same=True):
284
+ super(Branches, self).__init__()
285
+
286
+ self.basic_block_list = []
287
+
288
+ for i in range(len(out_channels)):
289
+ self.basic_block_list.append([])
290
+ for j in range(num_blocks[i]):
291
+ in_ch = in_channels[i] if j == 0 else out_channels[i]
292
+ basic_block_func = self.add_sublayer(
293
+ "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
294
+ BasicBlock(
295
+ num_channels=in_ch,
296
+ num_filters=out_channels[i],
297
+ has_se=has_se,
298
+ name=name + '_branch_layer_' + str(i + 1) + '_' +
299
+ str(j + 1),
300
+ padding_same=padding_same))
301
+ self.basic_block_list[i].append(basic_block_func)
302
+
303
+ def forward(self, x):
304
+ outs = []
305
+ for idx, input in enumerate(x):
306
+ conv = input
307
+ for basic_block_func in self.basic_block_list[idx]:
308
+ conv = basic_block_func(conv)
309
+ outs.append(conv)
310
+ return outs
311
+
312
+
313
+ class BottleneckBlock(nn.Layer):
314
+ def __init__(self,
315
+ num_channels,
316
+ num_filters,
317
+ has_se,
318
+ stride=1,
319
+ downsample=False,
320
+ name=None,
321
+ padding_same=True):
322
+ super(BottleneckBlock, self).__init__()
323
+
324
+ self.has_se = has_se
325
+ self.downsample = downsample
326
+
327
+ self.conv1 = layers.ConvBNReLU(
328
+ in_channels=num_channels,
329
+ out_channels=num_filters,
330
+ kernel_size=1,
331
+ bias_attr=False)
332
+
333
+ self.conv2 = layers.ConvBNReLU(
334
+ in_channels=num_filters,
335
+ out_channels=num_filters,
336
+ kernel_size=3,
337
+ stride=stride,
338
+ padding=1 if not padding_same else 'same',
339
+ bias_attr=False)
340
+
341
+ self.conv3 = layers.ConvBN(
342
+ in_channels=num_filters,
343
+ out_channels=num_filters * 4,
344
+ kernel_size=1,
345
+ bias_attr=False)
346
+
347
+ if self.downsample:
348
+ self.conv_down = layers.ConvBN(
349
+ in_channels=num_channels,
350
+ out_channels=num_filters * 4,
351
+ kernel_size=1,
352
+ bias_attr=False)
353
+
354
+ if self.has_se:
355
+ self.se = SELayer(
356
+ num_channels=num_filters * 4,
357
+ num_filters=num_filters * 4,
358
+ reduction_ratio=16,
359
+ name=name + '_fc')
360
+
361
+ self.add = layers.Add()
362
+ self.relu = layers.Activation("relu")
363
+
364
+ def forward(self, x):
365
+ residual = x
366
+ conv1 = self.conv1(x)
367
+ conv2 = self.conv2(conv1)
368
+ conv3 = self.conv3(conv2)
369
+
370
+ if self.downsample:
371
+ residual = self.conv_down(x)
372
+
373
+ if self.has_se:
374
+ conv3 = self.se(conv3)
375
+
376
+ y = self.add(conv3, residual)
377
+ y = self.relu(y)
378
+ return y
379
+
380
+
381
+ class BasicBlock(nn.Layer):
382
+ def __init__(self,
383
+ num_channels,
384
+ num_filters,
385
+ stride=1,
386
+ has_se=False,
387
+ downsample=False,
388
+ name=None,
389
+ padding_same=True):
390
+ super(BasicBlock, self).__init__()
391
+
392
+ self.has_se = has_se
393
+ self.downsample = downsample
394
+
395
+ self.conv1 = layers.ConvBNReLU(
396
+ in_channels=num_channels,
397
+ out_channels=num_filters,
398
+ kernel_size=3,
399
+ stride=stride,
400
+ padding=1 if not padding_same else 'same',
401
+ bias_attr=False)
402
+ self.conv2 = layers.ConvBN(
403
+ in_channels=num_filters,
404
+ out_channels=num_filters,
405
+ kernel_size=3,
406
+ padding=1 if not padding_same else 'same',
407
+ bias_attr=False)
408
+
409
+ if self.downsample:
410
+ self.conv_down = layers.ConvBNReLU(
411
+ in_channels=num_channels,
412
+ out_channels=num_filters,
413
+ kernel_size=1,
414
+ bias_attr=False)
415
+
416
+ if self.has_se:
417
+ self.se = SELayer(
418
+ num_channels=num_filters,
419
+ num_filters=num_filters,
420
+ reduction_ratio=16,
421
+ name=name + '_fc')
422
+
423
+ self.add = layers.Add()
424
+ self.relu = layers.Activation("relu")
425
+
426
+ def forward(self, x):
427
+ residual = x
428
+ conv1 = self.conv1(x)
429
+ conv2 = self.conv2(conv1)
430
+
431
+ if self.downsample:
432
+ residual = self.conv_down(x)
433
+
434
+ if self.has_se:
435
+ conv2 = self.se(conv2)
436
+
437
+ y = self.add(conv2, residual)
438
+ y = self.relu(y)
439
+ return y
440
+
441
+
442
+ class SELayer(nn.Layer):
443
+ def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
444
+ super(SELayer, self).__init__()
445
+
446
+ self.pool2d_gap = nn.AdaptiveAvgPool2D(1)
447
+
448
+ self._num_channels = num_channels
449
+
450
+ med_ch = int(num_channels / reduction_ratio)
451
+ stdv = 1.0 / math.sqrt(num_channels * 1.0)
452
+ self.squeeze = nn.Linear(
453
+ num_channels,
454
+ med_ch,
455
+ weight_attr=paddle.ParamAttr(
456
+ initializer=nn.initializer.Uniform(-stdv, stdv)))
457
+
458
+ stdv = 1.0 / math.sqrt(med_ch * 1.0)
459
+ self.excitation = nn.Linear(
460
+ med_ch,
461
+ num_filters,
462
+ weight_attr=paddle.ParamAttr(
463
+ initializer=nn.initializer.Uniform(-stdv, stdv)))
464
+
465
+ def forward(self, x):
466
+ pool = self.pool2d_gap(x)
467
+ pool = paddle.reshape(pool, shape=[-1, self._num_channels])
468
+ squeeze = self.squeeze(pool)
469
+ squeeze = F.relu(squeeze)
470
+ excitation = self.excitation(squeeze)
471
+ excitation = F.sigmoid(excitation)
472
+ excitation = paddle.reshape(
473
+ excitation, shape=[-1, self._num_channels, 1, 1])
474
+ out = x * excitation
475
+ return out
476
+
477
+
478
+ class Stage(nn.Layer):
479
+ def __init__(self,
480
+ num_channels,
481
+ num_modules,
482
+ num_blocks,
483
+ num_filters,
484
+ has_se=False,
485
+ multi_scale_output=True,
486
+ name=None,
487
+ align_corners=False,
488
+ padding_same=True):
489
+ super(Stage, self).__init__()
490
+
491
+ self._num_modules = num_modules
492
+
493
+ self.stage_func_list = []
494
+ for i in range(num_modules):
495
+ if i == num_modules - 1 and not multi_scale_output:
496
+ stage_func = self.add_sublayer(
497
+ "stage_{}_{}".format(name, i + 1),
498
+ HighResolutionModule(
499
+ num_channels=num_channels,
500
+ num_blocks=num_blocks,
501
+ num_filters=num_filters,
502
+ has_se=has_se,
503
+ multi_scale_output=False,
504
+ name=name + '_' + str(i + 1),
505
+ align_corners=align_corners,
506
+ padding_same=padding_same))
507
+ else:
508
+ stage_func = self.add_sublayer(
509
+ "stage_{}_{}".format(name, i + 1),
510
+ HighResolutionModule(
511
+ num_channels=num_channels,
512
+ num_blocks=num_blocks,
513
+ num_filters=num_filters,
514
+ has_se=has_se,
515
+ name=name + '_' + str(i + 1),
516
+ align_corners=align_corners,
517
+ padding_same=padding_same))
518
+
519
+ self.stage_func_list.append(stage_func)
520
+
521
+ def forward(self, x):
522
+ out = x
523
+ for idx in range(self._num_modules):
524
+ out = self.stage_func_list[idx](out)
525
+ return out
526
+
527
+
528
+ class HighResolutionModule(nn.Layer):
529
+ def __init__(self,
530
+ num_channels,
531
+ num_blocks,
532
+ num_filters,
533
+ has_se=False,
534
+ multi_scale_output=True,
535
+ name=None,
536
+ align_corners=False,
537
+ padding_same=True):
538
+ super(HighResolutionModule, self).__init__()
539
+
540
+ self.branches_func = Branches(
541
+ num_blocks=num_blocks,
542
+ in_channels=num_channels,
543
+ out_channels=num_filters,
544
+ has_se=has_se,
545
+ name=name,
546
+ padding_same=padding_same)
547
+
548
+ self.fuse_func = FuseLayers(
549
+ in_channels=num_filters,
550
+ out_channels=num_filters,
551
+ multi_scale_output=multi_scale_output,
552
+ name=name,
553
+ align_corners=align_corners,
554
+ padding_same=padding_same)
555
+
556
+ def forward(self, x):
557
+ out = self.branches_func(x)
558
+ out = self.fuse_func(out)
559
+ return out
560
+
561
+
562
+ class FuseLayers(nn.Layer):
563
+ def __init__(self,
564
+ in_channels,
565
+ out_channels,
566
+ multi_scale_output=True,
567
+ name=None,
568
+ align_corners=False,
569
+ padding_same=True):
570
+ super(FuseLayers, self).__init__()
571
+
572
+ self._actual_ch = len(in_channels) if multi_scale_output else 1
573
+ self._in_channels = in_channels
574
+ self.align_corners = align_corners
575
+
576
+ self.residual_func_list = []
577
+ for i in range(self._actual_ch):
578
+ for j in range(len(in_channels)):
579
+ if j > i:
580
+ residual_func = self.add_sublayer(
581
+ "residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
582
+ layers.ConvBN(
583
+ in_channels=in_channels[j],
584
+ out_channels=out_channels[i],
585
+ kernel_size=1,
586
+ bias_attr=False))
587
+ self.residual_func_list.append(residual_func)
588
+ elif j < i:
589
+ pre_num_filters = in_channels[j]
590
+ for k in range(i - j):
591
+ if k == i - j - 1:
592
+ residual_func = self.add_sublayer(
593
+ "residual_{}_layer_{}_{}_{}".format(
594
+ name, i + 1, j + 1, k + 1),
595
+ layers.ConvBN(
596
+ in_channels=pre_num_filters,
597
+ out_channels=out_channels[i],
598
+ kernel_size=3,
599
+ stride=2,
600
+ padding=1 if not padding_same else 'same',
601
+ bias_attr=False))
602
+ pre_num_filters = out_channels[i]
603
+ else:
604
+ residual_func = self.add_sublayer(
605
+ "residual_{}_layer_{}_{}_{}".format(
606
+ name, i + 1, j + 1, k + 1),
607
+ layers.ConvBNReLU(
608
+ in_channels=pre_num_filters,
609
+ out_channels=out_channels[j],
610
+ kernel_size=3,
611
+ stride=2,
612
+ padding=1 if not padding_same else 'same',
613
+ bias_attr=False))
614
+ pre_num_filters = out_channels[j]
615
+ self.residual_func_list.append(residual_func)
616
+
617
+ def forward(self, x):
618
+ outs = []
619
+ residual_func_idx = 0
620
+ for i in range(self._actual_ch):
621
+ residual = x[i]
622
+ residual_shape = paddle.shape(residual)[-2:]
623
+ for j in range(len(self._in_channels)):
624
+ if j > i:
625
+ y = self.residual_func_list[residual_func_idx](x[j])
626
+ residual_func_idx += 1
627
+
628
+ y = F.interpolate(
629
+ y,
630
+ residual_shape,
631
+ mode='bilinear',
632
+ align_corners=self.align_corners)
633
+ residual = residual + y
634
+ elif j < i:
635
+ y = x[j]
636
+ for k in range(i - j):
637
+ y = self.residual_func_list[residual_func_idx](y)
638
+ residual_func_idx += 1
639
+
640
+ residual = residual + y
641
+
642
+ residual = F.relu(residual)
643
+ outs.append(residual)
644
+
645
+ return outs
646
+
647
+
648
+ @manager.BACKBONES.add_component
649
+ def HRNet_W18_Small_V1(**kwargs):
650
+ model = HRNet(
651
+ stage1_num_modules=1,
652
+ stage1_num_blocks=[1],
653
+ stage1_num_channels=[32],
654
+ stage2_num_modules=1,
655
+ stage2_num_blocks=[2, 2],
656
+ stage2_num_channels=[16, 32],
657
+ stage3_num_modules=1,
658
+ stage3_num_blocks=[2, 2, 2],
659
+ stage3_num_channels=[16, 32, 64],
660
+ stage4_num_modules=1,
661
+ stage4_num_blocks=[2, 2, 2, 2],
662
+ stage4_num_channels=[16, 32, 64, 128],
663
+ **kwargs)
664
+ return model
665
+
666
+
667
+ @manager.BACKBONES.add_component
668
+ def HRNet_W18_Small_V2(**kwargs):
669
+ model = HRNet(
670
+ stage1_num_modules=1,
671
+ stage1_num_blocks=[2],
672
+ stage1_num_channels=[64],
673
+ stage2_num_modules=1,
674
+ stage2_num_blocks=[2, 2],
675
+ stage2_num_channels=[18, 36],
676
+ stage3_num_modules=3,
677
+ stage3_num_blocks=[2, 2, 2],
678
+ stage3_num_channels=[18, 36, 72],
679
+ stage4_num_modules=2,
680
+ stage4_num_blocks=[2, 2, 2, 2],
681
+ stage4_num_channels=[18, 36, 72, 144],
682
+ **kwargs)
683
+ return model
684
+
685
+
686
+ @manager.BACKBONES.add_component
687
+ def HRNet_W18(**kwargs):
688
+ model = HRNet(
689
+ stage1_num_modules=1,
690
+ stage1_num_blocks=[4],
691
+ stage1_num_channels=[64],
692
+ stage2_num_modules=1,
693
+ stage2_num_blocks=[4, 4],
694
+ stage2_num_channels=[18, 36],
695
+ stage3_num_modules=4,
696
+ stage3_num_blocks=[4, 4, 4],
697
+ stage3_num_channels=[18, 36, 72],
698
+ stage4_num_modules=3,
699
+ stage4_num_blocks=[4, 4, 4, 4],
700
+ stage4_num_channels=[18, 36, 72, 144],
701
+ **kwargs)
702
+ return model
703
+
704
+
705
+ @manager.BACKBONES.add_component
706
+ def HRNet_W30(**kwargs):
707
+ model = HRNet(
708
+ stage1_num_modules=1,
709
+ stage1_num_blocks=[4],
710
+ stage1_num_channels=[64],
711
+ stage2_num_modules=1,
712
+ stage2_num_blocks=[4, 4],
713
+ stage2_num_channels=[30, 60],
714
+ stage3_num_modules=4,
715
+ stage3_num_blocks=[4, 4, 4],
716
+ stage3_num_channels=[30, 60, 120],
717
+ stage4_num_modules=3,
718
+ stage4_num_blocks=[4, 4, 4, 4],
719
+ stage4_num_channels=[30, 60, 120, 240],
720
+ **kwargs)
721
+ return model
722
+
723
+
724
+ @manager.BACKBONES.add_component
725
+ def HRNet_W32(**kwargs):
726
+ model = HRNet(
727
+ stage1_num_modules=1,
728
+ stage1_num_blocks=[4],
729
+ stage1_num_channels=[64],
730
+ stage2_num_modules=1,
731
+ stage2_num_blocks=[4, 4],
732
+ stage2_num_channels=[32, 64],
733
+ stage3_num_modules=4,
734
+ stage3_num_blocks=[4, 4, 4],
735
+ stage3_num_channels=[32, 64, 128],
736
+ stage4_num_modules=3,
737
+ stage4_num_blocks=[4, 4, 4, 4],
738
+ stage4_num_channels=[32, 64, 128, 256],
739
+ **kwargs)
740
+ return model
741
+
742
+
743
+ @manager.BACKBONES.add_component
744
+ def HRNet_W40(**kwargs):
745
+ model = HRNet(
746
+ stage1_num_modules=1,
747
+ stage1_num_blocks=[4],
748
+ stage1_num_channels=[64],
749
+ stage2_num_modules=1,
750
+ stage2_num_blocks=[4, 4],
751
+ stage2_num_channels=[40, 80],
752
+ stage3_num_modules=4,
753
+ stage3_num_blocks=[4, 4, 4],
754
+ stage3_num_channels=[40, 80, 160],
755
+ stage4_num_modules=3,
756
+ stage4_num_blocks=[4, 4, 4, 4],
757
+ stage4_num_channels=[40, 80, 160, 320],
758
+ **kwargs)
759
+ return model
760
+
761
+
762
+ @manager.BACKBONES.add_component
763
+ def HRNet_W44(**kwargs):
764
+ model = HRNet(
765
+ stage1_num_modules=1,
766
+ stage1_num_blocks=[4],
767
+ stage1_num_channels=[64],
768
+ stage2_num_modules=1,
769
+ stage2_num_blocks=[4, 4],
770
+ stage2_num_channels=[44, 88],
771
+ stage3_num_modules=4,
772
+ stage3_num_blocks=[4, 4, 4],
773
+ stage3_num_channels=[44, 88, 176],
774
+ stage4_num_modules=3,
775
+ stage4_num_blocks=[4, 4, 4, 4],
776
+ stage4_num_channels=[44, 88, 176, 352],
777
+ **kwargs)
778
+ return model
779
+
780
+
781
+ @manager.BACKBONES.add_component
782
+ def HRNet_W48(**kwargs):
783
+ model = HRNet(
784
+ stage1_num_modules=1,
785
+ stage1_num_blocks=[4],
786
+ stage1_num_channels=[64],
787
+ stage2_num_modules=1,
788
+ stage2_num_blocks=[4, 4],
789
+ stage2_num_channels=[48, 96],
790
+ stage3_num_modules=4,
791
+ stage3_num_blocks=[4, 4, 4],
792
+ stage3_num_channels=[48, 96, 192],
793
+ stage4_num_modules=3,
794
+ stage4_num_blocks=[4, 4, 4, 4],
795
+ stage4_num_channels=[48, 96, 192, 384],
796
+ **kwargs)
797
+ return model
798
+
799
+
800
+ @manager.BACKBONES.add_component
801
+ def HRNet_W60(**kwargs):
802
+ model = HRNet(
803
+ stage1_num_modules=1,
804
+ stage1_num_blocks=[4],
805
+ stage1_num_channels=[64],
806
+ stage2_num_modules=1,
807
+ stage2_num_blocks=[4, 4],
808
+ stage2_num_channels=[60, 120],
809
+ stage3_num_modules=4,
810
+ stage3_num_blocks=[4, 4, 4],
811
+ stage3_num_channels=[60, 120, 240],
812
+ stage4_num_modules=3,
813
+ stage4_num_blocks=[4, 4, 4, 4],
814
+ stage4_num_channels=[60, 120, 240, 480],
815
+ **kwargs)
816
+ return model
817
+
818
+
819
+ @manager.BACKBONES.add_component
820
+ def HRNet_W64(**kwargs):
821
+ model = HRNet(
822
+ stage1_num_modules=1,
823
+ stage1_num_blocks=[4],
824
+ stage1_num_channels=[64],
825
+ stage2_num_modules=1,
826
+ stage2_num_blocks=[4, 4],
827
+ stage2_num_channels=[64, 128],
828
+ stage3_num_modules=4,
829
+ stage3_num_blocks=[4, 4, 4],
830
+ stage3_num_channels=[64, 128, 256],
831
+ stage4_num_modules=3,
832
+ stage4_num_blocks=[4, 4, 4, 4],
833
+ stage4_num_channels=[64, 128, 256, 512],
834
+ **kwargs)
835
+ return model
matting/model/loss.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import paddle
16
+ import paddle.nn as nn
17
+ import paddle.nn.functional as F
18
+
19
+ from paddleseg.cvlibs import manager
20
+
21
+
22
+ @manager.LOSSES.add_component
23
+ class MRSD(nn.Layer):
24
+ def __init__(self, eps=1e-6):
25
+ super().__init__()
26
+ self.eps = eps
27
+
28
+ def forward(self, logit, label, mask=None):
29
+ """
30
+ Forward computation.
31
+
32
+ Args:
33
+ logit (Tensor): Logit tensor, the data type is float32, float64.
34
+ label (Tensor): Label tensor, the data type is float32, float64. The shape should equal to logit.
35
+ mask (Tensor, optional): The mask where the loss valid. Default: None.
36
+ """
37
+ if len(label.shape) == 3:
38
+ label = label.unsqueeze(1)
39
+ sd = paddle.square(logit - label)
40
+ loss = paddle.sqrt(sd + self.eps)
41
+ if mask is not None:
42
+ mask = mask.astype('float32')
43
+ if len(mask.shape) == 3:
44
+ mask = mask.unsqueeze(1)
45
+ loss = loss * mask
46
+ loss = loss.sum() / (mask.sum() + self.eps)
47
+ mask.stop_gradient = True
48
+ else:
49
+ loss = loss.mean()
50
+
51
+ return loss
matting/model/mobilenet_v2.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import numpy as np
18
+ import paddle
19
+ from paddle import ParamAttr
20
+ import paddle.nn as nn
21
+ import paddle.nn.functional as F
22
+ from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
23
+ from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
24
+
25
+ from paddleseg import utils
26
+ from paddleseg.cvlibs import manager
27
+
28
+ MODEL_URLS = {
29
+ "MobileNetV2_x0_25":
30
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x0_25_pretrained.pdparams",
31
+ "MobileNetV2_x0_5":
32
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x0_5_pretrained.pdparams",
33
+ "MobileNetV2_x0_75":
34
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x0_75_pretrained.pdparams",
35
+ "MobileNetV2":
36
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_pretrained.pdparams",
37
+ "MobileNetV2_x1_5":
38
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x1_5_pretrained.pdparams",
39
+ "MobileNetV2_x2_0":
40
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x2_0_pretrained.pdparams"
41
+ }
42
+
43
+ __all__ = ["MobileNetV2"]
44
+
45
+
46
+ class ConvBNLayer(nn.Layer):
47
+ def __init__(self,
48
+ num_channels,
49
+ filter_size,
50
+ num_filters,
51
+ stride,
52
+ padding,
53
+ channels=None,
54
+ num_groups=1,
55
+ name=None,
56
+ use_cudnn=True):
57
+ super(ConvBNLayer, self).__init__()
58
+
59
+ self._conv = Conv2D(
60
+ in_channels=num_channels,
61
+ out_channels=num_filters,
62
+ kernel_size=filter_size,
63
+ stride=stride,
64
+ padding=padding,
65
+ groups=num_groups,
66
+ weight_attr=ParamAttr(name=name + "_weights"),
67
+ bias_attr=False)
68
+
69
+ self._batch_norm = BatchNorm(
70
+ num_filters,
71
+ param_attr=ParamAttr(name=name + "_bn_scale"),
72
+ bias_attr=ParamAttr(name=name + "_bn_offset"),
73
+ moving_mean_name=name + "_bn_mean",
74
+ moving_variance_name=name + "_bn_variance")
75
+
76
+ def forward(self, inputs, if_act=True):
77
+ y = self._conv(inputs)
78
+ y = self._batch_norm(y)
79
+ if if_act:
80
+ y = F.relu6(y)
81
+ return y
82
+
83
+
84
+ class InvertedResidualUnit(nn.Layer):
85
+ def __init__(self, num_channels, num_in_filter, num_filters, stride,
86
+ filter_size, padding, expansion_factor, name):
87
+ super(InvertedResidualUnit, self).__init__()
88
+ num_expfilter = int(round(num_in_filter * expansion_factor))
89
+ self._expand_conv = ConvBNLayer(
90
+ num_channels=num_channels,
91
+ num_filters=num_expfilter,
92
+ filter_size=1,
93
+ stride=1,
94
+ padding=0,
95
+ num_groups=1,
96
+ name=name + "_expand")
97
+
98
+ self._bottleneck_conv = ConvBNLayer(
99
+ num_channels=num_expfilter,
100
+ num_filters=num_expfilter,
101
+ filter_size=filter_size,
102
+ stride=stride,
103
+ padding=padding,
104
+ num_groups=num_expfilter,
105
+ use_cudnn=False,
106
+ name=name + "_dwise")
107
+
108
+ self._linear_conv = ConvBNLayer(
109
+ num_channels=num_expfilter,
110
+ num_filters=num_filters,
111
+ filter_size=1,
112
+ stride=1,
113
+ padding=0,
114
+ num_groups=1,
115
+ name=name + "_linear")
116
+
117
+ def forward(self, inputs, ifshortcut):
118
+ y = self._expand_conv(inputs, if_act=True)
119
+ y = self._bottleneck_conv(y, if_act=True)
120
+ y = self._linear_conv(y, if_act=False)
121
+ if ifshortcut:
122
+ y = paddle.add(inputs, y)
123
+ return y
124
+
125
+
126
+ class InvresiBlocks(nn.Layer):
127
+ def __init__(self, in_c, t, c, n, s, name):
128
+ super(InvresiBlocks, self).__init__()
129
+
130
+ self._first_block = InvertedResidualUnit(
131
+ num_channels=in_c,
132
+ num_in_filter=in_c,
133
+ num_filters=c,
134
+ stride=s,
135
+ filter_size=3,
136
+ padding=1,
137
+ expansion_factor=t,
138
+ name=name + "_1")
139
+
140
+ self._block_list = []
141
+ for i in range(1, n):
142
+ block = self.add_sublayer(
143
+ name + "_" + str(i + 1),
144
+ sublayer=InvertedResidualUnit(
145
+ num_channels=c,
146
+ num_in_filter=c,
147
+ num_filters=c,
148
+ stride=1,
149
+ filter_size=3,
150
+ padding=1,
151
+ expansion_factor=t,
152
+ name=name + "_" + str(i + 1)))
153
+ self._block_list.append(block)
154
+
155
+ def forward(self, inputs):
156
+ y = self._first_block(inputs, ifshortcut=False)
157
+ for block in self._block_list:
158
+ y = block(y, ifshortcut=True)
159
+ return y
160
+
161
+
162
+ class MobileNet(nn.Layer):
163
+ def __init__(self,
164
+ input_channels=3,
165
+ scale=1.0,
166
+ pretrained=None,
167
+ prefix_name=""):
168
+ super(MobileNet, self).__init__()
169
+ self.scale = scale
170
+
171
+ bottleneck_params_list = [
172
+ (1, 16, 1, 1),
173
+ (6, 24, 2, 2),
174
+ (6, 32, 3, 2),
175
+ (6, 64, 4, 2),
176
+ (6, 96, 3, 1),
177
+ (6, 160, 3, 2),
178
+ (6, 320, 1, 1),
179
+ ]
180
+
181
+ self.conv1 = ConvBNLayer(
182
+ num_channels=input_channels,
183
+ num_filters=int(32 * scale),
184
+ filter_size=3,
185
+ stride=2,
186
+ padding=1,
187
+ name=prefix_name + "conv1_1")
188
+
189
+ self.block_list = []
190
+ i = 1
191
+ in_c = int(32 * scale)
192
+ for layer_setting in bottleneck_params_list:
193
+ t, c, n, s = layer_setting
194
+ i += 1
195
+ block = self.add_sublayer(
196
+ prefix_name + "conv" + str(i),
197
+ sublayer=InvresiBlocks(
198
+ in_c=in_c,
199
+ t=t,
200
+ c=int(c * scale),
201
+ n=n,
202
+ s=s,
203
+ name=prefix_name + "conv" + str(i)))
204
+ self.block_list.append(block)
205
+ in_c = int(c * scale)
206
+
207
+ self.out_c = int(1280 * scale) if scale > 1.0 else 1280
208
+ self.conv9 = ConvBNLayer(
209
+ num_channels=in_c,
210
+ num_filters=self.out_c,
211
+ filter_size=1,
212
+ stride=1,
213
+ padding=0,
214
+ name=prefix_name + "conv9")
215
+
216
+ self.feat_channels = [int(i * scale) for i in [16, 24, 32, 96, 1280]]
217
+ self.pretrained = pretrained
218
+ self.init_weight()
219
+
220
+ def forward(self, inputs):
221
+ feat_list = []
222
+ y = self.conv1(inputs, if_act=True)
223
+
224
+ block_index = 0
225
+ for block in self.block_list:
226
+ y = block(y)
227
+ if block_index in [0, 1, 2, 4]:
228
+ feat_list.append(y)
229
+ block_index += 1
230
+ y = self.conv9(y, if_act=True)
231
+ feat_list.append(y)
232
+ return feat_list
233
+
234
+ def init_weight(self):
235
+ utils.load_pretrained_model(self, self.pretrained)
236
+
237
+
238
+ @manager.BACKBONES.add_component
239
+ def MobileNetV2(**kwargs):
240
+ model = MobileNet(scale=1.0, **kwargs)
241
+ return model
matting/model/modnet.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # you may not use this file except in compliance with the License.
2
+ # You may obtain a copy of the License at
3
+ #
4
+ # http://www.apache.org/licenses/LICENSE-2.0
5
+ #
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from collections import defaultdict
13
+
14
+ import paddle
15
+ import paddle.nn as nn
16
+ import paddle.nn.functional as F
17
+ import paddleseg
18
+ from paddleseg.models import layers, losses
19
+ from paddleseg import utils
20
+ from paddleseg.cvlibs import manager, param_init
21
+ import numpy as np
22
+ import scipy
23
+
24
+
25
+ @manager.MODELS.add_component
26
+ class MODNet(nn.Layer):
27
+ """
28
+ The MODNet implementation based on PaddlePaddle.
29
+
30
+ The original article refers to
31
+ Zhanghan Ke, et, al. "Is a Green Screen Really Necessary for Real-Time Portrait Matting?"
32
+ (https://arxiv.org/pdf/2011.11961.pdf).
33
+
34
+ Args:
35
+ backbone: backbone model.
36
+ hr(int, optional): The channels of high resolutions branch. Defautl: None.
37
+ pretrained(str, optional): The path of pretrianed model. Defautl: None.
38
+
39
+ """
40
+
41
+ def __init__(self, backbone, hr_channels=32, pretrained=None):
42
+ super().__init__()
43
+ self.backbone = backbone
44
+ self.pretrained = pretrained
45
+
46
+ self.head = MODNetHead(
47
+ hr_channels=hr_channels, backbone_channels=backbone.feat_channels)
48
+ self.init_weight()
49
+ self.blurer = GaussianBlurLayer(1, 3)
50
+
51
+ def forward(self, inputs):
52
+ """
53
+ If training, return a dict.
54
+ If evaluation, return the final alpha prediction.
55
+ """
56
+ x = inputs['img']
57
+ feat_list = self.backbone(x)
58
+ y = self.head(inputs=inputs, feat_list=feat_list)
59
+
60
+ return y
61
+
62
+ def loss(self, logit_dict, label_dict, loss_func_dict=None):
63
+ if loss_func_dict is None:
64
+ loss_func_dict = defaultdict(list)
65
+ loss_func_dict['semantic'].append(paddleseg.models.MSELoss())
66
+ loss_func_dict['detail'].append(paddleseg.models.L1Loss())
67
+ loss_func_dict['fusion'].append(paddleseg.models.L1Loss())
68
+ loss_func_dict['fusion'].append(paddleseg.models.L1Loss())
69
+
70
+ loss = {}
71
+ # semantic loss
72
+ semantic_gt = F.interpolate(
73
+ label_dict['alpha'],
74
+ scale_factor=1 / 16,
75
+ mode='bilinear',
76
+ align_corners=False)
77
+ semantic_gt = self.blurer(semantic_gt)
78
+ # semantic_gt.stop_gradient=True
79
+ loss['semantic'] = loss_func_dict['semantic'][0](logit_dict['semantic'],
80
+ semantic_gt)
81
+
82
+ # detail loss
83
+ trimap = label_dict['trimap']
84
+ mask = (trimap == 128).astype('float32')
85
+ logit_detail = logit_dict['detail'] * mask
86
+ label_detail = label_dict['alpha'] * mask
87
+ loss_detail = loss_func_dict['detail'][0](logit_detail, label_detail)
88
+ loss_detail = loss_detail / (mask.mean() + 1e-6)
89
+ loss['detail'] = 10 * loss_detail
90
+
91
+ # fusion loss
92
+ matte = logit_dict['matte']
93
+ alpha = label_dict['alpha']
94
+ transition_mask = label_dict['trimap'] == 128
95
+ matte_boundary = paddle.where(transition_mask, matte, alpha)
96
+ # l1 loss
97
+ loss_fusion_l1 = loss_func_dict['fusion'][0](
98
+ matte,
99
+ alpha) + 4 * loss_func_dict['fusion'][0](matte_boundary, alpha)
100
+ # composition loss
101
+ loss_fusion_comp = loss_func_dict['fusion'][1](
102
+ matte * label_dict['img'],
103
+ alpha * label_dict['img']) + 4 * loss_func_dict['fusion'][1](
104
+ matte_boundary * label_dict['img'], alpha * label_dict['img'])
105
+ # consisten loss with semantic
106
+ transition_mask = F.interpolate(
107
+ label_dict['trimap'],
108
+ scale_factor=1 / 16,
109
+ mode='nearest',
110
+ align_corners=False)
111
+ transition_mask = transition_mask == 128
112
+ matte_con_sem = F.interpolate(
113
+ matte, scale_factor=1 / 16, mode='bilinear', align_corners=False)
114
+ matte_con_sem = self.blurer(matte_con_sem)
115
+ logit_semantic = logit_dict['semantic'].clone()
116
+ logit_semantic.stop_gradient = True
117
+ matte_con_sem = paddle.where(transition_mask, logit_semantic,
118
+ matte_con_sem)
119
+ if False:
120
+ import cv2
121
+ matte_con_sem_num = matte_con_sem.numpy()
122
+ matte_con_sem_num = matte_con_sem_num[0].squeeze()
123
+ matte_con_sem_num = (matte_con_sem_num * 255).astype('uint8')
124
+ semantic = logit_dict['semantic'].numpy()
125
+ semantic = semantic[0].squeeze()
126
+ semantic = (semantic * 255).astype('uint8')
127
+ transition_mask = transition_mask.astype('uint8')
128
+ transition_mask = transition_mask.numpy()
129
+ transition_mask = (transition_mask[0].squeeze()) * 255
130
+ cv2.imwrite('matte_con.png', matte_con_sem_num)
131
+ cv2.imwrite('semantic.png', semantic)
132
+ cv2.imwrite('transition.png', transition_mask)
133
+ mse_loss = paddleseg.models.MSELoss()
134
+ loss_fusion_con_sem = mse_loss(matte_con_sem, logit_dict['semantic'])
135
+ loss_fusion = loss_fusion_l1 + loss_fusion_comp + loss_fusion_con_sem
136
+ loss['fusion'] = loss_fusion
137
+ loss['fusion_l1'] = loss_fusion_l1
138
+ loss['fusion_comp'] = loss_fusion_comp
139
+ loss['fusion_con_sem'] = loss_fusion_con_sem
140
+
141
+ loss['all'] = loss['semantic'] + loss['detail'] + loss['fusion']
142
+
143
+ return loss
144
+
145
+ def init_weight(self):
146
+ if self.pretrained is not None:
147
+ utils.load_entire_model(self, self.pretrained)
148
+
149
+
150
+ class MODNetHead(nn.Layer):
151
+ def __init__(self, hr_channels, backbone_channels):
152
+ super().__init__()
153
+
154
+ self.lr_branch = LRBranch(backbone_channels)
155
+ self.hr_branch = HRBranch(hr_channels, backbone_channels)
156
+ self.f_branch = FusionBranch(hr_channels, backbone_channels)
157
+ self.init_weight()
158
+
159
+ def forward(self, inputs, feat_list):
160
+ pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(feat_list)
161
+ pred_detail, hr2x = self.hr_branch(inputs['img'], enc2x, enc4x, lr8x)
162
+ pred_matte = self.f_branch(inputs['img'], lr8x, hr2x)
163
+
164
+ if self.training:
165
+ logit_dict = {
166
+ 'semantic': pred_semantic,
167
+ 'detail': pred_detail,
168
+ 'matte': pred_matte
169
+ }
170
+ return logit_dict
171
+ else:
172
+ return pred_matte
173
+
174
+ def init_weight(self):
175
+ for layer in self.sublayers():
176
+ if isinstance(layer, nn.Conv2D):
177
+ param_init.kaiming_uniform(layer.weight)
178
+
179
+
180
+ class FusionBranch(nn.Layer):
181
+ def __init__(self, hr_channels, enc_channels):
182
+ super().__init__()
183
+ self.conv_lr4x = Conv2dIBNormRelu(
184
+ enc_channels[2], hr_channels, 5, stride=1, padding=2)
185
+
186
+ self.conv_f2x = Conv2dIBNormRelu(
187
+ 2 * hr_channels, hr_channels, 3, stride=1, padding=1)
188
+ self.conv_f = nn.Sequential(
189
+ Conv2dIBNormRelu(
190
+ hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
191
+ Conv2dIBNormRelu(
192
+ int(hr_channels / 2),
193
+ 1,
194
+ 1,
195
+ stride=1,
196
+ padding=0,
197
+ with_ibn=False,
198
+ with_relu=False))
199
+
200
+ def forward(self, img, lr8x, hr2x):
201
+ lr4x = F.interpolate(
202
+ lr8x, scale_factor=2, mode='bilinear', align_corners=False)
203
+ lr4x = self.conv_lr4x(lr4x)
204
+ lr2x = F.interpolate(
205
+ lr4x, scale_factor=2, mode='bilinear', align_corners=False)
206
+
207
+ f2x = self.conv_f2x(paddle.concat((lr2x, hr2x), axis=1))
208
+ f = F.interpolate(
209
+ f2x, scale_factor=2, mode='bilinear', align_corners=False)
210
+ f = self.conv_f(paddle.concat((f, img), axis=1))
211
+ pred_matte = F.sigmoid(f)
212
+
213
+ return pred_matte
214
+
215
+
216
+ class HRBranch(nn.Layer):
217
+ """
218
+ High Resolution Branch of MODNet
219
+ """
220
+
221
+ def __init__(self, hr_channels, enc_channels):
222
+ super().__init__()
223
+
224
+ self.tohr_enc2x = Conv2dIBNormRelu(
225
+ enc_channels[0], hr_channels, 1, stride=1, padding=0)
226
+ self.conv_enc2x = Conv2dIBNormRelu(
227
+ hr_channels + 3, hr_channels, 3, stride=2, padding=1)
228
+
229
+ self.tohr_enc4x = Conv2dIBNormRelu(
230
+ enc_channels[1], hr_channels, 1, stride=1, padding=0)
231
+ self.conv_enc4x = Conv2dIBNormRelu(
232
+ 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)
233
+
234
+ self.conv_hr4x = nn.Sequential(
235
+ Conv2dIBNormRelu(
236
+ 2 * hr_channels + enc_channels[2] + 3,
237
+ 2 * hr_channels,
238
+ 3,
239
+ stride=1,
240
+ padding=1),
241
+ Conv2dIBNormRelu(
242
+ 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
243
+ Conv2dIBNormRelu(
244
+ 2 * hr_channels, hr_channels, 3, stride=1, padding=1))
245
+
246
+ self.conv_hr2x = nn.Sequential(
247
+ Conv2dIBNormRelu(
248
+ 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
249
+ Conv2dIBNormRelu(
250
+ 2 * hr_channels, hr_channels, 3, stride=1, padding=1),
251
+ Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
252
+ Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1))
253
+
254
+ self.conv_hr = nn.Sequential(
255
+ Conv2dIBNormRelu(
256
+ hr_channels + 3, hr_channels, 3, stride=1, padding=1),
257
+ Conv2dIBNormRelu(
258
+ hr_channels,
259
+ 1,
260
+ 1,
261
+ stride=1,
262
+ padding=0,
263
+ with_ibn=False,
264
+ with_relu=False))
265
+
266
+ def forward(self, img, enc2x, enc4x, lr8x):
267
+ img2x = F.interpolate(
268
+ img, scale_factor=1 / 2, mode='bilinear', align_corners=False)
269
+ img4x = F.interpolate(
270
+ img, scale_factor=1 / 4, mode='bilinear', align_corners=False)
271
+
272
+ enc2x = self.tohr_enc2x(enc2x)
273
+ hr4x = self.conv_enc2x(paddle.concat((img2x, enc2x), axis=1))
274
+
275
+ enc4x = self.tohr_enc4x(enc4x)
276
+ hr4x = self.conv_enc4x(paddle.concat((hr4x, enc4x), axis=1))
277
+
278
+ lr4x = F.interpolate(
279
+ lr8x, scale_factor=2, mode='bilinear', align_corners=False)
280
+ hr4x = self.conv_hr4x(paddle.concat((hr4x, lr4x, img4x), axis=1))
281
+
282
+ hr2x = F.interpolate(
283
+ hr4x, scale_factor=2, mode='bilinear', align_corners=False)
284
+ hr2x = self.conv_hr2x(paddle.concat((hr2x, enc2x), axis=1))
285
+
286
+ pred_detail = None
287
+ if self.training:
288
+ hr = F.interpolate(
289
+ hr2x, scale_factor=2, mode='bilinear', align_corners=False)
290
+ hr = self.conv_hr(paddle.concat((hr, img), axis=1))
291
+ pred_detail = F.sigmoid(hr)
292
+
293
+ return pred_detail, hr2x
294
+
295
+
296
+ class LRBranch(nn.Layer):
297
+ def __init__(self, backbone_channels):
298
+ super().__init__()
299
+ self.se_block = SEBlock(backbone_channels[4], reduction=4)
300
+ self.conv_lr16x = Conv2dIBNormRelu(
301
+ backbone_channels[4], backbone_channels[3], 5, stride=1, padding=2)
302
+ self.conv_lr8x = Conv2dIBNormRelu(
303
+ backbone_channels[3], backbone_channels[2], 5, stride=1, padding=2)
304
+ self.conv_lr = Conv2dIBNormRelu(
305
+ backbone_channels[2],
306
+ 1,
307
+ 3,
308
+ stride=2,
309
+ padding=1,
310
+ with_ibn=False,
311
+ with_relu=False)
312
+
313
+ def forward(self, feat_list):
314
+ enc2x, enc4x, enc32x = feat_list[0], feat_list[1], feat_list[4]
315
+
316
+ enc32x = self.se_block(enc32x)
317
+ lr16x = F.interpolate(
318
+ enc32x, scale_factor=2, mode='bilinear', align_corners=False)
319
+ lr16x = self.conv_lr16x(lr16x)
320
+ lr8x = F.interpolate(
321
+ lr16x, scale_factor=2, mode='bilinear', align_corners=False)
322
+ lr8x = self.conv_lr8x(lr8x)
323
+
324
+ pred_semantic = None
325
+ if self.training:
326
+ lr = self.conv_lr(lr8x)
327
+ pred_semantic = F.sigmoid(lr)
328
+
329
+ return pred_semantic, lr8x, [enc2x, enc4x]
330
+
331
+
332
+ class IBNorm(nn.Layer):
333
+ """
334
+ Combine Instance Norm and Batch Norm into One Layer
335
+ """
336
+
337
+ def __init__(self, in_channels):
338
+ super().__init__()
339
+ self.bnorm_channels = in_channels // 2
340
+ self.inorm_channels = in_channels - self.bnorm_channels
341
+
342
+ self.bnorm = nn.BatchNorm2D(self.bnorm_channels)
343
+ self.inorm = nn.InstanceNorm2D(self.inorm_channels)
344
+
345
+ def forward(self, x):
346
+ bn_x = self.bnorm(x[:, :self.bnorm_channels, :, :])
347
+ in_x = self.inorm(x[:, self.bnorm_channels:, :, :])
348
+
349
+ return paddle.concat((bn_x, in_x), 1)
350
+
351
+
352
+ class Conv2dIBNormRelu(nn.Layer):
353
+ """
354
+ Convolution + IBNorm + Relu
355
+ """
356
+
357
+ def __init__(self,
358
+ in_channels,
359
+ out_channels,
360
+ kernel_size,
361
+ stride=1,
362
+ padding=0,
363
+ dilation=1,
364
+ groups=1,
365
+ bias_attr=None,
366
+ with_ibn=True,
367
+ with_relu=True):
368
+
369
+ super().__init__()
370
+
371
+ layers = [
372
+ nn.Conv2D(
373
+ in_channels,
374
+ out_channels,
375
+ kernel_size,
376
+ stride=stride,
377
+ padding=padding,
378
+ dilation=dilation,
379
+ groups=groups,
380
+ bias_attr=bias_attr)
381
+ ]
382
+
383
+ if with_ibn:
384
+ layers.append(IBNorm(out_channels))
385
+
386
+ if with_relu:
387
+ layers.append(nn.ReLU())
388
+
389
+ self.layers = nn.Sequential(*layers)
390
+
391
+ def forward(self, x):
392
+ return self.layers(x)
393
+
394
+
395
+ class SEBlock(nn.Layer):
396
+ """
397
+ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
398
+ """
399
+
400
+ def __init__(self, num_channels, reduction=1):
401
+ super().__init__()
402
+ self.pool = nn.AdaptiveAvgPool2D(1)
403
+ self.conv = nn.Sequential(
404
+ nn.Conv2D(
405
+ num_channels,
406
+ int(num_channels // reduction),
407
+ 1,
408
+ bias_attr=False), nn.ReLU(),
409
+ nn.Conv2D(
410
+ int(num_channels // reduction),
411
+ num_channels,
412
+ 1,
413
+ bias_attr=False), nn.Sigmoid())
414
+
415
+ def forward(self, x):
416
+ w = self.pool(x)
417
+ w = self.conv(w)
418
+ return w * x
419
+
420
+
421
+ class GaussianBlurLayer(nn.Layer):
422
+ """ Add Gaussian Blur to a 4D tensors
423
+ This layer takes a 4D tensor of {N, C, H, W} as input.
424
+ The Gaussian blur will be performed in given channel number (C) splitly.
425
+ """
426
+
427
+ def __init__(self, channels, kernel_size):
428
+ """
429
+ Args:
430
+ channels (int): Channel for input tensor
431
+ kernel_size (int): Size of the kernel used in blurring
432
+ """
433
+
434
+ super(GaussianBlurLayer, self).__init__()
435
+ self.channels = channels
436
+ self.kernel_size = kernel_size
437
+ assert self.kernel_size % 2 != 0
438
+
439
+ self.op = nn.Sequential(
440
+ nn.Pad2D(int(self.kernel_size / 2), mode='reflect'),
441
+ nn.Conv2D(
442
+ channels,
443
+ channels,
444
+ self.kernel_size,
445
+ stride=1,
446
+ padding=0,
447
+ bias_attr=False,
448
+ groups=channels))
449
+
450
+ self._init_kernel()
451
+ self.op[1].weight.stop_gradient = True
452
+
453
+ def forward(self, x):
454
+ """
455
+ Args:
456
+ x (paddle.Tensor): input 4D tensor
457
+ Returns:
458
+ paddle.Tensor: Blurred version of the input
459
+ """
460
+
461
+ if not len(list(x.shape)) == 4:
462
+ print('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
463
+ exit()
464
+ elif not x.shape[1] == self.channels:
465
+ print('In \'GaussianBlurLayer\', the required channel ({0}) is'
466
+ 'not the same as input ({1})\n'.format(
467
+ self.channels, x.shape[1]))
468
+ exit()
469
+
470
+ return self.op(x)
471
+
472
+ def _init_kernel(self):
473
+ sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8
474
+
475
+ n = np.zeros((self.kernel_size, self.kernel_size))
476
+ i = int(self.kernel_size / 2)
477
+ n[i, i] = 1
478
+ kernel = scipy.ndimage.gaussian_filter(n, sigma)
479
+ kernel = kernel.astype('float32')
480
+ kernel = kernel[np.newaxis, np.newaxis, :, :]
481
+ paddle.assign(kernel, self.op[1].weight)
matting/model/resnet_vd.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import paddle
16
+ import paddle.nn as nn
17
+ import paddle.nn.functional as F
18
+
19
+ from paddleseg.cvlibs import manager
20
+ from paddleseg.models import layers
21
+ from paddleseg.utils import utils
22
+
23
+ __all__ = [
24
+ "ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd"
25
+ ]
26
+
27
+
28
+ class ConvBNLayer(nn.Layer):
29
+ def __init__(
30
+ self,
31
+ in_channels,
32
+ out_channels,
33
+ kernel_size,
34
+ stride=1,
35
+ dilation=1,
36
+ groups=1,
37
+ is_vd_mode=False,
38
+ act=None,
39
+ ):
40
+ super(ConvBNLayer, self).__init__()
41
+
42
+ self.is_vd_mode = is_vd_mode
43
+ self._pool2d_avg = nn.AvgPool2D(
44
+ kernel_size=2, stride=2, padding=0, ceil_mode=True)
45
+ self._conv = nn.Conv2D(
46
+ in_channels=in_channels,
47
+ out_channels=out_channels,
48
+ kernel_size=kernel_size,
49
+ stride=stride,
50
+ padding=(kernel_size - 1) // 2 if dilation == 1 else 0,
51
+ dilation=dilation,
52
+ groups=groups,
53
+ bias_attr=False)
54
+
55
+ self._batch_norm = layers.SyncBatchNorm(out_channels)
56
+ self._act_op = layers.Activation(act=act)
57
+
58
+ def forward(self, inputs):
59
+ if self.is_vd_mode:
60
+ inputs = self._pool2d_avg(inputs)
61
+ y = self._conv(inputs)
62
+ y = self._batch_norm(y)
63
+ y = self._act_op(y)
64
+
65
+ return y
66
+
67
+
68
+ class BottleneckBlock(nn.Layer):
69
+ def __init__(self,
70
+ in_channels,
71
+ out_channels,
72
+ stride,
73
+ shortcut=True,
74
+ if_first=False,
75
+ dilation=1):
76
+ super(BottleneckBlock, self).__init__()
77
+
78
+ self.conv0 = ConvBNLayer(
79
+ in_channels=in_channels,
80
+ out_channels=out_channels,
81
+ kernel_size=1,
82
+ act='relu')
83
+
84
+ self.dilation = dilation
85
+
86
+ self.conv1 = ConvBNLayer(
87
+ in_channels=out_channels,
88
+ out_channels=out_channels,
89
+ kernel_size=3,
90
+ stride=stride,
91
+ act='relu',
92
+ dilation=dilation)
93
+ self.conv2 = ConvBNLayer(
94
+ in_channels=out_channels,
95
+ out_channels=out_channels * 4,
96
+ kernel_size=1,
97
+ act=None)
98
+
99
+ if not shortcut:
100
+ self.short = ConvBNLayer(
101
+ in_channels=in_channels,
102
+ out_channels=out_channels * 4,
103
+ kernel_size=1,
104
+ stride=1,
105
+ is_vd_mode=False if if_first or stride == 1 else True)
106
+
107
+ self.shortcut = shortcut
108
+
109
+ def forward(self, inputs):
110
+ y = self.conv0(inputs)
111
+
112
+ ####################################################################
113
+ # If given dilation rate > 1, using corresponding padding.
114
+ # The performance drops down without the follow padding.
115
+ if self.dilation > 1:
116
+ padding = self.dilation
117
+ y = F.pad(y, [padding, padding, padding, padding])
118
+ #####################################################################
119
+
120
+ conv1 = self.conv1(y)
121
+ conv2 = self.conv2(conv1)
122
+
123
+ if self.shortcut:
124
+ short = inputs
125
+ else:
126
+ short = self.short(inputs)
127
+
128
+ y = paddle.add(x=short, y=conv2)
129
+ y = F.relu(y)
130
+ return y
131
+
132
+
133
+ class BasicBlock(nn.Layer):
134
+ def __init__(self,
135
+ in_channels,
136
+ out_channels,
137
+ stride,
138
+ shortcut=True,
139
+ if_first=False):
140
+ super(BasicBlock, self).__init__()
141
+ self.stride = stride
142
+ self.conv0 = ConvBNLayer(
143
+ in_channels=in_channels,
144
+ out_channels=out_channels,
145
+ kernel_size=3,
146
+ stride=stride,
147
+ act='relu')
148
+ self.conv1 = ConvBNLayer(
149
+ in_channels=out_channels,
150
+ out_channels=out_channels,
151
+ kernel_size=3,
152
+ act=None)
153
+
154
+ if not shortcut:
155
+ self.short = ConvBNLayer(
156
+ in_channels=in_channels,
157
+ out_channels=out_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ is_vd_mode=False if if_first else True)
161
+
162
+ self.shortcut = shortcut
163
+
164
+ def forward(self, inputs):
165
+ y = self.conv0(inputs)
166
+ conv1 = self.conv1(y)
167
+
168
+ if self.shortcut:
169
+ short = inputs
170
+ else:
171
+ short = self.short(inputs)
172
+ y = paddle.add(x=short, y=conv1)
173
+ y = F.relu(y)
174
+
175
+ return y
176
+
177
+
178
+ class ResNet_vd(nn.Layer):
179
+ """
180
+ The ResNet_vd implementation based on PaddlePaddle.
181
+
182
+ The original article refers to Jingdong
183
+ Tong He, et, al. "Bag of Tricks for Image Classification with Convolutional Neural Networks"
184
+ (https://arxiv.org/pdf/1812.01187.pdf).
185
+
186
+ Args:
187
+ layers (int, optional): The layers of ResNet_vd. The supported layers are (18, 34, 50, 101, 152, 200). Default: 50.
188
+ output_stride (int, optional): The stride of output features compared to input images. It is 8 or 16. Default: 8.
189
+ multi_grid (tuple|list, optional): The grid of stage4. Defult: (1, 1, 1).
190
+ pretrained (str, optional): The path of pretrained model.
191
+
192
+ """
193
+
194
+ def __init__(self,
195
+ input_channels=3,
196
+ layers=50,
197
+ output_stride=32,
198
+ multi_grid=(1, 1, 1),
199
+ pretrained=None):
200
+ super(ResNet_vd, self).__init__()
201
+
202
+ self.conv1_logit = None # for gscnn shape stream
203
+ self.layers = layers
204
+ supported_layers = [18, 34, 50, 101, 152, 200]
205
+ assert layers in supported_layers, \
206
+ "supported layers are {} but input layer is {}".format(
207
+ supported_layers, layers)
208
+
209
+ if layers == 18:
210
+ depth = [2, 2, 2, 2]
211
+ elif layers == 34 or layers == 50:
212
+ depth = [3, 4, 6, 3]
213
+ elif layers == 101:
214
+ depth = [3, 4, 23, 3]
215
+ elif layers == 152:
216
+ depth = [3, 8, 36, 3]
217
+ elif layers == 200:
218
+ depth = [3, 12, 48, 3]
219
+ num_channels = [64, 256, 512, 1024
220
+ ] if layers >= 50 else [64, 64, 128, 256]
221
+ num_filters = [64, 128, 256, 512]
222
+
223
+ # for channels of four returned stages
224
+ self.feat_channels = [c * 4 for c in num_filters
225
+ ] if layers >= 50 else num_filters
226
+ self.feat_channels = [64] + self.feat_channels
227
+
228
+ dilation_dict = None
229
+ if output_stride == 8:
230
+ dilation_dict = {2: 2, 3: 4}
231
+ elif output_stride == 16:
232
+ dilation_dict = {3: 2}
233
+
234
+ self.conv1_1 = ConvBNLayer(
235
+ in_channels=input_channels,
236
+ out_channels=32,
237
+ kernel_size=3,
238
+ stride=2,
239
+ act='relu')
240
+ self.conv1_2 = ConvBNLayer(
241
+ in_channels=32,
242
+ out_channels=32,
243
+ kernel_size=3,
244
+ stride=1,
245
+ act='relu')
246
+ self.conv1_3 = ConvBNLayer(
247
+ in_channels=32,
248
+ out_channels=64,
249
+ kernel_size=3,
250
+ stride=1,
251
+ act='relu')
252
+ self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
253
+
254
+ # self.block_list = []
255
+ self.stage_list = []
256
+ if layers >= 50:
257
+ for block in range(len(depth)):
258
+ shortcut = False
259
+ block_list = []
260
+ for i in range(depth[block]):
261
+ if layers in [101, 152] and block == 2:
262
+ if i == 0:
263
+ conv_name = "res" + str(block + 2) + "a"
264
+ else:
265
+ conv_name = "res" + str(block + 2) + "b" + str(i)
266
+ else:
267
+ conv_name = "res" + str(block + 2) + chr(97 + i)
268
+
269
+ ###############################################################################
270
+ # Add dilation rate for some segmentation tasks, if dilation_dict is not None.
271
+ dilation_rate = dilation_dict[
272
+ block] if dilation_dict and block in dilation_dict else 1
273
+
274
+ # Actually block here is 'stage', and i is 'block' in 'stage'
275
+ # At the stage 4, expand the the dilation_rate if given multi_grid
276
+ if block == 3:
277
+ dilation_rate = dilation_rate * multi_grid[i]
278
+ ###############################################################################
279
+
280
+ bottleneck_block = self.add_sublayer(
281
+ 'bb_%d_%d' % (block, i),
282
+ BottleneckBlock(
283
+ in_channels=num_channels[block]
284
+ if i == 0 else num_filters[block] * 4,
285
+ out_channels=num_filters[block],
286
+ stride=2 if i == 0 and block != 0
287
+ and dilation_rate == 1 else 1,
288
+ shortcut=shortcut,
289
+ if_first=block == i == 0,
290
+ dilation=dilation_rate))
291
+
292
+ block_list.append(bottleneck_block)
293
+ shortcut = True
294
+ self.stage_list.append(block_list)
295
+ else:
296
+ for block in range(len(depth)):
297
+ shortcut = False
298
+ block_list = []
299
+ for i in range(depth[block]):
300
+ conv_name = "res" + str(block + 2) + chr(97 + i)
301
+ basic_block = self.add_sublayer(
302
+ 'bb_%d_%d' % (block, i),
303
+ BasicBlock(
304
+ in_channels=num_channels[block]
305
+ if i == 0 else num_filters[block],
306
+ out_channels=num_filters[block],
307
+ stride=2 if i == 0 and block != 0 else 1,
308
+ shortcut=shortcut,
309
+ if_first=block == i == 0))
310
+ block_list.append(basic_block)
311
+ shortcut = True
312
+ self.stage_list.append(block_list)
313
+
314
+ self.pretrained = pretrained
315
+ self.init_weight()
316
+
317
+ def forward(self, inputs):
318
+ feat_list = []
319
+ y = self.conv1_1(inputs)
320
+ y = self.conv1_2(y)
321
+ y = self.conv1_3(y)
322
+ feat_list.append(y)
323
+
324
+ y = self.pool2d_max(y)
325
+
326
+ # A feature list saves the output feature map of each stage.
327
+ for stage in self.stage_list:
328
+ for block in stage:
329
+ y = block(y)
330
+ feat_list.append(y)
331
+
332
+ return feat_list
333
+
334
+ def init_weight(self):
335
+ utils.load_pretrained_model(self, self.pretrained)
336
+
337
+
338
+ @manager.BACKBONES.add_component
339
+ def ResNet18_vd(**args):
340
+ model = ResNet_vd(layers=18, **args)
341
+ return model
342
+
343
+
344
+ def ResNet34_vd(**args):
345
+ model = ResNet_vd(layers=34, **args)
346
+ return model
347
+
348
+
349
+ @manager.BACKBONES.add_component
350
+ def ResNet50_vd(**args):
351
+ model = ResNet_vd(layers=50, **args)
352
+ return model
353
+
354
+
355
+ @manager.BACKBONES.add_component
356
+ def ResNet101_vd(**args):
357
+ model = ResNet_vd(layers=101, **args)
358
+ return model
359
+
360
+
361
+ def ResNet152_vd(**args):
362
+ model = ResNet_vd(layers=152, **args)
363
+ return model
364
+
365
+
366
+ def ResNet200_vd(**args):
367
+ model = ResNet_vd(layers=200, **args)
368
+ return model
matting/model/vgg.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import paddle
16
+ from paddle import ParamAttr
17
+ import paddle.nn as nn
18
+ import paddle.nn.functional as F
19
+ from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
20
+ from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
21
+
22
+ from paddleseg.cvlibs import manager
23
+ from paddleseg.utils import utils
24
+
25
+
26
+ class ConvBlock(nn.Layer):
27
+ def __init__(self, input_channels, output_channels, groups, name=None):
28
+ super(ConvBlock, self).__init__()
29
+
30
+ self.groups = groups
31
+ self._conv_1 = Conv2D(
32
+ in_channels=input_channels,
33
+ out_channels=output_channels,
34
+ kernel_size=3,
35
+ stride=1,
36
+ padding=1,
37
+ weight_attr=ParamAttr(name=name + "1_weights"),
38
+ bias_attr=False)
39
+ if groups == 2 or groups == 3 or groups == 4:
40
+ self._conv_2 = Conv2D(
41
+ in_channels=output_channels,
42
+ out_channels=output_channels,
43
+ kernel_size=3,
44
+ stride=1,
45
+ padding=1,
46
+ weight_attr=ParamAttr(name=name + "2_weights"),
47
+ bias_attr=False)
48
+ if groups == 3 or groups == 4:
49
+ self._conv_3 = Conv2D(
50
+ in_channels=output_channels,
51
+ out_channels=output_channels,
52
+ kernel_size=3,
53
+ stride=1,
54
+ padding=1,
55
+ weight_attr=ParamAttr(name=name + "3_weights"),
56
+ bias_attr=False)
57
+ if groups == 4:
58
+ self._conv_4 = Conv2D(
59
+ in_channels=output_channels,
60
+ out_channels=output_channels,
61
+ kernel_size=3,
62
+ stride=1,
63
+ padding=1,
64
+ weight_attr=ParamAttr(name=name + "4_weights"),
65
+ bias_attr=False)
66
+
67
+ self._pool = MaxPool2D(
68
+ kernel_size=2, stride=2, padding=0, return_mask=True)
69
+
70
+ def forward(self, inputs):
71
+ x = self._conv_1(inputs)
72
+ x = F.relu(x)
73
+ if self.groups == 2 or self.groups == 3 or self.groups == 4:
74
+ x = self._conv_2(x)
75
+ x = F.relu(x)
76
+ if self.groups == 3 or self.groups == 4:
77
+ x = self._conv_3(x)
78
+ x = F.relu(x)
79
+ if self.groups == 4:
80
+ x = self._conv_4(x)
81
+ x = F.relu(x)
82
+ skip = x
83
+ x, max_indices = self._pool(x)
84
+ return x, max_indices, skip
85
+
86
+
87
+ class VGGNet(nn.Layer):
88
+ def __init__(self, input_channels=3, layers=11, pretrained=None):
89
+ super(VGGNet, self).__init__()
90
+ self.pretrained = pretrained
91
+
92
+ self.layers = layers
93
+ self.vgg_configure = {
94
+ 11: [1, 1, 2, 2, 2],
95
+ 13: [2, 2, 2, 2, 2],
96
+ 16: [2, 2, 3, 3, 3],
97
+ 19: [2, 2, 4, 4, 4]
98
+ }
99
+ assert self.layers in self.vgg_configure.keys(), \
100
+ "supported layers are {} but input layer is {}".format(
101
+ self.vgg_configure.keys(), layers)
102
+ self.groups = self.vgg_configure[self.layers]
103
+
104
+ # matting的第一层卷积输入为4通道,初始化是直接初始化为0
105
+ self._conv_block_1 = ConvBlock(
106
+ input_channels, 64, self.groups[0], name="conv1_")
107
+ self._conv_block_2 = ConvBlock(64, 128, self.groups[1], name="conv2_")
108
+ self._conv_block_3 = ConvBlock(128, 256, self.groups[2], name="conv3_")
109
+ self._conv_block_4 = ConvBlock(256, 512, self.groups[3], name="conv4_")
110
+ self._conv_block_5 = ConvBlock(512, 512, self.groups[4], name="conv5_")
111
+
112
+ # 这一层的初始化需要利用vgg fc6的参数转换后进行初始化,可以暂时不考虑初始化
113
+ self._conv_6 = Conv2D(
114
+ 512, 512, kernel_size=3, padding=1, bias_attr=False)
115
+
116
+ self.init_weight()
117
+
118
+ def forward(self, inputs):
119
+ fea_list = []
120
+ ids_list = []
121
+ x, ids, skip = self._conv_block_1(inputs)
122
+ fea_list.append(skip)
123
+ ids_list.append(ids)
124
+ x, ids, skip = self._conv_block_2(x)
125
+ fea_list.append(skip)
126
+ ids_list.append(ids)
127
+ x, ids, skip = self._conv_block_3(x)
128
+ fea_list.append(skip)
129
+ ids_list.append(ids)
130
+ x, ids, skip = self._conv_block_4(x)
131
+ fea_list.append(skip)
132
+ ids_list.append(ids)
133
+ x, ids, skip = self._conv_block_5(x)
134
+ fea_list.append(skip)
135
+ ids_list.append(ids)
136
+ x = F.relu(self._conv_6(x))
137
+ fea_list.append(x)
138
+ return fea_list
139
+
140
+ def init_weight(self):
141
+ if self.pretrained is not None:
142
+ utils.load_pretrained_model(self, self.pretrained)
143
+
144
+
145
+ @manager.BACKBONES.add_component
146
+ def VGG11(**args):
147
+ model = VGGNet(layers=11, **args)
148
+ return model
149
+
150
+
151
+ @manager.BACKBONES.add_component
152
+ def VGG13(**args):
153
+ model = VGGNet(layers=13, **args)
154
+ return model
155
+
156
+
157
+ @manager.BACKBONES.add_component
158
+ def VGG16(**args):
159
+ model = VGGNet(layers=16, **args)
160
+ return model
161
+
162
+
163
+ @manager.BACKBONES.add_component
164
+ def VGG19(**args):
165
+ model = VGGNet(layers=19, **args)
166
+ return model
matting/transforms.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+
17
+ import cv2
18
+ import numpy as np
19
+ from paddleseg.transforms import functional
20
+ from paddleseg.cvlibs import manager
21
+ from PIL import Image
22
+
23
+
24
+ @manager.TRANSFORMS.add_component
25
+ class Compose:
26
+ """
27
+ Do transformation on input data with corresponding pre-processing and augmentation operations.
28
+ The shape of input data to all operations is [height, width, channels].
29
+ """
30
+
31
+ def __init__(self, transforms, to_rgb=True):
32
+ if not isinstance(transforms, list):
33
+ raise TypeError('The transforms must be a list!')
34
+ self.transforms = transforms
35
+ self.to_rgb = to_rgb
36
+
37
+ def __call__(self, data):
38
+ """
39
+ Args:
40
+ data (dict): The data to transform.
41
+
42
+ Returns:
43
+ dict: Data after transformation
44
+ """
45
+ if 'trans_info' not in data:
46
+ data['trans_info'] = []
47
+ for op in self.transforms:
48
+ data = op(data)
49
+ if data is None:
50
+ return None
51
+
52
+ data['img'] = np.transpose(data['img'], (2, 0, 1))
53
+ for key in data.get('gt_fields', []):
54
+ if len(data[key].shape) == 2:
55
+ continue
56
+ data[key] = np.transpose(data[key], (2, 0, 1))
57
+
58
+ return data
59
+
60
+
61
+ @manager.TRANSFORMS.add_component
62
+ class LoadImages:
63
+ def __init__(self, to_rgb=True):
64
+ self.to_rgb = to_rgb
65
+
66
+ def __call__(self, data):
67
+ if isinstance(data['img'], str):
68
+ data['img'] = cv2.imread(data['img'])
69
+ for key in data.get('gt_fields', []):
70
+ if isinstance(data[key], str):
71
+ data[key] = cv2.imread(data[key], cv2.IMREAD_UNCHANGED)
72
+ # if alpha and trimap has 3 channels, extract one.
73
+ if key in ['alpha', 'trimap']:
74
+ if len(data[key].shape) > 2:
75
+ data[key] = data[key][:, :, 0]
76
+
77
+ if self.to_rgb:
78
+ data['img'] = cv2.cvtColor(data['img'], cv2.COLOR_BGR2RGB)
79
+ for key in data.get('gt_fields', []):
80
+ if len(data[key].shape) == 2:
81
+ continue
82
+ data[key] = cv2.cvtColor(data[key], cv2.COLOR_BGR2RGB)
83
+
84
+ return data
85
+
86
+
87
+ @manager.TRANSFORMS.add_component
88
+ class Resize:
89
+ def __init__(self, target_size=(512, 512)):
90
+ if isinstance(target_size, list) or isinstance(target_size, tuple):
91
+ if len(target_size) != 2:
92
+ raise ValueError(
93
+ '`target_size` should include 2 elements, but it is {}'.
94
+ format(target_size))
95
+ else:
96
+ raise TypeError(
97
+ "Type of `target_size` is invalid. It should be list or tuple, but it is {}"
98
+ .format(type(target_size)))
99
+
100
+ self.target_size = target_size
101
+
102
+ def __call__(self, data):
103
+ data['trans_info'].append(('resize', data['img'].shape[0:2]))
104
+ data['img'] = functional.resize(data['img'], self.target_size)
105
+ for key in data.get('gt_fields', []):
106
+ data[key] = functional.resize(data[key], self.target_size)
107
+ return data
108
+
109
+
110
+ @manager.TRANSFORMS.add_component
111
+ class ResizeByLong:
112
+ """
113
+ Resize the long side of an image to given size, and then scale the other side proportionally.
114
+
115
+ Args:
116
+ long_size (int): The target size of long side.
117
+ """
118
+
119
+ def __init__(self, long_size):
120
+ self.long_size = long_size
121
+
122
+ def __call__(self, data):
123
+ data['trans_info'].append(('resize', data['img'].shape[0:2]))
124
+ data['img'] = functional.resize_long(data['img'], self.long_size)
125
+ for key in data.get('gt_fields', []):
126
+ data[key] = functional.resize_long(data[key], self.long_size)
127
+ return data
128
+
129
+
130
+ @manager.TRANSFORMS.add_component
131
+ class ResizeByShort:
132
+ """
133
+ Resize the short side of an image to given size, and then scale the other side proportionally.
134
+
135
+ Args:
136
+ short_size (int): The target size of short side.
137
+ """
138
+
139
+ def __init__(self, short_size):
140
+ self.short_size = short_size
141
+
142
+ def __call__(self, data):
143
+ data['trans_info'].append(('resize', data['img'].shape[0:2]))
144
+ data['img'] = functional.resize_short(data['img'], self.short_size)
145
+ for key in data.get('gt_fields', []):
146
+ data[key] = functional.resize_short(data[key], self.short_size)
147
+ return data
148
+
149
+
150
+ @manager.TRANSFORMS.add_component
151
+ class ResizeToIntMult:
152
+ """
153
+ Resize to some int muitple, d.g. 32.
154
+ """
155
+
156
+ def __init__(self, mult_int=32):
157
+ self.mult_int = mult_int
158
+
159
+ def __call__(self, data):
160
+ data['trans_info'].append(('resize', data['img'].shape[0:2]))
161
+
162
+ h, w = data['img'].shape[0:2]
163
+ rw = w - w % 32
164
+ rh = h - h % 32
165
+ data['img'] = functional.resize(data['img'], (rw, rh))
166
+ for key in data.get('gt_fields', []):
167
+ data[key] = functional.resize(data[key], (rw, rh))
168
+
169
+ return data
170
+
171
+
172
+ @manager.TRANSFORMS.add_component
173
+ class Normalize:
174
+ """
175
+ Normalize an image.
176
+
177
+ Args:
178
+ mean (list, optional): The mean value of a data set. Default: [0.5, 0.5, 0.5].
179
+ std (list, optional): The standard deviation of a data set. Default: [0.5, 0.5, 0.5].
180
+
181
+ Raises:
182
+ ValueError: When mean/std is not list or any value in std is 0.
183
+ """
184
+
185
+ def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
186
+ self.mean = mean
187
+ self.std = std
188
+ if not (isinstance(self.mean, (list, tuple))
189
+ and isinstance(self.std, (list, tuple))):
190
+ raise ValueError(
191
+ "{}: input type is invalid. It should be list or tuple".format(
192
+ self))
193
+ from functools import reduce
194
+ if reduce(lambda x, y: x * y, self.std) == 0:
195
+ raise ValueError('{}: std is invalid!'.format(self))
196
+
197
+ def __call__(self, data):
198
+ mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
199
+ std = np.array(self.std)[np.newaxis, np.newaxis, :]
200
+ data['img'] = functional.normalize(data['img'], mean, std)
201
+ if 'fg' in data.get('gt_fields', []):
202
+ data['fg'] = functional.normalize(data['fg'], mean, std)
203
+ if 'bg' in data.get('gt_fields', []):
204
+ data['bg'] = functional.normalize(data['bg'], mean, std)
205
+
206
+ return data
207
+
208
+
209
+ @manager.TRANSFORMS.add_component
210
+ class RandomCropByAlpha:
211
+ """
212
+ Randomly crop while centered on uncertain area by a certain probability.
213
+
214
+ Args:
215
+ crop_size (tuple|list): The size you want to crop from image.
216
+ p (float): The probability centered on uncertain area.
217
+
218
+ """
219
+
220
+ def __init__(self, crop_size=((320, 320), (480, 480), (640, 640)),
221
+ prob=0.5):
222
+ self.crop_size = crop_size
223
+ self.prob = prob
224
+
225
+ def __call__(self, data):
226
+ idex = np.random.randint(low=0, high=len(self.crop_size))
227
+ crop_w, crop_h = self.crop_size[idex]
228
+
229
+ img_h = data['img'].shape[0]
230
+ img_w = data['img'].shape[1]
231
+ if np.random.rand() < self.prob:
232
+ crop_center = np.where((data['alpha'] > 0) & (data['alpha'] < 255))
233
+ center_h_array, center_w_array = crop_center
234
+ if len(center_h_array) == 0:
235
+ return data
236
+ rand_ind = np.random.randint(len(center_h_array))
237
+ center_h = center_h_array[rand_ind]
238
+ center_w = center_w_array[rand_ind]
239
+ delta_h = crop_h // 2
240
+ delta_w = crop_w // 2
241
+ start_h = max(0, center_h - delta_h)
242
+ start_w = max(0, center_w - delta_w)
243
+ else:
244
+ start_h = 0
245
+ start_w = 0
246
+ if img_h > crop_h:
247
+ start_h = np.random.randint(img_h - crop_h + 1)
248
+ if img_w > crop_w:
249
+ start_w = np.random.randint(img_w - crop_w + 1)
250
+
251
+ end_h = min(img_h, start_h + crop_h)
252
+ end_w = min(img_w, start_w + crop_w)
253
+
254
+ data['img'] = data['img'][start_h:end_h, start_w:end_w]
255
+ for key in data.get('gt_fields', []):
256
+ data[key] = data[key][start_h:end_h, start_w:end_w]
257
+
258
+ return data
259
+
260
+
261
+ @manager.TRANSFORMS.add_component
262
+ class RandomCrop:
263
+ """
264
+ Randomly crop
265
+
266
+ Args:
267
+ crop_size (tuple|list): The size you want to crop from image.
268
+ """
269
+
270
+ def __init__(self, crop_size=((320, 320), (480, 480), (640, 640))):
271
+ if not isinstance(crop_size[0], (list, tuple)):
272
+ crop_size = [crop_size]
273
+ self.crop_size = crop_size
274
+
275
+ def __call__(self, data):
276
+ idex = np.random.randint(low=0, high=len(self.crop_size))
277
+ crop_w, crop_h = self.crop_size[idex]
278
+ img_h, img_w = data['img'].shape[0:2]
279
+
280
+ start_h = 0
281
+ start_w = 0
282
+ if img_h > crop_h:
283
+ start_h = np.random.randint(img_h - crop_h + 1)
284
+ if img_w > crop_w:
285
+ start_w = np.random.randint(img_w - crop_w + 1)
286
+
287
+ end_h = min(img_h, start_h + crop_h)
288
+ end_w = min(img_w, start_w + crop_w)
289
+
290
+ data['img'] = data['img'][start_h:end_h, start_w:end_w]
291
+ for key in data.get('gt_fields', []):
292
+ data[key] = data[key][start_h:end_h, start_w:end_w]
293
+
294
+ return data
295
+
296
+
297
+ @manager.TRANSFORMS.add_component
298
+ class LimitLong:
299
+ """
300
+ Limit the long edge of image.
301
+
302
+ If the long edge is larger than max_long, resize the long edge
303
+ to max_long, while scale the short edge proportionally.
304
+
305
+ If the long edge is smaller than min_long, resize the long edge
306
+ to min_long, while scale the short edge proportionally.
307
+
308
+ Args:
309
+ max_long (int, optional): If the long edge of image is larger than max_long,
310
+ it will be resize to max_long. Default: None.
311
+ min_long (int, optional): If the long edge of image is smaller than min_long,
312
+ it will be resize to min_long. Default: None.
313
+ """
314
+
315
+ def __init__(self, max_long=None, min_long=None):
316
+ if max_long is not None:
317
+ if not isinstance(max_long, int):
318
+ raise TypeError(
319
+ "Type of `max_long` is invalid. It should be int, but it is {}"
320
+ .format(type(max_long)))
321
+ if min_long is not None:
322
+ if not isinstance(min_long, int):
323
+ raise TypeError(
324
+ "Type of `min_long` is invalid. It should be int, but it is {}"
325
+ .format(type(min_long)))
326
+ if (max_long is not None) and (min_long is not None):
327
+ if min_long > max_long:
328
+ raise ValueError(
329
+ '`max_long should not smaller than min_long, but they are {} and {}'
330
+ .format(max_long, min_long))
331
+ self.max_long = max_long
332
+ self.min_long = min_long
333
+
334
+ def __call__(self, data):
335
+ h, w = data['img'].shape[:2]
336
+ long_edge = max(h, w)
337
+ target = long_edge
338
+ if (self.max_long is not None) and (long_edge > self.max_long):
339
+ target = self.max_long
340
+ elif (self.min_long is not None) and (long_edge < self.min_long):
341
+ target = self.min_long
342
+
343
+ if target != long_edge:
344
+ data['trans_info'].append(('resize', data['img'].shape[0:2]))
345
+ data['img'] = functional.resize_long(data['img'], target)
346
+ for key in data.get('gt_fields', []):
347
+ data[key] = functional.resize_long(data[key], target)
348
+
349
+ return data
350
+
351
+
352
+ @manager.TRANSFORMS.add_component
353
+ class RandomHorizontalFlip:
354
+ """
355
+ Flip an image horizontally with a certain probability.
356
+
357
+ Args:
358
+ prob (float, optional): A probability of horizontally flipping. Default: 0.5.
359
+ """
360
+
361
+ def __init__(self, prob=0.5):
362
+ self.prob = prob
363
+
364
+ def __call__(self, data):
365
+ if random.random() < self.prob:
366
+ data['img'] = functional.horizontal_flip(data['img'])
367
+ for key in data.get('gt_fields', []):
368
+ data[key] = functional.horizontal_flip(data[key])
369
+
370
+ return data
371
+
372
+
373
+ @manager.TRANSFORMS.add_component
374
+ class RandomBlur:
375
+ """
376
+ Blurring an image by a Gaussian function with a certain probability.
377
+
378
+ Args:
379
+ prob (float, optional): A probability of blurring an image. Default: 0.1.
380
+ """
381
+
382
+ def __init__(self, prob=0.1):
383
+ self.prob = prob
384
+
385
+ def __call__(self, data):
386
+ if self.prob <= 0:
387
+ n = 0
388
+ elif self.prob >= 1:
389
+ n = 1
390
+ else:
391
+ n = int(1.0 / self.prob)
392
+ if n > 0:
393
+ if np.random.randint(0, n) == 0:
394
+ radius = np.random.randint(3, 10)
395
+ if radius % 2 != 1:
396
+ radius = radius + 1
397
+ if radius > 9:
398
+ radius = 9
399
+ data['img'] = cv2.GaussianBlur(data['img'], (radius, radius), 0,
400
+ 0)
401
+ for key in data.get('gt_fields', []):
402
+ data[key] = cv2.GaussianBlur(data[key], (radius, radius), 0,
403
+ 0)
404
+ return data
405
+
406
+
407
+ @manager.TRANSFORMS.add_component
408
+ class RandomDistort:
409
+ """
410
+ Distort an image with random configurations.
411
+
412
+ Args:
413
+ brightness_range (float, optional): A range of brightness. Default: 0.5.
414
+ brightness_prob (float, optional): A probability of adjusting brightness. Default: 0.5.
415
+ contrast_range (float, optional): A range of contrast. Default: 0.5.
416
+ contrast_prob (float, optional): A probability of adjusting contrast. Default: 0.5.
417
+ saturation_range (float, optional): A range of saturation. Default: 0.5.
418
+ saturation_prob (float, optional): A probability of adjusting saturation. Default: 0.5.
419
+ hue_range (int, optional): A range of hue. Default: 18.
420
+ hue_prob (float, optional): A probability of adjusting hue. Default: 0.5.
421
+ """
422
+
423
+ def __init__(self,
424
+ brightness_range=0.5,
425
+ brightness_prob=0.5,
426
+ contrast_range=0.5,
427
+ contrast_prob=0.5,
428
+ saturation_range=0.5,
429
+ saturation_prob=0.5,
430
+ hue_range=18,
431
+ hue_prob=0.5):
432
+ self.brightness_range = brightness_range
433
+ self.brightness_prob = brightness_prob
434
+ self.contrast_range = contrast_range
435
+ self.contrast_prob = contrast_prob
436
+ self.saturation_range = saturation_range
437
+ self.saturation_prob = saturation_prob
438
+ self.hue_range = hue_range
439
+ self.hue_prob = hue_prob
440
+
441
+ def __call__(self, data):
442
+ brightness_lower = 1 - self.brightness_range
443
+ brightness_upper = 1 + self.brightness_range
444
+ contrast_lower = 1 - self.contrast_range
445
+ contrast_upper = 1 + self.contrast_range
446
+ saturation_lower = 1 - self.saturation_range
447
+ saturation_upper = 1 + self.saturation_range
448
+ hue_lower = -self.hue_range
449
+ hue_upper = self.hue_range
450
+ ops = [
451
+ functional.brightness, functional.contrast, functional.saturation,
452
+ functional.hue
453
+ ]
454
+ random.shuffle(ops)
455
+ params_dict = {
456
+ 'brightness': {
457
+ 'brightness_lower': brightness_lower,
458
+ 'brightness_upper': brightness_upper
459
+ },
460
+ 'contrast': {
461
+ 'contrast_lower': contrast_lower,
462
+ 'contrast_upper': contrast_upper
463
+ },
464
+ 'saturation': {
465
+ 'saturation_lower': saturation_lower,
466
+ 'saturation_upper': saturation_upper
467
+ },
468
+ 'hue': {
469
+ 'hue_lower': hue_lower,
470
+ 'hue_upper': hue_upper
471
+ }
472
+ }
473
+ prob_dict = {
474
+ 'brightness': self.brightness_prob,
475
+ 'contrast': self.contrast_prob,
476
+ 'saturation': self.saturation_prob,
477
+ 'hue': self.hue_prob
478
+ }
479
+
480
+ im = data['img'].astype('uint8')
481
+ im = Image.fromarray(im)
482
+ for id in range(len(ops)):
483
+ params = params_dict[ops[id].__name__]
484
+ params['im'] = im
485
+ prob = prob_dict[ops[id].__name__]
486
+ if np.random.uniform(0, 1) < prob:
487
+ im = ops[id](**params)
488
+ data['img'] = np.asarray(im)
489
+
490
+ for key in data.get('gt_fields', []):
491
+ if key in ['alpha', 'trimap']:
492
+ continue
493
+ else:
494
+ im = data[key].astype('uint8')
495
+ im = Image.fromarray(im)
496
+ for id in range(len(ops)):
497
+ params = params_dict[ops[id].__name__]
498
+ params['im'] = im
499
+ prob = prob_dict[ops[id].__name__]
500
+ if np.random.uniform(0, 1) < prob:
501
+ im = ops[id](**params)
502
+ data[key] = np.asarray(im)
503
+ return data
504
+
505
+
506
+ if __name__ == "__main__":
507
+ transforms = [RandomDistort()]
508
+ transforms = Compose(transforms)
509
+ fg_path = '/ssd1/home/chenguowei01/github/PaddleSeg/contrib/matting/data/matting/human_matting/Distinctions-646/train/fg/13(2).png'
510
+ alpha_path = fg_path.replace('fg', 'alpha')
511
+ bg_path = '/ssd1/home/chenguowei01/github/PaddleSeg/contrib/matting/data/matting/human_matting/bg/unsplash_bg/attic/photo-1443884590026-2e4d21aee71c?crop=entropy&cs=tinysrgb&fit=max&fm=jpg&ixid=MnwxMjA3fDB8MXxzZWFyY2h8Nzh8fGF0dGljfGVufDB8fHx8MTYyOTY4MDcxNQ&ixlib=rb-1.2.1&q=80&w=400.jpg'
512
+ data = {}
513
+ data['fg'] = cv2.imread(fg_path)
514
+ data['bg'] = cv2.imread(bg_path)
515
+ h, w, c = data['fg'].shape
516
+ data['bg'] = cv2.resize(data['bg'], (w, h))
517
+ alpha = cv2.imread(alpha_path)
518
+ data['alpha'] = alpha[:, :, 0]
519
+ alpha = alpha / 255.
520
+ data['img'] = alpha * data['fg'] + (1 - alpha) * data['bg']
521
+
522
+ data['gt_fields'] = ['fg', 'bg']
523
+ print(data['img'].shape)
524
+ for key in data['gt_fields']:
525
+ print(data[key].shape)
526
+ # import pdb
527
+ # pdb.set_trace()
528
+ data = transforms(data)
529
+ print(data['img'].dtype, data['img'].shape)
530
+ cv2.imwrite('distort_img.jpg', data['img'].transpose([1, 2, 0]))
matting/utils.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+
18
+ def get_files(root_path):
19
+ res = []
20
+ for root, dirs, files in os.walk(root_path, followlinks=True):
21
+ for f in files:
22
+ if f.endswith(('.jpg', '.png', '.jpeg', 'JPG')):
23
+ res.append(os.path.join(root, f))
24
+ return res
25
+
26
+
27
+ def get_image_list(image_path):
28
+ """Get image list"""
29
+ valid_suffix = [
30
+ '.JPEG', '.jpeg', '.JPG', '.jpg', '.BMP', '.bmp', '.PNG', '.png'
31
+ ]
32
+ image_list = []
33
+ image_dir = None
34
+ if os.path.isfile(image_path):
35
+ if os.path.splitext(image_path)[-1] in valid_suffix:
36
+ image_list.append(image_path)
37
+ else:
38
+ image_dir = os.path.dirname(image_path)
39
+ with open(image_path, 'r') as f:
40
+ for line in f:
41
+ line = line.strip()
42
+ if len(line.split()) > 1:
43
+ raise RuntimeError(
44
+ 'There should be only one image path per line in `image_path` file. Wrong line: {}'
45
+ .format(line))
46
+ image_list.append(os.path.join(image_dir, line))
47
+ elif os.path.isdir(image_path):
48
+ image_dir = image_path
49
+ for root, dirs, files in os.walk(image_path):
50
+ for f in files:
51
+ if '.ipynb_checkpoints' in root:
52
+ continue
53
+ if os.path.splitext(f)[-1] in valid_suffix:
54
+ image_list.append(os.path.join(root, f))
55
+ image_list.sort()
56
+ else:
57
+ raise FileNotFoundError(
58
+ '`image_path` is not found. it should be an image file or a directory including images'
59
+ )
60
+
61
+ if len(image_list) == 0:
62
+ raise RuntimeError('There are not image file in `image_path`')
63
+
64
+ return image_list, image_dir
65
+
66
+
67
+ def mkdir(path):
68
+ sub_dir = os.path.dirname(path)
69
+ if not os.path.exists(sub_dir):
70
+ os.makedirs(sub_dir)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
1
+ paddlepaddle
2
+ paddleseg
train.txt ADDED
File without changes
val.txt ADDED
File without changes