Nunzio commited on
Commit
60fd570
·
1 Parent(s): ff83735

added BiSeNet V2

Browse files
app.py CHANGED
@@ -29,13 +29,14 @@ def run_prediction(image: gr.Image, selected_model: str)-> tuple[torch.Tensor]:
29
 
30
  if selected_model is None:
31
  return (gr.update(value=None, visible=False), gr.update(value=f"❌ No model selected for prediction.", visible=True))
32
- # try:
33
- model = loadModel(selected_model, device)
34
- image = hfImageToTensor(image, width=1024, height=512)
35
- prediction = predict(image, model)
36
- prediction = postprocessing(prediction)
37
- # except Exception as e:
38
- # return (gr.update(value=None, visible=False), gr.update(value=f"❌ {str(e)}.", visible=True))
 
39
  return (gr.update(value=prediction, visible=True), gr.update(value="", visible=False))
40
 
41
  # Gradio UI
 
29
 
30
  if selected_model is None:
31
  return (gr.update(value=None, visible=False), gr.update(value=f"❌ No model selected for prediction.", visible=True))
32
+
33
+ try:
34
+ model = loadModel(selected_model, device)
35
+ image = hfImageToTensor(image, width=1024, height=512)
36
+ prediction = predict(image, model)
37
+ prediction = postprocessing(prediction)
38
+ except Exception as e:
39
+ return (gr.update(value=None, visible=False), gr.update(value=f"❌ {str(e)}.", visible=True))
40
  return (gr.update(value=prediction, visible=True), gr.update(value="", visible=False))
41
 
42
  # Gradio UI
