File size: 3,833 Bytes
542c815
3f8e328
542c815
 
37ded0e
542c815
a888400
d6e753e
 
 
8a357d1
37ded0e
542c815
37ded0e
fbe23c5
37ded0e
c0a3a3c
542c815
4f91b95
 
542c815
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c9e50b
542c815
 
0c9e50b
1605763
 
70974c3
542c815
 
70974c3
542c815
 
 
 
 
 
 
70974c3
 
542c815
 
70974c3
542c815
 
 
 
 
70974c3
 
542c815
 
 
 
 
 
70974c3
542c815
 
d909bca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542c815
d909bca
 
 
 
542c815
d909bca
 
 
 
 
 
8cb0f2e
 
d909bca
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
# from foo import hello
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
# import git  # pip install gitpython

# hello()

# git.Git(".").clone("https://huggingface.co/briaai/RMBG-1.4")
# git.Git(".").clone("git@hf.co:briaai/RMBG-1.4")

# Build the BRIA RMBG network and load its pretrained weights from disk.
# NOTE: torch.load unpickles the checkpoint file; only point it at trusted files.
net = BriaRMBG()
model_path = "./model.pth"
if not torch.cuda.is_available():
    # No GPU available: remap any CUDA-saved tensors onto the CPU while loading.
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
else:
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
# Switch to evaluation mode (affects e.g. dropout / batch-norm layers, if present).
net.eval()

def image_size_by_min_resolution(
    image: Image.Image,
    resolution: Tuple,
    resample=None,
):
    """Compute a size that scales *image* so its shorter side matches ``min(resolution)``.

    The aspect ratio is preserved (the longer side is scaled by the same
    factor and rounded down).

    Args:
        image: Image whose ``size`` attribute is ``(width, height)``.
        resolution: Sequence of target dimensions; only ``min(resolution)``
            is used.
        resample: Unused; kept only for backward compatibility.

    Returns:
        ``(width, height)`` as ints. For integer inputs the shorter side is
        exactly ``min(resolution)``.

    Raises:
        ValueError: if the shorter image dimension is not positive.
    """
    w, h = image.size

    image_min = min(w, h)
    resolution_min = min(resolution)
    if image_min <= 0:
        raise ValueError(f"Cannot resize an image of size {image.size}.")

    # Scale with exact integer arithmetic. Floor-dividing by a float scale
    # factor (w // (image_min / resolution_min)) can land one pixel short,
    # e.g. 3 // (3 / 1000) == 999.0 instead of 1000.
    resize_to: Tuple[int, int] = (
        int(w * resolution_min // image_min),
        int(h * resolution_min // image_min),
    )
    return resize_to
    

def resize_image(image):
    """Return an RGB copy of *image* scaled so its shorter side is 1024 px.

    The aspect ratio is kept (approximately, since the longer side is rounded
    down); the target size comes from ``image_size_by_min_resolution``, and
    bilinear resampling is used.
    """
    rgb_image = image.convert('RGB')
    target_size = image_size_by_min_resolution(
        image=rgb_image,
        resolution=(1024, 1024),
    )
    return rgb_image.resize(target_size, Image.BILINEAR)


def process(image):
    """Remove the background from *image* with the module-level ``net``.

    Args:
        image: Input picture as an array accepted by ``Image.fromarray``
            (what the Gradio ``"image"`` input component passes to ``fn``);
            ``None`` when the user submits without an image.

    Returns:
        ``[orig_image, new_im]`` for the ``ImageSlider`` output: the original
        PIL image, and an RGBA image whose colour is the original weighted by
        the predicted mask and whose alpha channel is that mask (so the
        background is transparent).

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Gradio passes None when Submit is pressed without an input image.
        raise gr.Error("Please provide an input image.")

    # prepare input
    orig_image = Image.fromarray(image)
    w, h = orig_image.size
    model_image = resize_image(orig_image)
    im_np = np.array(model_image)
    # HWC uint8 -> 1xCxHxW float in [0, 1].
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    # Mean 0.5 / std 1.0: shifts pixels to [-0.5, 0.5] without rescaling.
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # inference: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        result = net(im_tensor)

    # post process: result[0][0] must be 4-D for bilinear interpolation;
    # presumably it is the primary (N, 1, H, W) mask prediction -- TODO confirm.
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    value_range = ma - mi
    if value_range > 0:
        result = (result - mi) / value_range
    else:
        # A constant prediction would give 0 / 0 = NaN, whose uint8 cast is
        # undefined; treat it as an empty mask instead.
        result = torch.zeros_like(result)

    # image to pil
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    # Paste the original through the mask onto a fully transparent canvas.
    # A 3-tuple fill such as (0, 0, 0) would give opaque black (alpha 255),
    # leaving a black background instead of removing it.
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)

    return [orig_image, new_im]


# block = gr.Blocks().queue()

# with block:
#     gr.Markdown("## BRIA RMBG 1.4")
#     gr.HTML('''
#       <p style="margin-bottom: 10px; font-size: 94%">
#         This is a demo for BRIA RMBG 1.4 that using
#         <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone. 
#       </p>
#     ''')
#     with gr.Row():
#         with gr.Column():
#             input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
#             # input_image = gr.Image(sources=None, type="numpy") # None for upload, ctrl+v and webcam
#             run_button = gr.Button(value="Run")
            
#         with gr.Column():
#             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
#     ips = [input_image]
#     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])

# block.launch(debug = True)


# Gradio UI: a single image input; the output is a before/after slider
# comparing the original with the background-removed image.
title = "background_removal"
description = "remove image background"
examples = [["./input.jpg"]]
output = ImageSlider(
    position=0.5,
    label='Image without background slider-view',
    type="pil",
)
demo = gr.Interface(
    fn=process,
    inputs="image",
    outputs=output,
    examples=examples,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch(share=False)