import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from ormbg import ORMBG
from PIL import Image
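
# Gradio demo for the Open Remove Background Model (ormbg): resize the input
# to the model's 1024x1024 input size, predict a foreground mask, and use it
# as the alpha channel to cut the subject out of the original image.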


model_path = "ormbg.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location covers the CPU-only case, so no separate CUDA branch is needed
net = ORMBG()
net.load_state_dict(torch.load(model_path, map_location=device))
net.to(device)
net.eval()


def resize_image(image):
    # the model takes a fixed 1024x1024 RGB input; the aspect ratio is not
    # preserved here, and the mask is scaled back to the original size later
    image = image.convert("RGB")
    model_input_size = (1024, 1024)
    image = image.resize(model_input_size, Image.BILINEAR)
    return image


def inference(image):
    # prepare input: Gradio passes the image in as a numpy array
    orig_image = Image.fromarray(image)
    w, h = orig_image.size
    image = resize_image(orig_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = im_tensor.to(device)

    # inference; no_grad avoids building the autograd graph
    with torch.no_grad():
        result = net(im_tensor)

    # post process: scale the predicted mask back to the original size
    # and min-max normalize it to [0, 1]
    result = torch.squeeze(
        F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0
    )
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)

    # mask to pil
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # paste the original image onto a transparent canvas,
    # using the mask as the alpha channel
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)

    return new_im
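
# A minimal sketch of calling `inference` directly, outside the Gradio UI
# (a hypothetical quick test; "./example1.png" is one of the demo examples below):
#
#   cutout = inference(np.array(Image.open("./example1.png").convert("RGB")))
#   cutout.save("example1_cutout.png")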


title = "Open Remove Background Model (ormbg)"
description = r"""
This is a demo for the <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Open Remove Background Model (ormbg)</a>,
a <strong>fully open-source background remover</strong> optimized for images with humans.
If you identify cases where the model fails, <a href='https://huggingface.co/schirrmacher/ormbg/discussions' target='_blank'>please contact me</a>!

- <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Model card: inference code</a>
- <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Dataset: all images and backgrounds</a>

Known issues (work in progress):
- close-ups: shots from above, from below, in profile, or from the side
- minor issues with hair segmentation when the hair forms loops
- a greater variety of backgrounds is needed

"""
examples = ["./example1.png", "./example2.png", "./example3.png"]

demo = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="pil"),
    examples=examples,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch(share=False)