import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
# from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
from PIL import Image
import cv2
import os
import shutil
import glob
from tqdm import tqdm
from ffmpy import FFmpeg
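# Instantiate RMBG-1.4 and fetch the pretrained weights from the Hugging Face Hub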
net = BriaRMBG()
# model_path = "./model1.pth"
model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
if torch.cuda.is_available():
net.load_state_dict(torch.load(model_path))
net = net.cuda()
print("GPU is available")
else:
net.load_state_dict(torch.load(model_path, map_location="cpu"))
print("GPU is NOT available")
net.eval()
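# RMBG-1.4 expects a fixed 1024x1024 RGB input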
def resize_image(image):
image = image.convert("RGB")
model_input_size = (1024, 1024)
image = image.resize(model_input_size, Image.BILINEAR)
return image
def process(image):
# prepare input
orig_image = Image.fromarray(image)
    w, h = orig_image.size
image = resize_image(orig_image)
im_np = np.array(image)
im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
im_tensor = torch.unsqueeze(im_tensor, 0)
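    # RMBG preprocessing: scale pixels to [0, 1], then shift by mean 0.5 (std 1.0 leaves the scale unchanged)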
im_tensor = torch.divide(im_tensor, 255.0)
im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
if torch.cuda.is_available():
im_tensor = im_tensor.cuda()
    # inference (no_grad avoids building the autograd graph during inference)
    with torch.no_grad():
        result = net(im_tensor)
# post process
result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
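    # min-max normalize the predicted matte to [0, 1]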
ma = torch.max(result)
mi = torch.min(result)
result = (result - mi) / (ma - mi)
# image to pil
    im_array = (result * 255).detach().cpu().numpy().astype(np.uint8)
pil_im = Image.fromarray(np.squeeze(im_array))
# paste the mask on the original image
new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
new_im.paste(orig_image, mask=pil_im)
# new_orig_image = orig_image.convert('RGBA')
return new_im
# return [new_orig_image, new_im]
def process_video(video, key_color):
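    # Layout: {workspace}/frames holds extracted frames, {workspace}/result the silent
    # re-encoded video, and ./video_result the final output with audio muxed back in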
workspace = "./temp"
original_video_name_without_ext = os.path.splitext(os.path.basename(video))[0]
os.makedirs(workspace, exist_ok=True)
os.makedirs(f"{workspace}/frames", exist_ok=True)
os.makedirs(f"{workspace}/result", exist_ok=True)
os.makedirs("./video_result", exist_ok=True)
video_file = cv2.VideoCapture(video)
fps = video_file.get(cv2.CAP_PROP_FPS)
    # First, read the video and dump every frame to {workspace}/frames/
def extract_frames():
success, frame = video_file.read()
frame_num = 0
with tqdm(
total=None,
desc="Extracting frames",
) as pbar:
while success:
file_name = f"{workspace}/frames/{frame_num:015d}.png"
cv2.imwrite(file_name, frame)
success, frame = video_file.read()
frame_num += 1
pbar.update(1)
video_file.release()
return
extract_frames()
    # Run background removal on each extracted frame
def process_frame(frame_file):
image = cv2.imread(frame_file)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
new_image = process(image)
        # composite the cutout over a key_color background (for chroma keying later)
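        # key_color arrives from gr.ColorPicker as a hex string (e.g. "#00ff00"), which PIL accepts directly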
key_back_image = Image.new("RGBA", new_image.size, key_color)
new_image = Image.alpha_composite(key_back_image, new_image)
new_image.save(frame_file)
frame_files = sorted(glob.glob(f"{workspace}/frames/*.png"))
with tqdm(total=len(frame_files), desc="Processing frames") as pbar:
for file in frame_files:
process_frame(file)
pbar.update(1)
    # Re-encode the processed frames into a video
# first_frame = cv2.imread(frame_files[0])
# h, w, _ = first_frame.shape
# fourcc = cv2.VideoWriter_fourcc(*"avc1")
# new_video = cv2.VideoWriter(f"{workspace}/result/video.mp4", fourcc, fps, (w, h))
# for file in frame_files:
# image = cv2.imread(file)
# new_video.write(image)
# new_video.release()
    # The OpenCV VideoWriter above was rewritten with ffmpy
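    # Note: ffmpy only shells out to the ffmpeg binary, which must be installed and on PATH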
ff = FFmpeg(
inputs={f"{workspace}/frames/%015d.png": f"-r {fps}"},
outputs={
f"{workspace}/result/video.mp4": f"-c:v libx264 -vf fps={fps},format=yuv420p -hide_banner -loglevel error -y"
},
)
ff.run()
    # Known issue: the key_color background comes out darker than specified (cause unknown)
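    # Second pass: copy the processed video stream and mux in the original audio
    # (-map 1:a:0 assumes the source video actually has an audio track)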
ff2 = FFmpeg(
inputs={f"{workspace}/result/video.mp4": None, f"{video}": None},
outputs={
f"./video_result/{original_video_name_without_ext}_BGremoved.mp4": "-c:v copy -c:a aac -strict experimental -map 0:v:0 -map 1:a:0 -shortest -hide_banner -loglevel error -y"
},
)
ff2.run()
    # A transparent (alpha-channel) video would have been preferable, but compatibility is poor, so it was dropped
# subprocess.run(
# f'ffmpeg -framerate {fps} -i {workspace}/frames/%015d.png -auto-alt-ref 0 -c:v libvpx "./video_result/{original_video_name_without_ext}_BGremoved.webm" -hide_banner -loglevel error -y',
# shell=True,
# check=True,
# )
    # The output is meant for chroma keying, so audio isn't needed
# subprocess.run(
# f'ffmpeg -i "./video_result/{original_video_name_without_ext}_BGremoved.webm" -c:v libx264 -c:a aac -strict experimental -b:a 192k ./demo/demo.mp4 -hide_banner -loglevel error -y',
# shell=True,
# check=True,
# )
    # clean up temporary files
shutil.rmtree(workspace)
return f"./video_result/{original_video_name_without_ext}_BGremoved.mp4"
gr.Markdown("## BRIA RMBG 1.4")
gr.HTML(
"""
<p style="margin-bottom: 10px; font-size: 94%">
This is a demo for BRIA RMBG 1.4 that uses the
<a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as its backbone.
</p>
"""
)
title = "Background Removal"
description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and available as an open-source model for non-commercial use.<br>
To try it, upload an image and wait. Read more at the model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
"""
examples = [
["./input.jpg"],
]
title2 = "Background Removal For Video"
description2 = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and available as an open-source model for non-commercial use.<br>
To try it, upload a video and wait. Read more at the model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
This tab removes the background from every frame of the video, so processing takes noticeably longer; running locally on a strong GPU is recommended.<br>
"""
# output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
# demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)
demo1 = gr.Interface(
fn=process,
inputs="image",
outputs="image",
title=title,
description=description,
examples=examples,
api_name="demo1",
)
demo2 = gr.Interface(
fn=process_video,
inputs=[
gr.Video(label="Video"),
gr.ColorPicker(label="Key Color(Background color)"),
],
outputs="video",
title=title2,
description=description2,
api_name="demo2",
)
demo = gr.TabbedInterface(
interface_list=[demo1, demo2],
tab_names=["Image", "Video"],
)
if __name__ == "__main__":
demo.launch(share=False)