| | |
| | |
| | |
| | |
| |
|
| | import spaces |
| | import gradio as gr |
| | import cv2 |
| | import numpy as np |
| | import time |
| | import random |
| | from PIL import Image |
| | import torch |
| | import re |
| | import os |
| | import shutil |
| | import subprocess |
| | import tempfile |
| |
|
def _skip_torchscript(f):
    """Return *f* unchanged; drop-in stand-in for ``torch.jit.script``."""
    return f


# Replace TorchScript compilation with a pass-through so that anything calling
# ``torch.jit.script`` gets the plain Python callable back.
# NOTE(review): presumably this must run before ``transparent_background`` is
# imported below -- confirm before reordering these statements.
torch.jit.script = _skip_torchscript
| |
|
| | from transparent_background import Remover |
| |
|
# Wall-clock budget (seconds) for the frame loop before processing stops early.
# NOTE(review): the decorator on ``doo`` requests a 90 s GPU slot while this
# guard only fires after ~20 minutes -- confirm which limit is intended.
_TIME_BUDGET_S = 20 * 60 - 5

# Frame rate used when the container does not report a usable one.
_DEFAULT_FPS = 25.0

# Matches the leading channels of ``rgb(r, g, b)`` / ``rgba(r, g, b, a)``.
_RGB_FUNC_RE = re.compile(r'rgba?\(\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)')


def _parse_color(color):
    """Turn a colour-picker value into the ``"[r, g, b]"`` string form.

    Accepts ``#RGB``, ``#RGBA``, ``#RRGGBB``, ``#RRGGBBAA`` (alpha ignored) and
    ``rgb(...)`` / ``rgba(...)`` with arbitrary spacing. Any other value is
    returned unchanged so it can be passed straight through to the remover.

    Raises:
        ValueError: if a ``#`` value is too short or contains non-hex digits.
    """
    text = str(color).strip()
    if text.startswith('#'):
        digits = text.lstrip('#')
        if len(digits) in (3, 4):
            # Expand CSS shorthand: "0f0" -> "00ff00".
            digits = ''.join(ch * 2 for ch in digits)
        if len(digits) < 6:
            raise ValueError(f"Unsupported hex colour: {color!r}")
        return str([int(digits[i:i + 2], 16) for i in (0, 2, 4)])

    found = _RGB_FUNC_RE.match(text)
    if found:
        # Channels may arrive as floats; truncate and clamp to 8 bits.
        return str([min(255, int(float(v))) for v in found.groups()])
    return color


def _iter_rgb_frames(cap, deadline):
    """Yield frames of *cap* as RGB PIL images until EOF or *deadline*.

    *deadline* is a ``time.monotonic()`` timestamp; once it has passed the
    generator logs a message and stops, leaving already-yielded frames valid.
    """
    while cap.isOpened():
        ok, frame_bgr = cap.read()
        if not ok:
            break
        if time.monotonic() >= deadline:
            print("GPU Timeout is coming")
            break
        yield Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))


def _write_mp4(frames, remover, color, fps, out_path, report):
    """Composite every frame onto *color* and write it to *out_path*.

    Returns the number of frames written (0 when *frames* was empty, in which
    case no file is created).
    """
    writer = None
    count = 0
    try:
        for count, img in enumerate(frames, start=1):
            if writer is None:
                # The frame size is only known once the first frame is decoded.
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                writer = cv2.VideoWriter(out_path, fourcc, fps, img.size)
            report(count)
            out = remover.process(img, type=color)
            writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR))
    finally:
        if writer is not None:
            writer.release()
    return count


def _dump_rgba_frames(frames, remover, frames_dir, report):
    """Save each frame as ``frame_NNNNNN.png`` with alpha; return the count."""
    count = 0
    for count, img in enumerate(frames, start=1):
        report(count)
        cutout = remover.process(img, type='rgba').convert('RGBA')
        cutout.save(os.path.join(frames_dir, f"frame_{count:06d}.png"), 'PNG')
    return count


