ZhengPeng7 committed on
Commit
7febe9c
1 Parent(s): 1a1cf3c

Initialization.

Browse files
Files changed (5) hide show
  1. .gitignore +134 -0
  2. app.py +73 -0
  3. config.py +107 -0
  4. models/GCoNet.py +248 -0
  5. models/modules.py +516 -0
.gitignore ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Custom
2
+ .vscode
3
+ *.pth
4
+
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ pip-wheel-metadata/
29
+ share/python-wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+ MANIFEST
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100
+ __pypackages__/
101
+
102
+ # Celery stuff
103
+ celerybeat-schedule
104
+ celerybeat.pid
105
+
106
+ # SageMath parsed files
107
+ *.sage.py
108
+
109
+ # Environments
110
+ .env
111
+ .venv
112
+ env/
113
+ venv/
114
+ ENV/
115
+ env.bak/
116
+ venv.bak/
117
+
118
+ # Spyder project settings
119
+ .spyderproject
120
+ .spyproject
121
+
122
+ # Rope project settings
123
+ .ropeproject
124
+
125
+ # mkdocs documentation
126
+ /site
127
+
128
+ # mypy
129
+ .mypy_cache/
130
+ .dmypy.json
131
+ dmypy.json
132
+
133
+ # Pyre type checker
134
+ .pyre/
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from glob import glob
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+ import matplotlib.pyplot as plt
7
+ import torch
8
+ from torchvision import transforms
9
+ import gradio as gr
10
+
11
+ from models.GCoNet import GCoNet
12
+
13
+
14
# Device used for the model and its inputs; index 0 selects 'cpu', change the index to 1 for 'cuda'.
device = ['cpu', 'cuda'][0]
15
+
16
+
17
class ImagePreprocessor():
    """Convert an input image into a fixed-size, normalized tensor."""

    def __init__(self) -> None:
        # Pipeline stages, applied in this order: resize -> tensor -> per-channel normalization.
        resize = transforms.Resize((256, 256))
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        self.transform_image = transforms.Compose([resize, to_tensor, normalize])

    def proc(self, image):
        """Run ``image`` through the transform pipeline and return the result."""
        return self.transform_image(image)
28
+
29
+
30
# Build the network without downloading backbone weights; the full checkpoint below is expected to supply them.
model = GCoNet(bb_pretrained=False).to(device)
state_dict = './ultimate_duts_cocoseg (The best one).pth'
# NOTE(review): when the checkpoint file is missing, loading is skipped silently and the demo
# runs with the freshly initialized weights — confirm that this is intended.
if os.path.exists(state_dict):
    gconet_dict = torch.load(state_dict, map_location=device)
    model.load_state_dict(gconet_dict)
# Inference only: switch the module to eval mode.
model.eval()
36
+
37
+
38
def pred_maps(dr):
    """Predict a map for every image found directly inside a directory.

    Args:
        dr: Path of a directory. Every entry matched by ``<dr>/*`` that OpenCV
            can decode is treated as one image of the group.

    Returns:
        A list with one RGB ``uint8`` array per image: the input picture with
        its predicted map stacked to its right.

    Raises:
        ValueError: If ``dr`` contains no image that could be read.
    """
    # Sorted so that the output order is deterministic (glob order is unspecified).
    image_paths = sorted(glob(os.path.join(dr, '*')))
    # cv2.imread returns None for entries it cannot decode; drop those instead
    # of failing later on `.shape`.
    images_bgr = [image for image in map(cv2.imread, image_paths) if image is not None]
    if not images_bgr:
        raise ValueError('No readable image found in directory: {}'.format(dr))

    image_shapes = [image.shape[:2] for image in images_bgr]
    # OpenCV decodes to BGR. Convert to RGB before wrapping in PIL so that the
    # channel order matches the per-channel statistics used by ImagePreprocessor
    # and so that the arrays can be returned for display without a second swap.
    images = [Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) for image in images_bgr]

    image_preprocessor = ImagePreprocessor()
    images_proc = torch.stack([image_preprocessor.proc(image) for image in images])

    with torch.no_grad():
        # The model returns a list of predictions; the last entry is the one used here.
        scaled_preds_tensor = model(images_proc.to(device))[-1]

    image_preds = []
    for image, image_shape, pred_tensor in zip(images, image_shapes, scaled_preds_tensor):
        # Resize the prediction back to the resolution of the original picture.
        pred = torch.nn.functional.interpolate(
            pred_tensor.cpu().unsqueeze(0), size=image_shape, mode='bilinear', align_corners=True
        ).squeeze().numpy()
        pred_rgb = cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)
        image_preds.append(np.hstack([np.array(image), pred_rgb]))
    return image_preds
