Paul Engstler
add arxiv link
1f0b43e
raw
history blame
12.2 kB
import spaces
import os
# this is a HF Spaces specific hack: pytorch3d is installed when this module is executed, as
# (i) building pytorch3d with GPU support is a bit tricky here
# (ii) installing the wheel via requirements.txt breaks ZeroGPU
# NOTE(review): the exit status of the pip command is not checked, so a failed
# install only surfaces at the `pytorch3d` imports further down.
os.system("pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt221/download.html")
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import skimage
from PIL import Image
import gradio as gr
from utils.render import PointsRendererWithMasks, render
from utils.ops import snap_high_gradients_to_nn, project_points, get_pointcloud, merge_pointclouds, outpaint_with_depth_estimation
from utils.gs import gs_options, read_cameras_from_optimization_bundle, Scene, run_gaussian_splatting, get_blank_gs_bundle
from pytorch3d.utils import opencv_from_cameras_projection
from utils.ops import focal2fov, fov2focal
from utils.models import infer_with_zoe_dc
from utils.scene import GaussianModel
from utils.demo import downsample_point_cloud
from typing import Iterable, Tuple, Dict, Optional
import itertools
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
look_at_view_transform,
PerspectiveCameras,
)
from pytorch3d.io import IO
def get_blank_gs_bundle(h, w):
    """Create an empty optimization bundle for an image of height ``h`` and width ``w``.

    Note: this definition replaces the ``get_blank_gs_bundle`` that is imported
    from ``utils.gs`` earlier in the file.

    Args:
        h: image height, stored under ``"H"``.
        w: image width, stored under ``"W"`` and also used to derive
            ``"camera_angle_x"``.

    Returns:
        A dict with the camera angle and image size filled in, no point cloud
        data (``None``) and an empty list of frames.
    """
    # the width doubles as the focal length here (as a float tensor)
    focal_length = torch.tensor([w], dtype=torch.float32)

    bundle = {}
    bundle["camera_angle_x"] = focal2fov(focal_length, w)
    bundle["W"] = w
    bundle["H"] = h
    bundle["pcd_points"] = None
    bundle["pcd_colors"] = None
    bundle["frames"] = []
    return bundle