def _encode_webm(frames_dir, source_video, fps, out_path):
    """Encode the PNG sequence to VP9 with alpha, copying audio if present.

    Raises:
        gr.Error: if ffmpeg is missing or exits with a non-zero status.
    """
    ffmpeg_cmd = [
        "ffmpeg", "-y",
        # Pass the measured rate through unrounded so that e.g. 29.97 fps
        # material keeps its duration and stays in sync with the audio.
        "-framerate", f"{fps:.6f}",
        "-i", os.path.join(frames_dir, "frame_%06d.png"),
        "-i", str(source_video),
        "-map", "0:v",
        "-map", "1:a?",
        "-c:v", "libvpx-vp9",
        "-pix_fmt", "yuva420p",
        "-auto-alt-ref", "0",
        "-metadata:s:v:0", "alpha_mode=1",
        "-c:a", "libopus",
        "-shortest",
        out_path
    ]
    print("Running ffmpeg:", " ".join(ffmpeg_cmd))
    try:
        subprocess.run(ffmpeg_cmd, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as exc:
        print("ffmpeg failed:", exc)
        raise gr.Error("Encoding the WebM output with ffmpeg failed.") from exc


@spaces.GPU(duration=90)
def doo(video, color, mode, out_format, progress=gr.Progress()):
    """Remove the background from *video* and return the output file path.

    Args:
        video: path of the input video file.
        color: background colour from the colour picker (``#hex`` or
            ``rgb()/rgba()`` string); only used for MP4 output.
        mode: ``'Fast'`` selects ``Remover(mode='fast')``; anything else uses
            the default ``Remover()``.
        out_format: ``'mp4'`` writes an opaque MP4 with the chosen background;
            any other value writes a WebM with an alpha channel via ffmpeg.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Path of the written ``.mp4`` or ``.webm`` file. If the time budget is
        exhausted, the file contains only the frames processed so far.

    Raises:
        gr.Error: if no video was given, no frame could be decoded, or the
            ffmpeg encode failed.
        ValueError: if *color* is a malformed ``#`` value.
    """
    if not video:
        raise gr.Error("Please provide a video to process.")

    print(str(color))
    color = _parse_color(color)
    print("Parsed color:", color)

    remover = Remover(mode='fast') if mode == 'Fast' else Remover()

    cap = cv2.VideoCapture(video)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:  # also catches NaN
        fps = _DEFAULT_FPS
    tmpname = random.randint(111111111, 999999999)
    frames = _iter_rgb_frames(cap, time.monotonic() + _TIME_BUDGET_S)

    def report(done):
        """Log and publish progress; tolerate a missing/low frame count."""
        print(f"Processing frame {done}")
        if total_frames > 0:
            progress(min(done / total_frames, 1.0),
                     desc=f"Processing frame {done}/{total_frames}")
        else:
            # Some containers report no frame count; avoid dividing by zero.
            progress(0.0, desc=f"Processing frame {done}")

    if out_format == 'mp4':
        mp4_path = str(tmpname) + '.mp4'
        try:
            written = _write_mp4(frames, remover, color, fps, mp4_path, report)
        finally:
            cap.release()
        if written == 0:
            raise gr.Error("Could not read any frames from the input video.")
        return mp4_path

    webm_path = str(tmpname) + '.webm'
    temp_dir = tempfile.mkdtemp(prefix=f"tb_{tmpname}_")
    try:
        try:
            written = _dump_rgba_frames(frames, remover, temp_dir, report)
        finally:
            # Release the input before ffmpeg opens it for the audio track.
            cap.release()
        if written == 0:
            raise gr.Error("Could not read any frames from the input video.")
        _encode_webm(temp_dir, video, fps, webm_path)
        return webm_path
    except Exception as e:
        print("Error during processing:", e)
        raise
    finally:
        # The intermediate PNGs are never needed once we leave this branch.
        shutil.rmtree(temp_dir, ignore_errors=True)
| |
|
# ---------------------------------------------------------------------------
# Gradio UI wiring: static texts, the example row, the input widgets, and the
# Interface that binds them to ``doo``. The app is launched on import/run.
# ---------------------------------------------------------------------------

# Heading and the notice rendered underneath it.
title = "🎞️ Video Background Removal Tool 🎥"
description = (
    "*Please note that if your video file is long (has a high number of frames), "
    "there is a chance that processing break due to GPU timeout. "
    "In this case, consider trying Fast mode.*"
)

# One example row; the column order mirrors the ``inputs`` list below.
examples = [['./input.mp4', '#00FF00', 'Normal', 'mp4']]

# Input widgets, created in the same order as the parameters of ``doo``.
_background_picker = gr.ColorPicker(label="Background color", value="#00FF00")
_mode_choice = gr.components.Radio(
    ['Normal', 'Fast'],
    label='Select mode',
    value='Normal',
    info='Normal is more accurate, but takes longer. | Fast has lower accuracy so the process will be faster.',
)
_format_choice = gr.components.Radio(
    ['mp4', 'webm'],
    label='Output format',
    value='mp4',
)

iface = gr.Interface(
    fn=doo,
    inputs=["video", _background_picker, _mode_choice, _format_choice],
    outputs="video",
    examples=examples,
    title=title,
    description=description,
)

# Start the web server (blocks when run as a script).
iface.launch()
| |
|