model/BiSeNetV2/model.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.model_zoo as modelzoo
5
+
6
+ # URL for pretrained backbone weights
7
+ backbone_url = 'https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/backbone_v2.pth'
8
+
9
+ class ConvBNReLU(nn.Module):
10
+ """
11
+ Convolution + BatchNorm + ReLU block.
12
+ """
13
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1,
14
+ dilation=1, groups=1, bias=False):
15
+ super(ConvBNReLU, self).__init__()
16
+ self.conv = nn.Conv2d(
17
+ in_chan, out_chan, kernel_size=ks, stride=stride,
18
+ padding=padding, dilation=dilation,
19
+ groups=groups, bias=bias)
20
+ self.bn = nn.BatchNorm2d(out_chan)
21
+ self.relu = nn.ReLU(inplace=True)
22
+
23
+ def forward(self, x):
24
+ feat = self.conv(x)
25
+ feat = self.bn(feat)
26
+ feat = self.relu(feat)
27
+ return feat
28
+
29
+ class UpSample(nn.Module):
30
+ """
31
+ Upsample block using PixelShuffle.
32
+ """
33
+ def __init__(self, n_chan, factor=2):
34
+ super(UpSample, self).__init__()
35
+ out_chan = n_chan * factor * factor
36
+ self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0)
37
+ self.up = nn.PixelShuffle(factor)
38
+ self.init_weight()
39
+
40
+ def forward(self, x):
41
+ feat = self.proj(x)
42
+ feat = self.up(feat)
43
+ return feat
44
+
45
+ def init_weight(self):
46
+ nn.init.xavier_normal_(self.proj.weight, gain=1.)
47
+
48
+ class DetailBranch(nn.Module):
49
+ """
50
+ Detail branch for capturing spatial details.
51
+ """
52
+ def __init__(self):
53
+ super(DetailBranch, self).__init__()
54
+ self.S1 = nn.Sequential(
55
+ ConvBNReLU(3, 64, 3, stride=2),
56
+ ConvBNReLU(64, 64, 3, stride=1),
57
+ )
58
+ self.S2 = nn.Sequential(
59
+ ConvBNReLU(64, 64, 3, stride=2),
60
+ ConvBNReLU(64, 64, 3, stride=1),
61
+ ConvBNReLU(64, 64, 3, stride=1),
62
+ )
63
+ self.S3 = nn.Sequential(
64
+ ConvBNReLU(64, 128, 3, stride=2),
65
+ ConvBNReLU(128, 128, 3, stride=1),
66
+ ConvBNReLU(128, 128, 3, stride=1),
67
+ )
68
+
69
+ def forward(self, x):
70
+ feat = self.S1(x)
71
+ feat = self.S2(feat)
72
+ feat = self.S3(feat)
73
+ return feat
74
+
75
+ class StemBlock(nn.Module):
76
+ """
77
+ Stem block for the semantic branch.
78
+ """
79
+ def __init__(self):
80
+ super(StemBlock, self).__init__()
81
+ self.conv = ConvBNReLU(3, 16, 3, stride=2)
82
+ self.left = nn.Sequential(
83
+ ConvBNReLU(16, 8, 1, stride=1, padding=0),
84
+ ConvBNReLU(8, 16, 3, stride=2),
85
+ )
86
+ self.right = nn.MaxPool2d(
87
+ kernel_size=3, stride=2, padding=1, ceil_mode=False)
88
+ self.fuse = ConvBNReLU(32, 16, 3, stride=1)
89
+
90
+ def forward(self, x):
91
+ feat = self.conv(x)
92
+ feat_left = self.left(feat)
93
+ feat_right = self.right(feat)
94
+ feat = torch.cat([feat_left, feat_right], dim=1)
95
+ feat = self.fuse(feat)
96
+ return feat
97
+
98
+ class CEBlock(nn.Module):
99
+ """
100
+ Context Embedding Block.
101
+ """
102
+ def __init__(self):
103
+ super(CEBlock, self).__init__()
104
+ self.bn = nn.BatchNorm2d(128)
105
+ self.conv_gap = ConvBNReLU(128, 128, 1, stride=1, padding=0)
106
+ # In paper, this is a naive conv2d, no bn-relu
107
+ self.conv_last = ConvBNReLU(128, 128, 3, stride=1)
108
+
109
+ def forward(self, x):
110
+ feat = torch.mean(x, dim=(2, 3), keepdim=True)
111
+ feat = self.bn(feat)
112
+ feat = self.conv_gap(feat)
113
+ feat = feat + x
114
+ feat = self.conv_last(feat)
115
+ return feat
116
+
117
+ class GELayerS1(nn.Module):
118
+ """
119
+ Gather-and-Expansion Layer with stride 1.
120
+ """
121
+ def __init__(self, in_chan, out_chan, exp_ratio=6):
122
+ super(GELayerS1, self).__init__()
123
+ mid_chan = in_chan * exp_ratio
124
+ self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
125
+ self.dwconv = nn.Sequential(
126
+ nn.Conv2d(
127
+ in_chan, mid_chan, kernel_size=3, stride=1,
128
+ padding=1, groups=in_chan, bias=False),
129
+ nn.BatchNorm2d(mid_chan),
130
+ nn.ReLU(inplace=True), # not shown in paper
131
+ )
132
+ self.conv2 = nn.Sequential(
133
+ nn.Conv2d(
134
+ mid_chan, out_chan, kernel_size=1, stride=1,
135
+ padding=0, bias=False),
136
+ nn.BatchNorm2d(out_chan),
137
+ )
138
+ self.conv2[1].last_bn = True
139
+ self.relu = nn.ReLU(inplace=True)
140
+
141
+ def forward(self, x):
142
+ feat = self.conv1(x)
143
+ feat = self.dwconv(feat)
144
+ feat = self.conv2(feat)
145
+ feat = feat + x
146
+ feat = self.relu(feat)
147
+ return feat
148
+
149
+ class GELayerS2(nn.Module):
150
+ """
151
+ Gather-and-Expansion Layer with stride 2.
152
+ """
153
+ def __init__(self, in_chan, out_chan, exp_ratio=6):
154
+ super(GELayerS2, self).__init__()
155
+ mid_chan = in_chan * exp_ratio
156
+ self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
157
+ self.dwconv1 = nn.Sequential(
158
+ nn.Conv2d(
159
+ in_chan, mid_chan, kernel_size=3, stride=2,
160
+ padding=1, groups=in_chan, bias=False),
161
+ nn.BatchNorm2d(mid_chan),
162
+ )
163
+ self.dwconv2 = nn.Sequential(
164
+ nn.Conv2d(
165
+ mid_chan, mid_chan, kernel_size=3, stride=1,
166
+ padding=1, groups=mid_chan, bias=False),
167
+ nn.BatchNorm2d(mid_chan),
168
+ nn.ReLU(inplace=True), # not shown in paper
169
+ )
170
+ self.conv2 = nn.Sequential(
171
+ nn.Conv2d(
172
+ mid_chan, out_chan, kernel_size=1, stride=1,
173
+ padding=0, bias=False),
174
+ nn.BatchNorm2d(out_chan),
175
+ )
176
+ self.conv2[1].last_bn = True
177
+ self.shortcut = nn.Sequential(
178
+ nn.Conv2d(
179
+ in_chan, in_chan, kernel_size=3, stride=2,
180
+ padding=1, groups=in_chan, bias=False),
181
+ nn.BatchNorm2d(in_chan),
182
+ nn.Conv2d(
183
+ in_chan, out_chan, kernel_size=1, stride=1,
184
+ padding=0, bias=False),
185
+ nn.BatchNorm2d(out_chan),
186
+ )
187
+ self.relu = nn.ReLU(inplace=True)
188
+
189
+ def forward(self, x):
190
+ feat = self.conv1(x)
191
+ feat = self.dwconv1(feat)
192
+ feat = self.dwconv2(feat)
193
+ feat = self.conv2(feat)
194
+ shortcut = self.shortcut(x)
195
+ feat = feat + shortcut
196
+ feat = self.relu(feat)
197
+ return feat
198
+
199
+ class SegmentBranch(nn.Module):
200
+ """
201
+ Semantic branch for extracting semantic features.
202
+ """
203
+ def __init__(self):
204
+ super(SegmentBranch, self).__init__()
205
+ self.S1S2 = StemBlock()
206
+ self.S3 = nn.Sequential(
207
+ GELayerS2(16, 32),
208
+ GELayerS1(32, 32),
209
+ )
210
+ self.S4 = nn.Sequential(
211
+ GELayerS2(32, 64),
212
+ GELayerS1(64, 64),
213
+ )
214
+ self.S5_4 = nn.Sequential(
215
+ GELayerS2(64, 128),
216
+ GELayerS1(128, 128),
217
+ GELayerS1(128, 128),
218
+ GELayerS1(128, 128),
219
+ )
220
+ self.S5_5 = CEBlock()
221
+
222
+ def forward(self, x):
223
+ feat2 = self.S1S2(x)
224
+ feat3 = self.S3(feat2)
225
+ feat4 = self.S4(feat3)
226
+ feat5_4 = self.S5_4(feat4)
227
+ feat5_5 = self.S5_5(feat5_4)
228
+ return feat2, feat3, feat4, feat5_4, feat5_5
229
+
230
+ class BGALayer(nn.Module):
231
+ """
232
+ Bilateral Guided Aggregation Layer.
233
+ """
234
+ def __init__(self):
235
+ super(BGALayer, self).__init__()
236
+ self.left1 = nn.Sequential(
237
+ nn.Conv2d(
238
+ 128, 128, kernel_size=3, stride=1,
239
+ padding=1, groups=128, bias=False),
240
+ nn.BatchNorm2d(128),
241
+ nn.Conv2d(
242
+ 128, 128, kernel_size=1, stride=1,
243
+ padding=0, bias=False),
244
+ )
245
+ self.left2 = nn.Sequential(
246
+ nn.Conv2d(
247
+ 128, 128, kernel_size=3, stride=2,
248
+ padding=1, bias=False),
249
+ nn.BatchNorm2d(128),
250
+ nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
251
+ )
252
+ self.right1 = nn.Sequential(
253
+ nn.Conv2d(
254
+ 128, 128, kernel_size=3, stride=1,
255
+ padding=1, bias=False),
256
+ nn.BatchNorm2d(128),
257
+ )
258
+ self.right2 = nn.Sequential(
259
+ nn.Conv2d(
260
+ 128, 128, kernel_size=3, stride=1,
261
+ padding=1, groups=128, bias=False),
262
+ nn.BatchNorm2d(128),
263
+ nn.Conv2d(
264
+ 128, 128, kernel_size=1, stride=1,
265
+ padding=0, bias=False),
266
+ )
267
+ self.up1 = nn.Upsample(scale_factor=4)
268
+ self.up2 = nn.Upsample(scale_factor=4)
269
+ # In paper, this may have no relu
270
+ self.conv = nn.Sequential(
271
+ nn.Conv2d(
272
+ 128, 128, kernel_size=3, stride=1,
273
+ padding=1, bias=False),
274
+ nn.BatchNorm2d(128),
275
+ nn.ReLU(inplace=True), # not shown in paper
276
+ )
277
+
278
+ def forward(self, x_d, x_s):
279
+ dsize = x_d.size()[2:]
280
+ left1 = self.left1(x_d)
281
+ left2 = self.left2(x_d)
282
+ right1 = self.right1(x_s)
283
+ right2 = self.right2(x_s)
284
+ right1 = self.up1(right1)
285
+ left = left1 * torch.sigmoid(right1)
286
+ right = left2 * torch.sigmoid(right2)
287
+ right = self.up2(right)
288
+ out = self.conv(left + right)
289
+ return out
290
+
291
+ class SegmentHead(nn.Module):
292
+ """
293
+ Segmentation head for outputting logits.
294
+ """
295
+ def __init__(self, in_chan, mid_chan, n_classes, up_factor=8, aux=True):
296
+ super(SegmentHead, self).__init__()
297
+ self.conv = ConvBNReLU(in_chan, mid_chan, 3, stride=1)
298
+ self.drop = nn.Dropout(0.1)
299
+ self.up_factor = up_factor
300
+
301
+ out_chan = n_classes
302
+ mid_chan2 = up_factor * up_factor if aux else mid_chan
303
+ up_factor = up_factor // 2 if aux else up_factor
304
+ self.conv_out = nn.Sequential(
305
+ nn.Sequential(
306
+ nn.Upsample(scale_factor=2),
307
+ ConvBNReLU(mid_chan, mid_chan2, 3, stride=1)
308
+ ) if aux else nn.Identity(),
309
+ nn.Conv2d(mid_chan2, out_chan, 1, 1, 0, bias=True),
310
+ nn.Upsample(scale_factor=up_factor, mode='bilinear', align_corners=False)
311
+ )
312
+
313
+ def forward(self, x):
314
+ feat = self.conv(x)
315
+ feat = self.drop(feat)
316
+ feat = self.conv_out(feat)
317
+ return feat
318
+
319
+ class CustomArgMax(torch.autograd.Function):
320
+ """
321
+ Custom ArgMax function for ONNX export compatibility.
322
+ """
323
+ @staticmethod
324
+ def forward(ctx, feat_out, dim):
325
+ return feat_out.argmax(dim=dim).int()
326
+
327
+ @staticmethod
328
+ def symbolic(g, feat_out, dim: int):
329
+ return g.op('CustomArgMax', feat_out, dim_i=dim)
330
+
331
+ class BiSeNetV2(nn.Module):
332
+ """
333
+ BiSeNetV2 main model.
334
+ """
335
+ def __init__(self, n_classes, aux_mode='train'):
336
+ super(BiSeNetV2, self).__init__()
337
+ self.aux_mode = aux_mode
338
+ self.detail = DetailBranch()
339
+ self.segment = SegmentBranch()
340
+ self.bga = BGALayer()
341
+
342
+ # Main segmentation head
343
+ self.head = SegmentHead(128, 1024, n_classes, up_factor=8, aux=False)
344
+ if self.aux_mode == 'train':
345
+ # Auxiliary heads for deep supervision
346
+ self.aux2 = SegmentHead(16, 128, n_classes, up_factor=4)
347
+ self.aux3 = SegmentHead(32, 128, n_classes, up_factor=8)
348
+ self.aux4 = SegmentHead(64, 128, n_classes, up_factor=16)
349
+ self.aux5_4 = SegmentHead(128, 128, n_classes, up_factor=32)
350
+
351
+ self.init_weights()
352
+
353
+ def forward(self, x):
354
+ size = x.size()[2:]
355
+ feat_d = self.detail(x)
356
+ feat2, feat3, feat4, feat5_4, feat_s = self.segment(x)
357
+ feat_head = self.bga(feat_d, feat_s)
358
+
359
+ logits = self.head(feat_head)
360
+ if self.aux_mode == 'train':
361
+ logits_aux2 = self.aux2(feat2)
362
+ logits_aux3 = self.aux3(feat3)
363
+ logits_aux4 = self.aux4(feat4)
364
+ logits_aux5_4 = self.aux5_4(feat5_4)
365
+ return logits, logits_aux2, logits_aux3, logits_aux4, logits_aux5_4
366
+ elif self.aux_mode == 'eval':
367
+ return logits,
368
+ elif self.aux_mode == 'pred':
369
+ # Use custom argmax for ONNX compatibility
370
+ pred = CustomArgMax.apply(logits, 1)
371
+ return pred
372
+ else:
373
+ raise NotImplementedError
374
+
375
+ def init_weights(self):
376
+ """
377
+ Initialize model weights.
378
+ """
379
+ for name, module in self.named_modules():
380
+ if isinstance(module, (nn.Conv2d, nn.Linear)):
381
+ nn.init.kaiming_normal_(module.weight, mode='fan_out')
382
+ if not module.bias is None: nn.init.constant_(module.bias, 0)
383
+ elif isinstance(module, nn.modules.batchnorm._BatchNorm):
384
+ if hasattr(module, 'last_bn') and module.last_bn:
385
+ nn.init.zeros_(module.weight)
386
+ else:
387
+ nn.init.ones_(module.weight)
388
+ nn.init.zeros_(module.bias)
389
+ self.load_pretrain()
390
+
391
+ def load_pretrain(self):
392
+ """
393
+ Load pretrained backbone weights.
394
+ """
395
+ state = modelzoo.load_url(backbone_url)
396
+ for name, child in self.named_children():
397
+ if name in state.keys():
398
+ child.load_state_dict(state[name], strict=True)
399
+
400
+ def get_params(self):
401
+ """
402
+ Get model parameters for optimizer with/without weight decay.
403
+ """
404
+ def add_param_to_list(mod, wd_params, nowd_params):
405
+ for param in mod.parameters():
406
+ if param.dim() == 1:
407
+ nowd_params.append(param)
408
+ elif param.dim() == 4:
409
+ wd_params.append(param)
410
+ else:
411
+ print(name)
412
+
413
+ wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
414
+ for name, child in self.named_children():
415
+ if 'head' in name or 'aux' in name:
416
+ add_param_to_list(child, lr_mul_wd_params, lr_mul_nowd_params)
417
+ else:
418
+ add_param_to_list(child, wd_params, nowd_params)
419
+ return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
model/modelLoading.py CHANGED
@@ -1,10 +1,9 @@
1
  import torch
