Spaces:
Running
on
Zero
Running
on
Zero
import spaces | |
import os | |
# this is a HF Spaces specific hack, as | |
# (i) building pytorch3d with GPU support is a bit tricky here | |
# (ii) installing the wheel via requirements.txt breaks ZeroGPU | |
os.system("pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt221/download.html") | |
import torch | |
import torch.nn.functional as F | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import skimage | |
from PIL import Image | |
import gradio as gr | |
from utils.render import PointsRendererWithMasks, render | |
from utils.ops import snap_high_gradients_to_nn, project_points, get_pointcloud, merge_pointclouds, outpaint_with_depth_estimation | |
from utils.gs import gs_options, read_cameras_from_optimization_bundle, Scene, run_gaussian_splatting, get_blank_gs_bundle | |
from pytorch3d.utils import opencv_from_cameras_projection | |
from utils.ops import focal2fov, fov2focal | |
from utils.models import infer_with_zoe_dc | |
from utils.scene import GaussianModel | |
from utils.demo import downsample_point_cloud | |
from typing import Iterable, Tuple, Dict, Optional | |
import itertools | |
from pytorch3d.structures import Pointclouds | |
from pytorch3d.renderer import ( | |
look_at_view_transform, | |
PerspectiveCameras, | |
) | |
from pytorch3d.io import IO | |
def get_blank_gs_bundle(h, w):
    """Create an empty optimization bundle for an ``h`` x ``w`` image.

    The bundle is a plain dict holding the horizontal field of view
    (``camera_angle_x``, derived via ``focal2fov`` from a focal length equal
    to the image width), the image dimensions (``W``/``H``), placeholders for
    the point-cloud arrays, and an empty ``frames`` list.

    Note: this definition replaces the ``get_blank_gs_bundle`` name imported
    from ``utils.gs`` at the top of the file.
    """
    focal_length = torch.tensor([w], dtype=torch.float32)
    bundle = {"camera_angle_x": focal2fov(focal_length, w), "W": w, "H": h}
    # nothing has been reconstructed yet
    bundle["pcd_points"] = None
    bundle["pcd_colors"] = None
    bundle["frames"] = []
    return bundle
