ZhengPeng7 committed on
Commit 59b8f8d
1 Parent(s): dc4c093

Update the model code and fix previous inconsistencies.

models/backbones/build_backbone.py CHANGED
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from collections import OrderedDict
 from torchvision.models import vgg16, vgg16_bn, VGG16_Weights, VGG16_BN_Weights, resnet50, ResNet50_Weights
-from models.backbones.pvt_v2 import pvt_v2_b2, pvt_v2_b5
+from models.backbones.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5
 from models.backbones.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l
 from config import Config
 
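Note: the widened import list suggests the factory in this file dispatches on the Config().bb string. A minimal sketch of such a factory, assuming a hypothetical name-to-constructor mapping (the real dispatch logic is not shown in this commit):

# Hypothetical sketch; the mapping below is assumed, not taken from this diff.
def build_backbone_sketch(bb_name):
    from models.backbones.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5
    from models.backbones.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l
    backbones = {
        'pvt_v2_b0': pvt_v2_b0, 'pvt_v2_b1': pvt_v2_b1,
        'pvt_v2_b2': pvt_v2_b2, 'pvt_v2_b5': pvt_v2_b5,
        'swin_v1_t': swin_v1_t, 'swin_v1_s': swin_v1_s,
        'swin_v1_b': swin_v1_b, 'swin_v1_l': swin_v1_l,
    }
    if bb_name not in backbones:
        raise ValueError('Unknown backbone: {}'.format(bb_name))
    return backbones[bb_name]()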
models/backbones/swin_v1.py CHANGED
@@ -578,31 +578,6 @@ class SwinTransformer(nn.Module):
             for param in m.parameters():
                 param.requires_grad = False
 
-    def init_weights(self, pretrained=None):
-        """Initialize the weights in backbone.
-
-        Args:
-            pretrained (str, optional): Path to pre-trained weights.
-                Defaults to None.
-        """
-
-        def _init_weights(m):
-            if isinstance(m, nn.Linear):
-                trunc_normal_(m.weight, std=.02)
-                if isinstance(m, nn.Linear) and m.bias is not None:
-                    nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.LayerNorm):
-                nn.init.constant_(m.bias, 0)
-                nn.init.constant_(m.weight, 1.0)
-
-        if isinstance(pretrained, str):
-            self.apply(_init_weights)
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            self.apply(_init_weights)
-        else:
-            raise TypeError('pretrained must be a str or None')
 
     def forward(self, x):
         """Forward function."""
models/{baseline.py → birefnet.py} RENAMED
@@ -41,14 +41,6 @@ class BiRefNet(nn.Module):
         ])
 
         self.decoder = Decoder(channels)
-
-        if self.config.locate_head:
-            self.locate_header = nn.ModuleList([
-                BasicDecBlk(channels[0], channels[-1]),
-                nn.Sequential(
-                    nn.Conv2d(channels[-1], 1, 1, 1, 0),
-                )
-            ])
 
         if self.config.ender:
             self.dec_end = nn.Sequential(
@@ -60,7 +52,7 @@ class BiRefNet(nn.Module):
         # refine patch-level segmentation
         if self.config.refine:
             if self.config.refine == 'itself':
-                self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3)
+                self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')
            else:
                self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1'))
@@ -105,20 +97,6 @@ class BiRefNet(nn.Module):
             )
         return (x1, x2, x3, x4), class_preds
 
-    # def forward_loc(self, x):
-    #     ########## Encoder ##########
-    #     (x1, x2, x3, x4), class_preds = self.forward_enc(x)
-    #     if self.config.squeeze_block:
-    #         x4 = self.squeeze_module(x4)
-    #     if self.config.locate_head:
-    #         locate_preds = self.locate_header[1](
-    #             F.interpolate(
-    #                 self.locate_header[0](
-    #                     F.interpolate(x4, size=x2.shape[2:], mode='bilinear', align_corners=True)
-    #                 ), size=x.shape[2:], mode='bilinear', align_corners=True
-    #             )
-    #         )
-
     def forward_ori(self, x):
         ########## Encoder ##########
         (x1, x2, x3, x4), class_preds = self.forward_enc(x)
@@ -131,22 +109,22 @@ class BiRefNet(nn.Module):
         scaled_preds = self.decoder(features)
         return scaled_preds, class_preds
 
-    def forward_ref(self, x, pred):
-        # refine patch-level segmentation
-        if pred.shape[2:] != x.shape[2:]:
-            pred = F.interpolate(pred, size=x.shape[2:], mode='bilinear', align_corners=True)
-        # pred = pred.sigmoid()
-        if self.config.refine == 'itself':
-            x = self.stem_layer(torch.cat([x, pred], dim=1))
-            scaled_preds, class_preds = self.forward_ori(x)
-        else:
-            scaled_preds = self.refiner([x, pred])
-            class_preds = None
-        return scaled_preds, class_preds
+    # def forward_ref(self, x, pred):
+    #     # refine patch-level segmentation
+    #     if pred.shape[2:] != x.shape[2:]:
+    #         pred = F.interpolate(pred, size=x.shape[2:], mode='bilinear', align_corners=True)
+    #     # pred = pred.sigmoid()
+    #     if self.config.refine == 'itself':
+    #         x = self.stem_layer(torch.cat([x, pred], dim=1))
+    #         scaled_preds, class_preds = self.forward_ori(x)
+    #     else:
+    #         scaled_preds = self.refiner([x, pred])
+    #         class_preds = None
+    #     return scaled_preds, class_preds
 
-    def forward_ref_end(self, x):
-        # remove the grids of concatenated preds
-        return self.dec_end(x) if self.config.ender else x
+    # def forward_ref_end(self, x):
+    #     # remove the grids of concatenated preds
+    #     return self.dec_end(x) if self.config.ender else x
 
 
     # def forward(self, x):
@@ -181,6 +159,7 @@ class Decoder(nn.Module):
             DBlock = SimpleConvs
             ic = 64
             ipt_cha_opt = 1
+            self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
             self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
             self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
             self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
@@ -188,7 +167,7 @@ class Decoder(nn.Module):
         else:
             self.split = None
 
-        self.decoder_block4 = DecoderBlock(channels[0], channels[1])
+        self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[1])
         self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[2])
         self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3])
         self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3]//2)
@@ -205,15 +184,15 @@ class Decoder(nn.Module):
 
         if self.config.out_ref:
             _N = 16
-            # self.gdt_convs_4 = nn.Sequential(nn.Conv2d(channels[1], _N, 3, 1, 1), nn.BatchNorm2d(_N), nn.ReLU(inplace=True))
-            self.gdt_convs_3 = nn.Sequential(nn.Conv2d(channels[2], _N, 3, 1, 1), nn.BatchNorm2d(_N), nn.ReLU(inplace=True))
-            self.gdt_convs_2 = nn.Sequential(nn.Conv2d(channels[3], _N, 3, 1, 1), nn.BatchNorm2d(_N), nn.ReLU(inplace=True))
+            self.gdt_convs_4 = nn.Sequential(nn.Conv2d(channels[1], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
+            self.gdt_convs_3 = nn.Sequential(nn.Conv2d(channels[2], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
+            self.gdt_convs_2 = nn.Sequential(nn.Conv2d(channels[3], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
 
-            # self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
+            self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
             self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
             self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
 
-            # self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
+            self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
             self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
             self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
 
@@ -238,14 +217,31 @@ class Decoder(nn.Module):
         else:
             x, x1, x2, x3, x4 = features
         outs = []
+
+        if self.config.dec_ipt:
+            patches_batch = self.get_patches_batch(x, x4) if self.split else x
+            x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
         p4 = self.decoder_block4(x4)
         m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None
+        if self.config.out_ref:
+            p4_gdt = self.gdt_convs_4(p4)
+            if self.training:
+                # >> GT:
+                m4_dia = m4
+                gdt_label_main_4 = gdt_gt * F.interpolate(m4_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
+                outs_gdt_label.append(gdt_label_main_4)
+                # >> Pred:
+                gdt_pred_4 = self.gdt_convs_pred_4(p4_gdt)
+                outs_gdt_pred.append(gdt_pred_4)
+            gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
+            # >> Finally:
+            p4 = p4 * gdt_attn_4
         _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
         _p3 = _p4 + self.lateral_block4(x3)
+
         if self.config.dec_ipt:
             patches_batch = self.get_patches_batch(x, _p3) if self.split else x
             _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
-
         p3 = self.decoder_block3(_p3)
         m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
         if self.config.out_ref:
@@ -268,10 +264,10 @@ class Decoder(nn.Module):
             p3 = p3 * gdt_attn_3
         _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
         _p2 = _p3 + self.lateral_block3(x2)
+
         if self.config.dec_ipt:
             patches_batch = self.get_patches_batch(x, _p2) if self.split else x
             _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
-
         p2 = self.decoder_block2(_p2)
         m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
         if self.config.out_ref:
@@ -289,12 +285,13 @@ class Decoder(nn.Module):
             p2 = p2 * gdt_attn_2
         _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
         _p1 = _p2 + self.lateral_block2(x1)
+
         if self.config.dec_ipt:
             patches_batch = self.get_patches_batch(x, _p1) if self.split else x
             _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
-
         _p1 = self.decoder_block1(_p1)
         _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
+
         if self.config.dec_ipt:
             patches_batch = self.get_patches_batch(x, _p1) if self.split else x
             _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
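The new out_ref branch at the coarsest level follows the same pattern as the existing levels 3 and 2: a small conv stack compresses the decoder features, and a one-channel sigmoid map gates them multiplicatively. A standalone sketch of that gating step, with channel and spatial sizes chosen for illustration only:

import torch
import torch.nn as nn

N = 16
gdt_convs = nn.Sequential(nn.Conv2d(128, N, 3, 1, 1), nn.Identity(), nn.ReLU(inplace=True))
gdt_convs_attn = nn.Sequential(nn.Conv2d(N, 1, 1, 1, 0))

p = torch.randn(1, 128, 32, 32)         # decoder features at one level
p_gdt = gdt_convs(p)                    # compress to N channels
attn = gdt_convs_attn(p_gdt).sigmoid()  # one-channel gate in (0, 1)
p = p * attn                            # broadcasts over the channel dim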
models/modules/aspp.py CHANGED
@@ -8,56 +8,12 @@ from config import Config
 config = Config()
 
 
-class ASPPComplex(nn.Module):
-    def __init__(self, in_channels=64, out_channels=None, output_stride=16):
-        super(ASPPComplex, self).__init__()
-        self.down_scale = 1
-        if out_channels is None:
-            out_channels = in_channels
-        self.in_channelster = 256 // self.down_scale
-        if output_stride == 16:
-            dilations = [1, 6, 12, 18]
-        elif output_stride == 8:
-            dilations = [1, 12, 24, 36]
-        else:
-            raise NotImplementedError
-
-        self.aspp1 = _ASPPModule(in_channels, self.in_channelster, 1, padding=0, dilation=dilations[0])
-        self.aspp2 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[1], dilation=dilations[1])
-        self.aspp3 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[2], dilation=dilations[2])
-        self.aspp4 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[3], dilation=dilations[3])
-
-        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
-                                             nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False),
-                                             nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(),
-                                             nn.ReLU(inplace=True))
-        self.conv1 = nn.Conv2d(self.in_channelster * 5, out_channels, 1, bias=False)
-        self.bn1 = nn.BatchNorm2d(out_channels)
-        self.relu = nn.ReLU(inplace=True)
-        self.dropout = nn.Dropout(0.5)
-
-    def forward(self, x):
-        x1 = self.aspp1(x)
-        x2 = self.aspp2(x)
-        x3 = self.aspp3(x)
-        x4 = self.aspp4(x)
-        x5 = self.global_avg_pool(x)
-        x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
-        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
-
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-
-        return self.dropout(x)
-
-
 class _ASPPModule(nn.Module):
     def __init__(self, in_channels, planes, kernel_size, padding, dilation):
         super(_ASPPModule, self).__init__()
         self.atrous_conv = nn.Conv2d(in_channels, planes, kernel_size=kernel_size,
                                      stride=1, padding=padding, dilation=dilation, bias=False)
-        self.bn = nn.BatchNorm2d(planes)
+        self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity()
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
@@ -66,6 +22,7 @@ class _ASPPModule(nn.Module):
 
         return self.relu(x)
 
+
 class ASPP(nn.Module):
     def __init__(self, in_channels=64, out_channels=None, output_stride=16):
         super(ASPP, self).__init__()
@@ -90,7 +47,7 @@ class ASPP(nn.Module):
                                              nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(),
                                              nn.ReLU(inplace=True))
         self.conv1 = nn.Conv2d(self.in_channelster * 5, out_channels, 1, bias=False)
-        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
         self.relu = nn.ReLU(inplace=True)
         self.dropout = nn.Dropout(0.5)
 
@@ -116,7 +73,7 @@ class _ASPPModuleDeformable(nn.Module):
         super(_ASPPModuleDeformable, self).__init__()
         self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
                                             stride=1, padding=padding, bias=False)
-        self.bn = nn.BatchNorm2d(planes)
+        self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity()
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
@@ -127,7 +84,7 @@ class _ASPPModuleDeformable(nn.Module):
 
 
 class ASPPDeformable(nn.Module):
-    def __init__(self, in_channels, out_channels=None, num_parallel_block=1):
+    def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7]):
         super(ASPPDeformable, self).__init__()
         self.down_scale = 1
         if out_channels is None:
@@ -136,7 +93,7 @@ class ASPPDeformable(nn.Module):
 
         self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0)
         self.aspp_deforms = nn.ModuleList([
-            _ASPPModuleDeformable(in_channels, self.in_channelster, 3, padding=1) for _ in range(num_parallel_block)
+            _ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2)) for conv_size in parallel_block_sizes
        ])
 
         self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
@@ -144,7 +101,7 @@ class ASPPDeformable(nn.Module):
                                              nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(),
                                              nn.ReLU(inplace=True))
         self.conv1 = nn.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False)
-        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
         self.relu = nn.ReLU(inplace=True)
         self.dropout = nn.Dropout(0.5)
 
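The new parallel_block_sizes=[1, 3, 7] replaces the single repeated 3x3 deformable branch with one branch per kernel size; padding k//2 keeps the spatial size unchanged for odd k. A quick check of that padding arithmetic, with a plain Conv2d standing in for DeformableConv2d:

import torch
import torch.nn as nn

# Padding check only; ordinary Conv2d stands in for the deformable conv here.
x = torch.randn(2, 64, 32, 32)
for k in [1, 3, 7]:
    conv = nn.Conv2d(64, 256, kernel_size=k, stride=1, padding=k // 2, bias=False)
    assert conv(x).shape[2:] == x.shape[2:]  # odd k with padding k//2 preserves H and W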
models/modules/decoder_blocks.py CHANGED
@@ -19,8 +19,8 @@ class BasicDecBlk(nn.Module):
         elif config.dec_att == 'ASPPDeformable':
             self.dec_att = ASPPDeformable(in_channels=inter_channels)
         self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
-        self.bn_in = nn.BatchNorm2d(inter_channels)
-        self.bn_out = nn.BatchNorm2d(out_channels)
+        self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
+        self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
 
     def forward(self, x):
         x = self.conv_in(x)
@@ -41,7 +41,7 @@ class ResBlk(nn.Module):
         inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64
 
         self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
-        self.bn_in = nn.BatchNorm2d(inter_channels)
+        self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
         self.relu_in = nn.ReLU(inplace=True)
 
         if config.dec_att == 'ASPP':
@@ -50,7 +50,7 @@ class ResBlk(nn.Module):
             self.dec_att = ASPPDeformable(in_channels=inter_channels)
 
         self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
-        self.bn_out = nn.BatchNorm2d(out_channels)
+        self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
 
         self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
 
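The recurring `nn.BatchNorm2d(...) if config.batch_size > 1 else nn.Identity()` pattern exists because BatchNorm needs more than one value per channel to compute batch statistics in training mode. With a batch of 1 this fails outright after a global average pool (where each channel is reduced to a single 1x1 value), and merely produces noisy statistics elsewhere. A minimal reproduction:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8).train()
pooled = nn.AdaptiveAvgPool2d((1, 1))(torch.randn(1, 8, 16, 16))  # shape (1, 8, 1, 1)
try:
    bn(pooled)
except ValueError as e:
    print('BatchNorm fails at batch size 1:', e)  # "Expected more than 1 value per channel ..."
print(nn.Identity()(pooled).shape)  # the Identity fallback just passes features through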
models/refinement/refiner.py CHANGED
@@ -65,7 +65,7 @@ class Refiner(nn.Module):
         super(Refiner, self).__init__()
         self.config = Config()
         self.epoch = 1
-        self.stem_layer = StemLayer(in_channels=in_channels, inter_channels=48, out_channels=3)
+        self.stem_layer = StemLayer(in_channels=in_channels, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')
         self.bb = build_backbone(self.config.bb)
 
         lateral_channels_in_collection = {
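StemLayer's implementation is not part of this diff; a hypothetical sketch of how a norm_layer='BN'/'LN' switch might be wired internally, to illustrate the intent of the new argument:

import torch.nn as nn

# Hypothetical helper; StemLayer's real code is not shown in this commit.
def build_norm_layer_sketch(norm_layer, num_features):
    if norm_layer == 'BN':
        return nn.BatchNorm2d(num_features)
    elif norm_layer == 'LN':
        # channels-first LayerNorm over C, approximated by GroupNorm(1, C)
        return nn.GroupNorm(1, num_features)
    raise ValueError('Unsupported norm_layer: {}'.format(norm_layer))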