66
+
67
# Gradio UI: one text box (the directory path handed to pred_maps) in, five image panels out.
# NOTE(review): pred_maps returns one image per file in the directory, so this layout
# presumably expects the directory to hold exactly five images — confirm.
demo = gr.Interface(
    fn=pred_maps,
    inputs='text',
    outputs=['image', 'image', 'image', 'image', 'image'],
    css=".output_image, .input_image {height: 300px !important}",
)
# Starts the web server as soon as this module is executed.
demo.launch(debug=True)
config.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
class Config():
    """Hyper-parameters and feature switches read by the model code.

    All values are computed once in ``__init__``; several of them are derived
    from earlier ones (for example the loss weights are rescaled by the number
    of supervised output layers).
    """

    def __init__(self) -> None:
        # Backbone
        self.bb = ['vgg16', 'vgg16bn', 'resnet50'][1]
        # BN
        self.use_bn = 'bn' in self.bb or 'resnet' in self.bb
        # Augmentation
        self.preproc_methods = ['flip', 'enhance', 'rotate', 'crop', 'pepper'][:3]

        # Mask
        losses = ['sal', 'cls', 'contrast', 'cls_mask']
        self.loss = losses[:]
        self.cls_mask_operation = ['x', '+', 'c'][0]
        # Loss + Triplet Loss
        self.lambdas_sal_last = {
            # not 0 means opening this loss
            # original rate -- 1 : 30 : 1.5 : 0.2, bce x 30
            'bce': 30 * 1,          # high performance
            'iou': 0.5 * 1,         # 0 / 255
            'ssim': 1 * 0,          # help contours
            'mse': 150 * 0,         # can smooth the saliency map
            'reg': 100 * 0,
            'triplet': 3 * 1 * ('cls' in self.loss),
        }

        # DB
        self.db_output_decoder = True
        self.db_k = 300
        self.db_k_alpha = 1
        self.split_mask = True and 'cls_mask' in self.loss
        self.db_mask = False and self.split_mask

        # Triplet Loss
        self.triplet = ['_x5', 'mask'][:1]
        self.triplet_loss_margin = 0.1
        # Adv
        self.lambda_adv = 0.    # turn to 0 to avoid adv training

        # Refiner
        self.refine = [0, 1, 4][0]  # 0 -- no refinement, 1 -- only output mask for refinement, 4 -- but also raw input.
        if self.refine:
            self.batch_size = 16
        else:
            if self.bb != 'vgg16':
                self.batch_size = 26
            else:
                self.batch_size = 48
        self.db_output_refiner = False and self.refine

        # Intermediate Layers
        self.lambdas_sal_others = {
            'bce': 0,
            'iou': 0.,
            'ssim': 0,
            'mse': 0,
            'reg': 0,
            'triplet': 0,
        }
        self.output_number = 1
        self.loss_sal_layers = 4              # used to be last 4 layers
        self.loss_cls_mask_last_layers = 1    # used to be last 4 layers
        # Keep the layer counts in range: never supervise more layers than are output,
        # and never output more layers than are supervised.
        self.loss_sal_layers = min(self.output_number, self.loss_sal_layers)
        self.loss_cls_mask_last_layers = min(self.output_number, self.loss_cls_mask_last_layers)
        self.output_number = min(self.output_number, max(self.loss_sal_layers, self.loss_cls_mask_last_layers))
        if self.output_number == 1:
            # With a single output there are no intermediate layers to weight.
            for cri in self.lambdas_sal_others:
                self.lambdas_sal_others[cri] = 0
        self.conv_after_itp = False
        self.complex_lateral_connection = False

        # to control the quantitative level of each single loss by number of output branches.
        self.loss_cls_mask_ratio_by_last_layers = 4 / self.loss_cls_mask_last_layers
        for loss_sal in self.lambdas_sal_last.keys():
            loss_sal_ratio_by_last_layers = 4 / (int(bool(self.lambdas_sal_others[loss_sal])) * (self.loss_sal_layers - 1) + 1)
            self.lambdas_sal_last[loss_sal] *= loss_sal_ratio_by_last_layers
            self.lambdas_sal_others[loss_sal] *= loss_sal_ratio_by_last_layers
        self.lambda_cls_mask = 2.5 * self.loss_cls_mask_ratio_by_last_layers
        self.lambda_cls = 3.
        self.lambda_contrast = 250.

        # Performance of GCoNet
        self.val_measures = {
            'Emax': {'CoCA': 0.760, 'CoSOD3k': 0.860, 'CoSal2015': 0.887},
            'Smeasure': {'CoCA': 0.673, 'CoSOD3k': 0.802, 'CoSal2015': 0.845},
            'Fmax': {'CoCA': 0.544, 'CoSOD3k': 0.777, 'CoSal2015': 0.847},
        }

        # others
        self.GAM = True
        if not self.GAM and 'contrast' in self.loss:
            self.loss.remove('contrast')
        self.lr = 1e-4 * (self.batch_size / 16)
        self.relation_module = ['GAM', 'ICE', 'NonLocal', 'MHA'][0]
        self.self_supervision = False
        self.label_smoothing = False
        self.freeze = True

        self.validation = False
        self.decay_step_size = 3000
        self.rand_seed = 7
        # NOTE: `val_last` used to be parsed from a `gco*.sh` launcher script found in '.' or '..'.
        # That parsing was already disabled, so the directory listing that only fed it is gone too;
        # constructing a Config no longer touches the file system.
