silentchen's picture
first commit
19c4ddf
raw
history blame contribute delete
No virus
2.72 kB
import base64
import io
from typing import Union, Optional
import numpy as np
import torch
from PIL import Image
from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
from shap_e.models.transmitter.base import Transmitter, VectorDecoder
from shap_e.rendering.torch_mesh import TorchMesh
from shap_e.util.collections import AttrDict
def create_pan_cameras(size: int, device: torch.device, batch_size: Optional[int] = 1, dist: int = 4) -> DifferentiableCameraBatch:
    """Build a ring of 20 cameras that pan around the world origin.

    For each of 20 angles ``theta`` taken from ``np.linspace(0, 2 * pi)``, a
    camera is placed ``dist`` units from the world origin so that its ``z``
    axis points from the camera position back at the origin.

    Args:
        size: Width and height (resolution) given to every camera.
        device: Device the camera tensors are moved to.
        batch_size: How many times the 20-camera ring is repeated. ``None``
            is treated the same as the default of 1.
        dist: Distance between each camera position and the world origin.

    Returns:
        A ``DifferentiableCameraBatch`` with shape ``(batch_size, 20)`` whose
        flat camera tensors each have ``batch_size * 20`` rows of 3 values.
    """
    # The signature advertises Optional[int]; without this guard a None would
    # fail inside Tensor.repeat() and produce an invalid batch shape.
    if batch_size is None:
        batch_size = 1

    origins, xs, ys, zs = [], [], [], []
    # linspace includes both endpoints, so the first and last cameras coincide
    # up to floating-point rounding.
    for theta in np.linspace(0, 2 * np.pi, num=20):
        # Unit z axis with a fixed downward component (-0.5 before normalization).
        z = np.array([np.sin(theta), np.cos(theta), -0.5])
        z /= np.sqrt(np.sum(z**2))
        # x has no vertical component and is perpendicular to z.
        x = np.array([np.cos(theta), -np.sin(theta), 0.0])
        # origin + dist * z == 0, i.e. the camera looks at the world origin.
        origins.append(-z * dist)
        xs.append(x)
        # y completes the camera frame.
        ys.append(np.cross(z, x))
        zs.append(z)

    def _tiled(vectors):
        # Stack the per-view vectors into a (views, 3) float tensor on the
        # target device, then repeat the whole ring once per batch element.
        stacked = torch.from_numpy(np.stack(vectors, axis=0)).float().to(device)
        return stacked.repeat(batch_size, 1)

    return DifferentiableCameraBatch(
        shape=(batch_size, len(xs)),
        flat_camera=DifferentiableProjectiveCamera(
            origin=_tiled(origins),
            x=_tiled(xs),
            y=_tiled(ys),
            z=_tiled(zs),
            width=size,
            height=size,
            x_fov=0.7,
            y_fov=0.7,
        ),
    )
@torch.no_grad()
def decode_latent_images(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
    cameras: DifferentiableCameraBatch,
    rendering_mode: str = "stf",
):
    """Render a latent through ``xm`` from the given cameras as PIL images.

    Runs with gradient tracking disabled.

    Args:
        xm: Model providing ``renderer.render_views``. For a ``Transmitter``
            the decoder parameters come from ``xm.encoder``; otherwise from
            ``xm`` itself.
        latent: A single latent; a leading dimension of size 1 is added
            before it is converted to decoder parameters.
        cameras: Camera batch to render from.
        rendering_mode: Forwarded to the renderer as the ``rendering_mode``
            option.

    Returns:
        A list of ``PIL.Image.Image`` objects, one per entry along the leading
        axis of the first batch element's rendered channels (presumably one
        per camera view -- confirm against the renderer).
    """
    param_source = xm.encoder if isinstance(xm, Transmitter) else xm
    decoded = xm.renderer.render_views(
        AttrDict(cameras=cameras),
        params=param_source.bottleneck_to_params(latent[None]),
        options=AttrDict(rendering_mode=rendering_mode, render_with_direction=False),
    )
    # Clamp to the 8-bit range before casting so out-of-range values saturate
    # instead of wrapping around; only the first batch element is kept.
    arr = decoded.channels.clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    return [Image.fromarray(x) for x in arr]
@torch.no_grad()
def decode_latent_mesh(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
) -> TorchMesh:
    """Decode a latent into a mesh by running the renderer in "stf" mode.

    Runs with gradient tracking disabled. The rendered views themselves are
    discarded; only the first entry of the renderer's ``raw_meshes`` output is
    returned.

    Args:
        xm: Model providing ``renderer.render_views``. For a ``Transmitter``
            the decoder parameters come from ``xm.encoder``; otherwise from
            ``xm`` itself.
        latent: A single latent; a leading dimension of size 1 is added
            before it is converted to decoder parameters.

    Returns:
        The first raw mesh produced by the renderer.
    """
    render_views = xm.renderer.render_views

    # The images are thrown away, so use the lowest resolution possible.
    throwaway_cameras = create_pan_cameras(2, latent.device)
    batch = AttrDict(cameras=throwaway_cameras)

    if isinstance(xm, Transmitter):
        param_source = xm.encoder
    else:
        param_source = xm
    decoder_params = param_source.bottleneck_to_params(latent[None])

    render_options = AttrDict(rendering_mode="stf", render_with_direction=False)
    rendered = render_views(batch, params=decoder_params, options=render_options)
    return rendered.raw_meshes[0]