import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
from torch.autograd import Function
from transformers import AutoModelForDepthEstimation, AutoImageProcessor
from huggingface_hub import hf_hub_download

# === DEVICE ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")
# ==============================================================================
# 1. FORWARD WARP IMPLEMENTATION (Native PyTorch)
# ==============================================================================
class ForwardWarpFunction(Function):
    @staticmethod
    def forward(ctx, im0, flow, interpolation_mode_int):
        # Input validation
        assert len(im0.shape) == len(flow.shape) == 4
        assert interpolation_mode_int in (0, 1)
        assert im0.shape[0] == flow.shape[0]
        assert im0.shape[-2:] == flow.shape[1:3]
        assert flow.shape[3] == 2

        B, C, H, W = im0.shape
        # Create a contiguous output tensor to prevent view/reshape errors
        im1 = torch.zeros(im0.shape, device=im0.device, dtype=im0.dtype).contiguous()

        # Grid creation: with indexing='xy', both grids have shape (H, W)
        grid_x, grid_y = torch.meshgrid(
            torch.arange(W, device=im0.device, dtype=im0.dtype),
            torch.arange(H, device=im0.device, dtype=im0.dtype),
            indexing='xy'
        )
        grid_x = grid_x.unsqueeze(0).expand(B, -1, -1)
        grid_y = grid_y.unsqueeze(0).expand(B, -1, -1)

        # Destination coordinates
        x_dest = grid_x + flow[:, :, :, 0]
        y_dest = grid_y + flow[:, :, :, 1]

        if interpolation_mode_int == 0:  # Bilinear Splatting
            x_f = torch.floor(x_dest).long()
            y_f = torch.floor(y_dest).long()
            x_c = x_f + 1
            y_c = y_f + 1

            # Bilinear weights for the four neighboring pixels
            nw_k = (x_c.float() - x_dest) * (y_c.float() - y_dest)
            ne_k = (x_dest - x_f.float()) * (y_c.float() - y_dest)
            sw_k = (x_c.float() - x_dest) * (y_dest - y_f.float())
            se_k = (x_dest - x_f.float()) * (y_dest - y_f.float())

            # Clamp coords
            x_f_clamped = torch.clamp(x_f, 0, W - 1)
            y_f_clamped = torch.clamp(y_f, 0, H - 1)
            x_c_clamped = torch.clamp(x_c, 0, W - 1)
            y_c_clamped = torch.clamp(y_c, 0, H - 1)

            # Per-corner validity masks
            mask_nw = (x_f >= 0) & (x_f < W) & (y_f >= 0) & (y_f < H)
            mask_ne = (x_c >= 0) & (x_c < W) & (y_f >= 0) & (y_f < H)
            mask_sw = (x_f >= 0) & (x_f < W) & (y_c >= 0) & (y_c < H)
            mask_se = (x_c >= 0) & (x_c < W) & (y_c >= 0) & (y_c < H)

            # Reshape for broadcasting over the channel dim
            nw_k, ne_k = nw_k.unsqueeze(1), ne_k.unsqueeze(1)
            sw_k, se_k = sw_k.unsqueeze(1), se_k.unsqueeze(1)
            mask_nw, mask_ne = mask_nw.unsqueeze(1), mask_ne.unsqueeze(1)
            mask_sw, mask_se = mask_sw.unsqueeze(1), mask_se.unsqueeze(1)

            # Flatten indices for scatter_add
            b_indices = torch.arange(B, device=im0.device).view(B, 1, 1, 1).expand(-1, C, H, W)
            c_indices = torch.arange(C, device=im0.device).view(1, C, 1, 1).expand(B, -1, H, W)
            base_idx = b_indices * (C * H * W) + c_indices * (H * W)

            # Scatter to 4 neighbors (Accumulate/Splat)
            def scatter_corner(y_idx, x_idx, weights, mask):
                flat_idx = base_idx + y_idx.unsqueeze(1) * W + x_idx.unsqueeze(1)
                # Masked-out corners contribute zero, so their clamped indices are harmless
                values = (im0 * weights) * mask.float()
                # Since im1 is contiguous, we can safely use view() for in-place scatter
                im1_flat = im1.view(-1)
                im1_flat.scatter_add_(0, flat_idx.contiguous().view(-1), values.contiguous().view(-1))

            scatter_corner(y_f_clamped, x_f_clamped, nw_k, mask_nw)  # NW
            scatter_corner(y_f_clamped, x_c_clamped, ne_k, mask_ne)  # NE
            scatter_corner(y_c_clamped, x_f_clamped, sw_k, mask_sw)  # SW
            scatter_corner(y_c_clamped, x_c_clamped, se_k, mask_se)  # SE
        else:  # Nearest Neighbor (legacy fallback)
            x_nearest = torch.round(x_dest).long()
            y_nearest = torch.round(y_dest).long()
            valid_mask = (x_nearest >= 0) & (x_nearest < W) & (y_nearest >= 0) & (y_nearest < H)
            valid_mask = valid_mask.unsqueeze(1)
            x_clamped = torch.clamp(x_nearest, 0, W - 1)
            y_clamped = torch.clamp(y_nearest, 0, H - 1)
            b_indices = torch.arange(B, device=im0.device).view(B, 1, 1, 1).expand(-1, C, H, W)
            c_indices = torch.arange(C, device=im0.device).view(1, C, 1, 1).expand(B, -1, H, W)
            dest_idx = b_indices * (C * H * W) + c_indices * (H * W) + y_clamped.unsqueeze(1) * W + x_clamped.unsqueeze(1)
            # Scatter only the in-bounds pixels so out-of-range coordinates
            # (clamped to the border) don't stamp zeros over valid data
            keep = valid_mask.expand(-1, C, -1, -1).contiguous().view(-1)
            im1.view(-1).scatter_(0, dest_idx.contiguous().view(-1)[keep], im0.contiguous().view(-1)[keep])

        return im1

    @staticmethod
    def backward(ctx, grad_output):
        # Inference-only: gradients are not implemented
        return None, None, None
class forward_warp(nn.Module):
    def __init__(self, interpolation_mode="Bilinear"):
        super(forward_warp, self).__init__()
        self.interpolation_mode_int = 0 if interpolation_mode == "Bilinear" else 1

    def forward(self, im0, flow):
        return ForwardWarpFunction.apply(im0, flow, self.interpolation_mode_int)
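

# --- Usage sketch (illustrative sanity check, not part of the pipeline) ---
# Splatting with an all-zero flow field should reproduce the input exactly,
# since every pixel lands on its own integer coordinate with weight 1:
#
#   warp = forward_warp(interpolation_mode="Bilinear")
#   img = torch.rand(1, 3, 64, 64)
#   flow = torch.zeros(1, 64, 64, 2)
#   assert torch.allclose(warp(img, flow), img, atol=1e-5)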
# ==============================================================================
# 2. STEREO WARPER WRAPPER
# ==============================================================================
class ForwardWarpStereo(nn.Module):
    """
    Weighted splatting wrapper.
    Handles occlusions with disparity-based weights (soft z-buffering):
    nearer pixels carry larger weights, so they dominate the accumulation
    wherever foreground and background splat onto the same target pixel.
    """

    def __init__(self, eps=1e-6):
        super(ForwardWarpStereo, self).__init__()
        self.eps = eps
        self.fw = forward_warp(interpolation_mode="Bilinear")

    def forward(self, im, disp, convergence, divergence):
        # disp comes in as [B, 1, H, W] or [1, 1, H, W].
        # Squeeze the channel dim to do math with coordinates: [B, H, W]
        disp_squeeze = disp.squeeze(1)

        # Create flow from disparity: shift = (depth - convergence) * divergence.
        # For the right-eye view, Target = Source - Shift, and flow is defined
        # source -> dest, so flow_x = -shift: pixels in front of the convergence
        # plane move left, pixels behind it move right.
        shift = (disp_squeeze - convergence) * divergence
        flow_x = -shift

        # Stack flow (x, y=0) along the last dim: [B, H, W] x 2 -> [B, H, W, 2]
        flow_y = torch.zeros_like(flow_x)
        flow = torch.stack((flow_x, flow_y), dim=-1)

        # 1. Calculate weights (soft z-buffer).
        # Weights grow linearly with normalized disparity, so closer objects
        # (higher disparity) dominate the background during accumulation;
        # the 0.05 floor keeps far pixels from vanishing entirely.
        disp_norm = disp_squeeze / (disp_squeeze.max() + 1e-8)
        weights_map = (disp_norm + 0.05).unsqueeze(1)

        # 2. Warp image * weights (accumulate weighted color).
        # Input im is (B, C, H, W), weights is (B, 1, H, W).
        res_accum = self.fw(im * weights_map, flow)

        # 3. Warp weights (accumulate weights)
        mask_accum = self.fw(weights_map, flow)

        # 4. Normalize (color / total weight).
        # Clamp to epsilon to avoid divide-by-zero in empty regions.
        mask_accum.clamp_(min=self.eps)
        res = res_accum / mask_accum

        # 5. Generate binary occlusion mask (for inpainting).
        # Splat a grid of ones: where the sum stays 0, nothing landed (a hole).
        # Holes = 1.0, filled = 0.0.
        ones = torch.ones_like(disp)
        occupancy = self.fw(ones, flow)
        occlusion_mask = (occupancy < self.eps).float()

        return res, occlusion_mask
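

# --- Worked example of the disparity-to-shift mapping (illustrative values) ---
# With convergence = 0.5 and divergence = 30 px:
#   far pixel,    disp = 0.0 -> shift = (0.0 - 0.5) * 30 = -15, flow_x = +15 (moves right)
#   focus plane,  disp = 0.5 -> shift = 0,                      flow_x = 0   (zero parallax)
#   near pixel,   disp = 1.0 -> shift = (1.0 - 0.5) * 30 = +15, flow_x = -15 (moves left)
# The convergence slider therefore picks which depth sits on the screen plane,
# and divergence caps the maximum parallax in pixels.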
# ==============================================================================
# 3. APP LOGIC & MODELS
# ==============================================================================

# === LOAD MODELS ===
def load_models():
    print("Loading Depth Anything V2 Large...")
    depth_model = AutoModelForDepthEstimation.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
    ).to(device)
    depth_processor = AutoImageProcessor.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
    )

    print("Loading LaMa Inpainting Model...")
    try:
        model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
        lama_model = torch.jit.load(model_path, map_location=device)
        lama_model.eval()
    except Exception as e:
        print(f"Error loading LaMa model: {e}")
        raise

    # Initialize the stereo warper
    stereo_warper = ForwardWarpStereo().to(device)
    return depth_model, depth_processor, lama_model, stereo_warper


# Load models once at startup
depth_model, depth_processor, lama_model, stereo_warper = load_models()
# === DEPTH ESTIMATION ===
def estimate_depth(image_pil, model, processor):
    original_size = image_pil.size
    inputs = processor(images=image_pil, return_tensors="pt").to(device)
    with torch.no_grad():
        depth = model(**inputs).predicted_depth
    # Resize the prediction back to the input resolution (PIL size is (W, H))
    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=(original_size[1], original_size[0]),
        mode="bicubic",
        align_corners=False,
    ).squeeze()
    # Normalize to [0, 1]; guard against constant depth maps
    depth_min, depth_max = depth.min(), depth.max()
    if depth_max - depth_min > 0:
        depth = (depth - depth_min) / (depth_max - depth_min)
    else:
        depth = torch.zeros_like(depth)
    return depth
# === DEPTH MANIPULATION ===
def erode_depth(depth_tensor, kernel_size):
    if kernel_size <= 0:
        return depth_tensor
    # Force an odd kernel so the erosion stays centered
    k = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
    x = depth_tensor.unsqueeze(0).unsqueeze(0)
    padding = k // 2
    # Grayscale erosion via min-pooling: min(x) == -max(-x)
    x_eroded = -torch.nn.functional.max_pool2d(-x, kernel_size=k, stride=1, padding=padding)
    return x_eroded.squeeze()
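

# --- Why erosion helps (illustrative) ---
# Depth edges from monocular models tend to bleed slightly past object
# silhouettes, so boundary pixels get foreground disparity and smear into
# halos when warped. Eroding the (near = bright) depth map shrinks the
# foreground by roughly kernel_size/2 px, so those edge pixels warp with
# the background instead. A 1-px-wide spike is wiped out by a 3-px kernel:
#
#   d = torch.tensor([[0., 0., 1., 0., 0.]])
#   erode_depth(d, kernel_size=3)   # -> all zeros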
# === LOCAL INPAINTING ===
def run_local_lama(image_bgr, mask_float):
    # 0. Dilate mask slightly to catch edge artifacts from splatting
    kernel = np.ones((3, 3), np.uint8)
    mask_uint8 = (mask_float * 255).astype(np.uint8)
    mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=3)

    # 1. Resize so both dims are divisible by 8 (stride requirement of the model)
    h, w = image_bgr.shape[:2]
    new_h = (h // 8) * 8
    new_w = (w // 8) * 8
    img_resized = cv2.resize(image_bgr, (new_w, new_h))
    mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    # 2. Convert to torch tensors in [0, 1]
    img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    img_t = img_t[:, [2, 1, 0], :, :]  # BGR to RGB
    mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
    mask_t = (mask_t > 0.5).float()
    img_t = img_t.to(device)
    mask_t = mask_t.to(device)

    # 3. Inference (zero out the hole region before passing image + mask)
    with torch.no_grad():
        img_t = img_t * (1 - mask_t)
        inpainted_t = lama_model(img_t, mask_t)

    # 4. Post-process back to uint8 BGR at the original resolution
    inpainted = inpainted_t[0].permute(1, 2, 0).cpu().numpy()
    inpainted = np.clip(inpainted * 255, 0, 255).astype(np.uint8)
    inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)
    if new_h != h or new_w != w:
        inpainted = cv2.resize(inpainted, (w, h))
    return inpainted
def make_anaglyph(left, right):
    l_arr = np.array(left)
    r_arr = np.array(right)
    anaglyph = np.zeros_like(l_arr)
    anaglyph[:, :, 0] = l_arr[:, :, 0]  # red from the left eye
    anaglyph[:, :, 1] = r_arr[:, :, 1]  # green from the right eye
    anaglyph[:, :, 2] = r_arr[:, :, 2]  # blue from the right eye
    return Image.fromarray(anaglyph)
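

# Red/cyan glasses filter this composite back apart: the red lens passes only
# the left view's red channel to the left eye, while the cyan lens passes the
# right view's green+blue channels to the right eye, so each eye decodes its
# own view from the single image.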
# === PIPELINE ===
def stereo_pipeline(image_pil, divergence, convergence, edge_erosion):
    if image_pil is None:
        return None, None, None, None

    # Resize input if too large (cap width at 1920 px, keep aspect ratio)
    w, h = image_pil.size
    if w > 1920:
        ratio = 1920 / w
        new_h = int(h * ratio)
        image_pil = image_pil.resize((1920, new_h), Image.LANCZOS)

    # 1. Depth estimation
    depth_tensor = estimate_depth(image_pil, depth_model, depth_processor)

    # 2. Depth erosion (optional halo reduction)
    if edge_erosion > 0:
        depth_tensor = erode_depth(depth_tensor, int(edge_erosion))

    # Visualize depth
    depth_vis = (depth_tensor.cpu().numpy() * 255).astype(np.uint8)
    depth_image = Image.fromarray(depth_vis)

    # 3. Forward warp (weighted bilinear splatting)
    # Convert image to tensor (B, C, H, W) in [0, 1]
    image_tensor = torch.from_numpy(np.array(image_pil)).float().to(device).permute(2, 0, 1).unsqueeze(0) / 255.0
    # Prepare depth tensor (B, 1, H, W)
    depth_input = depth_tensor.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        right_img_tensor, mask_tensor = stereo_warper(
            image_tensor,
            depth_input,
            float(convergence),
            float(divergence)
        )

    # Convert results back to CPU/numpy
    right_img_rgb = (right_img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    mask_vis = (mask_tensor.squeeze(0).squeeze(0).cpu().numpy() * 255).astype(np.uint8)
    mask_image = Image.fromarray(mask_vis)

    # 4. Inpaint the disoccluded holes
    right_img_bgr = cv2.cvtColor(right_img_rgb, cv2.COLOR_RGB2BGR)
    mask_float = mask_tensor.squeeze().cpu().numpy()
    right_filled_bgr = run_local_lama(right_img_bgr, mask_float)

    # 5. Finalize: side-by-side pair and anaglyph
    left = image_pil
    right = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))
    width, height = left.size
    combined_image = Image.new('RGB', (width * 2, height))
    combined_image.paste(left, (0, 0))
    combined_image.paste(right, (width, 0))
    anaglyph_image = make_anaglyph(left, right)

    return combined_image, anaglyph_image, depth_image, mask_image
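

# --- Programmatic use outside the UI (illustrative; file names hypothetical) ---
#   img = Image.open("input.jpg").convert("RGB")
#   sbs, anaglyph, depth_vis, holes = stereo_pipeline(
#       img, divergence=30, convergence=0.5, edge_erosion=2)
#   sbs.save("stereo_sbs.png")
#   anaglyph.save("anaglyph.png")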
# === GRADIO UI ===
with gr.Blocks(title="2D to 3D Stereo") as demo:
    gr.Markdown("## 2D to 3D Stereo Generator (High-Quality Splatting)")
    gr.Markdown("Uses **Depth Anything V2**, **Bilinear Weighted Splatting** (Soft Z-Buffer), and **LaMa Inpainting**.")

    with gr.Row():
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Input Image", height=320)
            with gr.Group():
                gr.Markdown("### 3D Controls")
                divergence_slider = gr.Slider(
                    minimum=0, maximum=100, value=30, step=1,
                    label="3D Strength (Divergence)",
                    info="Max separation in pixels."
                )
                convergence_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                    label="Focus Plane (Convergence)",
                    info="0.0 = background at screen, 1.0 = foreground at screen."
                )
                erosion_slider = gr.Slider(
                    minimum=0, maximum=20, value=2, step=1,
                    label="Edge Masking (Erosion)",
                    info="Cleans up depth edges. Set to 0 for raw splatting."
                )
            btn = gr.Button("Generate 3D", variant="primary")
        with gr.Column(scale=1):
            out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan)", height=320)
            out_stereo = gr.Image(label="Side-by-Side Stereo Pair", height=320)
            with gr.Row():
                out_depth = gr.Image(label="Depth Map", height=200)
                out_mask = gr.Image(label="Inpainting Mask (Holes)", height=200)

    btn.click(
        fn=stereo_pipeline,
        inputs=[input_img, divergence_slider, convergence_slider, erosion_slider],
        outputs=[out_stereo, out_anaglyph, out_depth, out_mask]
    )

if __name__ == "__main__":
    demo.launch()