Spaces:
Paused
Paused
| from dataclasses import dataclass, field | |
| import numpy as np | |
| import torch | |
| from skimage import measure | |
| from einops import repeat, rearrange | |
| import step1x3d_geometry | |
| from step1x3d_geometry.systems.base import BaseSystem | |
| from step1x3d_geometry.utils.ops import generate_dense_grid_points | |
| from step1x3d_geometry.utils.typing import * | |
| from step1x3d_geometry.utils.misc import get_rank | |
class ShapeAutoEncoderSystem(BaseSystem):
    """Training / validation system for a shape auto-encoder.

    The wrapped shape model is looked up by name through
    ``step1x3d_geometry.find`` and is used to encode surface samples into
    latents, decode them, and query per-point logits that are supervised
    with either SDF values or occupancies taken from the batch.
    """

    # ``field(default_factory=...)`` below is only meaningful on a dataclass,
    # so the config must be decorated.
    @dataclass
    class Config(BaseSystem.Config):
        # Name passed to ``step1x3d_geometry.find`` to resolve the model class.
        shape_model_type: str = None
        # Configuration handed to the resolved shape-model constructor.
        shape_model: dict = field(default_factory=dict)

        # Forwarded to ``shape_model.encode``; also decides whether a KL loss
        # term is produced by ``forward``.
        sample_posterior: bool = True

        # for mesh extraction (forwarded to ``shape_model.extract_geometry``)
        bounds: float = 1.05
        mc_level: float = 0.0
        octree_resolution: int = 256

    cfg: Config

    def configure(self):
        """Instantiate the shape model named by ``cfg.shape_model_type``."""
        super().configure()

        shape_model_cls = step1x3d_geometry.find(self.cfg.shape_model_type)
        self.shape_model = shape_model_cls(self.cfg.shape_model)

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Encode, decode and query the shape model for one batch.

        Args:
            batch: Must contain ``"rand_points"`` and ``"surface"``, plus
                either ``"sdf"`` (regressed with MSE) or ``"occupancies"``
                (classified with BCE-with-logits). ``"sharp_surface"`` is
                optional and forwarded to the encoder when present.

        Returns:
            A dict with ``"loss_logits"``, ``"logits"``, ``"target"`` and
            ``"latents"``; ``"loss_kl"`` is added when
            ``cfg.sample_posterior`` is enabled.

        Raises:
            NotImplementedError: If the batch has neither ``"sdf"`` nor
                ``"occupancies"``.
        """
        rand_points = batch["rand_points"]

        # "sdf" takes precedence over "occupancies" when both are present.
        if "sdf" in batch:
            target = batch["sdf"]
            criterion = torch.nn.MSELoss()
        elif "occupancies" in batch:
            target = batch["occupancies"]
            criterion = torch.nn.BCEWithLogitsLoss()
        else:
            raise NotImplementedError(
                "batch must provide either 'sdf' or 'occupancies' as supervision"
            )

        # forward pass
        # Keep the first three channels plus the configured number of extra
        # per-point features.
        num_point_feats = 3 + self.cfg.shape_model.point_feats
        sharp_surface = None
        if "sharp_surface" in batch:
            sharp_surface = batch["sharp_surface"][..., :num_point_feats]

        shape_latents, kl_embed, posterior = self.shape_model.encode(
            batch["surface"][..., :num_point_feats],
            sharp_surface=sharp_surface,
            sample_posterior=self.cfg.sample_posterior,
        )

        latents = self.shape_model.decode(kl_embed)  # [B, num_latents, width]
        logits = self.shape_model.query(rand_points, latents).squeeze(
            -1
        )  # [B, num_rand_points]

        # Both criteria use the default "mean" reduction, so the result is
        # already a scalar.
        result = {"loss_logits": criterion(logits, target)}

        if self.cfg.sample_posterior:
            loss_kl = posterior.kl()
            # Sum over all elements, normalised by the size of dim 0.
            result["loss_kl"] = torch.sum(loss_kl) / loss_kl.shape[0]

        # "target" is always returned: validation_step reads it regardless of
        # whether the posterior is sampled.
        result["logits"] = logits
        result["target"] = target
        result["latents"] = latents
        return result

    def training_step(self, batch, batch_idx):
        """Run one optimisation step and return the weighted total loss.

        Every ``loss_*`` entry produced by ``forward`` is logged under
        ``train/`` and multiplied by the matching ``lambda_*`` entry of
        ``cfg.loss`` (evaluated through ``self.C``). All ``cfg.loss`` entries
        are additionally logged under ``train_params/``.

        Args:
            batch: Batch dict forwarded to ``forward``.
            batch_idx: Index of the batch (unused).

        Returns:
            ``{"loss": total_loss}``.
        """
        out = self(batch)

        loss = 0.0
        for name, value in out.items():
            if not name.startswith("loss_"):
                continue
            self.log(f"train/{name}", value)
            weight = self.C(self.cfg.loss[name.replace("loss_", "lambda_")])
            loss += value * weight

        for name, value in self.cfg.loss.items():
            self.log(f"train_params/{name}", self.C(value))

        return {"loss": loss}

    @staticmethod
    def _occupancy_metrics(logits, labels, threshold=0):
        """Return ``(accuracy, iou)`` of thresholded logits against labels.

        Both metrics are computed per sample along dim 1 and then averaged.
        """
        pred = torch.zeros_like(logits)
        pred[logits >= threshold] = 1

        accuracy = (pred == labels).float().sum(dim=1) / labels.shape[1]

        intersection = (pred * labels).sum(dim=1)
        union = (pred + labels).gt(0).sum(dim=1)
        # The epsilon belongs in the denominator: it guards against an empty
        # union (0 / 0 -> NaN) without biasing every other IoU value.
        iou = intersection * 1.0 / (union + 1e-5)

        return accuracy.mean(), iou.mean()

    def validation_step(self, batch, batch_idx):
        """Extract and save meshes for the batch and log accuracy / IoU.

        Args:
            batch: Batch dict forwarded to ``forward``; must also contain
                ``"uid"``, used to name the saved ``.obj`` files.
            batch_idx: Index of the batch (unused).

        Returns:
            A dict with ``"val/loss"``, ``"val/accuracy"`` and ``"val/iou"``.
        """
        self.eval()
        out = self(batch)

        meshes = self.shape_model.extract_geometry(
            out["latents"],
            bounds=self.cfg.bounds,
            mc_level=self.cfg.mc_level,
            octree_resolution=self.cfg.octree_resolution,
            enable_pbar=False,
        )

        for idx, name in enumerate(batch["uid"]):
            mesh = meshes[idx]
            self.save_mesh(
                f"it{self.true_global_step}/{name}.obj",
                mesh.verts,
                mesh.faces,
            )

        # NOTE(review): these metrics compare binarised logits with the raw
        # target, which presumably only makes sense for 0/1 occupancy labels;
        # confirm whether they should be skipped or adapted for SDF targets.
        accuracy, iou = self._occupancy_metrics(
            out["logits"], out["target"], threshold=0
        )

        self.log("val/accuracy", accuracy)
        self.log("val/iou", iou)

        torch.cuda.empty_cache()

        return {
            "val/loss": out["loss_logits"],
            "val/accuracy": accuracy,
            "val/iou": iou,
        }

    def on_validation_epoch_end(self):
        """No aggregation is performed at the end of a validation epoch."""
        pass

    def test_step(self, batch, batch_idx):
        """Testing is not implemented; returns ``None``."""
        return