File size: 2,974 Bytes
6c0050a
80fd191
 
 
 
 
 
 
665e653
80fd191
d7921b8
665e653
d7921b8
cc6c61e
 
665e653
 
 
 
 
80fd191
675f1f9
 
665e653
d7921b8
 
 
80fd191
d7921b8
c57634b
665e653
 
 
 
 
 
d7921b8
80fd191
d7921b8
80fd191
d7921b8
665e653
d7921b8
665e653
 
 
 
d7921b8
665e653
 
d7921b8
665e653
 
 
 
80fd191
d7921b8
cc6c61e
80fd191
e19fd5f
6dc15f4
 
 
efe4474
d7921b8
80fd191
85f55d7
 
 
80fd191
d7921b8
4d6ff3b
80fd191
 
 
 
 
 
 
d7921b8
80fd191
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import spaces
import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from ormbg import ORMBG
from PIL import Image

# Path to the pretrained ORMBG checkpoint; it is loaded below as a state_dict.
model_path = "ormbg.pth"

# Build the network once at import time and keep it on the CPU for now;
# `inference` moves it to the selected device on every call.
net = ORMBG()
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, so a tampered checkpoint cannot run arbitrary code.
net.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
# Switch layers such as dropout/batch-norm to inference behaviour.
net.eval()

def resize_image(image, model_input_size=(1024, 1024)):
    """Convert *image* to RGB and resize it to the network's input resolution.

    Args:
        image: A PIL image in any mode. It is converted to RGB so the result
            always has three channels.
        model_input_size: Target ``(width, height)`` handed to
            ``PIL.Image.resize``. The default matches the size ``inference``
            has always used.

    Returns:
        A new RGB PIL image of exactly ``model_input_size``. The aspect ratio
        is not preserved.
    """
    rgb_image = image.convert("RGB")
    # Accept any 2-element sequence (e.g. a list) by normalising to a tuple.
    return rgb_image.resize(tuple(model_input_size), Image.BILINEAR)

def _preprocess(pil_image):
    """Resize *pil_image* and return it as a ``(1, 3, H, W)`` float tensor in [0, 1]."""
    resized = resize_image(pil_image)
    im_np = np.array(resized)  # (H, W, 3) uint8
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    return torch.unsqueeze(im_tensor, 0) / 255.0


def _prediction_to_mask(pred, size):
    """Upsample a raw prediction map to *size* and rescale it to an 8-bit mask.

    Args:
        pred: One output map of ``net``. It must be 4-D for bilinear
            interpolation; presumably ``(1, 1, h, w)`` — TODO confirm
            against ORMBG.
        size: ``(width, height)`` of the original image.

    Returns:
        A PIL image built from the min-max scaled, squeezed uint8 mask.
    """
    w, h = size
    pred = torch.squeeze(F.interpolate(pred, size=(h, w), mode="bilinear"), 0)
    ma = torch.max(pred)
    mi = torch.min(pred)
    # Clamp the range so a constant prediction gives an all-zero mask instead
    # of 0/0 = NaN, whose cast to uint8 is undefined.
    pred = (pred - mi) / torch.clamp(ma - mi, min=1e-8)
    im_array = (pred * 255).cpu().numpy().astype(np.uint8)
    return Image.fromarray(np.squeeze(im_array))


@spaces.GPU
@torch.inference_mode()
def inference(image):
    """Remove the background from *image* with the ORMBG segmentation network.

    Args:
        image: The input picture as a NumPy array (what the Gradio ``"image"``
            input delivers), or ``None`` when nothing was submitted.

    Returns:
        An RGBA PIL image the size of the input, where pixels outside the
        predicted foreground mask are transparent. Returns ``None`` if *image*
        is ``None``.
    """
    if image is None:
        return None

    # Choose the device inside the call rather than at import time
    # (presumably because the Spaces GPU is only attached while a
    # `spaces.GPU` function runs — confirm).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)

    orig_image = Image.fromarray(image)
    im_tensor = _preprocess(orig_image).to(device)

    result = net(im_tensor)
    mask = _prediction_to_mask(result[0][0], orig_image.size)

    # Paste the original through the mask onto a fully transparent canvas so
    # that background pixels end up with alpha 0.
    new_im = Image.new("RGBA", mask.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=mask)
    return new_im

# Gradio interface setup: page title, Markdown/HTML description shown above the
# demo, and bundled example images.
title = "Open Remove Background Model (ormbg)"
description = r"""
This model is a <strong>fully open-source background remover</strong> optimized for images with humans.
It is based on [Highly Accurate Dichotomous Image Segmentation research](https://github.com/xuebinqin/DIS).
The model was trained with the synthetic [Human Segmentation Dataset](https://huggingface.co/datasets/schirrmacher/humans).

This is the first iteration of the model, so there will be improvements!
If you identify cases where the model fails, <a href='https://huggingface.co/schirrmacher/ormbg/discussions' target='_blank'>upload your examples</a>!

- <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Model card</a>: find inference code, training information, tutorials
- <a href='https://huggingface.co/datasets/schirrmacher/humans' target='_blank'>Dataset</a>: see training images, segmentation data, backgrounds
- <a href='https://huggingface.co/schirrmacher/ormbg#research' target='_blank'>Research</a>: see current approach for improvements
"""

# Example inputs offered in the UI; paths are relative to the working directory.
examples = ["./example1.png", "./example2.png", "./example3.png"]

# Single image in -> single image out, wired to `inference`.
demo = gr.Interface(
    fn=inference,
    inputs="image",
    outputs="image",
    examples=examples,
    title=title,
    description=description
)

if __name__ == "__main__":
    # share=False: serve locally only, without creating a public share link.
    demo.launch(share=False)