@spaces.GPU(duration=30)
def extrapolate_point_cloud(
    prompt: str,
    image_size: Tuple[int, int],
    look_at_params: Iterable[Tuple[float, float, float, Tuple[float, float, float]]],
    point_cloud: Pointclouds = None,
    dry_run: bool = False,
    discard_mask: bool = False,
    initial_image: Optional[Image.Image] = None,
    depth_scaling: float = 1,
    **render_kwargs,
):
    """Produce one frame per camera pose and, unless in a dry run, grow the point cloud.

    For every ``(azim, elev, dist, at)`` entry of ``look_at_params`` a camera is
    built with ``look_at_view_transform`` and a frame dict is appended to the
    returned list. Three cases are handled per pose:

    * ``point_cloud is None``: ``initial_image`` becomes the frame image, its
      depth is estimated with ``infer_with_zoe_dc`` and the point cloud is
      created from it.
    * ``point_cloud`` given and ``dry_run`` false: the cloud is rendered, the
      render is handed to ``outpaint_with_depth_estimation`` and the points at
      pixels outside the (eroded) render mask are merged into the cloud.
    * ``point_cloud`` given and ``dry_run`` true: the cloud is only rendered;
      the frame carries no depth information and the cloud is left untouched.

    Relies on the module-level globals ``device``, ``pipe`` and ``zoe_dc_model``.

    Args:
        prompt: text prompt forwarded to ``outpaint_with_depth_estimation``.
        image_size: ``(width, height)`` of the frames.
        look_at_params: iterable of ``(azim, elev, dist, at)`` tuples forwarded
            to ``look_at_view_transform``.
        point_cloud: existing point cloud, or ``None`` to bootstrap one from
            ``initial_image``.
        dry_run: when true, nothing is outpainted and the point cloud is not
            modified. Not allowed while ``point_cloud`` is ``None``.
        discard_mask: when true, the ``"mask"`` entry is removed from each frame.
        initial_image: image used to bootstrap the point cloud; required while
            ``point_cloud`` is ``None``.
        depth_scaling: forwarded to ``outpaint_with_depth_estimation``.
        **render_kwargs: forwarded to ``render``.

    Returns:
        A tuple ``(frames, point_cloud)`` with the list of frame dicts and the
        (possibly extended) point cloud.
    """
    w, h = image_size
    frames = []

    for azim, elev, dist, at in look_at_params:
        R, T = look_at_view_transform(device=device, azim=azim, elev=elev, dist=dist, at=at)
        # NOTE(review): PyTorch3D documents `image_size` as (height, width) and
        # `principal_point` as (x, y), whereas `image_size` here is (w, h) and the
        # principal point is given as ((h-1)/2, (w-1)/2) -- confirm this ordering is
        # what `render` / `project_points` expect for non-square images.
        cameras = PerspectiveCameras(
            R=R,
            T=T,
            focal_length=torch.tensor([w], dtype=torch.float32),
            principal_point=(((h-1)/2, (w-1)/2),),
            image_size=(image_size,),
            device=device,
            in_ndc=False,
        )

        if point_cloud is None:
            assert initial_image is not None
            assert not dry_run

            # jumpstart the point cloud with a regular depth estimation
            t_initial_image = torch.from_numpy(np.asarray(initial_image)/255.).permute(2,0,1).float()
            aligned_depth = infer_with_zoe_dc(zoe_dc_model, t_initial_image, torch.zeros(h, w))

            outpainted_img = initial_image
            images = [t_initial_image.to(device)]
            masks = [torch.ones(h, w, dtype=torch.bool).to(device)]
        else:
            images, masks, depths = render(cameras, point_cloud, **render_kwargs)

            if dry_run:
                # in a dry run, we do not actually outpaint the image
                outpainted_img = Image.fromarray((255*images[0].cpu().numpy()).astype(np.uint8))
            else:
                # zero out the rendered depth that does not survive an erosion of the valid-depth region
                has_depth = (depths[0] > 0).cpu().numpy()
                eroded_mask = skimage.morphology.binary_erosion(has_depth, footprint=None)
                eroded_depth = depths[0].clone()
                eroded_depth[torch.from_numpy(eroded_mask).to(depths.device) <= 0] = 0

                # fixed seed, re-created for every pose
                seeded_generator = torch.Generator(device=pipe.device).manual_seed(0)
                outpainted_img, aligned_depth = outpaint_with_depth_estimation(
                    images[0], masks[0], eroded_depth, h, w, pipe, zoe_dc_model, prompt, cameras,
                    dilation_size=2, depth_scaling=depth_scaling, generator=seeded_generator,
                )
                aligned_depth = torch.from_numpy(aligned_depth).to(device)

        if not dry_run:
            # snap high gradients to nearest neighbor, which eliminates noodle artifacts
            aligned_depth = snap_high_gradients_to_nn(aligned_depth, threshold=12).cpu()
            xy_depth_world = project_points(cameras, aligned_depth)

        # stored under "transform_matrix": the camera's world-to-view matrix
        world_to_view = cameras.get_world_to_view_transform().get_matrix()[0]

        frame = {
            "image": outpainted_img,
            "mask": masks[0].cpu().numpy(),
            "transform_matrix": world_to_view.tolist(),
            "azim": azim,
            "elev": elev,
            "dist": dist,
        }
        frames.append(frame)

        if discard_mask:
            frame.pop("mask")

        if dry_run:
            # in a dry run, we do not modify the point cloud
            continue

        frame["center_point"] = xy_depth_world[0].mean(dim=0).tolist()
        frame["depth"] = aligned_depth.cpu().numpy()
        frame["mean_depth"] = aligned_depth.mean().item()

        rgb = (torch.from_numpy(np.asarray(outpainted_img).copy()).reshape(-1, 3).float() / 255).to(device)

        if point_cloud is None:
            point_cloud = get_pointcloud(xy_depth_world[0], device=device, features=rgb)
            continue

        # pytorch 3d's mask might be slightly too big (subpixels), so we erode it a little to avoid seams
        # in theory, 1 pixel is sufficient but we use 2 to be safe
        masks[0] = torch.from_numpy(skimage.morphology.binary_erosion(masks[0].cpu().numpy(), footprint=skimage.morphology.disk(2))).to(device)

        # only the pixels outside the eroded mask contribute new points
        outside_mask = ~masks[0].view(-1)
        partial_outpainted_point_cloud = get_pointcloud(xy_depth_world[0][outside_mask], device=device, features=rgb[outside_mask])

        point_cloud = merge_pointclouds([point_cloud, partial_outpainted_point_cloud])

    return frames, point_cloud
