aryanxxvii commited on
Commit
9961846
1 Parent(s): 0008ffa

Add config.json with model_type

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. u2net_pipeline.py +10 -4
config.json CHANGED
@@ -4,7 +4,7 @@
4
  "U2NET"
5
  ],
6
  "task": "image-segmentation",
7
- "pipeline_class": "U2NetPipeline",
8
  "model_file": "u2net.py",
9
  "weights_file": "u2net.pth"
10
  }
 
4
  "U2NET"
5
  ],
6
  "task": "image-segmentation",
7
+ "pipeline_class": "u2net_pipeline.U2NetPipeline",
8
  "model_file": "u2net.py",
9
  "weights_file": "u2net.pth"
10
  }
u2net_pipeline.py CHANGED
@@ -10,10 +10,15 @@ from transformers import Pipeline
10
  class U2NetPipeline(Pipeline):
11
  def __init__(self, model, **kwargs):
12
  super().__init__(model=model, **kwargs)
13
- self.model = U2NET(3, 1)
14
- self.model.load_state_dict(torch.load(model, map_location="cpu"))
15
  self.model.eval()
16
 
 
 
 
 
 
 
17
  def _sanitize_parameters(self, **kwargs):
18
  return {}, {}, {}
19
 
@@ -50,5 +55,6 @@ class U2NetPipeline(Pipeline):
50
  result = (result - mi) / (ma - mi)
51
  return (result * 255).astype(np.uint8)
52
 
53
- def load_model():
54
- return U2NetPipeline("u2net.pth")
 
 
10
  class U2NetPipeline(Pipeline):
11
  def __init__(self, model, **kwargs):
12
  super().__init__(model=model, **kwargs)
13
+ self.model = model
 
14
  self.model.eval()
15
 
16
+ @classmethod
17
+ def from_pretrained(cls, model_path, **kwargs):
18
+ model = U2NET(3, 1)
19
+ model.load_state_dict(torch.load(f"{model_path}/u2net.pth", map_location="cpu"))
20
+ return cls(model, **kwargs)
21
+
22
  def _sanitize_parameters(self, **kwargs):
23
  return {}, {}, {}
24
 
 
55
  result = (result - mi) / (ma - mi)
56
  return (result * 255).astype(np.uint8)
57
 
58
+ # Remove or comment out this function as it's no longer needed
59
+ # def load_model():
60
+ # return U2NetPipeline("u2net.pth")