Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,088 Bytes
0383b74 80fd191 913742e 80fd191 5932f9f 80fd191 0383b74 665e653 0383b74 cc6c61e 2c218d6 665e653 80fd191 2c218d6 0383b74 665e653 0383b74 80fd191 0383b74 c57634b 665e653 0383b74 80fd191 0383b74 80fd191 0383b74 665e653 0383b74 665e653 0383b74 665e653 0383b74 665e653 80fd191 2c218d6 0383b74 cc6c61e 80fd191 79d0705 5d8c246 0383b74 80fd191 85f55d7 79d0705 2c218d6 0383b74 92c1934 eacf438 92c1934 80fd191 2c218d6 80fd191 5932f9f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import spaces
import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from ormbg import ORMBG
from PIL import Image
# Location of the pretrained ORMBG checkpoint, relative to the working directory.
model_path = "ormbg.pth"

# Build the network once at import time and restore its weights on the CPU;
# `inference` is responsible for moving it to the target device later.
net = ORMBG()
net.load_state_dict(
    torch.load(model_path, map_location=torch.device("cpu")),
)
net.eval()
def resize_image(image):
    """Return *image* converted to RGB and scaled to 1024x1024.

    The aspect ratio is not preserved: the picture is stretched to the fixed
    square size using bilinear resampling.
    """
    target_size = (1024, 1024)
    rgb = image.convert("RGB")
    return rgb.resize(target_size, Image.BILINEAR)
@spaces.GPU
@torch.inference_mode()
def inference(image):
    """Cut the foreground out of *image* and return it on a transparent canvas.

    Args:
        image: The picture to process as a NumPy array accepted by
            ``Image.fromarray`` (what the Gradio ``"image"`` input passes in),
            or ``None`` when the user submitted without uploading anything.

    Returns:
        An RGBA ``PIL.Image`` with the same size as the input, in which the
        predicted mask is used as the paste mask for the original pixels.

    Raises:
        gr.Error: If *image* is ``None``.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque
        # failure inside Image.fromarray.
        raise gr.Error("Please upload an image first.")

    # Check for CUDA and set the device inside inference
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)

    # Prepare input: HWC uint8 -> 1xCxHxW float32 in [0, 1]
    orig_image = Image.fromarray(image)
    w, h = orig_image.size
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.divide(torch.unsqueeze(im_tensor, 0), 255.0)
    # Moving to the CPU device is a no-op, so no CUDA guard is needed here.
    im_tensor = im_tensor.to(device)

    # Inference
    result = net(im_tensor)

    # Post process: scale the first output back to the original resolution.
    # NOTE(review): assumes result[0][0] is a (1, 1, H, W) tensor - confirm
    # against the ORMBG forward() implementation.
    result = torch.squeeze(
        F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0
    )

    # Min-max normalise to [0, 1]. A perfectly uniform prediction has
    # max == min; clamp the denominator so that case yields an all-zero
    # (fully transparent) mask instead of NaNs and an undefined uint8 cast.
    ma = torch.max(result)
    mi = torch.min(result)
    span = torch.clamp(ma - mi, min=torch.finfo(result.dtype).eps)
    result = (result - mi) / span

    # Image to PIL. Reshape explicitly to (h, w) rather than squeezing every
    # singleton axis, so inputs that are 1 px wide or tall stay 2-D.
    im_array = (result * 255).detach().cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(im_array.reshape(h, w))

    # Paste the original image through the mask onto a transparent canvas
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    return new_im
# Gradio interface setup
# Heading displayed at the top of the demo page.
title = "Open Remove Background Model (ormbg)"

# Markdown/HTML blurb passed to gr.Interface as the page description.
# The "Dataset" bullet links to the dataset repository (it previously repeated
# the model-card URL), and the "Research" anchor no longer contains a stray
# backslash, which a raw string would otherwise keep verbatim inside the href.
description = r"""
This model is a <strong>fully open-source background remover</strong> optimized for images with humans. It is based on [Highly Accurate Dichotomous Image Segmentation research](https://github.com/xuebinqin/DIS). The model was trained with the synthetic <a href="https://huggingface.co/datasets/schirrmacher/humans">Human Segmentation Dataset</a>, <a href="https://paperswithcode.com/dataset/p3m-10k">P3M-10k</a> and <a href="https://paperswithcode.com/dataset/aim-500">AIM-500</a>.
If you identify cases where the model fails, <a href='https://huggingface.co/schirrmacher/ormbg/discussions' target='_blank'>upload your examples</a>!
- <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Model card</a>: find inference code, training information, tutorials
- <a href='https://huggingface.co/datasets/schirrmacher/humans' target='_blank'>Dataset</a>: see training images, segmentation data, backgrounds
- <a href='https://huggingface.co/schirrmacher/ormbg#research' target='_blank'>Research</a>: see current approach for improvements
"""
# Sample inputs offered in the UI: example01.jpeg through example03.jpeg.
examples = [f"example{index:02d}.jpeg" for index in range(1, 4)]
# Expose `inference` as an image-in / image-out Gradio app.
demo = gr.Interface(
    title=title,
    description=description,
    fn=inference,
    inputs="image",
    outputs="image",
    examples=examples,
)

if __name__ == "__main__":
    # Start the web server only when run as a script. No public share link is
    # requested; allowed_paths presumably lets Gradio serve the example files
    # from the working directory - confirm against the Gradio docs.
    demo.launch(allowed_paths=["./"], share=False)
|