models/GCoNet.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import torch
3
+ from torch.functional import norm
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torchvision.models import vgg16, vgg16_bn
7
+ import fvcore.nn.weight_init as weight_init
8
+ from torchvision.models import resnet50
9
+
10
+ from models.modules import ResBlk, DSLayer, half_DSLayer, CoAttLayer, RefUnet, DBHead
11
+
12
+ from config import Config
13
+
14
+
15
class GCoNet(nn.Module):
    """Encoder-decoder network built from a torchvision backbone and the blocks in ``models.modules``.

    The encoder is the backbone split into five stages (``conv1`` .. ``conv5``).
    The deepest feature map optionally goes through a ``CoAttLayer`` and is then
    decoded top-down with lateral connections from the encoder stages. Which
    auxiliary heads exist (classifier, contrast branch, mask head, refiner,
    DB heads) is decided by ``Config``.
    """

    def __init__(self, bb_pretrained=True):
        """Create the backbone, the decoder and the configured auxiliary heads.

        Args:
            bb_pretrained: Forwarded to the torchvision backbone constructor as ``pretrained``.

        Raises:
            ValueError: If ``Config().bb`` is not one of the supported backbone names.
        """
        super(GCoNet, self).__init__()
        self.config = Config()
        bb = self.config.bb
        if bb == 'vgg16':
            bb_net = list(vgg16(pretrained=bb_pretrained).children())[0]
            bb_convs = OrderedDict({
                'conv1': bb_net[:4],
                'conv2': bb_net[4:9],
                'conv3': bb_net[9:16],
                'conv4': bb_net[16:23],
                'conv5': bb_net[23:30]
            })
            channel_scale = 1
        elif bb == 'resnet50':
            bb_net = list(resnet50(pretrained=bb_pretrained).children())
            bb_convs = OrderedDict({
                'conv1': nn.Sequential(*bb_net[0:3]),
                'conv2': bb_net[4],
                'conv3': bb_net[5],
                'conv4': bb_net[6],
                'conv5': bb_net[7]
            })
            channel_scale = 4
        elif bb == 'vgg16bn':
            bb_net = list(vgg16_bn(pretrained=bb_pretrained).children())[0]
            bb_convs = OrderedDict({
                'conv1': bb_net[:6],
                'conv2': bb_net[6:13],
                'conv3': bb_net[13:23],
                'conv4': bb_net[23:33],
                'conv5': bb_net[33:43]
            })
            channel_scale = 1
        else:
            # Fail with a clear message instead of a NameError on `bb_convs` below.
            raise ValueError('Unsupported backbone: {}'.format(bb))
        self.bb = nn.Sequential(bb_convs)
        # Channel counts of the encoder stages, deepest first.
        lateral_channels_in = [512, 512, 256, 128, 64] if 'vgg16' in bb else [2048, 1024, 512, 256, 64]

        # channel_scale_latlayer = channel_scale // 2 if bb == 'resnet50' else 1
        # channel_last = 32

        # The decoder width is halved at every level on the way up.
        ch_decoder = lateral_channels_in[0]//2//channel_scale
        self.top_layer = ResBlk(lateral_channels_in[0], ch_decoder)
        self.enlayer5 = ResBlk(ch_decoder, ch_decoder)
        if self.config.conv_after_itp:
            self.dslayer5 = DSLayer(ch_decoder, ch_decoder)
        self.latlayer5 = ResBlk(lateral_channels_in[1], ch_decoder) if self.config.complex_lateral_connection else nn.Conv2d(lateral_channels_in[1], ch_decoder, 1, 1, 0)

        ch_decoder //= 2
        self.enlayer4 = ResBlk(ch_decoder*2, ch_decoder)
        if self.config.conv_after_itp:
            self.dslayer4 = DSLayer(ch_decoder, ch_decoder)
        self.latlayer4 = ResBlk(lateral_channels_in[2], ch_decoder) if self.config.complex_lateral_connection else nn.Conv2d(lateral_channels_in[2], ch_decoder, 1, 1, 0)
        if self.config.output_number >= 4:
            self.conv_out4 = nn.Sequential(nn.Conv2d(ch_decoder, 32, 1, 1, 0), nn.ReLU(inplace=True), nn.Conv2d(32, 1, 1, 1, 0))

        ch_decoder //= 2
        self.enlayer3 = ResBlk(ch_decoder*2, ch_decoder)
        if self.config.conv_after_itp:
            self.dslayer3 = DSLayer(ch_decoder, ch_decoder)
        self.latlayer3 = ResBlk(lateral_channels_in[3], ch_decoder) if self.config.complex_lateral_connection else nn.Conv2d(lateral_channels_in[3], ch_decoder, 1, 1, 0)
        if self.config.output_number >= 3:
            self.conv_out3 = nn.Sequential(nn.Conv2d(ch_decoder, 32, 1, 1, 0), nn.ReLU(inplace=True), nn.Conv2d(32, 1, 1, 1, 0))

        ch_decoder //= 2
        self.enlayer2 = ResBlk(ch_decoder*2, ch_decoder)
        if self.config.conv_after_itp:
            self.dslayer2 = DSLayer(ch_decoder, ch_decoder)
        self.latlayer2 = ResBlk(lateral_channels_in[4], ch_decoder) if self.config.complex_lateral_connection else nn.Conv2d(lateral_channels_in[4], ch_decoder, 1, 1, 0)
        if self.config.output_number >= 2:
            self.conv_out2 = nn.Sequential(nn.Conv2d(ch_decoder, 32, 1, 1, 0), nn.ReLU(inplace=True), nn.Conv2d(32, 1, 1, 1, 0))

        self.enlayer1 = ResBlk(ch_decoder, ch_decoder)
        self.conv_out1 = nn.Sequential(nn.Conv2d(ch_decoder, 1, 1, 1, 0))

        if self.config.GAM:
            self.co_x5 = CoAttLayer(channel_in=lateral_channels_in[0])

        if 'contrast' in self.config.loss:
            self.pred_layer = half_DSLayer(lateral_channels_in[0])

        if {'cls', 'cls_mask'} & set(self.config.loss):
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.classifier = nn.Linear(lateral_channels_in[0], 291)       # DUTS_class has 291 classes
            for layer in [self.classifier]:
                weight_init.c2_msra_fill(layer)
        if self.config.split_mask:
            self.sgm = nn.Sigmoid()
        if self.config.refine:
            self.refiner = nn.Sequential(RefUnet(self.config.refine, 64))
        if self.config.split_mask:
            self.conv_out_mask = nn.Sequential(nn.Conv2d(ch_decoder, 1, 1, 1, 0))
            if self.config.db_mask:
                self.db_mask = DBHead(32)
        if self.config.db_output_decoder:
            self.db_output_decoder = DBHead(32)
        if self.config.cls_mask_operation == 'c':
            self.conv_cat_mask = nn.Conv2d(4, 3, 1, 1, 0)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Input batch; its four dimensions are unpacked as ``N, _, H, W``.

        Returns:
            In eval mode, the list ``scaled_preds`` (its last entry is the output
            of the final decoder level, or of the refiner when one is configured).
            In training mode, a list that starts with ``scaled_preds`` and is
            followed by the classification / contrast / masked-classification
            outputs selected by ``config.loss``, plus the feature list used by the
            triplet loss when its weight is non-zero.
        """
        ########## Encoder ##########

        [N, _, H, W] = x.size()
        x1 = self.bb.conv1(x)
        x2 = self.bb.conv2(x1)
        x3 = self.bb.conv3(x2)
        x4 = self.bb.conv4(x3)
        x5 = self.bb.conv5(x4)

        # The classification output is only returned in training mode, so it is not computed at inference.
        if self.training and 'cls' in self.config.loss:
            _x5 = self.avgpool(x5)
            _x5 = _x5.view(_x5.size(0), -1)
            pred_cls = self.classifier(_x5)

        if self.config.GAM:
            weighted_x5, neg_x5 = self.co_x5(x5)
            if 'contrast' in self.config.loss:
                if self.training:
                    ########## contrastive branch #########
                    cat_x5 = torch.cat([weighted_x5, neg_x5], dim=0)
                    pred_contrast = self.pred_layer(cat_x5)
                    pred_contrast = F.interpolate(pred_contrast, size=(H, W), mode='bilinear', align_corners=True)
            p5 = self.top_layer(weighted_x5)
        else:
            p5 = self.top_layer(x5)

        ########## Decoder ##########
        # Each level: refine, upsample to the next encoder resolution, add the lateral feature.
        scaled_preds = []
        p5 = self.enlayer5(p5)
        p5 = F.interpolate(p5, size=x4.shape[2:], mode='bilinear', align_corners=True)
        if self.config.conv_after_itp:
            p5 = self.dslayer5(p5)
        p4 = p5 + self.latlayer5(x4)

        p4 = self.enlayer4(p4)
        p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
        if self.config.conv_after_itp:
            p4 = self.dslayer4(p4)
        if self.config.output_number >= 4:
            p4_out = self.conv_out4(p4)
            scaled_preds.append(p4_out)
        p3 = p4 + self.latlayer4(x3)

        p3 = self.enlayer3(p3)
        p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
        if self.config.conv_after_itp:
            p3 = self.dslayer3(p3)
        if self.config.output_number >= 3:
            p3_out = self.conv_out3(p3)
            scaled_preds.append(p3_out)
        p2 = p3 + self.latlayer3(x2)

        p2 = self.enlayer2(p2)
        p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
        if self.config.conv_after_itp:
            p2 = self.dslayer2(p2)
        if self.config.output_number >= 2:
            p2_out = self.conv_out2(p2)
            scaled_preds.append(p2_out)
        p1 = p2 + self.latlayer2(x1)

        p1 = self.enlayer1(p1)
        p1 = F.interpolate(p1, size=x.shape[2:], mode='bilinear', align_corners=True)
        if self.config.db_output_decoder:
            p1_out = self.db_output_decoder(p1)
        else:
            p1_out = self.conv_out1(p1)
        scaled_preds.append(p1_out)

        if self.config.refine == 1:
            scaled_preds.append(self.refiner(p1_out))
        elif self.config.refine == 4:
            scaled_preds.append(self.refiner(torch.cat([x, p1_out], dim=1)))

        # The masked-classification branch re-runs the backbone on masked inputs. Its results are
        # only returned in training mode, so the extra backbone pass is skipped at inference.
        if self.training and 'cls_mask' in self.config.loss:
            pred_cls_masks = []
            norm_features_mask = []
            input_features = [x, x1, x2, x3][:self.config.loss_cls_mask_last_layers]
            bb_lst = [self.bb.conv1, self.bb.conv2, self.bb.conv3, self.bb.conv4, self.bb.conv5]
            for idx_out in range(self.config.loss_cls_mask_last_layers):
                if idx_out:
                    # Use an earlier decoder prediction, skipping the refiner output if there is one.
                    mask_output = scaled_preds[-(idx_out+1+int(bool(self.config.refine)))]
                else:
                    if self.config.split_mask:
                        if self.config.db_mask:
                            mask_output = self.db_mask(p1)
                        else:
                            mask_output = self.sgm(self.conv_out_mask(p1))

                if self.config.cls_mask_operation == 'x':
                    masked_features = input_features[idx_out] * mask_output
                elif self.config.cls_mask_operation == '+':
                    masked_features = input_features[idx_out] + mask_output
                elif self.config.cls_mask_operation == 'c':
                    masked_features = self.conv_cat_mask(torch.cat((input_features[idx_out], mask_output), dim=1))
                # Push the masked features through the remaining backbone stages and pool them.
                norm_feature_mask = self.avgpool(
                    nn.Sequential(*bb_lst[idx_out:])(
                        masked_features
                    )
                ).view(N, -1)
                norm_features_mask.append(norm_feature_mask)
                pred_cls_masks.append(
                    self.classifier(
                        norm_feature_mask
                    )
                )

        if self.training:
            return_values = []
            if {'sal', 'cls', 'contrast', 'cls_mask'} == set(self.config.loss):
                return_values = [scaled_preds, pred_cls, pred_contrast, pred_cls_masks]
            elif {'sal', 'cls', 'contrast'} == set(self.config.loss):
                return_values = [scaled_preds, pred_cls, pred_contrast]
            elif {'sal', 'cls', 'cls_mask'} == set(self.config.loss):
                return_values = [scaled_preds, pred_cls, pred_cls_masks]
            elif {'sal', 'cls'} == set(self.config.loss):
                return_values = [scaled_preds, pred_cls]
            elif {'sal', 'contrast'} == set(self.config.loss):
                return_values = [scaled_preds, pred_contrast]
            elif {'sal', 'cls_mask'} == set(self.config.loss):
                return_values = [scaled_preds, pred_cls_masks]
            else:
                return_values = [scaled_preds]

            if self.config.lambdas_sal_last['triplet']:
                norm_features = []
                if '_x5' in self.config.triplet:
                    norm_features.append(_x5)
                if 'mask' in self.config.triplet:
                    norm_features.append(norm_features_mask[0])
                return_values.append(norm_features)
            return return_values
        else:
            return scaled_preds
models/modules.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import fvcore.nn.weight_init as weight_init
6
+
7
+ from config import Config
8
+
9
+
10
# Module-level settings object; the blocks below read `config.use_bn` and `config.db_k` from it.
config = Config()
11
+
12
+
13
class ResBlk(nn.Module):
    """Two 3x3 convolutions joined by a ReLU, with a fixed 64-channel hidden width.

    BatchNorm follows each convolution when ``config.use_bn`` is set. Despite
    the name, ``forward`` adds no skip connection.
    """

    def __init__(self, channel_in=64, channel_out=64):
        super().__init__()
        hidden = 64
        self.conv_in = nn.Conv2d(channel_in, hidden, kernel_size=3, stride=1, padding=1)
        self.relu_in = nn.ReLU(inplace=True)
        self.conv_out = nn.Conv2d(hidden, channel_out, kernel_size=3, stride=1, padding=1)
        if config.use_bn:
            self.bn_in = nn.BatchNorm2d(hidden)
            self.bn_out = nn.BatchNorm2d(channel_out)

    def forward(self, x):
        """Apply conv -> [BN] -> ReLU -> conv -> [BN] to ``x``."""
        out = self.conv_in(x)
        if config.use_bn:
            out = self.bn_in(out)
        out = self.conv_out(self.relu_in(out))
        if config.use_bn:
            out = self.bn_out(out)
        return out
32
+
33
+
34
class DSLayer(nn.Module):
    """Two 3x3 conv stages (64 channels) followed by a 1x1 prediction conv.

    Args:
        channel_in: Number of input channels.
        channel_out: Number of channels produced by the prediction conv.
        activation_out: Any truthy value appends a ReLU after the prediction
            stage; the value itself is not interpreted further.
    """

    def __init__(self, channel_in=64, channel_out=1, activation_out='relu'):
        super().__init__()
        self.activation_out = activation_out

        self.conv1 = nn.Conv2d(channel_in, 64, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU(inplace=True)

        # The prediction conv is the same with or without the output activation.
        self.pred_conv = nn.Conv2d(64, channel_out, kernel_size=1, stride=1, padding=0)
        if activation_out:
            self.pred_relu = nn.ReLU(inplace=True)

        if config.use_bn:
            self.bn1 = nn.BatchNorm2d(64)
            self.bn2 = nn.BatchNorm2d(64)
            self.pred_bn = nn.BatchNorm2d(channel_out)

    def forward(self, x):
        """Run the three stages; BatchNorm is applied after each conv when ``config.use_bn`` is set."""
        out = self.conv1(x)
        if config.use_bn:
            out = self.bn1(out)
        out = self.conv2(self.relu1(out))
        if config.use_bn:
            out = self.bn2(out)
        out = self.pred_conv(self.relu2(out))
        if config.use_bn:
            out = self.pred_bn(out)
        return self.pred_relu(out) if self.activation_out else out
70
+
71
+
72
class half_DSLayer(nn.Module):
    """A 3x3 conv that reduces the channels to a quarter, a ReLU, then a 1x1 conv to a single channel."""

    def __init__(self, channel_in=512):
        super().__init__()
        reduced = int(channel_in//4)
        encode = nn.Conv2d(channel_in, reduced, kernel_size=3, stride=1, padding=1)
        predict = nn.Conv2d(reduced, 1, kernel_size=1, stride=1, padding=0)
        self.enlayer = nn.Sequential(encode, nn.ReLU(inplace=True))
        self.predlayer = nn.Sequential(predict)

    def forward(self, x):
        """Return the single-channel map predicted from ``x``."""
        return self.predlayer(self.enlayer(x))
87
+
88
+
89
class CoAttLayer(nn.Module):
    """Re-weight the channels of a feature batch with a prototype computed from the whole group.

    The relation module that produces the attended features is chosen by
    ``Config().relation_module``.
    """

    def __init__(self, channel_in=512):
        """Create the relation module and the auxiliary layers.

        Args:
            channel_in: Number of channels of the incoming feature map.

        Raises:
            ValueError: If ``Config().relation_module`` names an unknown module.
        """
        super(CoAttLayer, self).__init__()

        # Explicit lookup instead of eval() on the configured string.
        relation_modules = {'GAM': GAM, 'ICE': ICE, 'NonLocal': NonLocal, 'MHA': MHA}
        relation_name = Config().relation_module
        if relation_name not in relation_modules:
            raise ValueError('Unknown relation module: {}'.format(relation_name))
        # channel_in is passed positionally, as the first constructor argument of the chosen class.
        self.all_attention = relation_modules[relation_name](channel_in)
        # These three layers are not used in forward(); they are kept so that the module's
        # parameter set (and therefore its state_dict) stays the same.
        self.conv_output = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.conv_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.fc_transform = nn.Linear(channel_in, channel_in)

        for layer in [self.conv_output, self.conv_transform, self.fc_transform]:
            weight_init.c2_msra_fill(layer)

    def _prototype(self, feats):
        """Return the (1, C, 1, 1) mean of the attended features over batch and spatial axes."""
        return torch.mean(self.all_attention(feats), (0, 2, 3), True)

    def forward(self, x5):
        """Weight ``x5`` by its group prototype.

        In training mode the batch is split into a first and a second half, each
        treated as its own group; ``neg_x5`` then holds every half weighted by
        the prototype of the *other* half. In eval mode the whole batch is one
        group and ``neg_x5`` is ``None``.

        Returns:
            Tuple ``(weighted_x5, neg_x5)``.
        """
        if self.training:
            half = int(x5.shape[0] / 2)
            x5_1 = x5[:half]
            x5_2 = x5[half:]

            x5_1_proto = self._prototype(x5_1)  # 1, C, 1, 1
            x5_2_proto = self._prototype(x5_2)  # 1, C, 1, 1

            # Each half weighted by its own prototype.
            weighted_x5 = torch.cat([x5_1 * x5_1_proto, x5_2 * x5_2_proto], dim=0)
            # Each half weighted by the other half's prototype.
            neg_x5 = torch.cat([x5_1 * x5_2_proto, x5_2 * x5_1_proto], dim=0)
        else:
            weighted_x5 = x5 * self._prototype(x5)
            neg_x5 = None
        return weighted_x5, neg_x5
136
+
137
+
138
class ICE(nn.Module):
    """The Integrity Channel Enhancement (ICE) module.

    Three views of the input (element-wise product, sum, and their
    concatenation with the input) are turned into a per-channel gate that
    rescales the convolved product branch.
    """

    def __init__(self, channel_in=512):
        super().__init__()
        self.conv_1 = nn.Conv2d(channel_in, channel_in, 3, 1, 1)
        self.conv_2 = nn.Conv1d(channel_in, channel_in, 3, 1, 1)
        self.conv_3 = nn.Conv2d(channel_in*3, channel_in, 3, 1, 1)

        self.fc_2 = nn.Linear(channel_in, channel_in)
        self.fc_3 = nn.Linear(channel_in, channel_in)

    def forward(self, x):
        """Return ``V * gate + V`` where ``V`` is the convolved cube of ``x``."""
        prod = x * x * x                              # first column
        total = prod + x + x                          # second column
        stacked = torch.cat((prod, total, x), dim=1)  # third column

        V = self.conv_1(prod)

        bs, c, h, w = total.shape
        K = self.conv_2(total.view(bs, c, h*w))       # bs, c, hw

        # Per-channel norm of the third column -> softmax weights over channels.
        channel_norm = torch.norm(self.conv_3(stacked), dim=(-2, -1)).view(bs, c)
        Q_prime = torch.softmax(self.fc_3(channel_norm), dim=-1).unsqueeze(1)  # bs, 1, c

        Q = torch.matmul(Q_prime, K)                  # bs, 1, hw

        # Similarity of every channel of K to the pooled query, turned into a gate.
        gate = torch.sigmoid(torch.nn.functional.cosine_similarity(K, Q, dim=-1))
        gate = self.fc_2(gate).unsqueeze(-1).unsqueeze(-1)  # bs, c, 1, 1

        return V * gate + V
177
+
178
+
179
class GAM(nn.Module):
    """Attention over all spatial positions of all images in the batch.

    Every position is scored against every position of every image; the scores
    become one spatial attention map per image, which multiplies the input
    before a final 1x1 convolution.
    """

    def __init__(self, channel_in=512):
        super().__init__()
        self.query_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.key_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)

        self.scale = 1.0 / (channel_in ** 0.5)

        self.conv6 = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)

        for layer in (self.query_transform, self.key_transform, self.conv6):
            weight_init.c2_msra_fill(layer)

    def forward(self, x5):
        """Return ``conv6(x5 * attention)`` for an input of shape (B, C, H, W)."""
        B, C, H5, W5 = x5.size()
        positions = H5 * W5

        # Queries as rows: (B, C, HW) -> (B, HW, C) -> (B*HW, C)
        queries = self.query_transform(x5).view(B, C, -1)
        queries = queries.transpose(1, 2).contiguous().view(-1, C)
        # Keys as columns: (B, C, HW) -> (C, B, HW) -> (C, B*HW)
        keys = self.key_transform(x5).view(B, C, -1)
        keys = keys.transpose(0, 1).contiguous().view(C, -1)

        # Affinity between every pair of positions in the batch: (B*HW, B*HW)
        affinity = torch.matmul(queries, keys)
        # Per query: best match inside each image, then the average over images.
        affinity = affinity.view(B * positions, B, positions)
        affinity = affinity.max(dim=-1).values.mean(dim=-1)  # B*HW

        # Scaled softmax over the positions of each image -> (B, 1, H, W)
        attention = F.softmax(affinity.view(B, -1) * self.scale, dim=-1)
        attention = attention.view(B, H5, W5).unsqueeze(1)

        return self.conv6(x5 * attention)
221
+
222
+
223
class MHA(nn.Module):
    '''
    Multi-head scaled dot-product self-attention over the positions of a feature map.
    '''

    def __init__(self, d_model=512, d_k=512, d_v=512, h=8, dropout=.1, channel_in=512):
        '''
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        :param dropout: Dropout probability applied to the attention weights
        :param channel_in: Number of channels of the input feature map (used by the 1x1 convs)
        '''
        super(MHA, self).__init__()
        self.query_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.key_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.value_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout = nn.Dropout(dropout)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()

    def init_weights(self):
        '''
        Initialize conv weights with Kaiming-normal, linear weights with a small normal, biases with 0.
        '''
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, attention_mask=None, attention_weights=None):
        '''
        Computes self-attention on a feature map.
        :param x: Feature map (B, C, H, W); C must match both channel_in and d_model
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return: Tensor with the same shape as x
        '''
        B, C, H, W = x.size()
        # Queries, keys and values each come from their own 1x1 projection.
        # NOTE(review): `.view(B, -1, C)` reinterprets the (B, C, H, W) memory without moving the
        # channel axis last; a flatten + transpose may have been intended — confirm before changing.
        queries = self.query_transform(x).view(B, -1, C)
        keys = self.key_transform(x).view(B, -1, C)
        values = self.value_transform(x).view(B, -1, C)

        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        att = self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out).view(B, C, H, W)  # (b_s, nq, d_model) viewed back as (B, C, H, W)
        return out
299
+
300
+
301
class NonLocal(nn.Module):
    """Non-local block on 2D feature maps with a zero-initialized output projection.

    Because the last layer of ``W`` starts at zero, a freshly built block
    returns its input unchanged.
    """

    def __init__(self, channel_in=512, inter_channels=None, dimension=2, sub_sample=True, bn_layer=True):
        super().__init__()

        assert dimension in [1, 2, 3]
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.channel_in = channel_in

        # Default inner width: half of the input channels, but at least 1.
        if inter_channels is None:
            inter_channels = (channel_in // 2) or 1
        self.inter_channels = inter_channels

        def pointwise(c_in, c_out):
            return nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0)

        g = pointwise(self.channel_in, self.inter_channels)

        # Output projection; its last layer is zero-initialized.
        w_conv = pointwise(self.inter_channels, self.channel_in)
        if bn_layer:
            W = nn.Sequential(w_conv, nn.BatchNorm2d(self.channel_in))
            last = W[1]
        else:
            W = w_conv
            last = W
        nn.init.constant_(last.weight, 0)
        nn.init.constant_(last.bias, 0)

        theta = pointwise(self.channel_in, self.inter_channels)
        phi = pointwise(self.channel_in, self.inter_channels)

        if sub_sample:
            # Halve the spatial size on the key/value side.
            g = nn.Sequential(g, nn.MaxPool2d(kernel_size=(2, 2)))
            phi = nn.Sequential(phi, nn.MaxPool2d(kernel_size=(2, 2)))

        self.g = g
        self.W = W
        self.theta = theta
        self.phi = phi

    def forward(self, x, return_nl_map=False):
        """
        :param x: feature map whose first axis is the batch and second axis the channels
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return: z = W(attention output) + x, optionally with the attention map
        """
        batch_size = x.size(0)
        inner = self.inter_channels

        value = self.g(x).view(batch_size, inner, -1).permute(0, 2, 1)
        query = self.theta(x).view(batch_size, inner, -1).permute(0, 2, 1)
        key = self.phi(x).view(batch_size, inner, -1)

        # Pairwise affinities, normalized over the key positions.
        nl_map = F.softmax(torch.matmul(query, key), dim=-1)

        y = torch.matmul(nl_map, value).permute(0, 2, 1).contiguous()
        y = y.view(batch_size, inner, *x.size()[2:])
        z = self.W(y) + x

        if return_nl_map:
            return z, nl_map
        return z
365
+
366
+
367
class DBHead(nn.Module):
    """Differentiable-binarization style head.

    Two parallel convolutional branches with identical structure predict a
    shrink map (``binarize``) and a threshold map (``thresh``), each ending
    in a sigmoid. ``forward`` combines them with a steep sigmoid-like step
    function ``1 / (1 + exp(-k * (shrink - thresh)))``.

    :param channel_in: Number of input channels.
    :param channel_out: Number of channels of each predicted map.
    :param k: Steepness of the step function.
    """

    def __init__(self, channel_in=32, channel_out=1, k=config.db_k):
        super().__init__()
        self.k = k
        self.binarize = self._make_branch(channel_in, channel_out)
        self.thresh = self._make_branch(channel_in, channel_out)

    @staticmethod
    def _make_branch(channel_in, channel_out):
        """Build one conv -> [BN] -> ReLU -> conv -> [BN] -> ReLU -> conv -> sigmoid branch."""

        def norm_act():
            # Always return a list so it can be star-unpacked, whether or
            # not batch norm is enabled.
            if config.use_bn:
                return [nn.BatchNorm2d(channel_in), nn.ReLU(inplace=True)]
            return [nn.ReLU(inplace=True)]

        return nn.Sequential(
            nn.Conv2d(channel_in, channel_in, 3, 1, 1),
            *norm_act(),
            nn.Conv2d(channel_in, channel_in, 3, 1, 1),
            *norm_act(),
            nn.Conv2d(channel_in, channel_out, 1, 1, 0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the binary map computed from the shrink and threshold maps of ``x``."""
        shrink_maps = self.binarize(x)
        threshold_maps = self.thresh(x)
        binary_maps = self.step_function(shrink_maps, threshold_maps)
        return binary_maps

    def step_function(self, x, y):
        """Soft step ``1 / (1 + exp(-k * g(x - y)))``.

        ``g`` is the identity when ``config.db_k_alpha == 1``; otherwise it is
        the sign-preserving root ``sign(z) * |z| ** (1 / db_k_alpha)``.
        """
        if config.db_k_alpha != 1:
            z = x - y
            # +1 where z >= 0 and -1 where z < 0, so z * mask_neg_inv == |z|.
            mask_neg_inv = 1 - 2 * (z < 0)
            # The exponent reads the same `db_k_alpha` setting that selects
            # this branch; the small epsilon keeps the base of the power
            # away from zero.
            a = torch.exp(
                -self.k * (torch.pow(z * mask_neg_inv + 1e-16, 1 / config.db_k_alpha) * mask_neg_inv)
            )
        else:
            a = torch.exp(-self.k * (x - y))
        # If the exponential overflowed anywhere, recompute the whole map
        # with a fixed, smaller steepness of 50.
        if torch.isinf(a).any():
            a = torch.exp(-50 * (x - y))
        return torch.reciprocal(1 + a)