def extrapolate_point_cloud(prompt: str, image_size: Tuple[int, int], look_at_params: Iterable[Tuple[float, float, float, Tuple[float, float, float]]], point_cloud: Pointclouds = None, dry_run: bool = False, discard_mask: bool = False, initial_image: Optional[Image.Image] = None, depth_scaling: float = 1, **render_kwargs):
    """Visit a sequence of camera poses, recording a frame per pose and growing the point cloud.

    For every ``(azim, elev, dist, at)`` entry a camera is built with
    ``look_at_view_transform``. If a point cloud already exists it is rendered
    from that camera and (unless ``dry_run``) outpainted together with a depth
    estimate; if no point cloud exists yet, ``initial_image`` is used as the
    view and its depth is estimated directly. Outside of a dry run the
    resulting depth is projected to world space and merged into the cloud.

    Relies on the module-level globals ``device``, ``pipe`` and
    ``zoe_dc_model``, which this file assigns in its ``__main__`` section.

    Args:
        prompt: Text prompt forwarded to ``outpaint_with_depth_estimation``.
        image_size: ``(w, h)`` of the views.
        look_at_params: Iterable of ``(azim, elev, dist, at)`` camera poses.
        point_cloud: Existing cloud to render from / extend, or ``None`` to
            bootstrap one from ``initial_image``.
        dry_run: If true, views are only rendered; nothing is outpainted and
            the point cloud is left untouched. Requires ``point_cloud``.
        discard_mask: If true, the ``"mask"`` entry is removed from each frame.
        initial_image: Required when ``point_cloud`` is ``None``.
        depth_scaling: Forwarded to ``outpaint_with_depth_estimation``.
        **render_kwargs: Forwarded verbatim to ``render``.

    Returns:
        ``(frames, point_cloud)`` where ``frames`` is a list of dicts with the
        keys ``image``, ``mask`` (unless discarded), ``transform_matrix``,
        ``azim``, ``elev``, ``dist`` and, outside of a dry run,
        ``center_point``, ``depth`` and ``mean_depth``.

    Raises:
        AssertionError: If ``point_cloud`` is ``None`` and either
            ``initial_image`` is missing or ``dry_run`` is set.
    """
    w, h = image_size
    optimization_bundle_frames = []
    for azim, elev, dist, at in look_at_params:
        R, T = look_at_view_transform(device=device, azim=azim, elev=elev, dist=dist, at=at)
        # NOTE(review): principal_point is passed as ((h-1)/2, (w-1)/2) and image_size as (w, h);
        # PyTorch3D documents these as (px, py) and (height, width) respectively -- confirm that
        # this ordering is what render()/project_points() expect before touching it.
        cameras = PerspectiveCameras(R=R, T=T, focal_length=torch.tensor([w], dtype=torch.float32), principal_point=(((h-1)/2, (w-1)/2),), image_size=(image_size,), device=device, in_ndc=False)
        if point_cloud is not None:
            # view the existing cloud from the new camera
            images, masks, depths = render(cameras, point_cloud, **render_kwargs)
            if not dry_run:
                # shrink the valid-depth region (depth > 0) with skimage's default footprint and
                # zero out the rendered depth everywhere outside of it
                eroded_mask = skimage.morphology.binary_erosion((depths[0] > 0).cpu().numpy(), footprint=None)#skimage.morphology.disk(1))
                eroded_depth = depths[0].clone()
                eroded_depth[torch.from_numpy(eroded_mask).to(depths.device) <= 0] = 0
                # the generator is re-seeded with 0 for every view
                outpainted_img, aligned_depth = outpaint_with_depth_estimation(images[0], masks[0], eroded_depth, h, w, pipe, zoe_dc_model, prompt, cameras, dilation_size=2, depth_scaling=depth_scaling, generator=torch.Generator(device=pipe.device).manual_seed(0))
                aligned_depth = torch.from_numpy(aligned_depth).to(device)
            else:
                # in a dry run, we do not actually outpaint the image
                outpainted_img = Image.fromarray((255*images[0].cpu().numpy()).astype(np.uint8))
        else:
            assert initial_image is not None
            assert not dry_run
            # jumpstart the point cloud with a regular depth estimation
            t_initial_image = torch.from_numpy(np.asarray(initial_image)/255.).permute(2,0,1).float()
            # `depth` is not read again in this function; only `aligned_depth` is used below
            depth = aligned_depth = infer_with_zoe_dc(zoe_dc_model, t_initial_image, torch.zeros(h, w))
            outpainted_img = initial_image
            images = [t_initial_image.to(device)]
            # the whole first view counts as covered
            masks = [torch.ones(h, w, dtype=torch.bool).to(device)]
        if not dry_run:
            # snap high gradients to nearest neighbor, which eliminates noodle artifacts
            aligned_depth = snap_high_gradients_to_nn(aligned_depth, threshold=12).cpu()
            xy_depth_world = project_points(cameras, aligned_depth)
        # despite the name, this holds the matrix of get_world_to_view_transform()
        c2w = cameras.get_world_to_view_transform().get_matrix()[0]
        # NOTE(review): if `masks` is a CPU tensor, masks[0].cpu().numpy() presumably shares memory
        # with it, so the in-place erosion further down would also alter this stored mask -- verify.
        optimization_bundle_frames.append({
            "image": outpainted_img,
            "mask": masks[0].cpu().numpy(),
            "transform_matrix": c2w.tolist(),
            "azim": azim,
            "elev": elev,
            "dist": dist,
        })
        if discard_mask:
            optimization_bundle_frames[-1].pop("mask")
        if not dry_run:
            # per-frame statistics that supplement_point_cloud() reads back later
            optimization_bundle_frames[-1]["center_point"] = xy_depth_world[0].mean(dim=0).tolist()
            optimization_bundle_frames[-1]["depth"] = aligned_depth.cpu().numpy()
            optimization_bundle_frames[-1]["mean_depth"] = aligned_depth.mean().item()
        else:
            # in a dry run, we do not modify the point cloud
            continue
        # one RGB triple in [0, 1] per pixel, flattened to match the projected points
        rgb = (torch.from_numpy(np.asarray(outpainted_img).copy()).reshape(-1, 3).float() / 255).to(device)
        if point_cloud is None:
            # first view: every pixel becomes a point
            point_cloud = get_pointcloud(xy_depth_world[0], device=device, features=rgb)
        else:
            # pytorch 3d's mask might be slightly too big (subpixels), so we erode it a little to avoid seams
            # in theory, 1 pixel is sufficient but we use 2 to be safe
            masks[0] = torch.from_numpy(skimage.morphology.binary_erosion(masks[0].cpu().numpy(), footprint=skimage.morphology.disk(2))).to(device)
            # only pixels outside the (eroded) mask contribute new points
            partial_outpainted_point_cloud = get_pointcloud(xy_depth_world[0][~masks[0].view(-1)], device=device, features=rgb[~masks[0].view(-1)])
            point_cloud = merge_pointclouds([point_cloud, partial_outpainted_point_cloud])
    return optimization_bundle_frames, point_cloud
