arxivgpt kim commited on
Commit
dd67556
·
verified ·
1 Parent(s): 34bcb5d

Upload app (15).py

Browse files
Files changed (1) hide show
  1. app (15).py +106 -0
app (15).py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision.transforms.functional import normalize
5
+ from huggingface_hub import hf_hub_download
6
+ import gradio as gr
7
+ from gradio_imageslider import ImageSlider
8
+ from briarmbg import BriaRMBG
9
+ import PIL
10
+ from PIL import Image
11
+ from typing import Tuple
12
+
13
+ net=BriaRMBG()
14
+ # model_path = "./model1.pth"
15
+ model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
16
+ if torch.cuda.is_available():
17
+ net.load_state_dict(torch.load(model_path))
18
+ net=net.cuda()
19
+ else:
20
+ net.load_state_dict(torch.load(model_path,map_location="cpu"))
21
+ net.eval()
22
+
23
+
24
+ def resize_image(image):
25
+ image = image.convert('RGB')
26
+ model_input_size = (1024, 1024)
27
+ image = image.resize(model_input_size, Image.BILINEAR)
28
+ return image
29
+
30
+
31
+ def process(image):
32
+
33
+ # prepare input
34
+ orig_image = Image.fromarray(image)
35
+ w,h = orig_im_size = orig_image.size
36
+ image = resize_image(orig_image)
37
+ im_np = np.array(image)
38
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
39
+ im_tensor = torch.unsqueeze(im_tensor,0)
40
+ im_tensor = torch.divide(im_tensor,255.0)
41
+ im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
42
+ if torch.cuda.is_available():
43
+ im_tensor=im_tensor.cuda()
44
+
45
+ #inference
46
+ result=net(im_tensor)
47
+ # post process
48
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
49
+ ma = torch.max(result)
50
+ mi = torch.min(result)
51
+ result = (result-mi)/(ma-mi)
52
+ # image to pil
53
+ im_array = (result*255).cpu().data.numpy().astype(np.uint8)
54
+ pil_im = Image.fromarray(np.squeeze(im_array))
55
+ # paste the mask on the original image
56
+ new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
57
+ new_im.paste(orig_image, mask=pil_im)
58
+ # new_orig_image = orig_image.convert('RGBA')
59
+
60
+ return new_im
61
+ # return [new_orig_image, new_im]
62
+
63
+
64
+ # block = gr.Blocks().queue()
65
+
66
+ # with block:
67
+ # gr.Markdown("## BRIA RMBG 1.4")
68
+ # gr.HTML('''
69
+ # <p style="margin-bottom: 10px; font-size: 94%">
70
+ # This is a demo for BRIA RMBG 1.4 that using
71
+ # <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
72
+ # </p>
73
+ # ''')
74
+ # with gr.Row():
75
+ # with gr.Column():
76
+ # input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
77
+ # # input_image = gr.Image(sources=None, type="numpy") # None for upload, ctrl+v and webcam
78
+ # run_button = gr.Button(value="Run")
79
+
80
+ # with gr.Column():
81
+ # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
82
+ # ips = [input_image]
83
+ # run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
84
+
85
+ # block.launch(debug = True)
86
+
87
+ # block = gr.Blocks().queue()
88
+
89
+ gr.Markdown("## BRIA RMBG 1.4")
90
+ gr.HTML('''
91
+ <p style="margin-bottom: 10px; font-size: 94%">
92
+ This is a demo for BRIA RMBG 1.4 that using
93
+ <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
94
+ </p>
95
+ ''')
96
+ title = "Background Removal"
97
+ description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
98
+ For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
99
+ """
100
+ examples = [['./input.jpg'],]
101
+ # output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
102
+ # demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)
103
+ demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
104
+
105
+ if __name__ == "__main__":
106
+ demo.launch(share=False)