@spaces.GPU(duration=30)
def generate_point_cloud(initial_image: Image.Image, prompt: str):
    """Build the initial point cloud and optimization bundle from a single image.

    The image is extrapolated at three azimuths (0, +25 and -25) with zero
    elevation via ``extrapolate_point_cloud``.

    Args:
        initial_image: image the point cloud is bootstrapped from.
        prompt: text prompt forwarded to ``extrapolate_point_cloud``.

    Returns:
        A tuple ``(optimization_bundle, point_cloud)``; the bundle holds the
        generated frames and the point cloud's points and features as numpy arrays.
    """
    size = initial_image.size
    width, height = size
    bundle = get_blank_gs_bundle(height, width)

    azim_step = 25
    look_at_params = []
    for azim in (0, azim_step, -azim_step):
        look_at_params.append((azim, 0, 0.01, torch.zeros((1, 3))))

    frames, point_cloud = extrapolate_point_cloud(
        prompt,
        size,
        look_at_params,
        discard_mask=True,
        initial_image=initial_image,
        depth_scaling=0.5,
        fill_point_cloud_holes=True,
    )

    bundle["frames"] = frames
    bundle["pcd_points"] = point_cloud.points_padded()[0].cpu().numpy()
    bundle["pcd_colors"] = point_cloud.features_padded()[0].cpu().numpy()

    return bundle, point_cloud
@spaces.GPU(duration=30)
def supplement_point_cloud(optimization_bundle: Dict, point_cloud: Pointclouds, prompt: str):
    """Add jittered supporting views around every non-supporting frame of the bundle.

    For each frame that is not already marked ``"supporting"``, a 3x3 grid of
    cameras is created by offsetting the frame's azimuth and elevation by
    -5, 0 and +5. The cameras look at the frame's ``"center_point"`` from a
    distance of its ``"mean_depth"``. The views are produced by
    ``extrapolate_point_cloud`` in dry-run mode, flagged with
    ``"supporting": True`` and appended to ``optimization_bundle["frames"]``.

    Args:
        optimization_bundle: bundle dict with ``"W"``, ``"H"`` and ``"frames"``;
            it is updated in place.
        point_cloud: point cloud the supporting views are rendered from.
        prompt: text prompt forwarded to ``extrapolate_point_cloud``.

    Returns:
        A tuple ``(optimization_bundle, point_cloud)``.
    """
    w, h = optimization_bundle["W"], optimization_bundle["H"]

    # the jitter offsets are identical for every frame, so compute them once
    azim_jitters = torch.linspace(-5, 5, 3).tolist()
    elev_jitters = torch.linspace(-5, 5, 3).tolist()

    supporting_frames = []

    for frame in optimization_bundle["frames"]:
        # skip supporting views
        if frame.get("supporting", False):
            continue

        center_point = torch.tensor(frame["center_point"]).to(device)
        mean_depth = frame["mean_depth"]
        azim, elev = frame["azim"], frame["elev"]

        # build the product of azim and elev jitters
        look_at_params = [
            (azim + azim_jitter, elev + elev_jitter, mean_depth, center_point.unsqueeze(0))
            for azim_jitter, elev_jitter in itertools.product(azim_jitters, elev_jitters)
        ]

        local_supporting_frames, point_cloud = extrapolate_point_cloud(prompt, (w, h), look_at_params, point_cloud, dry_run=True, depth_scaling=0.5, antialiasing=3)

        for local_supporting_frame in local_supporting_frames:
            local_supporting_frame["supporting"] = True

        supporting_frames.extend(local_supporting_frames)

    # Register the supporting views in the bundle; previously they were collected
    # and then dropped. This happens after the loop so that the list being
    # iterated above is not grown while iterating.
    optimization_bundle["frames"].extend(supporting_frames)

    optimization_bundle["pcd_points"] = point_cloud.points_padded()[0].cpu().numpy()
    optimization_bundle["pcd_colors"] = point_cloud.features_padded()[0].cpu().numpy()

    return optimization_bundle, point_cloud