def generate_point_cloud(initial_image: Image.Image, prompt: str):
    """Build the initial point cloud and optimization bundle from one image.

    Three views are processed by ``extrapolate_point_cloud``: the input view
    itself (azimuth 0) followed by one view rotated 25 degrees to either side.
    All of them look at the origin from a distance of 0.01 with zero elevation.

    Returns:
        ``(optimization_bundle, point_cloud)``; the bundle's ``frames``,
        ``pcd_points`` and ``pcd_colors`` entries are filled in.
    """
    width, height = size = initial_image.size
    bundle = get_blank_gs_bundle(height, width)

    yaw = 25
    views = []
    for azimuth in (0, yaw, -yaw):
        views.append((azimuth, 0, 0.01, torch.zeros((1, 3))))

    frames, cloud = extrapolate_point_cloud(
        prompt,
        size,
        views,
        discard_mask=True,
        initial_image=initial_image,
        depth_scaling=0.5,
        fill_point_cloud_holes=True,  # not a named parameter: forwarded to render()
    )

    bundle["frames"] = frames
    # the cloud holds a single batch element; export it as numpy arrays
    bundle["pcd_points"] = cloud.points_padded()[0].cpu().numpy()
    bundle["pcd_colors"] = cloud.features_padded()[0].cpu().numpy()
    return bundle, cloud
def supplement_point_cloud(optimization_bundle: Dict, point_cloud: Pointclouds, prompt: str):
    """Add jittered "supporting" views around every regular frame of the bundle.

    For each frame that is not itself a supporting view, a 3 x 3 grid of
    cameras is created by offsetting the frame's azimuth and elevation by
    -5, 0 and +5 degrees. The cameras look at the frame's ``center_point``
    from a distance of its ``mean_depth``. The views are produced by
    ``extrapolate_point_cloud`` in dry-run mode, i.e. they are rendered from
    the existing cloud without outpainting and without changing the cloud.

    The new frames are tagged with ``"supporting": True`` and appended to
    ``optimization_bundle["frames"]``. Previously they were collected and then
    dropped, which made the whole function a no-op; the tag check at the top
    of the loop keeps a repeated call from jittering supporting views again.

    Relies on the module-level global ``device``.

    Args:
        optimization_bundle: Bundle whose frames carry ``center_point``,
            ``mean_depth``, ``azim`` and ``elev``. It is modified in place.
        point_cloud: Cloud to render the supporting views from.
        prompt: Forwarded to ``extrapolate_point_cloud``.

    Returns:
        ``(optimization_bundle, point_cloud)``.
    """
    w, h = optimization_bundle["W"], optimization_bundle["H"]

    # the jitter grid is the same for every frame, so build it once
    jitters = torch.linspace(-5, 5, 3).tolist()
    jitter_grid = list(itertools.product(jitters, jitters))

    supporting_frames = []
    for frame in optimization_bundle["frames"]:
        # skip supporting views
        if frame.get("supporting", False):
            continue

        look_at = torch.tensor(frame["center_point"]).to(device).unsqueeze(0)
        mean_depth = frame["mean_depth"]
        azim, elev = frame["azim"], frame["elev"]

        look_at_params = [
            (azim + azim_jitter, elev + elev_jitter, mean_depth, look_at)
            for azim_jitter, elev_jitter in jitter_grid
        ]

        local_supporting_frames, point_cloud = extrapolate_point_cloud(
            prompt,
            (w, h),
            look_at_params,
            point_cloud,
            dry_run=True,
            depth_scaling=0.5,
            antialiasing=3,
        )

        for local_supporting_frame in local_supporting_frames:
            local_supporting_frame["supporting"] = True
        supporting_frames.extend(local_supporting_frames)

    # extend only after the loop so the list is not modified while iterating
    optimization_bundle["frames"].extend(supporting_frames)

    optimization_bundle["pcd_points"] = point_cloud.points_padded()[0].cpu().numpy()
    optimization_bundle["pcd_colors"] = point_cloud.features_padded()[0].cpu().numpy()
    return optimization_bundle, point_cloud
