|
import gradio as gr |
|
import torch |
|
import spaces |
|
import moviepy.editor as mp |
|
from PIL import Image |
|
import numpy as np |
|
import tempfile |
|
import time |
|
import os |
|
import shutil |
|
import ffmpeg |
|
from concurrent.futures import ThreadPoolExecutor |
|
from gradio.themes.base import Base |
|
from gradio.themes.utils import colors, fonts |
|
from infer import lotus |
|
|
|
|
|
class WhiteTheme(Base):
    """A light Gradio theme: all-white surfaces with black text and orange accents.

    Dark-mode variables are pinned to the same light values so the UI stays
    white regardless of the client's color-scheme preference.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.orange,
        font: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-monospace",
            "system-ui",
            "monospace",
        )
    ):
        super().__init__(primary_hue=primary_hue, font=font, font_mono=font_mono)

        # Variables forced to plain white in both light and dark mode.
        forced_white = {
            "background_fill_secondary": "white",
            "body_background_fill": "white",
            "body_background_fill_dark": "white",
            "block_background_fill": "white",
            "block_background_fill_dark": "white",
            "panel_background_fill": "white",
            "panel_background_fill_dark": "white",
            "block_border_color": "white",
            "panel_border_color": "white",
            "input_background_fill": "white",
            "input_background_fill_dark": "white",
        }
        # Text pinned to black in both modes.
        forced_black = {
            "body_text_color": "black",
            "body_text_color_dark": "black",
            "block_label_text_color": "black",
            "block_label_text_color_dark": "black",
        }
        self.set(
            background_fill_primary="*primary_50",
            border_color_primary="*primary_300",
            input_border_color="lightgray",
            shadow_drop="none",
            **forced_white,
            **forced_black,
        )
|
|
|
|
|
# Select the inference device once at import time: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
def process_frame(frame, seed=0):
    """
    Process a single video frame through the depth model.

    Args:
        frame: RGB frame as a numpy array (as yielded by moviepy's iter_frames).
        seed: Random seed forwarded to the model for reproducibility.

    Returns:
        The discriminative depth map as a numpy array, or None if processing
        failed (callers skip None results rather than aborting the video).
    """
    try:
        image = Image.fromarray(frame)

        # lotus() takes a file path, so round-trip the frame through a temp PNG.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            image.save(tmp.name)

        try:
            _, output_d = lotus(tmp.name, 'depth', seed, device)
        finally:
            # Always remove the temp file, even when inference raises.
            # The original only unlinked on success, leaking a PNG per
            # failed frame.
            os.unlink(tmp.name)

        return np.array(output_d)

    except Exception as e:
        # Best-effort: log and return None so one bad frame doesn't kill
        # the whole video run.
        print(f"Error processing frame: {e}")
        return None
|
|
|
@spaces.GPU
def process_video(video_path, fps=0, seed=0, max_workers=6):
    """
    Process a video into a depth-map frame sequence and an MP4.

    Generator used directly as a Gradio streaming handler.

    Args:
        video_path: Path to the uploaded source video.
        fps: Output framerate; 0 means inherit the source video's fps.
        seed: Seed forwarded to process_frame for reproducibility.
        max_workers: Thread-pool size for per-frame inference, also passed
            to ffmpeg as its thread count.

    Yields:
        4-tuples (preview_frame, zip_path, mp4_path, status_text):
        per-frame progress updates during processing, then one final tuple
        with the ZIP and MP4 paths (preview is None in the final tuple).
    """
    temp_dir = None
    try:
        start_time = time.time()
        video = mp.VideoFileClip(video_path)

        # fps == 0 is the sentinel for "keep the source framerate".
        if fps == 0:
            fps = video.fps

        # NOTE(review): this materializes every decoded frame in memory at
        # once — could be large for long/high-res videos. Consider streaming.
        frames = list(video.iter_frames(fps=fps))
        total_frames = len(frames)

        print(f"Processing {total_frames} frames at {fps} FPS...")

        # Scratch directory for the rendered depth PNGs; removed in `finally`.
        temp_dir = tempfile.mkdtemp()
        frames_dir = os.path.join(temp_dir, "frames")
        os.makedirs(frames_dir, exist_ok=True)

        processed_frames = []
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all frames up front; iterate futures in submission order
            # so yielded previews stay in frame order.
            futures = [executor.submit(process_frame, frame, seed) for frame in frames]
            for i, future in enumerate(futures):
                try:
                    result = future.result()
                    if result is not None:
                        # Zero-padded names keep ffmpeg's %06d pattern happy
                        # and the ZIP contents sorted.
                        frame_path = os.path.join(frames_dir, f"frame_{i:06d}.png")
                        Image.fromarray(result).save(frame_path)

                        processed_frames.append(result)

                        # Stream a live preview + status line back to the UI.
                        elapsed_time = time.time() - start_time
                        yield processed_frames[-1], None, None, f"Processing frame {i+1}/{total_frames}... Elapsed time: {elapsed_time:.2f} seconds"

                    if (i + 1) % 10 == 0:
                        print(f"Processed {i+1}/{total_frames} frames")
                except Exception as e:
                    # A failed frame is logged and skipped; numbering continues,
                    # so the PNG sequence may have gaps (see NOTE below).
                    print(f"Error processing frame {i+1}: {e}")

        print("Creating output files...")

        # Outputs land next to the uploaded video so Gradio can serve them.
        output_dir = os.path.join(os.path.dirname(video_path), "output")
        os.makedirs(output_dir, exist_ok=True)

        # make_archive appends ".zip" itself, hence stripping the suffix.
        zip_filename = f"depth_frames_{int(time.time())}.zip"
        zip_path = os.path.join(output_dir, zip_filename)
        shutil.make_archive(zip_path[:-4], 'zip', frames_dir)

        print("Creating MP4 video...")
        video_filename = f"depth_video_{int(time.time())}.mp4"
        # NOTE(review): rebinding the `video_path` parameter to the OUTPUT
        # path here; works, but shadows the input path — consider a new name.
        video_path = os.path.join(output_dir, video_filename)

        try:
            # NOTE(review): a skipped (failed) frame leaves a hole in the
            # %06d sequence, which ffmpeg's sequence pattern may stop at —
            # verify behavior when frames fail.
            stream = ffmpeg.input(
                os.path.join(frames_dir, 'frame_%06d.png'),
                pattern_type='sequence',
                framerate=fps
            )

            # crf=17 ≈ visually lossless H.264; yuv420p for broad player compat.
            stream = ffmpeg.output(
                stream,
                video_path,
                vcodec='libx264',
                pix_fmt='yuv420p',
                crf=17,
                threads=max_workers
            )

            ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
            print("MP4 video created successfully!")

        except ffmpeg.Error as e:
            # Encoding failure is non-fatal: the ZIP is still delivered.
            print(f"Error creating video: {e.stderr.decode() if e.stderr else str(e)}")
            video_path = None

        print("Processing complete!")
        yield None, zip_path, video_path, f"Processing complete! Total time: {time.time() - start_time:.2f} seconds"

    except Exception as e:
        print(f"Error: {e}")
        yield None, None, None, f"Error processing video: {e}"
    finally:
        # Best-effort cleanup of the scratch frame directory.
        if temp_dir and os.path.exists(temp_dir):
            try:
                shutil.rmtree(temp_dir)
            except Exception as e:
                print(f"Error cleaning up temp directory: {e}")
|
|
|
def process_wrapper(video, fps=0, seed=0, max_workers=6):
    """
    Gradio entry point: validate the upload, then stream process_video's updates.

    Args:
        video: Path to the uploaded video (None when nothing was uploaded).
        fps: Output framerate; 0 inherits the source fps.
        seed: Seed for reproducible depth inference.
        max_workers: Parallel frame-processing workers.

    Yields:
        The (preview, zip_path, mp4_path, status) tuples from process_video.

    Raises:
        gr.Error: For a missing upload or any processing failure.
    """
    if video is None:
        raise gr.Error("Please upload a video.")
    try:
        # Stream updates straight through. The original collected every
        # yielded tuple (including numpy preview frames) in a list just to
        # `return outputs[-1]` — a generator return value Gradio never reads —
        # keeping all previews alive in memory for the whole run.
        yield from process_video(video, fps, seed, max_workers)
    except Exception as e:
        # Surface any failure as a user-visible Gradio error.
        raise gr.Error(f"Error processing video: {str(e)}")
|
|
|
|
|
# CSS injected via gr.Blocks(css=...): centers the title and gives the
# "#title" element an animated pastel-gradient background (15 s loop).
custom_css = """
.title-container {
    text-align: center;
    padding: 10px 0;
}

#title {
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
    font-size: 36px;
    font-weight: bold;
    color: #000000;
    padding: 10px;
    border-radius: 10px;
    display: inline-block;
    background: linear-gradient(
        135deg,
        #e0f7fa, #e8f5e9, #fff9c4, #ffebee,
        #f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
    );
    background-size: 400% 400%;
    animation: gradient-animation 15s ease infinite;
}