405
+
406
+
407
class RefUnet(nn.Module):
    """Refinement U-Net.

    An encoder of five conv -> [BN] -> ReLU stages separated by 2x max
    pooling, and a decoder of four stages that each upsample, concatenate the
    matching encoder feature and convolve. The result is either passed
    through a ``DBHead`` (when ``config.db_output_refiner`` is set) or turned
    into a single-channel residual that is added to the input.

    :param in_ch: Number of channels of the input.
    :param inc_ch: Number of channels produced by the first convolution.
    """

    def __init__(self, in_ch, inc_ch):
        super(RefUnet, self).__init__()
        self.conv0 = nn.Conv2d(in_ch, inc_ch, 3, padding=1)

        # Encoder. Attribute names and registration order are kept stable so
        # the state_dict layout does not change.
        self._add_stage('1', inc_ch)
        self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self._add_stage('2', 64)
        self.pool2 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self._add_stage('3', 64)
        self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self._add_stage('4', 64)
        self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)

        # Bottleneck.
        self._add_stage('5', 64)

        # Decoder: each stage sees 64 upsampled + 64 skip channels.
        self._add_stage('_d4', 128)
        self._add_stage('_d3', 128)
        self._add_stage('_d2', 128)
        self._add_stage('_d1', 128)

        self.conv_d0 = nn.Conv2d(64, 1, 3, padding=1)

        # Kept as a public attribute for compatibility; forward() upsamples
        # to the skip tensor's size instead (see _upsample_like).
        self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        if config.db_output_refiner:
            self.db_output_refiner = DBHead(64)

    def _add_stage(self, suffix, in_channels):
        """Register ``conv<suffix>``, optional ``bn<suffix>`` and ``relu<suffix>`` (64 output channels)."""
        setattr(self, 'conv' + suffix, nn.Conv2d(in_channels, 64, 3, padding=1))
        if config.use_bn:
            setattr(self, 'bn' + suffix, nn.BatchNorm2d(64))
        setattr(self, 'relu' + suffix, nn.ReLU(inplace=True))

    def _run_stage(self, suffix, hx):
        """Apply the conv -> [BN] -> ReLU stage registered under ``suffix``."""
        hx = getattr(self, 'conv' + suffix)(hx)
        if config.use_bn:
            hx = getattr(self, 'bn' + suffix)(hx)
        return getattr(self, 'relu' + suffix)(hx)

    @staticmethod
    def _upsample_like(src, ref):
        """Bilinearly upsample ``src`` to the spatial size of ``ref``.

        The pools use ``ceil_mode=True``, so an encoder feature is not always
        exactly twice the size of the next one. Matching the skip tensor's
        size keeps the concatenation valid for any input size. For exact 2x
        relations the output size is the same as a fixed ``scale_factor=2``.
        """
        return F.interpolate(src, size=ref.shape[2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        """Refine ``x``.

        :param x: Input feature map with ``in_ch`` channels.
        :return: ``DBHead`` output when ``config.db_output_refiner`` is set,
            otherwise ``x`` plus the predicted single-channel residual.
        """
        # Encoder
        hx1 = self._run_stage('1', self.conv0(x))
        hx2 = self._run_stage('2', self.pool1(hx1))
        hx3 = self._run_stage('3', self.pool2(hx2))
        hx4 = self._run_stage('4', self.pool3(hx3))
        hx5 = self._run_stage('5', self.pool4(hx4))

        # Decoder with skip connections
        d4 = self._run_stage('_d4', torch.cat((self._upsample_like(hx5, hx4), hx4), 1))
        d3 = self._run_stage('_d3', torch.cat((self._upsample_like(d4, hx3), hx3), 1))
        d2 = self._run_stage('_d2', torch.cat((self._upsample_like(d3, hx2), hx2), 1))
        d1 = self._run_stage('_d1', torch.cat((self._upsample_like(d2, hx1), hx1), 1))

        if config.db_output_refiner:
            return self.db_output_refiner(d1)
        residual = self.conv_d0(d1)
        return x + residual