akaaku committed on
Commit c5325e3 · verified · 1 Parent(s): 4941fcb

Added Video feature


But there is still an issue:
・the key_color background becomes dark
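A hedged guess at the cause, not verified against this commit: when ffmpeg converts the RGB PNG frames to yuv420p it defaults to the BT.601 matrix, while most players assume BT.709, which can visibly shift a flat key color. Below is a minimal sketch of the encode step with the matrix forced and tagged as BT.709; the fps and workspace values mirror process_video() in app.py, and whether this is the actual culprit here is an assumption. Another hypothesis worth trying: flatten each frame to RGB before saving (new_image.convert("RGB").save(frame_file)) so the alpha channel cannot skew the conversion.

import subprocess

# Sketch: force and tag BT.709 so players decode the key color the same
# way ffmpeg encoded it. fps and workspace are assumed to come from
# process_video() as in app.py.
fps = 30
workspace = "./temp"
subprocess.run(
    f'ffmpeg -r {fps} -i {workspace}/frames/%015d.png '
    f'-vf "scale=out_color_matrix=bt709,format=yuv420p" '
    f'-c:v libx264 -colorspace bt709 -color_primaries bt709 -color_trc bt709 '
    f'{workspace}/result/video.mp4 -hide_banner -loglevel error -y',
    shell=True,
    check=True,
)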

Files changed (1)
  1. app.py +170 -45
app.py CHANGED
@@ -4,25 +4,35 @@ import torch.nn.functional as F
  from torchvision.transforms.functional import normalize
  from huggingface_hub import hf_hub_download
  import gradio as gr
- from gradio_imageslider import ImageSlider
+
+ # from gradio_imageslider import ImageSlider
  from briarmbg import BriaRMBG
  import PIL
  from PIL import Image
  from typing import Tuple

- net=BriaRMBG()
+ import cv2
+ import os
+ import shutil
+ import glob
+ from tqdm import tqdm
+ import subprocess
+
+ net = BriaRMBG()
  # model_path = "./model1.pth"
- model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
+ model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
  if torch.cuda.is_available():
      net.load_state_dict(torch.load(model_path))
-     net=net.cuda()
+     net = net.cuda()
+     print("GPU is available")
  else:
-     net.load_state_dict(torch.load(model_path,map_location="cpu"))
- net.eval()
+     net.load_state_dict(torch.load(model_path, map_location="cpu"))
+     print("GPU is NOT available")
+ net.eval()
+

-
  def resize_image(image):
-     image = image.convert('RGB')
+     image = image.convert("RGB")
      model_input_size = (1024, 1024)
      image = image.resize(model_input_size, Image.BILINEAR)
      return image
@@ -32,28 +42,28 @@ def process(image):

      # prepare input
      orig_image = Image.fromarray(image)
-     w,h = orig_im_size = orig_image.size
+     w, h = orig_im_size = orig_image.size
      image = resize_image(orig_image)
      im_np = np.array(image)
-     im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
-     im_tensor = torch.unsqueeze(im_tensor,0)
-     im_tensor = torch.divide(im_tensor,255.0)
-     im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
+     im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
+     im_tensor = torch.unsqueeze(im_tensor, 0)
+     im_tensor = torch.divide(im_tensor, 255.0)
+     im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
      if torch.cuda.is_available():
-         im_tensor=im_tensor.cuda()
+         im_tensor = im_tensor.cuda()

-     #inference
-     result=net(im_tensor)
+     # inference
+     result = net(im_tensor)
      # post process
-     result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
+     result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
      ma = torch.max(result)
      mi = torch.min(result)
-     result = (result-mi)/(ma-mi)
+     result = (result - mi) / (ma - mi)
      # image to pil
-     im_array = (result*255).cpu().data.numpy().astype(np.uint8)
+     im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
      pil_im = Image.fromarray(np.squeeze(im_array))
      # paste the mask on the original image
-     new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
+     new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
      new_im.paste(orig_image, mask=pil_im)
      # new_orig_image = orig_image.convert('RGBA')

@@ -61,46 +71,161 @@ def process(image):
      # return [new_orig_image, new_im]


- # block = gr.Blocks().queue()
+ def process_video(video, key_color):
+     workspace = "./temp"
+     original_video_name_without_ext = os.path.splitext(os.path.basename(video))[0]
+
+     os.makedirs(workspace, exist_ok=True)
+     os.makedirs(f"{workspace}/frames", exist_ok=True)
+     os.makedirs(f"{workspace}/result", exist_ok=True)
+
+     video_file = cv2.VideoCapture(video)
+     fps = video_file.get(cv2.CAP_PROP_FPS)
+
+     # first, load the video and save its frames to ./frames/
+     def extract_frames():
+         success, frame = video_file.read()
+         frame_num = 0
+         with tqdm(
+             total=None,
+             desc="Extracting frames",
+         ) as pbar:
+             while success:
+                 file_name = f"{workspace}/frames/{frame_num:015d}.png"
+                 cv2.imwrite(file_name, frame)
+                 success, frame = video_file.read()
+                 frame_num += 1
+                 pbar.update(1)
+         video_file.release()
+         return
+
+     extract_frames()
+
+     # process each frame
+     def process_frame(frame_file):
+         image = cv2.imread(frame_file)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         new_image = process(image)
+         # composite key_color as the background
+         key_back_image = Image.new("RGBA", new_image.size, key_color)
+         new_image = Image.alpha_composite(key_back_image, new_image)
+         new_image.save(frame_file)
+
+     frame_files = sorted(glob.glob(f"{workspace}/frames/*.png"))
+     with tqdm(total=len(frame_files), desc="Processing frames") as pbar:
+         for file in frame_files:
+             process_frame(file)
+             pbar.update(1)
+
+     # create a video from the frames
+     # first_frame = cv2.imread(frame_files[0])
+     # h, w, _ = first_frame.shape
+     # fourcc = cv2.VideoWriter_fourcc(*"avc1")
+     # new_video = cv2.VideoWriter(f"{workspace}/result/video.mp4", fourcc, fps, (w, h))
+
+     # for file in frame_files:
+     #     image = cv2.imread(file)
+     #     new_video.write(image)
+     # new_video.release()
+
+     # the cv2.VideoWriter code above is replaced with ffmpeg
+     subprocess.run(
+         f'ffmpeg -r {fps} -i {workspace}/frames/%015d.png -c:v libx264 -vf "fps={fps},format=yuv420p" {workspace}/result/video.mp4 -hide_banner -loglevel error -y',
+         shell=True,
+         check=True,
+     )
+     # issue: the key_color background becomes dark for an unknown reason

- # with block:
- #     gr.Markdown("## BRIA RMBG 1.4")
- #     gr.HTML('''
- #     <p style="margin-bottom: 10px; font-size: 94%">
- #     This is a demo for BRIA RMBG 1.4 that using
- #     <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
- #     </p>
- #     ''')
- #     with gr.Row():
- #         with gr.Column():
- #             input_image = gr.Image(sources=None, type="pil")  # None for upload, ctrl+v and webcam
- #             # input_image = gr.Image(sources=None, type="numpy")  # None for upload, ctrl+v and webcam
- #             run_button = gr.Button(value="Run")
-
- #     with gr.Column():
- #         result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
- #     ips = [input_image]
- #     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+     subprocess.run(
+         f'ffmpeg -i {workspace}/result/video.mp4 -i "{video}" -c:v copy -c:a aac -strict experimental -map 0:v:0 -map 1:a:0 -shortest "./video_result/{original_video_name_without_ext}_BGremoved.mp4" -hide_banner -loglevel error -y',
+         shell=True,
+         check=True,
+     )

- # block.launch(debug = True)
+     # a transparent video would have been ideal, but player compatibility is poor, so it was dropped
+     # subprocess.run(
+     #     f'ffmpeg -framerate {fps} -i {workspace}/frames/%015d.png -auto-alt-ref 0 -c:v libvpx "./video_result/{original_video_name_without_ext}_BGremoved.webm" -hide_banner -loglevel error -y',
+     #     shell=True,
+     #     check=True,
+     # )
+     # this is for chroma keying, so audio is not needed anyway
+     # subprocess.run(
+     #     f'ffmpeg -i "./video_result/{original_video_name_without_ext}_BGremoved.webm" -c:v libx264 -c:a aac -strict experimental -b:a 192k ./demo/demo.mp4 -hide_banner -loglevel error -y',
+     #     shell=True,
+     #     check=True,
+     # )
+
+     # clean up temporary files
+     shutil.rmtree(workspace)
+
+     return f"./video_result/{original_video_name_without_ext}_BGremoved.mp4"

- # block = gr.Blocks().queue()

  gr.Markdown("## BRIA RMBG 1.4")
- gr.HTML('''
+ gr.HTML(
+     """
  <p style="margin-bottom: 10px; font-size: 94%">
  This is a demo for BRIA RMBG 1.4 that uses the
  <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as its backbone.
  </p>
- ''')
+     """
+ )
  title = "Background Removal"
  description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and available as an open-source model for non-commercial use.<br>
  To test, upload your image and wait. Read more at the model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
  """
- examples = [['./input.jpg'],]
+ examples = [
+     ["./input.jpg"],
+ ]
+
+ title2 = "Background Removal For Video"
+ description2 = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and available as an open-source model for non-commercial use.<br>
+ To test, upload your video and wait. Read more at the model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
+ Every frame of the video is processed, so this may take a while; a strong GPU is recommended.
+ You need <b>ffmpeg</b> installed to use this feature.
+ """
+
  # output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
  # demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)
- demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
+ demo1 = gr.Interface(
+     fn=process,
+     inputs="image",
+     outputs="image",
+     title=title,
+     description=description,
+     examples=examples,
+     api_name="demo1",
+ )
+
+ demo2 = gr.Interface(
+     fn=process_video,
+     inputs=[
+         gr.Video(label="Video"),
+         gr.ColorPicker(label="Key Color (background color)"),
+     ],
+     outputs="video",
+     title=title2,
+     description=description2,
+     api_name="demo2",
+ )
+
+ demo = gr.TabbedInterface(
+     interface_list=[demo1, demo2],
+     tab_names=["Image", "Video"],
+ )

  if __name__ == "__main__":
      demo.launch(share=False)
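For reference, a minimal client-side sketch of calling the two endpoints this commit registers via api_name="demo1" and api_name="demo2". It assumes the app is running locally on Gradio's default port and a recent gradio_client that provides handle_file; the "/demo1" and "/demo2" endpoint paths are inferred from the api_name values above, not taken from the commit.

from gradio_client import Client, handle_file

# Connect to the locally running app (default Gradio port assumed).
client = Client("http://127.0.0.1:7860/")

# Image tab: plain background removal.
image_out = client.predict(handle_file("input.jpg"), api_name="/demo1")

# Video tab: video path plus the key color composited behind the
# foreground (the ColorPicker passes a hex string such as "#00ff00").
video_out = client.predict(
    handle_file("input.mp4"),
    "#00ff00",
    api_name="/demo2",
)
print(image_out, video_out)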