schirrmacher commited on
Commit
665e653
1 Parent(s): 0479f79

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +32 -42
app.py CHANGED
@@ -6,58 +6,48 @@ from ormbg import ORMBG
6
  from PIL import Image
7
 
8
 
9
- def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
10
- if len(im.shape) < 3:
11
- im = im[:, :, np.newaxis]
12
- im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
13
- im_tensor = F.interpolate(
14
- torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
15
- ).type(torch.uint8)
16
- image = torch.divide(im_tensor, 255.0)
17
- return image
18
-
19
 
20
- def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
21
- result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
22
- ma = torch.max(result)
23
- mi = torch.min(result)
24
- result = (result - mi) / (ma - mi)
25
- im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
26
- im_array = np.squeeze(im_array)
27
- return im_array
28
 
29
 
30
- def inference(image):
 
 
 
 
31
 
32
- model_path = "ormbg.pth"
33
 
34
- net = ORMBG()
35
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
 
 
37
  orig_image = Image.fromarray(image)
38
-
 
 
 
 
 
39
  if torch.cuda.is_available():
40
- net.load_state_dict(torch.load(model_path))
41
- net = net.cuda()
42
- else:
43
- net.load_state_dict(torch.load(model_path, map_location="cpu"))
44
- net.eval()
45
-
46
- model_input_size = [1024, 1024]
47
- orig_im_size = orig_image.size
48
- processed_image = preprocess_image(orig_image, model_input_size).to(device)
49
-
50
- result = net(processed_image)
51
 
 
 
52
  # post process
53
- result_image = postprocess_image(result[0][0], orig_im_size)
54
-
55
- # save result
56
- pil_im = Image.fromarray(result_image)
57
- no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
58
- no_bg_image.paste(orig_image, mask=pil_im)
59
-
60
- return no_bg_image
 
 
 
 
61
 
62
 
63
  gr.Markdown("## Open Remove Background Model (ormbg)")
 
6
  from PIL import Image
7
 
8
 
9
+ model_path = "ormbg.pth"
 
 
 
 
 
 
 
 
 
10
 
11
+ net = ORMBG()
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ net.to(device)
 
 
 
 
 
14
 
15
 
16
+ def resize_image(image):
17
+ image = image.convert("RGB")
18
+ model_input_size = (1024, 1024)
19
+ image = image.resize(model_input_size, Image.BILINEAR)
20
+ return image
21
 
 
22
 
23
+ def inference(image):
 
24
 
25
+ # prepare input
26
  orig_image = Image.fromarray(image)
27
+ w, h = orig_image.size
28
+ image = resize_image(orig_image)
29
+ im_np = np.array(image)
30
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
31
+ im_tensor = torch.unsqueeze(im_tensor, 0)
32
+ im_tensor = torch.divide(im_tensor, 255.0)
33
  if torch.cuda.is_available():
34
+ im_tensor = im_tensor.cuda()
 
 
 
 
 
 
 
 
 
 
35
 
36
+ # inference
37
+ result = net(im_tensor)
38
  # post process
39
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
40
+ ma = torch.max(result)
41
+ mi = torch.min(result)
42
+ result = (result - mi) / (ma - mi)
43
+ # image to pil
44
+ im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
45
+ pil_im = Image.fromarray(np.squeeze(im_array))
46
+ # paste the mask on the original image
47
+ new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
48
+ new_im.paste(orig_image, mask=pil_im)
49
+
50
+ return new_im
51
 
52
 
53
  gr.Markdown("## Open Remove Background Model (ormbg)")