2
 
3
  from model.BiSeNet.build_bisenet import BiSeNet
 
4
 
5
-
6
- # %% load model
7
-
8
  def loadModel(model:str = 'bisenet', device: str = 'cpu')->BiSeNet:
9
  """
10
  Load the specified model and move it to the given device.
@@ -18,6 +17,7 @@ def loadModel(model:str = 'bisenet', device: str = 'cpu')->BiSeNet:
18
  """
19
  match model.lower() if isinstance(model, str) else model:
20
  case 'bisenet': model = loadBiSeNet(device)
 
21
  case _: raise NotImplementedError(f"Model {model} is not implemented. Please choose 'bisenet' .")
22
 
23
  return model
@@ -38,4 +38,21 @@ def loadBiSeNet(device: str = 'cpu') -> BiSeNet:
38
  model.load_state_dict(torch.load('./weights/BiSeNet/weightADV.pth', map_location=device)['model_state_dict'])
39
  model.eval()
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  return model
 
1
  import torch
2
 
3
  from model.BiSeNet.build_bisenet import BiSeNet
4
+ from model.BiSeNetV2.model import BiSeNetV2
5
 
6
+ # general loading function
 
 
7
  def loadModel(model:str = 'bisenet', device: str = 'cpu')->BiSeNet:
