Spaces:
Running
on
Zero
Running
on
Zero
import tempfile | |
import os | |
import cv2 | |
import numpy as np | |
import imageio | |
import torch | |
import torchvision.io as io | |
from torchvision.transforms import functional as F | |
from PIL import Image, ImageDraw, ImageFont | |
import torch.nn.functional as nnf | |
def convert_to_rgb(frame):
    """Return ``frame`` as a 3-channel RGB image.

    - RGBA ``(H, W, 4)``: alpha-composited over a white background, rounded to
      the nearest integer and returned as ``uint8``. Rounding (instead of plain
      truncation) avoids off-by-one darkening caused by floating-point error.
    - Grayscale ``(H, W)`` or ``(H, W, 1)``: the single channel is replicated
      into three channels (dtype preserved).
    - Anything else (e.g. RGB ``(H, W, 3)``) is returned unchanged, as the same
      object.

    Args:
        frame: numpy image array; channel values assumed to be in [0, 255].
    """
    if frame.ndim == 2:
        # Give 2-D grayscale images an explicit channel axis.
        frame = frame[:, :, np.newaxis]
    if frame.shape[2] == 1:
        return np.repeat(frame, 3, axis=2)
    if frame.shape[2] == 4:  # RGBA
        # Convert RGBA to RGB using alpha compositing with white background
        alpha = frame[:, :, 3:4] / 255.0
        rgb = frame[:, :, :3]
        composite = rgb * alpha + (1.0 - alpha) * 255.0
        return np.clip(np.rint(composite), 0, 255).astype(np.uint8)
    return frame