@keyframes gradient-animation {
    0% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
    100% { background-position: 0% 50%; }
}
"""
|
|
|
|
|
# Build the main Gradio app: inputs (video + sliders) on the left,
# live preview and downloadable outputs on the right.
with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
    gr.HTML('''
    <div class="title-container">
        <div id="title">Video Depth Estimation</div>
    </div>
    ''')

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(
                label="Upload Video",
                interactive=True,
                show_label=True,
                height=360,
                width=640
            )
            with gr.Row():
                # fps=0 is the sentinel for "inherit source framerate"
                # (see process_video).
                fps_slider = gr.Slider(
                    minimum=0,
                    maximum=60,
                    step=1,
                    value=0,
                    label="Output FPS (0 will inherit the original fps value)",
                )
                seed_slider = gr.Slider(
                    minimum=0,
                    maximum=999999999,
                    step=1,
                    value=0,
                    label="Seed",
                )
                max_workers_slider = gr.Slider(
                    minimum=1,
                    maximum=32,
                    step=1,
                    value=6,
                    label="Max Workers",
                    info="Determines how many frames to process in parallel"
                )
            btn = gr.Button("Process Video", elem_id="submit-button")

        with gr.Column():
            # Outputs match process_wrapper's yielded 4-tuple:
            # (preview image, frames ZIP, MP4 file, status text).
            preview_image = gr.Image(label="Live Preview", show_label=True)
            output_frames_zip = gr.File(label="Download Frame Sequence (ZIP)")
            output_video = gr.File(label="Download Video (MP4)")
            time_textbox = gr.Textbox(label="Status", interactive=False)

    gr.Markdown("""
    ### Output Information
    - High-quality MP4 video output
    - Original resolution and framerate are maintained
    - Frame sequence provided for maximum compatibility
    """)

    # process_wrapper is a generator, so each yield streams a UI update.
    btn.click(
        fn=process_wrapper,
        inputs=[video_input, fps_slider, seed_slider, max_workers_slider],
        outputs=[preview_image, output_frames_zip, output_video, time_textbox]
    )

# Queueing is required for generator (streaming) handlers.
demo.queue()
|
|
|
# NOTE(review): this Interface is constructed but never launched or mounted —
# only `demo` is launched below, so this appears to be dead code unless it is
# mounted elsewhere; verify intent. Also confirm that this Gradio version's
# gr.Interface constructor accepts an `api_name` keyword.
api = gr.Interface(
    fn=process_wrapper,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Number(label="FPS", value=0),
        gr.Number(label="Seed", value=0),
        gr.Number(label="Max Workers", value=6)
    ],
    outputs=[
        gr.Image(label="Preview"),
        gr.File(label="Frame Sequence"),
        gr.File(label="Video"),
        gr.Textbox(label="Status")
    ],
    title="Video Depth Estimation API",
    description="Generate depth maps from videos",
    api_name="/process_video"
)
|
|
|
# Launch only when run as a script; bind on all interfaces for container use
# (standard Hugging Face Spaces host/port).
if __name__ == "__main__":
    demo.launch(debug=True, show_error=True, share=False, server_name="0.0.0.0", server_port=7860)