ZhengPeng7 commited on
Commit
108ae46
1 Parent(s): af31851

Move all BiRefNet github codes to the first level directory.

Browse files
__init__.py CHANGED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from os.path import dirname, basename, isfile, join
2
+ import glob
3
+
4
+
5
+ modules = glob.glob(join(dirname(__file__), "*.py"))
6
+ __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
models/backbones/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from os.path import dirname, basename, isfile, join
2
+ import glob
3
+
4
+
5
+ modules = glob.glob(join(dirname(__file__), "*.py"))
6
+ __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
models/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from os.path import dirname, basename, isfile, join
2
+ import glob
3
+
4
+
5
+ modules = glob.glob(join(dirname(__file__), "*.py"))
6
+ __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
models/modules/refinement/refiner.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from collections import OrderedDict
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torchvision.models import vgg16, vgg16_bn
8
+ from torchvision.models import resnet50
9
+
10
+ from config import Config
11
+ from dataset import class_labels_TR_sorted
12
+ from models.backbones.build_backbone import build_backbone
13
+ from models.modules.decoder_blocks import BasicDecBlk
14
+ from models.modules.lateral_blocks import BasicLatBlk
15
+ from models.modules.ing import *
16
+ from models.refinement.stem_layer import StemLayer
17
+
18
+
19
+ class RefinerPVTInChannels4(nn.Module):
20
+ def __init__(self, in_channels=3+1):
21
+ super(RefinerPVTInChannels4, self).__init__()
22
+ self.config = Config()
23
+ self.epoch = 1
24
+ self.bb = build_backbone(self.config.bb, params_settings='in_channels=4')
25
+
26
+ lateral_channels_in_collection = {
27
+ 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64],
28
+ 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64],
29
+ 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192],
30
+ }
31
+ channels = lateral_channels_in_collection[self.config.bb]
32
+ self.squeeze_module = BasicDecBlk(channels[0], channels[0])
33
+
34
+ self.decoder = Decoder(channels)
35
+
36
+ if 0:
37
+ for key, value in self.named_parameters():
38
+ if 'bb.' in key:
39
+ value.requires_grad = False
40
+
41
+ def forward(self, x):
42
+ if isinstance(x, list):
43
+ x = torch.cat(x, dim=1)
44
+ ########## Encoder ##########
45
+ if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
46
+ x1 = self.bb.conv1(x)
47
+ x2 = self.bb.conv2(x1)
48
+ x3 = self.bb.conv3(x2)
49
+ x4 = self.bb.conv4(x3)
50
+ else:
51
+ x1, x2, x3, x4 = self.bb(x)
52
+
53
+ x4 = self.squeeze_module(x4)
54
+
55
+ ########## Decoder ##########
56
+
57
+ features = [x, x1, x2, x3, x4]
58
+ scaled_preds = self.decoder(features)
59
+
60
+ return scaled_preds
61
+
62
+
63
+ class Refiner(nn.Module):
64
+ def __init__(self, in_channels=3+1):
65
+ super(Refiner, self).__init__()
66
+ self.config = Config()
67
+ self.epoch = 1
68
+ self.stem_layer = StemLayer(in_channels=in_channels, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')
69
+ self.bb = build_backbone(self.config.bb)
70
+
71
+ lateral_channels_in_collection = {
72
+ 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64],
73
+ 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64],
74
+ 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192],
75
+ }
76
+ channels = lateral_channels_in_collection[self.config.bb]
77
+ self.squeeze_module = BasicDecBlk(channels[0], channels[0])
78
+
79
+ self.decoder = Decoder(channels)
80
+
81
+ if 0:
82
+ for key, value in self.named_parameters():
83
+ if 'bb.' in key:
84
+ value.requires_grad = False
85
+
86
+ def forward(self, x):
87
+ if isinstance(x, list):
88
+ x = torch.cat(x, dim=1)
89
+ x = self.stem_layer(x)
90
+ ########## Encoder ##########
91
+ if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
92
+ x1 = self.bb.conv1(x)
93
+ x2 = self.bb.conv2(x1)
94
+ x3 = self.bb.conv3(x2)
95
+ x4 = self.bb.conv4(x3)
96
+ else:
97
+ x1, x2, x3, x4 = self.bb(x)
98
+
99
+ x4 = self.squeeze_module(x4)
100
+
101
+ ########## Decoder ##########
102
+
103
+ features = [x, x1, x2, x3, x4]
104
+ scaled_preds = self.decoder(features)
105
+
106
+ return scaled_preds
107
+
108
+
109
+ class Decoder(nn.Module):
110
+ def __init__(self, channels):
111
+ super(Decoder, self).__init__()
112
+ self.config = Config()
113
+ DecoderBlock = eval('BasicDecBlk')
114
+ LateralBlock = eval('BasicLatBlk')
115
+
116
+ self.decoder_block4 = DecoderBlock(channels[0], channels[1])
117
+ self.decoder_block3 = DecoderBlock(channels[1], channels[2])
118
+ self.decoder_block2 = DecoderBlock(channels[2], channels[3])
119
+ self.decoder_block1 = DecoderBlock(channels[3], channels[3]//2)
120
+
121
+ self.lateral_block4 = LateralBlock(channels[1], channels[1])
122
+ self.lateral_block3 = LateralBlock(channels[2], channels[2])
123
+ self.lateral_block2 = LateralBlock(channels[3], channels[3])
124
+
125
+ if self.config.ms_supervision:
126
+ self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0)
127
+ self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0)
128
+ self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0)
129
+ self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2, 1, 1, 1, 0))
130
+
131
+ def forward(self, features):
132
+ x, x1, x2, x3, x4 = features
133
+ outs = []
134
+ p4 = self.decoder_block4(x4)
135
+ _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
136
+ _p3 = _p4 + self.lateral_block4(x3)
137
+
138
+ p3 = self.decoder_block3(_p3)
139
+ _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
140
+ _p2 = _p3 + self.lateral_block3(x2)
141
+
142
+ p2 = self.decoder_block2(_p2)
143
+ _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
144
+ _p1 = _p2 + self.lateral_block2(x1)
145
+
146
+ _p1 = self.decoder_block1(_p1)
147
+ _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
148
+ p1_out = self.conv_out1(_p1)
149
+
150
+ if self.config.ms_supervision:
151
+ outs.append(self.conv_ms_spvn_4(p4))
152
+ outs.append(self.conv_ms_spvn_3(p3))
153
+ outs.append(self.conv_ms_spvn_2(p2))
154
+ outs.append(p1_out)
155
+ return outs
156
+
157
+
158
+ class RefUNet(nn.Module):
159
+ # Refinement
160
+ def __init__(self, in_channels=3+1):
161
+ super(RefUNet, self).__init__()
162
+ self.encoder_1 = nn.Sequential(
163
+ nn.Conv2d(in_channels, 64, 3, 1, 1),
164
+ nn.Conv2d(64, 64, 3, 1, 1),
165
+ nn.BatchNorm2d(64),
166
+ nn.ReLU(inplace=True)
167
+ )
168
+
169
+ self.encoder_2 = nn.Sequential(
170
+ nn.MaxPool2d(2, 2, ceil_mode=True),
171
+ nn.Conv2d(64, 64, 3, 1, 1),
172
+ nn.BatchNorm2d(64),
173
+ nn.ReLU(inplace=True)
174
+ )
175
+
176
+ self.encoder_3 = nn.Sequential(
177
+ nn.MaxPool2d(2, 2, ceil_mode=True),
178
+ nn.Conv2d(64, 64, 3, 1, 1),
179
+ nn.BatchNorm2d(64),
180
+ nn.ReLU(inplace=True)
181
+ )
182
+
183
+ self.encoder_4 = nn.Sequential(
184
+ nn.MaxPool2d(2, 2, ceil_mode=True),
185
+ nn.Conv2d(64, 64, 3, 1, 1),
186
+ nn.BatchNorm2d(64),
187
+ nn.ReLU(inplace=True)
188
+ )
189
+
190
+ self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
191
+ #####
192
+ self.decoder_5 = nn.Sequential(
193
+ nn.Conv2d(64, 64, 3, 1, 1),
194
+ nn.BatchNorm2d(64),
195
+ nn.ReLU(inplace=True)
196
+ )
197
+ #####
198
+ self.decoder_4 = nn.Sequential(
199
+ nn.Conv2d(128, 64, 3, 1, 1),
200
+ nn.BatchNorm2d(64),
201
+ nn.ReLU(inplace=True)
202
+ )
203
+
204
+ self.decoder_3 = nn.Sequential(
205
+ nn.Conv2d(128, 64, 3, 1, 1),
206
+ nn.BatchNorm2d(64),
207
+ nn.ReLU(inplace=True)
208
+ )
209
+
210
+ self.decoder_2 = nn.Sequential(
211
+ nn.Conv2d(128, 64, 3, 1, 1),
212
+ nn.BatchNorm2d(64),
213
+ nn.ReLU(inplace=True)
214
+ )
215
+
216
+ self.decoder_1 = nn.Sequential(
217
+ nn.Conv2d(128, 64, 3, 1, 1),
218
+ nn.BatchNorm2d(64),
219
+ nn.ReLU(inplace=True)
220
+ )
221
+
222
+ self.conv_d0 = nn.Conv2d(64, 1, 3, 1, 1)
223
+
224
+ self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
225
+
226
+ def forward(self, x):
227
+ outs = []
228
+ if isinstance(x, list):
229
+ x = torch.cat(x, dim=1)
230
+ hx = x
231
+
232
+ hx1 = self.encoder_1(hx)
233
+ hx2 = self.encoder_2(hx1)
234
+ hx3 = self.encoder_3(hx2)
235
+ hx4 = self.encoder_4(hx3)
236
+
237
+ hx = self.decoder_5(self.pool4(hx4))
238
+ hx = torch.cat((self.upscore2(hx), hx4), 1)
239
+
240
+ d4 = self.decoder_4(hx)
241
+ hx = torch.cat((self.upscore2(d4), hx3), 1)
242
+
243
+ d3 = self.decoder_3(hx)
244
+ hx = torch.cat((self.upscore2(d3), hx2), 1)
245
+
246
+ d2 = self.decoder_2(hx)
247
+ hx = torch.cat((self.upscore2(d2), hx1), 1)
248
+
249
+ d1 = self.decoder_1(hx)
250
+
251
+ x = self.conv_d0(d1)
252
+ outs.append(x)
253
+ return outs
models/modules/refinement/stem_layer.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from models.modules.utils import build_act_layer, build_norm_layer
3
+
4
+
5
+ class StemLayer(nn.Module):
6
+ r""" Stem layer of InternImage
7
+ Args:
8
+ in_channels (int): number of input channels
9
+ out_channels (int): number of output channels
10
+ act_layer (str): activation layer
11
+ norm_layer (str): normalization layer
12
+ """
13
+
14
+ def __init__(self,
15
+ in_channels=3+1,
16
+ inter_channels=48,
17
+ out_channels=96,
18
+ act_layer='GELU',
19
+ norm_layer='BN'):
20
+ super().__init__()
21
+ self.conv1 = nn.Conv2d(in_channels,
22
+ inter_channels,
23
+ kernel_size=3,
24
+ stride=1,
25
+ padding=1)
26
+ self.norm1 = build_norm_layer(
27
+ inter_channels, norm_layer, 'channels_first', 'channels_first'
28
+ )
29
+ self.act = build_act_layer(act_layer)
30
+ self.conv2 = nn.Conv2d(inter_channels,
31
+ out_channels,
32
+ kernel_size=3,
33
+ stride=1,
34
+ padding=1)
35
+ self.norm2 = build_norm_layer(
36
+ out_channels, norm_layer, 'channels_first', 'channels_first'
37
+ )
38
+
39
+ def forward(self, x):
40
+ x = self.conv1(x)
41
+ x = self.norm1(x)
42
+ x = self.act(x)
43
+ x = self.conv2(x)
44
+ x = self.norm2(x)
45
+ return x
models/refinement/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from os.path import dirname, basename, isfile, join
2
+ import glob
3
+
4
+
5
+ modules = glob.glob(join(dirname(__file__), "*.py"))
6
+ __all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]