wyysf's picture
i
0f079b2
raw
history blame
10.4 kB
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch import distributed as tdist
from torch.nn import functional as F
import math
import mcubes
import numpy as np
from einops import repeat, rearrange
from skimage import measure
from craftsman.utils.base import BaseModule
from craftsman.utils.typing import *
from craftsman.utils.misc import get_world_size
from craftsman.utils.ops import generate_dense_grid_points
# Embedder names reported as valid by get_embedder's error message.
# "hashgrid" and "sphere_harmonic" are recognised by get_embedder but raise
# NotImplementedError there. NOTE(review): "triplane_fourier" is not handled by
# get_embedder in this file; presumably it is consumed elsewhere — confirm.
# "learned_fourier" and "siren" are appended (original entries/order unchanged)
# because get_embedder does build them.
VALID_EMBED_TYPES = [
    "identity",
    "fourier",
    "hashgrid",
    "sphere_harmonic",
    "triplane_fourier",
    "learned_fourier",
    "siren",
]
class FourierEmbedder(nn.Module):
    """Fixed sinusoidal encoding applied independently to every input channel.

    Each channel of the last axis is multiplied by a bank of ``num_freqs``
    frequencies; the module returns ``[x, sin(x * f), cos(x * f)]`` concatenated
    on the last axis (``x`` is left out when ``include_input`` is False). When
    ``num_freqs == 0`` the input is returned unchanged.

    Args:
        num_freqs: number of frequencies per channel.
        logspace: if True the bank is ``2**0, 2**1, ..., 2**(num_freqs-1)``;
            otherwise it is evenly spaced between 1 and ``2**(num_freqs-1)``.
        input_dim: size of the input's last axis; only used to compute ``out_dim``.
        include_input: prepend the raw input to the encoding.
        include_pi: multiply every frequency by pi.
    """

    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:
        super().__init__()
        if logspace:
            bank = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
        else:
            top = 2.0 ** (num_freqs - 1)
            bank = torch.linspace(1.0, top, num_freqs, dtype=torch.float32)
        if include_pi:
            bank = bank * torch.pi
        # persistent=False: the bank moves with the module but is not stored in state_dict.
        self.register_buffer("frequencies", bank, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs
        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        """Return the encoded width for ``input_dim`` input channels."""
        # The raw input is also kept when there are no frequencies (forward returns x).
        keeps_raw = self.include_input or self.num_freqs == 0
        return input_dim * (2 * self.num_freqs + (1 if keeps_raw else 0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` along its last axis."""
        if self.num_freqs <= 0:
            return x
        # [..., C, 1] * [F] -> [..., C, F] -> [..., C * F]
        scaled = x[..., None].contiguous() * self.frequencies
        angles = scaled.view(*x.shape[:-1], -1)
        pieces = [angles.sin(), angles.cos()]
        if self.include_input:
            pieces.insert(0, x)
        return torch.cat(pieces, dim=-1)
class LearnedFourierEmbedder(nn.Module):
    """Fourier encoding with a trainable frequency vector shared by all channels.

    Output along the last axis is ``[x, sin(2*pi*x*w), cos(2*pi*x*w)]`` where
    ``w`` has ``(dim // 2) // input_dim`` entries.

    Args:
        input_dim: number of input channels (size of the last axis).
        dim: total sine+cosine budget; must be even.
    """

    def __init__(self, input_dim, dim):
        super().__init__()
        assert (dim % 2) == 0
        freqs_per_channel = (dim // 2) // input_dim
        self.weights = nn.Parameter(torch.randn(freqs_per_channel))
        self.out_dim = self.get_dims(input_dim)

    def forward(self, x):
        """Encode ``x`` along its last axis."""
        # [..., C, 1] * [1, D] -> [..., C, D] -> [..., C * D]
        angles = x[..., None] * self.weights[None] * 2 * np.pi
        angles = angles.view(*x.shape[:-1], -1)
        return torch.cat((x, angles.sin(), angles.cos()), dim=-1)

    def get_dims(self, input_dim):
        """Return the encoded width for ``input_dim`` input channels."""
        n_freqs = self.weights.shape[0]
        return input_dim * (1 + 2 * n_freqs)
class Sine(nn.Module):
    """Elementwise activation computing ``sin(w0 * x)``."""

    def __init__(self, w0=1.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        scaled = self.w0 * x
        return scaled.sin()
class Siren(nn.Module):
    """Linear layer followed by an activation and dropout.

    The activation defaults to ``Sine(w0)``. Weights (and bias, if present) are
    drawn from ``U(-b, b)`` with ``b = 1 / in_dim`` for the first layer and
    ``b = sqrt(c / in_dim) / w0`` otherwise.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        w0: frequency scale used by the default activation and the init bound.
        c: constant in the non-first-layer init bound.
        is_first: use the first-layer init bound.
        use_bias: whether to learn a bias.
        activation: module applied after the linear map; ``None`` selects ``Sine(w0)``.
        dropout: dropout probability applied after the activation.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        w0=1.,
        c=6.,
        is_first=False,
        use_bias=True,
        activation=None,
        dropout=0.,
    ):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.is_first = is_first

        w = torch.zeros(out_dim, in_dim)
        b = torch.zeros(out_dim) if use_bias else None
        self.init_(w, b, c=c, w0=w0)

        self.weight = nn.Parameter(w)
        self.bias = None if b is None else nn.Parameter(b)
        self.activation = activation if activation is not None else Sine(w0)
        self.dropout = nn.Dropout(dropout)

    def init_(self, weight, bias, c, w0):
        """Fill ``weight`` and optional ``bias`` in place from a symmetric uniform."""
        fan_in = self.in_dim
        if self.is_first:
            bound = 1 / fan_in
        else:
            bound = math.sqrt(c / fan_in) / w0
        weight.uniform_(-bound, bound)
        if bias is not None:
            bias.uniform_(-bound, bound)

    def forward(self, x):
        pre = F.linear(x, self.weight, self.bias)
        return self.dropout(self.activation(pre))
def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, include_pi=True):
    """Build a positional embedder for inputs whose last axis has ``input_dim`` channels.

    Args:
        embed_type: "identity", "fourier", "learned_fourier" or "siren";
            "hashgrid" and "sphere_harmonic" raise NotImplementedError.
        num_freqs: frequencies per channel for "fourier" (-1 selects the identity
            embedder), total sin/cos width for "learned_fourier", and the
            frequency count used to size the output of "siren".
        input_dim: number of input channels.
        include_pi: forwarded to FourierEmbedder.

    Returns:
        For "identity" (or "fourier" with ``num_freqs == -1``) the tuple
        ``(nn.Identity(), input_dim)``; for every other type the embedder module,
        which exposes ``out_dim``.

    Raises:
        NotImplementedError: for "hashgrid" / "sphere_harmonic".
        ValueError: for any other unknown ``embed_type``.
    """
    # NOTE(review): this branch returns a (module, dim) tuple while the others
    # return the module alone; kept unchanged for existing callers.
    if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
        return nn.Identity(), input_dim

    if embed_type == "fourier":
        # Forward input_dim so out_dim is correct for inputs that are not 3D.
        embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim, include_pi=include_pi)
    elif embed_type == "learned_fourier":
        # LearnedFourierEmbedder's channel argument is named `input_dim`.
        embedder_obj = LearnedFourierEmbedder(input_dim=input_dim, dim=num_freqs)
    elif embed_type == "siren":
        embedder_obj = Siren(in_dim=input_dim, out_dim=num_freqs * input_dim * 2 + input_dim)
    elif embed_type in ("hashgrid", "sphere_harmonic"):
        raise NotImplementedError
    else:
        raise ValueError(f"{embed_type} is not valid. Currently only supports {VALID_EMBED_TYPES}")
    return embedder_obj
###################### AutoEncoder
class AutoEncoder(BaseModule):
    """Base shape autoencoder: surface -> latents -> scalar field queried at 3D points.

    ``encode``, ``decode`` and ``query`` must be provided by subclasses (the
    versions here raise NotImplementedError). When ``cfg.embed_dim > 0``,
    ``encode_kl_embed`` also requires a ``pre_kl`` module on the subclass.
    ``extract_geometry`` treats the output of ``query`` as a scalar field and
    extracts its zero level set as a triangle mesh.
    """

    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = ""
        # Presumably the number of latent tokens; not read by this base class.
        num_latents: int = 256
        # Width of the KL bottleneck; values <= 0 disable it (see encode_kl_embed).
        embed_dim: int = 64
        # Presumably the latent channel width (decode output is annotated [B, num_latents, width]).
        width: int = 768

    cfg: Config

    def configure(self) -> None:
        super().configure()

    def encode(
        self, x: torch.FloatTensor, sample_posterior: bool = True
    ) -> "Tuple[torch.FloatTensor, torch.FloatTensor, Optional[DiagonalGaussianDistribution]]":
        """Encode a surface; subclasses must return ``(shape_latents, kl_embed, posterior)``.

        The 3-tuple contract mirrors how ``forward`` unpacks the result.
        """
        raise NotImplementedError

    def decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
        """Map the (possibly KL-sampled) embedding back to latents; subclass hook."""
        raise NotImplementedError

    def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
        """Optionally pass ``latents`` through a diagonal-Gaussian bottleneck.

        Returns:
            ``(kl_embed, posterior)``. When ``cfg.embed_dim > 0`` the output of
            ``self.pre_kl(latents)`` is split on the last axis into mean/log-variance,
            and ``kl_embed`` is a sample (or the mean when ``sample_posterior`` is
            False). Otherwise ``latents`` is returned unchanged with ``posterior=None``.
        """
        posterior = None
        if self.cfg.embed_dim > 0:
            moments = self.pre_kl(latents)
            posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
            if sample_posterior:
                kl_embed = posterior.sample()
            else:
                kl_embed = posterior.mode()
        else:
            kl_embed = latents
        return kl_embed, posterior

    def forward(self,
                surface: torch.FloatTensor,
                queries: torch.FloatTensor,
                sample_posterior: bool = True):
        """Encode ``surface``, decode, and evaluate the field at ``queries``.

        Returns:
            ``(shape_latents, latents, posterior, logits)``.
        """
        shape_latents, kl_embed, posterior = self.encode(surface, sample_posterior=sample_posterior)
        latents = self.decode(kl_embed)  # [B, num_latents, width]
        logits = self.query(queries, latents)
        return shape_latents, latents, posterior, logits

    def query(self, queries: torch.FloatTensor, latents: torch.FloatTensor) -> torch.FloatTensor:
        """Evaluate the scalar field at ``queries`` conditioned on ``latents``; subclass hook."""
        raise NotImplementedError

    @torch.no_grad()
    def extract_geometry(self,
                         latents: torch.FloatTensor,
                         bounds: Union[Tuple[float], List[float], float] = (-1.05, -1.05, -1.05, 1.05, 1.05, 1.05),
                         octree_depth: int = 8,
                         num_chunks: int = 10000,
                         ):
        """Mesh the zero level set of ``self.query`` on a dense grid, per batch item.

        Args:
            latents: conditioning passed to ``self.query``; the batch size is
                ``latents.shape[0]``.
            bounds: ``(xmin, ymin, zmin, xmax, ymax, zmax)`` or a single number
                ``s`` meaning the cube ``[-s, s]^3``.
            octree_depth: forwarded to ``generate_dense_grid_points``
                (presumably controls grid resolution).
            num_chunks: number of grid points sent to ``self.query`` per call.

        Returns:
            ``(mesh_v_f, has_surface)``: a list with one ``(vertices, faces)`` pair
            per batch item — ``(None, None)`` when marching cubes fails — and a
            boolean array marking which items produced a mesh.
        """
        if isinstance(bounds, (int, float)):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min = np.array(bounds[0:3])
        bbox_max = np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min

        xyz_samples, grid_size, _length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_depth=octree_depth,
            indexing="ij",
        )
        xyz_samples = torch.FloatTensor(xyz_samples)
        batch_size = latents.shape[0]

        # Query the grid in chunks to bound peak memory; results are gathered on CPU.
        batch_logits = []
        for start in range(0, xyz_samples.shape[0], num_chunks):
            queries = xyz_samples[start: start + num_chunks, :].to(latents)
            batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
            logits = self.query(batch_queries, latents)
            batch_logits.append(logits.cpu())

        grid_logits = (
            torch.cat(batch_logits, dim=1)
            .view((batch_size, grid_size[0], grid_size[1], grid_size[2]))
            .float()
            .numpy()
        )

        mesh_v_f = []
        has_surface = np.zeros((batch_size,), dtype=np.bool_)
        for i in range(batch_size):
            try:
                vertices, faces, _normals, _values = measure.marching_cubes(grid_logits[i], 0, method="lewiner")
            except Exception:
                # Best effort: e.g. no zero crossing in this item's field.
                mesh_v_f.append((None, None))
                has_surface[i] = False
                continue
            # Grid-index space -> world coordinates.
            # NOTE(review): if grid_size counts sample points per axis with inclusive
            # endpoints, the exact mapping would divide by (grid_size - 1) — confirm
            # against generate_dense_grid_points before changing.
            vertices = vertices / grid_size * bbox_size + bbox_min
            # Reverse triangle winding order.
            faces = faces[:, [2, 1, 0]]
            mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces)))
            has_surface[i] = True

        return mesh_v_f, has_surface
class DiagonalGaussianDistribution(object):
    """Gaussian with diagonal covariance, parameterised by mean and log-variance.

    Args:
        parameters: either ``[mean, logvar]`` or a single tensor whose
            ``feat_dim`` axis is split in half into mean (first half) and
            log-variance (second half).
        deterministic: if True, std/var are zero; ``sample`` returns the mean and
            ``kl``/``nll`` return a one-element zero tensor.
        feat_dim: axis to split when ``parameters`` is a tensor.
    """

    def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
        self.feat_dim = feat_dim
        self.parameters = parameters

        if isinstance(parameters, list):
            self.mean = parameters[0]
            self.logvar = parameters[1]
        else:
            self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)

        # Clamp so exp() below stays finite and non-degenerate.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean)

    def _zero(self) -> torch.Tensor:
        """One-element zero on the mean's device, so it can be added to device losses."""
        return torch.zeros(1, device=self.mean.device)

    def sample(self):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)``."""
        x = self.mean + self.std * torch.randn_like(self.mean)
        return x

    def kl(self, other=None, dims=(1, 2)):
        """KL divergence to ``other`` (standard normal when None), averaged over ``dims``.

        Note: reduces with ``mean`` over ``dims`` (``nll`` uses ``sum``).
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.mean(torch.pow(self.mean, 2)
                                    + self.var - 1.0 - self.logvar,
                                    dim=dims)
        return 0.5 * torch.mean(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=dims)

    def nll(self, sample, dims=(1, 2)):
        """Negative log-likelihood of ``sample``, summed over ``dims``."""
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Return the mean (the mode of a Gaussian)."""
        return self.mean