def generate_scene(img: Image.Image, prompt: str):
    """Turn a single image and a prompt into a Gaussian-splat ``.ply`` file.

    The image is shrunk in place so that neither side exceeds 720 px, cropped
    to dimensions divisible by 8, and expanded into a point cloud. The cloud
    is downsampled, wrapped in a ``Scene`` with fully opaque Gaussians, moved
    into the viewer's coordinate system and written to ``./output.ply``.
    The Gaussian-splatting optimization step is not run here.

    Relies on the module-level global ``device``.

    Returns:
        Path of the written ``.ply`` file.
    """
    assert isinstance(img, Image.Image)

    # resize (aspect ratio preserved) so that the longest side is at most 720 pixels
    longest_side = 720
    img.thumbnail((longest_side, longest_side))

    # trim the right and bottom edges down to the nearest multiple of 8
    crop_w = img.width - img.width % 8
    crop_h = img.height - img.height % 8
    img = img.crop((0, 0, crop_w, crop_h))

    bundle, _ = generate_point_cloud(img, prompt)

    sparse_cloud = downsample_point_cloud(bundle, device=device)
    bundle["pcd_points"] = sparse_cloud.points_padded()[0].cpu().numpy()
    bundle["pcd_colors"] = sparse_cloud.features_padded()[0].cpu().numpy()

    scene = Scene(bundle, GaussianModel(gs_options.sh_degree), gs_options)
    scene.gaussians._opacity = torch.ones_like(scene.gaussians._opacity)

    # coordinate system transformation: negate the y and z columns
    scene.gaussians._xyz = scene.gaussians._xyz.detach()
    for axis in (1, 2):
        scene.gaussians._xyz[:, axis] = -scene.gaussians._xyz[:, axis]

    ply_path = "./output.ply"
    scene.gaussians.save_ply(ply_path)
    return ply_path
if __name__ == "__main__":
    # `device`, `zoe_dc_model` and `pipe` are module-level names that the
    # functions above read as globals.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    from utils.models import get_zoe_dc_model, get_sd_pipeline
    from huggingface_hub import hf_hub_download

    # depth-completion checkpoint is fetched from the Hugging Face Hub
    checkpoint_path = hf_hub_download(repo_id="paulengstler/invisible-stitch", filename="invisible-stitch.pt")
    zoe_dc_model = get_zoe_dc_model(ckpt_path=checkpoint_path).to(device)

    pipe = get_sd_pipeline().to(device)

    # --- user interface ---------------------------------------------------
    ui_inputs = [
        gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil"),
        gr.Textbox(label="Scene Hallucination Prompt"),
    ]
    ui_outputs = gr.Model3D(label="Generated Scene")

    ui_description = "Hallucinate geometrically coherent 3D scenes from a single input image in less than 30 seconds.<br /> [Project Page](https://research.paulengstler.com/invisible-stitch) | [GitHub](https://github.com/paulengstler/invisible-stitch) | [Paper](https://arxiv.org/abs/2404.19758) <br /><br />To keep this demo snappy, we have limited its functionality. Scenes are generated at a low resolution without densification, supporting views are not inpainted, and we do not optimize the resulting point cloud. Imperfections are to be expected, in particular around object borders. Please allow a couple of seconds for the generated scene to be downloaded (about 40 megabytes)."
    ui_article = "Please consider running this demo locally to obtain high-quality results (see the GitHub repository).<br /><br />Here are some observations we made that might help you to get better results:<ul><li>Use generic prompts that match the surroundings of your input image.</li><li>Ensure that the borders of your input image are free from partially visible objects.</li><li>Keep your prompts simple and avoid adding specific details.</li></ul>"

    # each example is an (image path, prompt) pair
    ui_examples = [
        ["examples/photo-1667788000333-4e36f948de9a.jpeg", "a street with traditional buildings in Kyoto, Japan"],
        ["examples/photo-1628624747186-a941c476b7ef.jpeg", "a suburban street in North Carolina on a bright, sunny day"],
        ["examples/photo-1469559845082-95b66baaf023.jpeg", "a view of Zion National Park"],
        ["examples/photo-1514984879728-be0aff75a6e8.jpeg", "a close-up view of a muddy path in a forest"],
        ["examples/photo-1618197345638-d2df92b39fe1.jpeg", "a close-up view of a white linen bed in a minimalistic room"],
        ["examples/photo-1546975490-e8b92a360b24.jpeg", "a warm living room with plants"],
        ["examples/photo-1499916078039-922301b0eb9b.jpeg", "a cozy bedroom on a bright day"],
    ]

    demo = gr.Interface(
        fn=generate_scene,
        inputs=ui_inputs,
        outputs=ui_outputs,
        allow_flagging="never",
        title="Invisible Stitch: Generating Smooth 3D Scenes with Depth Inpainting",
        description=ui_description,
        article=ui_article,
        examples=ui_examples,
    )

    # enable request queuing, then start the server
    queued_demo = demo.queue()
    queued_demo.launch(share=True)