import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
from torch.autograd import Function
from transformers import AutoModelForDepthEstimation, AutoImageProcessor
from huggingface_hub import hf_hub_download
import os
# === DEVICE ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")
# ==============================================================================
# 1. FORWARD WARP IMPLEMENTATION (Native PyTorch)
# ==============================================================================
class ForwardWarpFunction(Function):
    @staticmethod
    def forward(ctx, im0, flow, interpolation_mode_int):
        # Input validation
        assert (len(im0.shape) == len(flow.shape) == 4)
        assert (interpolation_mode_int == 0 or interpolation_mode_int == 1)
        assert (im0.shape[0] == flow.shape[0])
        assert (im0.shape[-2:] == flow.shape[1:3])
        assert (flow.shape[3] == 2)
        B, C, H, W = im0.shape

        # Create a contiguous output tensor to prevent view/reshape errors
        im1 = torch.zeros(im0.shape, device=im0.device, dtype=im0.dtype).contiguous()

        # Grid creation
        grid_x, grid_y = torch.meshgrid(
            torch.arange(W, device=im0.device, dtype=im0.dtype),
            torch.arange(H, device=im0.device, dtype=im0.dtype),
            indexing='xy'
        )
        grid_x = grid_x.unsqueeze(0).expand(B, -1, -1)
        grid_y = grid_y.unsqueeze(0).expand(B, -1, -1)

        # Destination coordinates
        x_dest = grid_x + flow[:, :, :, 0]
        y_dest = grid_y + flow[:, :, :, 1]

        if interpolation_mode_int == 0:  # Bilinear Splatting
            x_f = torch.floor(x_dest).long()
            y_f = torch.floor(y_dest).long()
            x_c = x_f + 1
            y_c = y_f + 1

            # Weights
            nw_k = (x_c.float() - x_dest) * (y_c.float() - y_dest)
            ne_k = (x_dest - x_f.float()) * (y_c.float() - y_dest)
            sw_k = (x_c.float() - x_dest) * (y_dest - y_f.float())
            se_k = (x_dest - x_f.float()) * (y_dest - y_f.float())

            # Clamp coords
            x_f_clamped = torch.clamp(x_f, 0, W - 1)
            y_f_clamped = torch.clamp(y_f, 0, H - 1)
            x_c_clamped = torch.clamp(x_c, 0, W - 1)
            y_c_clamped = torch.clamp(y_c, 0, H - 1)

            # Per-corner validity masks
            mask_nw = (x_f >= 0) & (x_f < W) & (y_f >= 0) & (y_f < H)
            mask_ne = (x_c >= 0) & (x_c < W) & (y_f >= 0) & (y_f < H)
            mask_sw = (x_f >= 0) & (x_f < W) & (y_c >= 0) & (y_c < H)
            mask_se = (x_c >= 0) & (x_c < W) & (y_c >= 0) & (y_c < H)

            # Reshape for broadcasting
            nw_k = nw_k.unsqueeze(1)
            ne_k = ne_k.unsqueeze(1)
            sw_k = sw_k.unsqueeze(1)
            se_k = se_k.unsqueeze(1)
            mask_nw = mask_nw.unsqueeze(1)
            mask_ne = mask_ne.unsqueeze(1)
            mask_sw = mask_sw.unsqueeze(1)
            mask_se = mask_se.unsqueeze(1)

            # Flatten indices for scatter_add
            b_indices = torch.arange(B, device=im0.device).view(B, 1, 1, 1).expand(-1, C, H, W)
            c_indices = torch.arange(C, device=im0.device).view(1, C, 1, 1).expand(B, -1, H, W)
            base_idx = b_indices * (C * H * W) + c_indices * (H * W)

            # Scatter to 4 neighbors (Accumulate/Splat)
            def scatter_corner(y_idx, x_idx, weights, mask):
                flat_idx = base_idx + y_idx.unsqueeze(1) * W + x_idx.unsqueeze(1)
                values = (im0 * weights) * mask.float()
                # Since im1 is contiguous, we can safely use view() for in-place scatter
                im1_flat = im1.view(-1)
                idx_flat = flat_idx.contiguous().view(-1)
                val_flat = values.contiguous().view(-1)
                im1_flat.scatter_add_(0, idx_flat, val_flat)

            scatter_corner(y_f_clamped, x_f_clamped, nw_k, mask_nw)  # NW
            scatter_corner(y_f_clamped, x_c_clamped, ne_k, mask_ne)  # NE
            scatter_corner(y_c_clamped, x_f_clamped, sw_k, mask_sw)  # SW
            scatter_corner(y_c_clamped, x_c_clamped, se_k, mask_se)  # SE
        else:  # Nearest Neighbor (Legacy fallback)
            x_nearest = torch.round(x_dest).long()
            y_nearest = torch.round(y_dest).long()
            valid_mask = (x_nearest >= 0) & (x_nearest < W) & (y_nearest >= 0) & (y_nearest < H)
            valid_mask = valid_mask.unsqueeze(1)
            x_clamped = torch.clamp(x_nearest, 0, W - 1)
            y_clamped = torch.clamp(y_nearest, 0, H - 1)
            b_indices = torch.arange(B, device=im0.device).view(B, 1, 1, 1).expand(-1, C, H, W)
            c_indices = torch.arange(C, device=im0.device).view(1, C, 1, 1).expand(B, -1, H, W)
            dest_idx = b_indices * (C * H * W) + c_indices * (H * W) + y_clamped.unsqueeze(1) * W + x_clamped.unsqueeze(1)
            source_values = im0 * valid_mask.float()
            # Since im1 is contiguous, we can safely use view()
            im1.view(-1).scatter_(0, dest_idx.contiguous().view(-1), source_values.contiguous().view(-1))
        return im1
    @staticmethod
    def backward(ctx, grad_output):
        # Warping is used for inference only; no gradients are propagated.
        return None, None, None
class forward_warp(nn.Module):
    def __init__(self, interpolation_mode="Bilinear"):
        super(forward_warp, self).__init__()
        self.interpolation_mode_int = 0 if interpolation_mode == "Bilinear" else 1

    def forward(self, im0, flow):
        return ForwardWarpFunction.apply(im0, flow, self.interpolation_mode_int)
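# Shape sketch for the warper above (illustrative only, not executed by the app):
#   warp = forward_warp("Bilinear")
#   im   = torch.rand(1, 3, 4, 4)    # (B, C, H, W)
#   flow = torch.zeros(1, 4, 4, 2)   # (B, H, W, 2), last dim = (dx, dy) in pixels
#   out  = warp(im, flow)            # with zero flow the splat reproduces im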
# ==============================================================================
# 2. STEREO WARPER WRAPPER
# ==============================================================================
class ForwardWarpStereo(nn.Module):
    """
    Weighted splatting wrapper.
    Handles occlusions with depth-based weights (soft Z-buffering): pixels with
    larger disparity (closer to the camera) receive larger weights, so the
    foreground dominates the background wherever they land on the same pixel.
    """

    def __init__(self, eps=1e-6):
        super(ForwardWarpStereo, self).__init__()
        self.eps = eps
        self.fw = forward_warp(interpolation_mode="Bilinear")

    def forward(self, im, disp, convergence, divergence):
        # disp comes in as [B, 1, H, W] (or [1, 1, H, W]); squeeze the channel
        # dim so it can be combined with the [B, H, W] coordinate grids.
        disp_squeeze = disp.squeeze(1)  # [B, H, W]

        # Create flow from disparity:
        #   shift = (depth - convergence) * divergence
        # For the right-eye view the target is source - shift, so flow_x = -shift.
        shift = (disp_squeeze - convergence) * divergence
        flow_x = -shift

        # Stack horizontal flow with zero vertical flow: [B, H, W] -> [B, H, W, 2]
        flow_y = torch.zeros_like(flow_x)
        flow = torch.stack((flow_x, flow_y), dim=-1)

        # 1. Compute splat weights (soft Z-buffer).
        # Closer objects (higher disparity) get higher weight, which lets the
        # foreground overwrite the background during accumulation. The small
        # offset keeps far-background pixels from having zero weight.
        disp_norm = disp_squeeze / (disp_squeeze.max() + 1e-8)
        weights_map = disp_norm + 0.05
        weights_map = weights_map.unsqueeze(1)

        # 2. Warp image * weights (accumulate weighted color).
        # im is (B, C, H, W); weights_map is (B, 1, H, W).
        res_accum = self.fw(im * weights_map, flow)

        # 3. Warp the weights themselves (accumulate total weight per pixel).
        mask_accum = self.fw(weights_map, flow)

        # 4. Normalize (color / total weight).
        # Clamp by epsilon to avoid divide-by-zero in empty regions.
        mask_accum.clamp_(min=self.eps)
        res = res_accum / mask_accum

        # 5. Generate binary occlusion mask (for inpainting).
        # Splat a grid of ones: wherever nothing lands we have a hole.
        # Holes = 1.0, filled pixels = 0.0.
        ones = torch.ones_like(disp)
        occupancy = self.fw(ones, flow)
        occlusion_mask = (occupancy < self.eps).float()

        return res, occlusion_mask
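# Worked example of the disparity-to-shift mapping (illustrative numbers): with
# divergence=30 and convergence=0.5, a pixel at normalized disparity 1.0
# (nearest) shifts (1.0 - 0.5) * 30 = 15 px left in the right-eye view, a pixel
# at 0.5 stays on the focus plane, and a pixel at 0.0 shifts 15 px right.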
# ==============================================================================
# 3. APP LOGIC & MODELS
# ==============================================================================
# === LOAD MODELS ===
def load_models():
    print("Loading Depth Anything V2 Large...")
    depth_model = AutoModelForDepthEstimation.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
    ).to(device)
    depth_processor = AutoImageProcessor.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
    )

    print("Loading LaMa Inpainting Model...")
    try:
        model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
        lama_model = torch.jit.load(model_path, map_location=device)
        lama_model.eval()
    except Exception as e:
        print(f"Error loading LaMa model: {e}")
        raise e

    # Initialize the stereo warper
    stereo_warper = ForwardWarpStereo().to(device)
    return depth_model, depth_processor, lama_model, stereo_warper
# Load models once at startup
depth_model, depth_processor, lama_model, stereo_warper = load_models()
# === DEPTH ESTIMATION ===
@torch.no_grad()
def estimate_depth(image_pil, model, processor):
    original_size = image_pil.size
    inputs = processor(images=image_pil, return_tensors="pt").to(device)
    depth = model(**inputs).predicted_depth
    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=(original_size[1], original_size[0]),
        mode="bicubic",
        align_corners=False,
    ).squeeze()
    depth_min, depth_max = depth.min(), depth.max()
    if depth_max - depth_min > 0:
        depth = (depth - depth_min) / (depth_max - depth_min)
    else:
        depth = torch.zeros_like(depth)
    return depth
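# The relative depth returned by the model is normalized to [0, 1] and used
# directly as a disparity proxy: larger values are treated as nearer surfaces,
# which drives both the splat weighting and the shift in ForwardWarpStereo.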
# === DEPTH MANIPULATION ===
def erode_depth(depth_tensor, kernel_size):
    if kernel_size <= 0:
        return depth_tensor
    k = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
    x = depth_tensor.unsqueeze(0).unsqueeze(0)
    padding = k // 2
    x_eroded = -torch.nn.functional.max_pool2d(-x, kernel_size=k, stride=1, padding=padding)
    return x_eroded.squeeze()
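# Note: the erosion above is a min-pool implemented via min(x) = -max(-x); each
# depth value is replaced by the minimum of its k x k neighborhood, which pulls
# foreground (high-disparity) edges inward and reduces halo artifacts.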
# === LOCAL INPAINTING ===
@torch.no_grad()
def run_local_lama(image_bgr, mask_float):
    # 0. Dilate mask slightly to catch edge artifacts from splatting
    kernel = np.ones((3, 3), np.uint8)
    mask_uint8 = (mask_float * 255).astype(np.uint8)
    mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=3)

    # 1. Resize to be divisible by 8
    h, w = image_bgr.shape[:2]
    new_h = (h // 8) * 8
    new_w = (w // 8) * 8
    img_resized = cv2.resize(image_bgr, (new_w, new_h))
    mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    # 2. Convert to torch
    img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    img_t = img_t[:, [2, 1, 0], :, :]  # BGR to RGB
    mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
    mask_t = (mask_t > 0.5).float()
    img_t = img_t.to(device)
    mask_t = mask_t.to(device)

    # 3. Inference
    img_t = img_t * (1 - mask_t)
    inpainted_t = lama_model(img_t, mask_t)

    # 4. Post-process
    inpainted = inpainted_t[0].permute(1, 2, 0).cpu().numpy()
    inpainted = np.clip(inpainted * 255, 0, 255).astype(np.uint8)
    inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)
    if new_h != h or new_w != w:
        inpainted = cv2.resize(inpainted, (w, h))
    return inpainted
def make_anaglyph(left, right):
    l_arr = np.array(left)
    r_arr = np.array(right)
    anaglyph = np.zeros_like(l_arr)
    anaglyph[:, :, 0] = l_arr[:, :, 0]
    anaglyph[:, :, 1] = r_arr[:, :, 1]
    anaglyph[:, :, 2] = r_arr[:, :, 2]
    return Image.fromarray(anaglyph)
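# Standard red/cyan anaglyph assembly: the red filter (left eye) passes the left
# view's red channel, while the cyan filter (right eye) passes the right view's
# green and blue channels.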
# === PIPELINE ===
def stereo_pipeline(image_pil, divergence, convergence, edge_erosion):
    if image_pil is None:
        return None, None, None, None

    # Resize input if too large
    w, h = image_pil.size
    if w > 1920:
        ratio = 1920 / w
        new_h = int(h * ratio)
        image_pil = image_pil.resize((1920, new_h), Image.LANCZOS)

    # 1. Depth Estimation
    depth_tensor = estimate_depth(image_pil, depth_model, depth_processor)

    # 2. Depth Erosion (optional halo reduction)
    if edge_erosion > 0:
        depth_tensor = erode_depth(depth_tensor, int(edge_erosion))

    # Visualize Depth
    depth_vis = (depth_tensor.cpu().numpy() * 255).astype(np.uint8)
    depth_image = Image.fromarray(depth_vis)

    # 3. Forward Warp (Weighted Bilinear Splatting)
    # Convert image to tensor (B, C, H, W)
    image_tensor = torch.from_numpy(np.array(image_pil)).float().to(device).permute(2, 0, 1).unsqueeze(0) / 255.0
    # Prepare depth tensor (B, 1, H, W)
    depth_input = depth_tensor.unsqueeze(0).unsqueeze(0)

    # Run the stereo warper
    with torch.no_grad():
        right_img_tensor, mask_tensor = stereo_warper(
            image_tensor,
            depth_input,
            float(convergence),
            float(divergence)
        )

    # Convert results back to CPU/NumPy
    right_img_rgb = (right_img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    mask_vis = (mask_tensor.squeeze(0).squeeze(0).cpu().numpy() * 255).astype(np.uint8)
    mask_image = Image.fromarray(mask_vis)

    # 4. Inpainting
    right_img_bgr = cv2.cvtColor(right_img_rgb, cv2.COLOR_RGB2BGR)
    mask_float = mask_tensor.squeeze().cpu().numpy()
    right_filled_bgr = run_local_lama(right_img_bgr, mask_float)

    # 5. Finalize
    left = image_pil
    right = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))
    width, height = left.size
    combined_image = Image.new('RGB', (width * 2, height))
    combined_image.paste(left, (0, 0))
    combined_image.paste(right, (width, 0))
    anaglyph_image = make_anaglyph(left, right)

    return combined_image, anaglyph_image, depth_image, mask_image
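# Note: the side-by-side output is a full-width stereo pair with the left view on
# the left half and the right view on the right half (parallel-viewing order).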
# === GRADIO UI ===
with gr.Blocks(title="2D to 3D Stereo") as demo:
    gr.Markdown("## 2D to 3D Stereo Generator (High-Quality Splatting)")
    gr.Markdown("Uses **Depth Anything V2**, **Bilinear Weighted Splatting** (Soft Z-Buffer), and **LaMa Inpainting**.")

    with gr.Row():
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Input Image", height=320)
            with gr.Group():
                gr.Markdown("### 3D Controls")
                divergence_slider = gr.Slider(
                    minimum=0, maximum=100, value=30, step=1,
                    label="3D Strength (Divergence)",
                    info="Max separation in pixels."
                )
                convergence_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                    label="Focus Plane (Convergence)",
                    info="0.0 = Background at screen. 1.0 = Foreground at screen."
                )
                erosion_slider = gr.Slider(
                    minimum=0, maximum=20, value=2, step=1,
                    label="Edge Masking (Erosion)",
                    info="Cleanup edges. Set to 0 for raw splatting."
                )
            btn = gr.Button("Generate 3D", variant="primary")
        with gr.Column(scale=1):
            out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan)", height=320)
            out_stereo = gr.Image(label="Side-by-Side Stereo Pair", height=320)
            with gr.Row():
                out_depth = gr.Image(label="Depth Map", height=200)
                out_mask = gr.Image(label="Inpainting Mask (Holes)", height=200)

    btn.click(
        fn=stereo_pipeline,
        inputs=[input_img, divergence_slider, convergence_slider, erosion_slider],
        outputs=[out_stereo, out_anaglyph, out_depth, out_mask]
    )
if __name__ == "__main__":
    demo.launch()