Files changed (1)
  1. app.py +174 -46
app.py CHANGED
@@ -4,25 +4,35 @@ import torch.nn.functional as F
 from torchvision.transforms.functional import normalize
 from huggingface_hub import hf_hub_download
 import gradio as gr
-from gradio_imageslider import ImageSlider
+
+# from gradio_imageslider import ImageSlider
 from briarmbg import BriaRMBG
 import PIL
 from PIL import Image
 from typing import Tuple
 
-net=BriaRMBG()
+import cv2
+import os
+import shutil
+import glob
+from tqdm import tqdm
+from ffmpy import FFmpeg
+
+net = BriaRMBG()
 # model_path = "./model1.pth"
-model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
+model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
 if torch.cuda.is_available():
     net.load_state_dict(torch.load(model_path))
-    net=net.cuda()
+    net = net.cuda()
+    print("GPU is available")
 else:
-    net.load_state_dict(torch.load(model_path,map_location="cpu"))
-net.eval()
+    net.load_state_dict(torch.load(model_path, map_location="cpu"))
+    print("GPU is NOT available")
+net.eval()
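+# eval() switches the network to inference mode: batch-norm uses its stored statistics and dropout is disabled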
+
 
-
 def resize_image(image):
-    image = image.convert('RGB')
+    image = image.convert("RGB")
     model_input_size = (1024, 1024)
     image = image.resize(model_input_size, Image.BILINEAR)
     return image
@@ -32,28 +42,28 @@ def process(image):
 
     # prepare input
     orig_image = Image.fromarray(image)
-    w,h = orig_im_size = orig_image.size
+    w, h = orig_im_size = orig_image.size
     image = resize_image(orig_image)
     im_np = np.array(image)
-    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
-    im_tensor = torch.unsqueeze(im_tensor,0)
-    im_tensor = torch.divide(im_tensor,255.0)
-    im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
+    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
+    im_tensor = torch.unsqueeze(im_tensor, 0)
+    im_tensor = torch.divide(im_tensor, 255.0)
+    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
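+    # normalize with mean 0.5 and std 1.0 maps the [0, 1] tensor into [-0.5, 0.5]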
     if torch.cuda.is_available():
-        im_tensor=im_tensor.cuda()
+        im_tensor = im_tensor.cuda()
 
-    #inference
-    result=net(im_tensor)
+    # inference
+    result = net(im_tensor)
     # post process
-    result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
+    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
     ma = torch.max(result)
     mi = torch.min(result)
-    result = (result-mi)/(ma-mi)
+    result = (result - mi) / (ma - mi)
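+    # min-max normalization stretches the predicted matte to the full [0, 1] range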
     # image to pil
-    im_array = (result*255).cpu().data.numpy().astype(np.uint8)
+    im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
     pil_im = Image.fromarray(np.squeeze(im_array))
     # paste the mask on the original image
-    new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
+    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
     new_im.paste(orig_image, mask=pil_im)
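+    # the matte serves as the paste mask, i.e. the cutout's alpha channel: background pixels stay transparent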
     # new_orig_image = orig_image.convert('RGBA')
 
@@ -61,46 +71,164 @@ def process(image):
     # return [new_orig_image, new_im]
 
 
-# block = gr.Blocks().queue()
+def process_video(video, key_color):
+    workspace = "./temp"
+    original_video_name_without_ext = os.path.splitext(os.path.basename(video))[0]
+
+    os.makedirs(workspace, exist_ok=True)
+    os.makedirs(f"{workspace}/frames", exist_ok=True)
+    os.makedirs(f"{workspace}/result", exist_ok=True)
+    os.makedirs("./video_result", exist_ok=True)
+
+    video_file = cv2.VideoCapture(video)
+    fps = video_file.get(cv2.CAP_PROP_FPS)
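+    # reuse the source frame rate so the rebuilt video plays at the original speed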
+
+    # first, read the video and save every frame into {workspace}/frames/
+    def extract_frames():
+        success, frame = video_file.read()
+        frame_num = 0
+        with tqdm(
+            total=None,
+            desc="Extracting frames",
+        ) as pbar:
+            while success:
+                file_name = f"{workspace}/frames/{frame_num:015d}.png"
+                cv2.imwrite(file_name, frame)
+                success, frame = video_file.read()
+                frame_num += 1
+                pbar.update(1)
+        video_file.release()
+        return
+
+    extract_frames()
+
+    # process each frame
+    def process_frame(frame_file):
+        image = cv2.imread(frame_file)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        new_image = process(image)
+        # composite key_color behind the cutout as the new background
+        key_back_image = Image.new("RGBA", new_image.size, key_color)
+        new_image = Image.alpha_composite(key_back_image, new_image)
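+        # alpha_composite flattens the transparent cutout onto the opaque key-color canvas, producing a chroma-keyable frame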
+        new_image.save(frame_file)
+
+    frame_files = sorted(glob.glob(f"{workspace}/frames/*.png"))
+    with tqdm(total=len(frame_files), desc="Processing frames") as pbar:
+        for file in frame_files:
+            process_frame(file)
+            pbar.update(1)
+
+    # create the video from the frames
+    # first_frame = cv2.imread(frame_files[0])
+    # h, w, _ = first_frame.shape
+    # fourcc = cv2.VideoWriter_fourcc(*"avc1")
+    # new_video = cv2.VideoWriter(f"{workspace}/result/video.mp4", fourcc, fps, (w, h))
+
+    # for file in frame_files:
+    #     image = cv2.imread(file)
+    #     new_video.write(image)
+    # new_video.release()
+
+    # the cv2.VideoWriter code above, rewritten with ffmpy
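+    # %015d matches the zero-padded frame filenames written by extract_frames; format=yuv420p keeps the mp4 playable in common players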
+    ff = FFmpeg(
+        inputs={f"{workspace}/frames/%015d.png": f"-r {fps}"},
+        outputs={
+            f"{workspace}/result/video.mp4": f"-c:v libx264 -vf fps={fps},format=yuv420p -hide_banner -loglevel error -y"
+        },
+    )
+    ff.run()
+    # known issue: the key_color background comes out darker than expected
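+    # (untested guess: the full-range RGB frames being converted to limited-range yuv420p may cause the shift)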
 
-# with block:
-#     gr.Markdown("## BRIA RMBG 1.4")
-#     gr.HTML('''
-#     <p style="margin-bottom: 10px; font-size: 94%">
-#         This is a demo for BRIA RMBG 1.4 that using
-#         <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
-#     </p>
-#     ''')
-#     with gr.Row():
-#         with gr.Column():
-#             input_image = gr.Image(sources=None, type="pil")  # None for upload, ctrl+v and webcam
-#             # input_image = gr.Image(sources=None, type="numpy")  # None for upload, ctrl+v and webcam
-#             run_button = gr.Button(value="Run")
-
-#         with gr.Column():
-#             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
-#     ips = [input_image]
-#     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+    ff2 = FFmpeg(
+        inputs={f"{workspace}/result/video.mp4": None, f"{video}": None},
+        outputs={
+            f"./video_result/{original_video_name_without_ext}_BGremoved.mp4": "-c:v copy -c:a aac -strict experimental -map 0:v:0 -map 1:a:0 -shortest -hide_banner -loglevel error -y"
+        },
+    )
+    ff2.run()
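+    # note: the -map flags above take the video stream from the keyed mp4 and the audio track from the original upload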
 
-# block.launch(debug = True)
+    # a transparent (alpha-channel) video would have been nicer, but player compatibility is poor, so the idea was dropped
+    # subprocess.run(
+    #     f'ffmpeg -framerate {fps} -i {workspace}/frames/%015d.png -auto-alt-ref 0 -c:v libvpx "./video_result/{original_video_name_without_ext}_BGremoved.webm" -hide_banner -loglevel error -y',
+    #     shell=True,
+    #     check=True,
+    # )
+    # since the output is meant for chroma keying, audio is not strictly needed
+
+    # subprocess.run(
+    #     f'ffmpeg -i "./video_result/{original_video_name_without_ext}_BGremoved.webm" -c:v libx264 -c:a aac -strict experimental -b:a 192k ./demo/demo.mp4 -hide_banner -loglevel error -y',
+    #     shell=True,
+    #     check=True,
+    # )
+
+    # clean up the temporary workspace
+    shutil.rmtree(workspace)
+
+    return f"./video_result/{original_video_name_without_ext}_BGremoved.mp4"
 
-# block = gr.Blocks().queue()
 
 gr.Markdown("## BRIA RMBG 1.4")
-gr.HTML('''
+gr.HTML(
+    """
 <p style="margin-bottom: 10px; font-size: 94%">
 This is a demo for BRIA RMBG 1.4 that uses the
 <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as its backbone.
 </p>
-''')
+    """
+)
 title = "Background Removal"
 description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and available as an open-source model for non-commercial use.<br>
 To test it, upload your image and wait. Read more on the model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
 """
-examples = [['./input.jpg'],]
+examples = [
+    ["./input.jpg"],
+]
+
+title2 = "Background Removal For Video"
+description2 = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and available as an open-source model for non-commercial use.<br>
+To test it, upload your video and wait. Read more on the model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
+This tab removes the background from a video. Every frame is processed, so it takes noticeably longer than a single image; running locally on a strong GPU is recommended.<br>
+"""
+
 # output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
 # demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)
-demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
+demo1 = gr.Interface(
+    fn=process,
+    inputs="image",
+    outputs="image",
+    title=title,
+    description=description,
+    examples=examples,
+    api_name="demo1",
+)
+
+
+demo2 = gr.Interface(
+    fn=process_video,
+    inputs=[
+        gr.Video(label="Video"),
+        gr.ColorPicker(label="Key Color (Background Color)"),
+    ],
+    outputs="video",
+    title=title2,
+    description=description2,
+    api_name="demo2",
+)
+
+demo = gr.TabbedInterface(
+    interface_list=[demo1, demo2],
+    tab_names=["Image", "Video"],
+)
 
 if __name__ == "__main__":
-    demo.launch(share=False)
+    demo.launch(share=False)