schirrmacher committed on
Commit
2c218d6
1 Parent(s): 3b4fc48

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. .gitattributes +3 -4
  2. app.py +31 -19
  3. example1.jpeg +3 -0
  4. example2.jpeg +3 -0
  5. example3.jpeg +3 -0
.gitattributes CHANGED
@@ -33,7 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- example1.png filter=lfs diff=lfs merge=lfs -text
37
- example2.png filter=lfs diff=lfs merge=lfs -text
38
- example3.png filter=lfs diff=lfs merge=lfs -text
39
- examples.jpg filter=lfs diff=lfs merge=lfs -text
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ example1.jpeg filter=lfs diff=lfs merge=lfs -text
37
+ example2.jpeg filter=lfs diff=lfs merge=lfs -text
38
+ example3.jpeg filter=lfs diff=lfs merge=lfs -text
 
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import spaces
2
  import numpy as np
3
  import torch
4
  import torch.nn.functional as F
@@ -6,27 +5,31 @@ import gradio as gr
6
  from ormbg import ORMBG
7
  from PIL import Image
8
 
 
9
  model_path = "ormbg.pth"
10
 
11
- # Load the model globally but don't send to device yet
12
  net = ORMBG()
13
- net.load_state_dict(torch.load(model_path, map_location="cpu"))
 
 
 
 
 
 
 
14
  net.eval()
15
 
 
16
  def resize_image(image):
17
  image = image.convert("RGB")
18
  model_input_size = (1024, 1024)
19
  image = image.resize(model_input_size, Image.BILINEAR)
20
  return image
21
 
22
- @spaces.GPU
23
- @torch.inference_mode()
24
  def inference(image):
25
- # Check for CUDA and set the device inside inference
26
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
- net.to(device)
28
 
29
- # Prepare input
30
  orig_image = Image.fromarray(image)
31
  w, h = orig_image.size
32
  image = resize_image(orig_image)
@@ -34,41 +37,50 @@ def inference(image):
34
  im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
35
  im_tensor = torch.unsqueeze(im_tensor, 0)
36
  im_tensor = torch.divide(im_tensor, 255.0)
37
-
38
  if torch.cuda.is_available():
39
- im_tensor = im_tensor.to(device)
40
 
41
- # Inference
42
  result = net(im_tensor)
43
- # Post process
44
  result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
45
  ma = torch.max(result)
46
  mi = torch.min(result)
47
  result = (result - mi) / (ma - mi)
48
- # Image to PIL
49
  im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
50
  pil_im = Image.fromarray(np.squeeze(im_array))
51
- # Paste the mask on the original image
52
  new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
53
  new_im.paste(orig_image, mask=pil_im)
54
 
55
  return new_im
56
 
57
- # Gradio interface setup
 
 
 
 
 
 
 
 
 
58
  title = "Open Remove Background Model (ormbg)"
59
  description = r"""
60
  This model is a <strong>fully open-source background remover</strong> optimized for images with humans.
 
61
  It is based on [Highly Accurate Dichotomous Image Segmentation research](https://github.com/xuebinqin/DIS).
62
  The model was trained with the synthetic [Human Segmentation Dataset](https://huggingface.co/datasets/schirrmacher/humans).
63
 
64
  This is the first iteration of the model, so there will be improvements!
65
- If you identify cases where the model fails, <a href='https://huggingface.co/schirrmacher/ormbg/discussions' target='_blank'>upload your examples</a>!
66
 
67
  - <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Model card</a>: find inference code, training information, tutorials
68
  - <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Dataset</a>: see training images, segmentation data, backgrounds
69
  - <a href='https://huggingface.co/schirrmacher/ormbg\#research' target='_blank'>Research</a>: see current approach for improvements
70
- """
71
 
 
72
  examples = ["./example1.png", "./example2.png", "./example3.png"]
73
 
