nielsr HF staff commited on
Commit
f7e05b7
1 Parent(s): 36193ca
Files changed (2) hide show
  1. briarmbg.py +3 -1
  2. mixin.py +15 -0
briarmbg.py CHANGED
@@ -2,6 +2,8 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
 
 
5
  class REBNCONV(nn.Module):
6
  def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
7
  super(REBNCONV,self).__init__()
@@ -344,7 +346,7 @@ class myrebnconv(nn.Module):
344
  return self.rl(self.bn(self.conv(x)))
345
 
346
 
347
- class BriaRMBG(nn.Module):
348
 
349
  def __init__(self,in_ch=3,out_ch=1):
350
  super(BriaRMBG,self).__init__()
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
+ from huggingface_hub import PyTorchModelHubMixin
6
+
7
  class REBNCONV(nn.Module):
8
  def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
9
  super(REBNCONV,self).__init__()
 
346
  return self.rl(self.bn(self.conv(x)))
347
 
348
 
349
+ class BriaRMBG(nn.Module, PyTorchModelHubMixin):
350
 
351
  def __init__(self,in_ch=3,out_ch=1):
352
  super(BriaRMBG,self).__init__()
mixin.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from briarmbg import BriaRMBG
2
+ import torch
3
+ from huggingface_hub import hf_hub_download
4
+
5
+ model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
6
+
7
+ net = BriaRMBG()
8
+ net.load_state_dict(torch.load(model_path, map_location="cpu"))
9
+ net.eval()
10
+
11
+ # push to hub
12
+ net.push_to_hub("nielsr/RMBG-1.4")
13
+
14
+ # reload
15
+ net = BriaRMBG.from_pretrained("nielsr/RMBG-1.4")