# Sharp-It / shap-e / shap_e / util / notebooks.py
# YiftachEde's picture
# Add shap-e without large binary files
# efa71f7
# raw
# history blame
# 2.84 kB
import base64
import io
from typing import Union
import ipywidgets as widgets
import numpy as np
import torch
from PIL import Image
from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
from shap_e.models.transmitter.base import Transmitter, VectorDecoder
from shap_e.rendering.torch_mesh import TorchMesh
from shap_e.util.collections import AttrDict
def create_pan_cameras(size: int, device: torch.device) -> DifferentiableCameraBatch:
    """
    Build a batch of 20 square cameras that pan once around the world origin.

    The pan angle is sampled with ``np.linspace(0, 2 * np.pi, num=20)``, which
    includes both endpoints, so the first and last cameras are a full turn
    apart and nearly coincide. For every angle the ``z`` axis is the
    normalized vector ``(sin, cos, -0.5)`` and the camera origin is ``-4 * z``,
    i.e. the camera sits 4 units from the world origin with ``z`` pointing
    back toward it.

    :param size: used as both the width and the height of every camera.
    :param device: torch device that the camera tensors are moved to.
    :return: a DifferentiableCameraBatch with shape ``(1, 20)`` wrapping a
             DifferentiableProjectiveCamera whose ``x_fov`` and ``y_fov`` are
             both 0.7.
    """
    poses = []
    for angle in np.linspace(0, 2 * np.pi, num=20):
        sin_a, cos_a = np.sin(angle), np.cos(angle)

        # Unit-length z axis, tilted by a fixed -0.5 third component.
        forward = np.array([sin_a, cos_a, -0.5])
        forward /= np.sqrt(np.sum(forward**2))

        # x axis has no third component; y completes the frame as z cross x.
        right = np.array([cos_a, -sin_a, 0.0])
        up = np.cross(forward, right)

        poses.append((-forward * 4, right, up, forward))

    def stack_to_tensor(vectors):
        # One row per camera, converted to float32 and moved to `device`.
        return torch.from_numpy(np.stack(vectors, axis=0)).float().to(device)

    origins, rights, ups, forwards = zip(*poses)
    flat_camera = DifferentiableProjectiveCamera(
        origin=stack_to_tensor(origins),
        x=stack_to_tensor(rights),
        y=stack_to_tensor(ups),
        z=stack_to_tensor(forwards),
        width=size,
        height=size,
        x_fov=0.7,
        y_fov=0.7,
    )
    return DifferentiableCameraBatch(shape=(1, len(poses)), flat_camera=flat_camera)
@torch.no_grad()
def decode_latent_images(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
    cameras: DifferentiableCameraBatch,
    rendering_mode: str = "stf",
):
    """
    Render one latent from the given cameras and return the views as images.

    Runs with gradient tracking disabled.

    :param xm: model providing ``renderer.render_views``. For a Transmitter,
               ``bottleneck_to_params`` is taken from ``xm.encoder``; otherwise
               it is called on ``xm`` directly.
    :param latent: a single latent; a leading dimension is added before it is
                   passed to ``bottleneck_to_params``.
    :param cameras: the camera batch to render from.
    :param rendering_mode: forwarded to the renderer as the ``rendering_mode``
                           option.
    :return: a list of PIL images, one per entry of the first element of the
             rendered ``channels`` tensor.
    """
    render_views = xm.renderer.render_views
    batch = AttrDict(cameras=cameras)

    # A Transmitter exposes bottleneck_to_params on its encoder.
    param_source = xm.encoder if isinstance(xm, Transmitter) else xm
    params = param_source.bottleneck_to_params(latent[None])

    render_options = AttrDict(rendering_mode=rendering_mode, render_with_direction=False)
    rendered = render_views(batch, params=params, options=render_options)

    # Clamp into the 8-bit range before casting, then drop the leading
    # dimension that was added to the latent above.
    pixels = rendered.channels.clamp(0, 255).to(torch.uint8)
    frames = pixels[0].cpu().numpy()
    return [Image.fromarray(frame) for frame in frames]
@torch.no_grad()
def decode_latent_mesh(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
) -> TorchMesh:
    """
    Decode one latent into a mesh via the renderer's "stf" mode.

    Runs with gradient tracking disabled. The renderer is invoked with 2x2
    pan cameras created on the latent's device; the rendered pixels are not
    used, only the first entry of ``raw_meshes`` in the render output.

    :param xm: model providing ``renderer.render_views``. For a Transmitter,
               ``bottleneck_to_params`` is taken from ``xm.encoder``; otherwise
               it is called on ``xm`` directly.
    :param latent: a single latent; a leading dimension is added before it is
                   passed to ``bottleneck_to_params``.
    :return: ``raw_meshes[0]`` from the render output.
    """
    render_views = xm.renderer.render_views

    # Lowest resolution possible: the image output is discarded anyway.
    batch = AttrDict(cameras=create_pan_cameras(2, latent.device))

    # A Transmitter exposes bottleneck_to_params on its encoder.
    param_source = xm.encoder if isinstance(xm, Transmitter) else xm
    params = param_source.bottleneck_to_params(latent[None])

    render_options = AttrDict(rendering_mode="stf", render_with_direction=False)
    rendered = render_views(batch, params=params, options=render_options)
    return rendered.raw_meshes[0]
def gif_widget(images):
    """
    Pack a sequence of images into an animated GIF shown by an HTML widget.

    The GIF is written to an in-memory buffer, base64-encoded and embedded in
    an ``<img>`` tag as a ``data:image/gif`` URI, so no file is written.

    :param images: a non-empty, sliceable sequence of images. ``save`` is
                   called on the first one with the rest passed as
                   ``append_images``; an empty sequence raises IndexError.
    :return: a ``widgets.HTML`` instance containing the ``<img>`` tag.
    """
    first_frame, remaining_frames = images[0], images[1:]

    buffer = io.BytesIO()
    # In Pillow's GIF writer, duration is the per-frame display time in
    # milliseconds and loop=0 repeats the animation indefinitely.
    first_frame.save(
        buffer,
        format="GIF",
        save_all=True,
        append_images=remaining_frames,
        duration=100,
        loop=0,
    )

    data = base64.b64encode(buffer.getvalue()).decode("ascii")
    return widgets.HTML(f'<img src="data:image/gif;base64,{data}" />')