not-lain's picture
Duplicate from Maikou/Michelangelo
522e1ad verified
# -*- coding: utf-8 -*-
from typing import List, Tuple, Dict, Optional
from omegaconf import DictConfig
import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from typing import Union
from functools import partial
from michelangelo.utils import instantiate_from_config
from .inference_utils import extract_geometry
from .tsal_base import (
AlignedShapeAsLatentModule,
ShapeAsLatentModule,
Latent2MeshOutput,
AlignedMeshOutput
)
class AlignedShapeAsLatentPLModule(pl.LightningModule):
def __init__(self, *,
shape_module_cfg,
aligned_module_cfg,
loss_cfg,
optimizer_cfg: Optional[DictConfig] = None,
ckpt_path: Optional[str] = None,
ignore_keys: Union[Tuple[str], List[str]] = ()):
super().__init__()
shape_model: ShapeAsLatentModule = instantiate_from_config(
shape_module_cfg, device=None, dtype=None
)
self.model: AlignedShapeAsLatentModule = instantiate_from_config(
aligned_module_cfg, shape_model=shape_model
)
self.loss = instantiate_from_config(loss_cfg)
self.optimizer_cfg = optimizer_cfg
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.save_hyperparameters()
def set_shape_model_only(self):
self.model.set_shape_model_only()
@property
def latent_shape(self):
return self.model.shape_model.latent_shape
@property
def zero_rank(self):
if self._trainer:
zero_rank = self.trainer.local_rank == 0
else:
zero_rank = True
return zero_rank
def init_from_ckpt(self, path, ignore_keys=()):
state_dict = torch.load(path, map_location="cpu")["state_dict"]
keys = list(state_dict.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del state_dict[k]
missing, unexpected = self.load_state_dict(state_dict, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def configure_optimizers(self) -> Tuple[List, List]:
lr = self.learning_rate
trainable_parameters = list(self.model.parameters())
if self.optimizer_cfg is None:
optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
schedulers = []
else:
optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
scheduler_func = instantiate_from_config(
self.optimizer_cfg.scheduler,
max_decay_steps=self.trainer.max_steps,
lr_max=lr
)
scheduler = {
"scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
"interval": "step",
"frequency": 1
}
optimizers = [optimizer]
schedulers = [scheduler]
return optimizers, schedulers
def forward(self,
surface: torch.FloatTensor,
image: torch.FloatTensor,
text: torch.FloatTensor,
volume_queries: torch.FloatTensor):
"""
Args:
surface (torch.FloatTensor):
image (torch.FloatTensor):
text (torch.FloatTensor):
volume_queries (torch.FloatTensor):
Returns:
"""
embed_outputs, shape_z = self.model(surface, image, text)
shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
latents = self.model.shape_model.decode(shape_zq)
logits = self.model.shape_model.query_geometry(volume_queries, latents)
return embed_outputs, logits, posterior
def encode(self, surface: torch.FloatTensor, sample_posterior=True):
pc = surface[..., 0:3]
feats = surface[..., 3:6]
shape_embed, shape_zq, posterior = self.model.shape_model.encode(
pc=pc, feats=feats, sample_posterior=sample_posterior
)
return shape_zq
def decode(self,
z_q,
bounds: Union[Tuple[float], List[float], float] = 1.1,
octree_depth: int = 7,
num_chunks: int = 10000) -> List[Latent2MeshOutput]:
latents = self.model.shape_model.decode(z_q) # latents: [bs, num_latents, dim]
outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks)
return outputs
def training_step(self, batch: Dict[str, torch.FloatTensor],
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
"""
Args:
batch (dict): the batch sample, and it contains:
- surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
- image (torch.FloatTensor): [bs, 3, 224, 224]
- text (torch.FloatTensor): [bs, num_templates, 77]
- geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)]
batch_idx (int):
optimizer_idx (int):
Returns:
loss (torch.FloatTensor):
"""
surface = batch["surface"]
image = batch["image"]
text = batch["text"]
volume_queries = batch["geo_points"][..., 0:3]
shape_labels = batch["geo_points"][..., -1]
embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)
aeloss, log_dict_ae = self.loss(
**embed_outputs,
posteriors=posteriors,
shape_logits=shape_logits,
shape_labels=shape_labels,
split="train"
)
self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
sync_dist=False, rank_zero_only=True)
return aeloss
def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor:
surface = batch["surface"]
image = batch["image"]
text = batch["text"]
volume_queries = batch["geo_points"][..., 0:3]
shape_labels = batch["geo_points"][..., -1]
embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)
aeloss, log_dict_ae = self.loss(
**embed_outputs,
posteriors=posteriors,
shape_logits=shape_logits,
shape_labels=shape_labels,
split="val"
)
self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
sync_dist=False, rank_zero_only=True)
return aeloss
def visual_alignment(self,
surface: torch.FloatTensor,
image: torch.FloatTensor,
text: torch.FloatTensor,
description: Optional[List[str]] = None,
bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
octree_depth: int = 7,
num_chunks: int = 10000) -> List[AlignedMeshOutput]:
"""
Args:
surface:
image:
text:
description:
bounds:
octree_depth:
num_chunks:
Returns:
mesh_outputs (List[AlignedMeshOutput]): the mesh outputs list.
"""
outputs = []
device = surface.device
bs = surface.shape[0]
embed_outputs, shape_z = self.model(surface, image, text)
# calculate the similarity
image_embed = embed_outputs["image_embed"]
text_embed = embed_outputs["text_embed"]
shape_embed = embed_outputs["shape_embed"]
# normalized features
shape_embed = F.normalize(shape_embed, dim=-1, p=2)
text_embed = F.normalize(text_embed, dim=-1, p=2)
image_embed = F.normalize(image_embed, dim=-1, p=2)
# B x B
shape_text_similarity = (100.0 * shape_embed @ text_embed.T).softmax(dim=-1)
# B x B
shape_image_similarity = (100.0 * shape_embed @ image_embed.T).softmax(dim=-1)
# shape reconstruction
shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
latents = self.model.shape_model.decode(shape_zq)
geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)
# 2. decode geometry
mesh_v_f, has_surface = extract_geometry(
geometric_func=geometric_func,
device=device,
batch_size=bs,
bounds=bounds,
octree_depth=octree_depth,
num_chunks=num_chunks,
disable=not self.zero_rank
)
# 3. decode texture
for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
if not is_surface:
outputs.append(None)
continue
out = AlignedMeshOutput()
out.mesh_v = mesh_v
out.mesh_f = mesh_f
out.surface = surface[i].cpu().numpy()
out.image = image[i].cpu().numpy()
if description is not None:
out.text = description[i]
out.shape_text_similarity = shape_text_similarity[i, i]
out.shape_image_similarity = shape_image_similarity[i, i]
outputs.append(out)
return outputs
def latent2mesh(self,
latents: torch.FloatTensor,
bounds: Union[Tuple[float], List[float], float] = 1.1,
octree_depth: int = 7,
num_chunks: int = 10000) -> List[Latent2MeshOutput]:
"""
Args:
latents: [bs, num_latents, dim]
bounds:
octree_depth:
num_chunks:
Returns:
mesh_outputs (List[MeshOutput]): the mesh outputs list.
"""
outputs = []
geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)
# 2. decode geometry
device = latents.device
mesh_v_f, has_surface = extract_geometry(
geometric_func=geometric_func,
device=device,
batch_size=len(latents),
bounds=bounds,
octree_depth=octree_depth,
num_chunks=num_chunks,
disable=not self.zero_rank
)
# 3. decode texture
for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
if not is_surface:
outputs.append(None)
continue
out = Latent2MeshOutput()
out.mesh_v = mesh_v
out.mesh_f = mesh_f
outputs.append(out)
return outputs