Nik Ska commited on
Commit
779e178
1 Parent(s): ccce03b
Files changed (1) hide show
  1. app.py +1 -43
app.py CHANGED
@@ -14,49 +14,7 @@ import io
14
  import os
15
  # import streamlit as st
16
 
17
- print(os.environ.get(["ENDPOINT"]))
18
-
19
- providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
20
- model_path = huggingface_hub.hf_hub_download(
21
- "skytnt/anime-seg", "isnetis.onnx")
22
- rmbg_model = rt.InferenceSession(model_path, providers=providers)
23
-
24
-
25
- def custom_background(background, foreground):
26
- foreground = ImageOps.contain(foreground, background.size)
27
- x = (background.size[0] - foreground.size[0]) // 2
28
- y = (background.size[1] - foreground.size[1]) // 2
29
- background.paste(foreground, (x, y), foreground)
30
- return background
31
-
32
-
33
- def get_mask(img, s=1024):
34
- img = (img / 255).astype(np.float32)
35
- h, w = h0, w0 = img.shape[:-1]
36
- h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
37
- ph, pw = s - h, s - w
38
- img_input = np.zeros([s, s, 3], dtype=np.float32)
39
- img_input[ph // 2:ph // 2 + h, pw //
40
- 2:pw // 2 + w] = cv2.resize(img, (w, h))
41
- img_input = np.transpose(img_input, (2, 0, 1))
42
- img_input = img_input[np.newaxis, :]
43
- mask = rmbg_model.run(None, {'img': img_input})[0][0]
44
- mask = np.transpose(mask, (1, 2, 0))
45
- mask = mask[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
46
- mask = cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
47
- return mask
48
-
49
-
50
- def predict(image, new_background):
51
- mask = get_mask(image)
52
- image = (mask * image + 255 * (1 - mask)).astype(np.uint8)
53
- mask = (mask * 255).astype(np.uint8)
54
- image = np.concatenate([image, mask], axis=2, dtype=np.uint8)
55
- mask = mask.repeat(3, axis=2)
56
- if new_background is not None:
57
- foreground = PIL.Image.fromarray(image)
58
- return mask, custom_background(new_background, foreground)
59
- return mask, image
60
 
61
 
62
  def get_mask(img_in):
 
14
  import os
15
  # import streamlit as st
16
 
17
+ print(os.environ.get("ENDPOINT"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  def get_mask(img_in):