aryanxxvii
commited on
Commit
•
9961846
1
Parent(s):
0008ffa
Add config.json with model_type
Browse files- config.json +1 -1
- u2net_pipeline.py +10 -4
config.json
CHANGED
@@ -4,7 +4,7 @@
|
|
4 |
"U2NET"
|
5 |
],
|
6 |
"task": "image-segmentation",
|
7 |
-
"pipeline_class": "U2NetPipeline",
|
8 |
"model_file": "u2net.py",
|
9 |
"weights_file": "u2net.pth"
|
10 |
}
|
|
|
4 |
"U2NET"
|
5 |
],
|
6 |
"task": "image-segmentation",
|
7 |
+
"pipeline_class": "u2net_pipeline.U2NetPipeline",
|
8 |
"model_file": "u2net.py",
|
9 |
"weights_file": "u2net.pth"
|
10 |
}
|
u2net_pipeline.py
CHANGED
@@ -10,10 +10,15 @@ from transformers import Pipeline
|
|
10 |
class U2NetPipeline(Pipeline):
|
11 |
def __init__(self, model, **kwargs):
|
12 |
super().__init__(model=model, **kwargs)
|
13 |
-
self.model =
|
14 |
-
self.model.load_state_dict(torch.load(model, map_location="cpu"))
|
15 |
self.model.eval()
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
def _sanitize_parameters(self, **kwargs):
|
18 |
return {}, {}, {}
|
19 |
|
@@ -50,5 +55,6 @@ class U2NetPipeline(Pipeline):
|
|
50 |
result = (result - mi) / (ma - mi)
|
51 |
return (result * 255).astype(np.uint8)
|
52 |
|
53 |
-
|
54 |
-
|
|
|
|
10 |
class U2NetPipeline(Pipeline):
|
11 |
def __init__(self, model, **kwargs):
|
12 |
super().__init__(model=model, **kwargs)
|
13 |
+
self.model = model
|
|
|
14 |
self.model.eval()
|
15 |
|
16 |
+
@classmethod
|
17 |
+
def from_pretrained(cls, model_path, **kwargs):
|
18 |
+
model = U2NET(3, 1)
|
19 |
+
model.load_state_dict(torch.load(f"{model_path}/u2net.pth", map_location="cpu"))
|
20 |
+
return cls(model, **kwargs)
|
21 |
+
|
22 |
def _sanitize_parameters(self, **kwargs):
|
23 |
return {}, {}, {}
|
24 |
|
|
|
55 |
result = (result - mi) / (ma - mi)
|
56 |
return (result * 255).astype(np.uint8)
|
57 |
|
58 |
+
# Remove or comment out this function as it's no longer needed
|
59 |
+
# def load_model():
|
60 |
+
# return U2NetPipeline("u2net.pth")
|