74
  demo = gr.Interface(
@@ -77,7 +89,7 @@ demo = gr.Interface(
77
  outputs="image",
78
  examples=examples,
79
  title=title,
80
- description=description
81
  )
82
 
83
  if __name__ == "__main__":
 
 
1
  import numpy as np
2
  import torch
3
  import torch.nn.functional as F
 
5
  from ormbg import ORMBG
6
  from PIL import Image
7
 
8
+
9
  model_path = "ormbg.pth"
10
 
 
11
  net = ORMBG()
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ net.to(device)
14
+
15
+ if torch.cuda.is_available():
16
+ net.load_state_dict(torch.load(model_path))
17
+ net = net.cuda()
18
+ else:
19
+ net.load_state_dict(torch.load(model_path, map_location="cpu"))
20
  net.eval()
21
 
22
+
23
  def resize_image(image):
24
  image = image.convert("RGB")
25
  model_input_size = (1024, 1024)
26
  image = image.resize(model_input_size, Image.BILINEAR)
27
  return image
28
 
29
+
 
30
  def inference(image):
 
 
 
31
 
32
+ # prepare input
33
  orig_image = Image.fromarray(image)
34
  w, h = orig_image.size
35
  image = resize_image(orig_image)
 
37
  im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
38
  im_tensor = torch.unsqueeze(im_tensor, 0)
39
  im_tensor = torch.divide(im_tensor, 255.0)
 
40
  if torch.cuda.is_available():
41
+ im_tensor = im_tensor.cuda()
42
 
43
+ # inference
44
  result = net(im_tensor)
45
+ # post process
46
  result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
47
  ma = torch.max(result)
48
  mi = torch.min(result)
49
  result = (result - mi) / (ma - mi)
50
+ # image to pil
51
  im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
52
  pil_im = Image.fromarray(np.squeeze(im_array))
53
+ # paste the mask on the original image
54
  new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
55
  new_im.paste(orig_image, mask=pil_im)
56
 
57
  return new_im
58
 
59
+
60
+ gr.Markdown("## Open Remove Background Model (ormbg)")
61
+ gr.HTML(
62
+ """
63
+ <p style="margin-bottom: 10px; font-size: 94%">
64
+ This is a demo for Open Remove Background Model (ormbg) that uses
65
+ <a href="https://huggingface.co/schirrmacher/ormbg" target="_blank">Open Remove Background Model (ormbg) model</a> as backbone.
66
+ </p>
67
+ """
68
+ )
69
  title = "Open Remove Background Model (ormbg)"
70
  description = r"""
71
  This model is a <strong>fully open-source background remover</strong> optimized for images with humans.
72
+
73
  It is based on [Highly Accurate Dichotomous Image Segmentation research](https://github.com/xuebinqin/DIS).
74
  The model was trained with the synthetic [Human Segmentation Dataset](https://huggingface.co/datasets/schirrmacher/humans).
75
 
76
  This is the first iteration of the model, so there will be improvements!
77
+ If you identify cases where the model fails, <a href='https://huggingface.co/schirrmacher/ormbg/discussions' target='_blank'>upload your examples</a>!
78
 
79
  - <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Model card</a>: find inference code, training information, tutorials
80
  - <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Dataset</a>: see training images, segmentation data, backgrounds
81
  - <a href='https://huggingface.co/schirrmacher/ormbg\#research' target='_blank'>Research</a>: see current approach for improvements
 
82
 
83
+ """
84
  examples = ["./example1.png", "./example2.png", "./example3.png"]
85
 
86
  demo = gr.Interface(
 
89
  outputs="image",
90
  examples=examples,
91
  title=title,
92
+ description=description,
93
  )
94
 
95
  if __name__ == "__main__":
example1.jpeg ADDED

Git LFS Details

  • SHA256: 2a48f83d810c2ebb1c4d0e51bdc5c9b290abaea7e25fe438001f6773ff9f0939
  • Pointer size: 132 Bytes
  • Size of remote file: 3.15 MB
example2.jpeg ADDED

Git LFS Details

  • SHA256: a867f03f26b1d1b68c03f7c217ba00d52a9fbce274211492582aea385829c657
  • Pointer size: 132 Bytes
  • Size of remote file: 3.28 MB
example3.jpeg ADDED

Git LFS Details

  • SHA256: 080ad9be300673eb598d9a21aab2fbc8849f68a0b9842029dc043a2e8cf5a614
  • Pointer size: 132 Bytes
  • Size of remote file: 5.96 MB