def process_frames_batch(frames, target_size, device):
    """Resize a list of same-shaped image tensors in one batched call.

    Args:
        frames: Sequence of tensors with identical shape. They are stacked into
            one batch that is fed to bilinear ``interpolate``, so each element
            is presumably a float ``(C, H, W)`` image.
        target_size: Output spatial size as ``(height, width)``.
        device: Device the stacked batch is moved to before resizing.

    Returns:
        Tensor of shape ``(N, C, height, width)`` on ``device``.
    """
    batch = torch.stack(frames)
    batch = batch.to(device)
    return nnf.interpolate(
        batch,
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )
def _load_gif_frames(gif_path, device):
    """Read every frame of one GIF as a float ``(C, H, W)`` tensor on ``device``.

    Each frame goes through ``convert_to_rgb``, is shrunk to a quarter of its
    width and height with ``cv2.INTER_AREA``, and is divided by 255.
    """
    frames = []
    with imageio.get_reader(gif_path) as reader:
        for frame in reader:
            frame = convert_to_rgb(frame)
            frame = cv2.resize(frame, (frame.shape[1] // 4, frame.shape[0] // 4),
                               interpolation=cv2.INTER_AREA)
            frames.append(torch.from_numpy(frame).permute(2, 0, 1).float().to(device) / 255.0)
    return frames


def _prepare_input_frames(input_frames, device):
    """Convert BGR images to RGB float ``(C, H, W)`` tensors, downscaled 8x.

    Uses nearest-neighbour resizing and divides by 255; tensors go to ``device``.
    """
    prepared = []
    for frame in input_frames:
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, (frame.shape[1] // 8, frame.shape[0] // 8),
                           interpolation=cv2.INTER_NEAREST)
        prepared.append(torch.from_numpy(frame).permute(2, 0, 1).float().to(device) / 255.0)
    return prepared


def _fit_inside(orig_h, orig_w, max_h, max_w):
    """Largest ``(h, w)`` with the aspect ratio of ``(orig_h, orig_w)`` inside ``(max_h, max_w)``.

    Each side is at least 1 so the result is always a valid resize target.
    """
    aspect = orig_w / orig_h
    if max_w / aspect <= max_h:
        height, width = int(max_w / aspect), max_w
    else:
        height, width = max_h, int(max_h * aspect)
    return max(1, height), max(1, width)


def combine_video(obj_dir, output_path, input_frames=None, displayed_preds=3,
                  fps=30, batch_size=4):
    """Combine multiple GIFs into a grid layout using torchvision.

    Picks up to ``displayed_preds`` files from ``<obj_dir>/shadow_gif`` whose
    names end in ``_tranp.gif`` and do not start with ``obs`` (sorted by
    name). Each GIF is downscaled 4x and resized to the size of the first
    GIF's first frame. The GIFs are then tiled on a white canvas with at most
    3 columns. When ``input_frames`` is given, input frame ``k`` is inset into
    the top-right corner of every cell of output frame ``k``, with a 10 px
    margin and at most half the cell size. The result is written with
    ``imageio.mimsave`` (``loop=0``).

    Args:
        obj_dir: Directory that contains the ``shadow_gif`` sub-directory.
        output_path: Destination path of the combined GIF.
        input_frames: Optional sequence of BGR images (OpenCV order) used as
            the picture-in-picture inset. Presumably all the same size. They
            are stacked per batch, so mixed sizes would fail.
        displayed_preds: Maximum number of GIFs to include.
        fps: Frame rate passed to ``imageio.mimsave`` (default keeps the
            previous hard-coded 30).
        batch_size: Number of output frames composed per step; lower it to
            reduce peak GPU memory.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If no matching GIF is found (including
            ``displayed_preds <= 0``), or if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    print("Starting video combination process...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    shadow_gif_dir = os.path.join(obj_dir, 'shadow_gif')
    gif_files = sorted(
        f for f in os.listdir(shadow_gif_dir)
        if f.endswith('_tranp.gif') and not f.startswith('obs')
    )
    # Clamp so a negative value cannot slice from the end of the list.
    gif_files = gif_files[:max(displayed_preds, 0)]
    print(f"Using {len(gif_files)} GIFs for {displayed_preds} predictions")
    # Check before the layout math: with zero GIFs it would divide by zero.
    if not gif_files:
        raise ValueError("No GIF files found!")

    # Size the grid from the GIFs actually found so a short directory does
    # not leave blank white cells.
    num_gifs = len(gif_files)
    grid_cols = min(num_gifs, 3)  # Maximum 3 columns
    grid_rows = (num_gifs + grid_cols - 1) // grid_cols
    print(f"Grid layout: {grid_rows}x{grid_cols}")

    gif_frames = []
    durations = []
    for gif_file in gif_files:
        gif_path = os.path.join(shadow_gif_dir, gif_file)
        print(f"Loading {gif_file}...")
        gif_frames.append(_load_gif_frames(gif_path, device))
        # GIF frame delay from the first frame (ms -> s), default 100 ms.
        with Image.open(gif_path) as img:
            durations.append(img.info.get('duration', 100) / 1000.0)

    # Informational only: the output timing is controlled by ``fps``.
    common_duration = min(durations)
    print(f"Common duration: {common_duration}")

    if input_frames is not None:
        input_frames = _prepare_input_frames(input_frames, device)

    # Every cell takes the size of the first GIF's first frame.
    first_frame = gif_frames[0][0]
    target_height, target_width = first_frame.shape[1], first_frame.shape[2]

    # The inset size depends only on the input-frame size, so compute it once.
    pip_size = None
    if input_frames:
        orig_h, orig_w = input_frames[0].shape[1:3]
        pip_size = _fit_inside(orig_h, orig_w, target_height // 2, target_width // 2)

    num_frames = max(len(frames) for frames in gif_frames)
    frames_np = []
    for start in range(0, num_frames, batch_size):
        end = min(start + batch_size, num_frames)
        # White canvas; cells whose GIF has already ended stay white.
        grid = torch.ones((end - start, 3, target_height * grid_rows, target_width * grid_cols),
                          device=device)

        for i, frames in enumerate(gif_frames):
            row, col = divmod(i, grid_cols)
            top, left = row * target_height, col * target_width
            batch_frames = frames[start:end]
            if batch_frames:
                resized = process_frames_batch(batch_frames, (target_height, target_width), device)
                grid[:len(resized), :, top:top + target_height, left:left + target_width] = resized

        # Insets are pasted after all cells, matching the original draw order.
        if pip_size is not None:
            batch_input_frames = input_frames[start:end]
            if batch_input_frames:
                pip_frames = process_frames_batch(batch_input_frames, pip_size, device)
                pip_h, pip_w = pip_size
                for i in range(len(gif_frames)):
                    row, col = divmod(i, grid_cols)
                    x_pos = col * target_width + target_width - pip_w - 10
                    y_pos = row * target_height + 10
                    grid[:len(pip_frames), :, y_pos:y_pos + pip_h, x_pos:x_pos + pip_w] = pip_frames

        # Round (not truncate) to uint8 on-device. Only bytes are copied to the
        # host, and the float canvas is dropped after each batch.
        grid_u8 = (grid.clamp(0.0, 1.0) * 255.0).round().to(torch.uint8)
        frames_np.extend(grid_u8.permute(0, 2, 3, 1).contiguous().cpu().numpy())

    print(f"Saving to {output_path}")
    imageio.mimsave(output_path, frames_np, fps=fps, optimize=True, quantizer=0, loop=0)
    print("Video combination completed!")
    return output_path
if __name__ == "__main__":
    # Reserve a unique .gif path, then close the handle immediately so the file
    # is not still open when combine_video reopens it for writing. Reopening an
    # open NamedTemporaryFile is not portable (e.g. it fails on Windows).
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp_file:
        output_gif = tmp_file.name
    combine_video("./9622_GRAB/", output_gif)