ZhengPeng7 commited on
Commit
d73d557
1 Parent(s): f10d16f

For users to load in one key.

Browse files
Files changed (3) hide show
  1. BiRefNet_pipe.py +10 -0
  2. MyPipe.py +0 -76
  3. config.json +1 -0
BiRefNet_pipe.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import Pipeline
3
+
4
+
5
+ class BiRefNetPipe(Pipeline):
6
+ def __init__(self, **kwargs):
7
+ Pipeline.__init__(self, **kwargs)
8
+ self.model.to(['cpu', 0][torch.cuda.is_available()])
9
+ self.model.eval()
10
+
MyPipe.py DELETED
@@ -1,76 +0,0 @@
1
- import torch, os
2
- import torch.nn.functional as F
3
- from torchvision.transforms.functional import normalize
4
- import numpy as np
5
- from transformers import Pipeline
6
- from transformers.image_utils import load_image
7
- from skimage import io
8
- from PIL import Image
9
-
10
- class RMBGPipe(Pipeline):
11
- def __init__(self,**kwargs):
12
- Pipeline.__init__(self,**kwargs)
13
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
- self.model.to(self.device)
15
- self.model.eval()
16
-
17
- def _sanitize_parameters(self, **kwargs):
18
- # parse parameters
19
- preprocess_kwargs = {}
20
- postprocess_kwargs = {}
21
- if "model_input_size" in kwargs :
22
- preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
23
- if "return_mask" in kwargs:
24
- postprocess_kwargs["return_mask"] = kwargs["return_mask"]
25
- return preprocess_kwargs, {}, postprocess_kwargs
26
-
27
- def preprocess(self,input_image,model_input_size: list=[1024,1024]):
28
- # preprocess the input
29
- orig_im = load_image(input_image)
30
- orig_im = np.array(orig_im)
31
- orig_im_size = orig_im.shape[0:2]
32
- preprocessed_image = self.preprocess_image(orig_im, model_input_size).to(self.device)
33
- inputs = {
34
- "preprocessed_image":preprocessed_image,
35
- "orig_im_size":orig_im_size,
36
- "input_image" : input_image
37
- }
38
- return inputs
39
-
40
- def _forward(self,inputs):
41
- result = self.model(inputs.pop("preprocessed_image"))
42
- inputs["result"] = result
43
- return inputs
44
-
45
- def postprocess(self,inputs,return_mask:bool=False ):
46
- result = inputs.pop("result")
47
- orig_im_size = inputs.pop("orig_im_size")
48
- input_image = inputs.pop("input_image")
49
- result_image = self.postprocess_image(result[0][0], orig_im_size)
50
- pil_im = Image.fromarray(result_image)
51
- if return_mask ==True :
52
- return pil_im
53
- no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
54
- input_image = load_image(input_image)
55
- no_bg_image.paste(input_image, mask=pil_im)
56
- return no_bg_image
57
-
58
- # utilities functions
59
- def preprocess_image(self,im: np.ndarray, model_input_size: list=[1024,1024]) -> torch.Tensor:
60
- # same as utilities.py with minor modification
61
- if len(im.shape) < 3:
62
- im = im[:, :, np.newaxis]
63
- im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
64
- im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear')
65
- image = torch.divide(im_tensor,255.0)
66
- image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
67
- return image
68
-
69
- def postprocess_image(self,result: torch.Tensor, im_size: list)-> np.ndarray:
70
- result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0)
71
- ma = torch.max(result)
72
- mi = torch.min(result)
73
- result = (result-mi)/(ma-mi)
74
- im_array = (result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8)
75
- im_array = np.squeeze(im_array)
76
- return im_array
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.json CHANGED
@@ -9,6 +9,7 @@
9
  },
10
  "custom_pipelines": {
11
  "image-segmentation": {
 
12
  "pt": [
13
  "AutoModelForImageSegmentation"
14
  ],
 
9
  },
10
  "custom_pipelines": {
11
  "image-segmentation": {
12
+ "impl": "BiRefNet_pipe.BiRefNetPipe",
13
  "pt": [
14
  "AutoModelForImageSegmentation"
15
  ],