nielsr HF staff commited on
Commit
e9d77ab
1 Parent(s): 018621a

Update briarmbg.py

Browse files
Files changed (1) hide show
  1. briarmbg.py +6 -5
briarmbg.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
4
 
5
  class REBNCONV(nn.Module):
6
  def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
@@ -344,11 +345,12 @@ class myrebnconv(nn.Module):
344
  return self.rl(self.bn(self.conv(x)))
345
 
346
 
347
- class BriaRMBG(nn.Module):
348
 
349
- def __init__(self,in_ch=3,out_ch=1):
350
  super(BriaRMBG,self).__init__()
351
-
 
352
  self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
353
  self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
354
 
@@ -451,5 +453,4 @@ class BriaRMBG(nn.Module):
451
  d6 = self.side6(hx6)
452
  d6 = _upsample_like(d6,x)
453
 
454
- return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
455
-
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
+ from huggingface_hub import PyTorchModelHubMixin
5
 
6
  class REBNCONV(nn.Module):
7
  def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
 
345
  return self.rl(self.bn(self.conv(x)))
346
 
347
 
348
+ class BriaRMBG(nn.Module, PyTorchModelHubMixin):
349
 
350
+ def __init__(self,config:dict={"in_ch":3,"out_ch":1}):
351
  super(BriaRMBG,self).__init__()
352
+ in_ch=config["in_ch"]
353
+ out_ch=config["out_ch"]
354
  self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
355
  self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
356
 
 
453
  d6 = self.side6(hx6)
454
  d6 = _upsample_like(d6,x)
455
 
456
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]