arxivgpt kim commited on
Commit
34bcb5d
·
verified ·
1 Parent(s): 4941fcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -67
app.py CHANGED
@@ -4,103 +4,116 @@ import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
5
  from huggingface_hub import hf_hub_download
6
  import gradio as gr
7
- from gradio_imageslider import ImageSlider
8
  from briarmbg import BriaRMBG
9
  import PIL
10
  from PIL import Image
11
- from typing import Tuple
12
 
13
- net=BriaRMBG()
14
- # model_path = "./model1.pth"
15
  model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
16
  if torch.cuda.is_available():
17
  net.load_state_dict(torch.load(model_path))
18
- net=net.cuda()
19
  else:
20
- net.load_state_dict(torch.load(model_path,map_location="cpu"))
21
- net.eval()
22
 
23
-
24
  def resize_image(image):
25
  image = image.convert('RGB')
26
  model_input_size = (1024, 1024)
27
  image = image.resize(model_input_size, Image.BILINEAR)
28
  return image
29
 
30
-
31
  def process(image):
32
-
33
- # prepare input
34
  orig_image = Image.fromarray(image)
35
- w,h = orig_im_size = orig_image.size
36
  image = resize_image(orig_image)
37
  im_np = np.array(image)
38
- im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
39
- im_tensor = torch.unsqueeze(im_tensor,0)
40
- im_tensor = torch.divide(im_tensor,255.0)
41
- im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
42
  if torch.cuda.is_available():
43
- im_tensor=im_tensor.cuda()
44
 
45
- #inference
46
- result=net(im_tensor)
47
- # post process
48
- result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
49
  ma = torch.max(result)
50
  mi = torch.min(result)
51
- result = (result-mi)/(ma-mi)
52
- # image to pil
53
- im_array = (result*255).cpu().data.numpy().astype(np.uint8)
54
  pil_im = Image.fromarray(np.squeeze(im_array))
55
- # paste the mask on the original image
56
- new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
57
  new_im.paste(orig_image, mask=pil_im)
58
- # new_orig_image = orig_image.convert('RGBA')
59
 
60
  return new_im
61
- # return [new_orig_image, new_im]
62
-
63
-
64
- # block = gr.Blocks().queue()
65
 
66
- # with block:
67
- # gr.Markdown("## BRIA RMBG 1.4")
68
- # gr.HTML('''
69
- # <p style="margin-bottom: 10px; font-size: 94%">
70
- # This is a demo for BRIA RMBG 1.4 that using
71
- # <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
72
- # </p>
73
- # ''')
74
- # with gr.Row():
75
- # with gr.Column():
76
- # input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
77
- # # input_image = gr.Image(sources=None, type="numpy") # None for upload, ctrl+v and webcam
78
- # run_button = gr.Button(value="Run")
79
-
80
- # with gr.Column():
81
- # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
82
- # ips = [input_image]
83
- # run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
84
-
85
- # block.launch(debug = True)
86
-
87
- # block = gr.Blocks().queue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- gr.Markdown("## BRIA RMBG 1.4")
90
- gr.HTML('''
91
- <p style="margin-bottom: 10px; font-size: 94%">
92
- This is a demo for BRIA RMBG 1.4 that using
93
- <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
94
- </p>
95
- ''')
96
  title = "Background Removal"
97
- description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
98
- For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
 
 
99
  """
100
- examples = [['./input.jpg'],]
101
- # output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
102
- # demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)
103
- demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
 
 
 
 
 
104
 
105
  if __name__ == "__main__":
106
- demo.launch(share=False)
 
4
  from torchvision.transforms.functional import normalize
5
  from huggingface_hub import hf_hub_download
6
  import gradio as gr
 
7
  from briarmbg import BriaRMBG
8
  import PIL
9
  from PIL import Image
 
10
 
11
+ net = BriaRMBG()
 
12
  model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
13
  if torch.cuda.is_available():
14
  net.load_state_dict(torch.load(model_path))
15
+ net = net.cuda()
16
  else:
17
+ net.load_state_dict(torch.load(model_path, map_location="cpu"))
18
+ net.eval()
19
 
 
20
  def resize_image(image):
21
  image = image.convert('RGB')
22
  model_input_size = (1024, 1024)
23
  image = image.resize(model_input_size, Image.BILINEAR)
24
  return image
25
 
 
26
  def process(image):
 
 
27
  orig_image = Image.fromarray(image)
28
+ w, h = orig_image.size
29
  image = resize_image(orig_image)
30
  im_np = np.array(image)
31
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
32
+ im_tensor = torch.unsqueeze(im_tensor, 0)
33
+ im_tensor = torch.divide(im_tensor, 255.0)
34
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
35
  if torch.cuda.is_available():
36
+ im_tensor = im_tensor.cuda()
37
 
38
+ result = net(im_tensor)
39
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
 
 
40
  ma = torch.max(result)
41
  mi = torch.min(result)
42
+ result = (result - mi) / (ma - mi)
43
+ im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
 
44
  pil_im = Image.fromarray(np.squeeze(im_array))
45
+ new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
 
46
  new_im.paste(orig_image, mask=pil_im)
 
47
 
48
  return new_im
 
 
 
 
49
 
50
+ css = """
51
+ body {
52
+ font-family: 'Arial', sans-serif;
53
+ margin: 0;
54
+ padding: 0;
55
+ background-color: #f0f2f5;
56
+ color: #333;
57
+ }
58
+ h1 {
59
+ color: #0000ff;
60
+ }
61
+ p {
62
+ color: #000000;
63
+ }
64
+ .gradio-app, .gradio-content {
65
+ background-color: #ffffff;
66
+ border-radius: 8px;
67
+ border: 1px solid #ccc;
68
+ box-shadow: 0 10px 25px 0 rgba(0,0,0,0.1);
69
+ padding: 20px;
70
+ }
71
+ button {
72
+ border: none;
73
+ color: white;
74
+ padding: 10px 20px;
75
+ margin: 10px 0;
76
+ cursor: pointer;
77
+ border-radius: 5px;
78
+ background-image: linear-gradient(to right, #6a11cb 0%, #2575fc 100%);
79
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2);
80
+ transition: all 0.2s ease-in-out;
81
+ }
82
+ button:hover {
83
+ box-shadow: 0 6px 8px rgba(0, 0, 0, 0.3);
84
+ }
85
+ input, textarea {
86
+ border: 2px solid #2575fc;
87
+ border-radius: 4px;
88
+ padding: 10px;
89
+ margin: 10px 0;
90
+ width: 100%;
91
+ box-sizing: border-box;
92
+ box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.1);
93
+ }
94
+ .gradio-toolbar {
95
+ background-color: #f0f2f5;
96
+ }
97
+ footer {
98
+ visibility: hidden;
99
+ }
100
+ """
101
 
 
 
 
 
 
 
 
102
  title = "Background Removal"
103
+ description = """
104
+ This is a demo for BRIA RMBG 1.4 using the BRIA RMBG-1.4 image matting model as a backbone.<br>
105
+ Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
106
+ For a test, upload your image and wait. Read more at the model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
107
  """
108
+
109
+ demo = gr.Interface(
110
+ fn=process,
111
+ inputs=gr.Image(type="pil"),
112
+ outputs=gr.Image(type="pil"),
113
+ title=title,
114
+ description=description,
115
+ css=css
116
+ )
117
 
118
  if __name__ == "__main__":
119
+ demo.launch(share=False)