ZhengPeng7 committed on
Commit e2ce7e5
1 Parent(s): 327742a

Remove redundant part of out_ref in inference.

Files changed (2)
  1. app.py  +1 -1
  2. models/baseline.py  +24 -22
app.py CHANGED
@@ -35,7 +35,7 @@ class ImagePreprocessor():
         return image


-model = BiRefNet().to(device)
+model = BiRefNet(bb_pretrained=False).to(device)
 state_dict = './BiRefNet_ep580.pth'
 if os.path.exists(state_dict):
     birefnet_dict = torch.load(state_dict, map_location=device)
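For reference, a minimal sketch of how inference code can use the new constructor argument; it mirrors the app.py hunk above, while the import path, device selection, load_state_dict call, and model.eval() are assumptions for illustration, not part of this commit:

import os
import torch

from models.baseline import BiRefNet  # import path assumed from this repo's layout

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# bb_pretrained=False skips fetching ImageNet weights for the backbone;
# the full BiRefNet checkpoint loaded below overwrites them anyway.
model = BiRefNet(bb_pretrained=False).to(device)

state_dict = './BiRefNet_ep580.pth'
if os.path.exists(state_dict):
    birefnet_dict = torch.load(state_dict, map_location=device)
    model.load_state_dict(birefnet_dict)  # assumed loading step, not shown in the hunk above

model.eval()  # sets Module.training to False, so the out_ref training branches below are skipped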
models/baseline.py CHANGED
@@ -20,11 +20,11 @@ from models.refinement.stem_layer import StemLayer


 class BiRefNet(nn.Module):
-    def __init__(self):
+    def __init__(self, bb_pretrained=True):
         super(BiRefNet, self).__init__()
         self.config = Config()
         self.epoch = 1
-        self.bb = build_backbone(self.config.bb, pretrained=False)
+        self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)

         channels = self.config.lateral_channels_in_collection

@@ -126,7 +126,7 @@ class BiRefNet(nn.Module):
         x4 = self.squeeze_module(x4)
         ########## Decoder ##########
         features = [x, x1, x2, x3, x4]
-        if self.config.out_ref:
+        if self.training and self.config.out_ref:
             features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
         scaled_preds = self.decoder(features)
         return scaled_preds, class_preds
@@ -231,7 +231,7 @@ class Decoder(nn.Module):
         return torch.cat(patches_batch, dim=0)

     def forward(self, features):
-        if self.config.out_ref:
+        if self.training and self.config.out_ref:
             outs_gdt_pred = []
             outs_gdt_label = []
             x, x1, x2, x3, x4, gdt_gt = features
@@ -249,18 +249,19 @@ class Decoder(nn.Module):
         p3 = self.decoder_block3(_p3)
         m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
         if self.config.out_ref:
-            # >> GT:
-            # m3 --dilation--> m3_dia
-            # G_3^gt * m3_dia --> G_3^m, which is the label of gradient
-            m3_dia = m3
-            gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
-            outs_gdt_label.append(gdt_label_main_3)
-            # >> Pred:
-            # p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx
-            # F_3^G --sigmoid--> A_3^G
             p3_gdt = self.gdt_convs_3(p3)
-            gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt)
-            outs_gdt_pred.append(gdt_pred_3)
+            if self.training:
+                # >> GT:
+                # m3 --dilation--> m3_dia
+                # G_3^gt * m3_dia --> G_3^m, which is the label of gradient
+                m3_dia = m3
+                gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
+                outs_gdt_label.append(gdt_label_main_3)
+                # >> Pred:
+                # p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx
+                # F_3^G --sigmoid--> A_3^G
+                gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt)
+                outs_gdt_pred.append(gdt_pred_3)
             gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
             # >> Finally:
             # p3 = p3 * A_3^G
@@ -274,14 +275,15 @@ class Decoder(nn.Module):
         p2 = self.decoder_block2(_p2)
         m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
         if self.config.out_ref:
-            # >> GT:
-            m2_dia = m2
-            gdt_label_main_2 = gdt_gt * F.interpolate(m2_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
-            outs_gdt_label.append(gdt_label_main_2)
-            # >> Pred:
             p2_gdt = self.gdt_convs_2(p2)
-            gdt_pred_2 = self.gdt_convs_pred_2(p2_gdt)
-            outs_gdt_pred.append(gdt_pred_2)
+            if self.training:
+                # >> GT:
+                m2_dia = m2
+                gdt_label_main_2 = gdt_gt * F.interpolate(m2_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
+                outs_gdt_label.append(gdt_label_main_2)
+                # >> Pred:
+                gdt_pred_2 = self.gdt_convs_pred_2(p2_gdt)
+                outs_gdt_pred.append(gdt_pred_2)
             gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
             # >> Finally:
             p2 = p2 * gdt_attn_2
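These hunks rely on PyTorch's built-in Module.training flag: model.train() sets it to True and model.eval() sets it to False, which is what lets the gradient-label (gdt_label_*) and gradient-prediction bookkeeping run only during training while inference keeps just the attention path. Below is a minimal, self-contained sketch of the same gating pattern; GatedHead, proj, and the tensor shapes are hypothetical and not taken from BiRefNet:

import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedHead(nn.Module):
    """Toy module mirroring the out_ref gating: extra supervision tensors are
    only produced while self.training is True."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Conv2d(8, 1, kernel_size=1)

    def forward(self, feat, gt=None):
        pred = self.proj(feat)
        aux = []
        if self.training:
            # Training-only branch: build a label by masking the ground truth
            # with the upsampled prediction, as the gdt_label_* lines do above.
            label = gt * F.interpolate(pred, size=gt.shape[2:], mode='bilinear', align_corners=True)
            aux.append(label)
        return pred, aux


head = GatedHead()
feat = torch.randn(1, 8, 16, 16)
gt = torch.rand(1, 1, 64, 64)

head.train()
_, aux_train = head(feat, gt)   # len(aux_train) == 1: label branch runs

head.eval()
_, aux_eval = head(feat)        # len(aux_eval) == 0: branch skipped at inference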