@spaces.GPU(duration=30)
def generate_scene(img: Image.Image, prompt: str):
    """Generate a 3D scene from a single image and save it as a ``.ply`` file.

    The image is converted to RGB, shrunk so that neither side exceeds 720
    pixels, and cropped so both dimensions are divisible by 8. A point cloud is
    then generated from it, downsampled and written out as Gaussians with full
    opacity.

    Relies on the module-level global ``device``.

    Args:
        img: input image; the caller's object is not modified.
        prompt: text prompt forwarded to ``generate_point_cloud``.

    Returns:
        The path of the written ``.ply`` file.

    Raises:
        gr.Error: if ``img`` is not a PIL image, or if one of its sides is
            shorter than 8 pixels after resizing.
    """
    # explicit check instead of `assert`, which is stripped under `python -O`
    if not isinstance(img, Image.Image):
        raise gr.Error("Please provide an input image.")

    # work on an RGB copy: `thumbnail` below resizes in place, and the point cloud
    # code reshapes the pixel data to three channels
    img = img.convert("RGB")

    # shrink the image (maintaining the aspect ratio) so its longest side is at most 720 pixels;
    # smaller images are left at their original size
    max_size = 720
    img.thumbnail((max_size, max_size))

    # the crop below would produce an empty image otherwise
    if min(img.size) < 8:
        raise gr.Error("The input image is too small: both sides must be at least 8 pixels after resizing.")

    # crop to ensure the image dimensions are divisible by 8
    img = img.crop((0, 0, img.width - img.width % 8, img.height - img.height % 8))

    gs_optimization_bundle, _ = generate_point_cloud(img, prompt)

    downsampled_point_cloud = downsample_point_cloud(gs_optimization_bundle, device=device)

    gs_optimization_bundle["pcd_points"] = downsampled_point_cloud.points_padded()[0].cpu().numpy()
    gs_optimization_bundle["pcd_colors"] = downsampled_point_cloud.features_padded()[0].cpu().numpy()

    # the scene is exported without running a Gaussian splatting optimization
    scene = Scene(gs_optimization_bundle, GaussianModel(gs_options.sh_degree), gs_options)
    scene.gaussians._opacity = torch.ones_like(scene.gaussians._opacity)

    # coordinate system transformation: negate the second and third coordinate
    scene.gaussians._xyz = scene.gaussians._xyz.detach()
    scene.gaussians._xyz[:, 1] = -scene.gaussians._xyz[:, 1]
    scene.gaussians._xyz[:, 2] = -scene.gaussians._xyz[:, 2]

    save_path = "./output.ply"
    scene.gaussians.save_ply(save_path)

    return save_path
if __name__ == "__main__":
    # `device`, `zoe_dc_model` and `pipe` are assigned at module level here and
    # read as globals by the functions above.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    from utils.models import get_zoe_dc_model, get_sd_pipeline
    from huggingface_hub import hf_hub_download

    # fetch the depth model checkpoint from the Hugging Face Hub and load it
    checkpoint_path = hf_hub_download(repo_id="paulengstler/invisible-stitch", filename="invisible-stitch.pt")
    zoe_dc_model = get_zoe_dc_model(ckpt_path=checkpoint_path).to(device)

    pipe = get_sd_pipeline().to(device)

    # interface components, created in the order inputs -> outputs
    input_components = [
        gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil"),
        gr.Textbox(label="Scene Hallucination Prompt"),
    ]
    output_component = gr.Model3D(label="Generated Scene")

    demo_description = "Hallucinate geometrically coherent 3D scenes from a single input image in less than 30 seconds.<br /> [Project Page](https://research.paulengstler.com/invisible-stitch) | [GitHub](https://github.com/paulengstler/invisible-stitch) | [Paper](https://arxiv.org/abs/2404.19758) <br /><br />To keep this demo snappy, we have limited its functionality. Scenes are generated at a low resolution without densification, supporting views are not inpainted, and we do not optimize the resulting point cloud. Imperfections are to be expected, in particular around object borders. Please allow a couple of seconds for the generated scene to be downloaded (about 40 megabytes)."
    demo_article = "Please consider running this demo locally to obtain high-quality results (see the GitHub repository).<br /><br />Here are some observations we made that might help you to get better results:<ul><li>Use generic prompts that match the surroundings of your input image.</li><li>Ensure that the borders of your input image are free from partially visible objects.</li><li>Keep your prompts simple and avoid adding specific details.</li></ul>"

    # each example is an (image path, prompt) pair
    demo_examples = [
        ["examples/photo-1667788000333-4e36f948de9a.jpeg", "a street with traditional buildings in Kyoto, Japan"],
        ["examples/photo-1628624747186-a941c476b7ef.jpeg", "a suburban street in North Carolina on a bright, sunny day"],
        ["examples/photo-1469559845082-95b66baaf023.jpeg", "a view of Zion National Park"],
        ["examples/photo-1514984879728-be0aff75a6e8.jpeg", "a close-up view of a muddy path in a forest"],
        ["examples/photo-1618197345638-d2df92b39fe1.jpeg", "a close-up view of a white linen bed in a minimalistic room"],
        ["examples/photo-1546975490-e8b92a360b24.jpeg", "a warm living room with plants"],
        ["examples/photo-1499916078039-922301b0eb9b.jpeg", "a cozy bedroom on a bright day"],
    ]

    demo = gr.Interface(
        fn=generate_scene,
        inputs=input_components,
        outputs=output_component,
        allow_flagging="never",
        title="Invisible Stitch: Generating Smooth 3D Scenes with Depth Inpainting",
        description=demo_description,
        article=demo_article,
        examples=demo_examples,
    )

    demo.queue().launch(share=True)