Spaces:
Runtime error
Runtime error
File size: 3,632 Bytes
a72119e 496112d 8365126 a72119e 6d754a8 de54836 bcece85 496112d 6d754a8 bcece85 6d754a8 496112d 6d754a8 293e082 26a50b2 6d754a8 496112d 293e082 b1d6fce 6d754a8 293e082 6d754a8 4902bd9 6d754a8 8365126 b1d6fce 293e082 6bfcd1c 6d754a8 de54836 2189235 a72119e 6d754a8 a72119e 6d754a8 293e082 6d754a8 2189235 6d754a8 2189235 6d754a8 293e082 2189235 4902bd9 b1d6fce 492f003 6d754a8 4902bd9 bcece85 4902bd9 6d754a8 9c23c95 26a50b2 2189235 492f003 2189235 26a50b2 a72119e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import gradio as gr
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import moviepy.editor as mp
from pydub import AudioSegment
from PIL import Image
import numpy as np
import os
import tempfile
import uuid
# Let PyTorch trade some float32 matmul precision for speed ("high" rather
# than "highest"); it may use faster internal formats where supported.
torch.set_float32_matmul_precision("high")

# BiRefNet segmentation model: loaded once at import time and moved to the GPU,
# so every request reuses the same weights.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# Pre-processing applied to each frame before inference: fixed 1024x1024
# resize, conversion to a tensor, then per-channel normalisation using the
# ImageNet mean/std values.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
@spaces.GPU
def fn(vid, fps=0, color="#00FF00"):
    """Replace the background of every frame of a video with a solid colour.

    This is a generator whose yielded pairs feed the ``[stream_image,
    out_video]`` outputs of the UI. It first reveals the streaming preview
    and hides the video player. It then yields each processed frame as soon
    as it is ready. Finally it swaps the visibility back and yields the path
    of the re-encoded video.

    Args:
        vid: Path to the input video file.
        fps: Frame rate used to sample and re-encode the video. ``0`` (or any
            falsy value such as ``None``) keeps the source frame rate.
        color: Background colour forwarded to :func:`process`.

    Yields:
        ``(stream_image, out_video)`` value/update pairs.

    Raises:
        gr.Error: if no frames can be read from the input video.
    """
    video = mp.VideoFileClip(vid)
    try:
        # Falsy fps (0 from the slider, or None when called without it)
        # means "inherit the source frame rate".
        if not fps:
            fps = video.fps
        # May be None for a silent video; set_audio(None) then yields a
        # silent output.
        audio = video.audio

        processed_frames = []
        processed_image = None
        # Show the live preview, hide the final-video player.
        yield gr.update(visible=True), gr.update(visible=False)
        for frame in video.iter_frames(fps=fps):
            processed_image = process(Image.fromarray(frame), color)
            processed_frames.append(np.array(processed_image))
            yield processed_image, None

        # Without this guard an empty clip would fail obscurely inside
        # ImageSequenceClip (and the final yield would reference an unset name).
        if not processed_frames:
            raise gr.Error("No frames could be read from the input video.")

        # Rebuild a clip from the processed frames and re-attach the
        # original soundtrack.
        processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
        processed_video = processed_video.set_audio(audio)

        # Write to a uniquely named file so concurrent requests never collide.
        temp_dir = "temp"
        os.makedirs(temp_dir, exist_ok=True)
        temp_filepath = os.path.join(temp_dir, f"{uuid.uuid4()}.mp4")
        processed_video.write_videofile(temp_filepath, codec="libx264")
    finally:
        # Release the ffmpeg readers (video and audio) held by the source
        # clip. This also runs if the generator is abandoned mid-stream.
        video.close()

    # Hide the live preview, reveal the final-video player.
    yield gr.update(visible=False), gr.update(visible=True)
    yield processed_image, temp_filepath
def process(image, color_hex):
    """Segment the foreground of *image* and composite it over a solid colour.

    Args:
        image: PIL image (one video frame). It is converted to RGB first,
            because ``transform_image`` normalises exactly three channels.
        color_hex: Background colour as a string. ``"#RRGGBB"`` behaves
            exactly as before. Any other format understood by
            ``PIL.ImageColor.getrgb`` (e.g. ``"#RGB"``, ``"rgb(r, g, b)"``)
            is also accepted. Any alpha component is ignored.

    Returns:
        An RGB PIL image the same size as *image*: foreground pixels come
        from *image*, the rest from the solid background, blended by the
        predicted mask.
    """
    from PIL import ImageColor

    image = image.convert("RGB")
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")

    # Prediction: the last model output, squashed to [0, 1] with a sigmoid.
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # Single-channel mask image, scaled back to the frame's original size.
    mask = transforms.ToPILImage()(pred).resize(image_size)

    # Keep only the RGB components so the background matches the frame's mode
    # (Image.composite expects both images to share mode and size).
    color_rgb = ImageColor.getrgb(color_hex)[:3]
    background = Image.new("RGB", image_size, color_rgb)

    return Image.composite(image, background, mask)
with gr.Blocks() as demo:
    # Media row: input video, live per-frame preview (hidden until a run
    # starts) and the final re-encoded video.
    with gr.Row():
        in_video = gr.Video(label="Input Video")
        stream_image = gr.Image(label="Streaming Output", visible=False)
        out_video = gr.Video(label="Final Output Video")

    submit_button = gr.Button("Change Background")

    # Options row: values forwarded to ``fn`` as its fps / color arguments.
    with gr.Row():
        fps_slider = gr.Slider(
            minimum=0,
            maximum=60,
            step=1,
            value=0,
            label="Output FPS (0 will inherit the original fps value)",
        )
        color_picker = gr.ColorPicker(
            label="Background Color",
            value="#00FF00",
        )

    # Example input whose outputs are computed eagerly and cached.
    examples = gr.Examples(
        ["rickroll-2sec.mp4"],
        inputs=in_video,
        outputs=[stream_image, out_video],
        fn=fn,
        cache_examples=True,
        cache_mode="eager",
    )

    # ``fn`` is a generator, so the outputs update after every yield.
    submit_button.click(
        fn,
        inputs=[in_video, fps_slider, color_picker],
        outputs=[stream_image, out_video],
    )

if __name__ == "__main__":
    demo.launch(show_error=True)