8
  """
9
  Load the specified model and move it to the given device.
 
17
  """
18
  match model.lower() if isinstance(model, str) else model:
19
  case 'bisenet': model = loadBiSeNet(device)
20
+ case 'bisenetv2': model = loadBiSeNetV2(device)
21
  case _: raise NotImplementedError(f"Model {model} is not implemented. Please choose 'bisenet' .")
22
 
23
  return model
 
38
  model.load_state_dict(torch.load('./weights/BiSeNet/weightADV.pth', map_location=device)['model_state_dict'])
39
  model.eval()
40
 
41
+ return model
42
+
43
+
44
+ def loadBiSeNetV2(device: str = 'cpu') -> BiSeNetV2:
45
+ """
46
+ Load the BiSeNetV2 model and move it to the specified device.
47
+
48
+ Args:
49
+ device (str): Device to load the model onto ('cpu' or 'cuda').
50
+
51
+ Returns:
52
+ model (BiSeNetV2): The loaded BiSeNetV2 model.
53
+ """
54
+ model = BiSeNetV2(n_classes=19).to(device)
55
+ model.load_state_dict(torch.load('./weights/BiSeNetV2/weightADV.pth', map_location=device)['model_state_dict'])
56
+ model.eval()
57
+
58
  return model
weights/BiSeNetV2/weightADV.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4985e58c8879e096c82e0eb95b3dc29beec5ceb60518d490e27a346b8a4b8b7
3
+ size 64390390