# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict

from michelangelo.models.modules.distributions import DiagonalGaussianDistribution
from michelangelo.utils.eval import compute_psnr
from michelangelo.utils import misc


class KLNearFar(nn.Module):
    def __init__(self,
                 near_weight: float = 0.1,
                 kl_weight: float = 1.0,
                 num_near_samples: Optional[int] = None):
        super().__init__()

        self.near_weight = near_weight
        self.kl_weight = kl_weight
        self.num_near_samples = num_near_samples
        self.geo_criterion = nn.BCEWithLogitsLoss()

    def forward(self,
                posteriors: Optional[DiagonalGaussianDistribution],
                logits: torch.FloatTensor,
                labels: torch.FloatTensor,
                split: Optional[str] = "train",
                **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
        """
        Args:
            posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
            logits (torch.FloatTensor): [B, 2*N]; logits[:, 0:N] are the volume (far-field) points
                and logits[:, N:2N] are the near-surface points.
            labels (torch.FloatTensor): [B, 2*N]; labels[:, 0:N] are the volume (far-field) points
                and labels[:, N:2N] are the near-surface points.
            split (str):
            **kwargs:

        Returns:
            loss (torch.Tensor): scalar
            log (dict):
        """

        if self.num_near_samples is None:
            num_vol = logits.shape[1] // 2
        else:
            num_vol = logits.shape[1] - self.num_near_samples

        vol_logits = logits[:, 0:num_vol]
        vol_labels = labels[:, 0:num_vol]

        near_logits = logits[:, num_vol:]
        near_labels = labels[:, num_vol:]

        # occupancy loss
        vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
        near_bce = self.geo_criterion(near_logits.float(), near_labels.float())

        if posteriors is None:
            kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
        else:
            kl_loss = posteriors.kl(dims=(1, 2))
            kl_loss = torch.mean(kl_loss)

        loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight

        with torch.no_grad():
            preds = logits >= 0
            accuracy = (preds == labels).float()
            accuracy = accuracy.mean()
            pos_ratio = torch.mean(labels)

        log = {
            "{}/total_loss".format(split): loss.clone().detach(),
            "{}/near".format(split): near_bce.detach(),
            "{}/far".format(split): vol_bce.detach(),
            "{}/kl".format(split): kl_loss.detach(),
            "{}/accuracy".format(split): accuracy,
            "{}/pos_ratio".format(split): pos_ratio
        }

        if posteriors is not None:
            log[f"{split}/mean"] = posteriors.mean.mean().detach()
            log[f"{split}/std_mean"] = posteriors.std.mean().detach()
            log[f"{split}/std_max"] = posteriors.std.max().detach()

        return loss, log
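
# Usage sketch for KLNearFar (illustrative only, not part of the training
# pipeline): the batch size, point count, and weights below are assumptions.
# Passing `posteriors=None` exercises the branch that zeroes the KL term.
#
#   criterion = KLNearFar(near_weight=0.1, kl_weight=1e-3)
#   logits = torch.randn(4, 2048)                  # [B, 2*N] occupancy logits
#   labels = (torch.rand(4, 2048) > 0.5).float()   # [B, 2*N] binary targets
#   loss, log = criterion(posteriors=None, logits=logits, labels=labels, split="train")
#   # log contains "train/total_loss", "train/near", "train/far", "train/kl",
#   # "train/accuracy" and "train/pos_ratio".
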
class KLNearFarColor(nn.Module):
    def __init__(self,
                 near_weight: float = 0.1,
                 kl_weight: float = 1.0,
                 color_weight: float = 1.0,
                 color_criterion: str = "mse",
                 num_near_samples: Optional[int] = None):
        super().__init__()

        self.color_weight = color_weight
        self.near_weight = near_weight
        self.kl_weight = kl_weight
        self.num_near_samples = num_near_samples

        if color_criterion == "mse":
            self.color_criterion = nn.MSELoss()
        elif color_criterion == "l1":
            self.color_criterion = nn.L1Loss()
        else:
            raise ValueError(f"color_criterion must be one of ['mse', 'l1'], got {color_criterion!r}.")

        self.geo_criterion = nn.BCEWithLogitsLoss()

    def forward(self,
                posteriors: Optional[DiagonalGaussianDistribution],
                logits: torch.FloatTensor,
                labels: torch.FloatTensor,
                pred_colors: torch.FloatTensor,
                gt_colors: torch.FloatTensor,
                split: Optional[str] = "train",
                **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
        """
        Args:
            posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
            logits (torch.FloatTensor): [B, 2*N]; logits[:, 0:N] are the volume (far-field) points
                and logits[:, N:2N] are the near-surface points.
            labels (torch.FloatTensor): [B, 2*N]; labels[:, 0:N] are the volume (far-field) points
                and labels[:, N:2N] are the near-surface points.
            pred_colors (torch.FloatTensor): [B, M, 3]
            gt_colors (torch.FloatTensor): [B, M, 3]
            split (str):
            **kwargs:

        Returns:
            loss (torch.Tensor): scalar
            log (dict):
        """

        if self.num_near_samples is None:
            num_vol = logits.shape[1] // 2
        else:
            num_vol = logits.shape[1] - self.num_near_samples

        vol_logits = logits[:, 0:num_vol]
        vol_labels = labels[:, 0:num_vol]

        near_logits = logits[:, num_vol:]
        near_labels = labels[:, num_vol:]

        # occupancy loss
        vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
        near_bce = self.geo_criterion(near_logits.float(), near_labels.float())

        # surface color loss
        color = self.color_criterion(pred_colors, gt_colors)

        if posteriors is None:
            kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device)
        else:
            kl_loss = posteriors.kl(dims=(1, 2))
            kl_loss = torch.mean(kl_loss)

        loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight

        with torch.no_grad():
            preds = logits >= 0
            accuracy = (preds == labels).float()
            accuracy = accuracy.mean()
            psnr = compute_psnr(pred_colors, gt_colors)

        log = {
            "{}/total_loss".format(split): loss.clone().detach(),
            "{}/near".format(split): near_bce.detach(),
            "{}/far".format(split): vol_bce.detach(),
            "{}/color".format(split): color.detach(),
            "{}/kl".format(split): kl_loss.detach(),
            "{}/psnr".format(split): psnr.detach(),
            "{}/accuracy".format(split): accuracy
        }

        return loss, log
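
# Usage sketch for KLNearFarColor (illustrative only): the geometry branch
# matches KLNearFar, with an added surface-color regression term. The shapes
# below (N = 1024 volume/near points, M = 512 surface points) are assumptions.
#
#   criterion = KLNearFarColor(color_criterion="l1", color_weight=1.0)
#   logits = torch.randn(4, 2048)                  # [B, 2*N]
#   labels = (torch.rand(4, 2048) > 0.5).float()   # [B, 2*N]
#   pred_colors = torch.rand(4, 512, 3)            # [B, M, 3]
#   gt_colors = torch.rand(4, 512, 3)              # [B, M, 3]
#   loss, log = criterion(None, logits, labels, pred_colors, gt_colors, split="val")
#   # log additionally reports "val/color" and "val/psnr".
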
class ContrastKLNearFar(nn.Module):
    def __init__(self,
                 contrast_weight: float = 1.0,
                 near_weight: float = 0.1,
                 kl_weight: float = 1.0,
                 num_near_samples: Optional[int] = None):
        super().__init__()

        self.labels = None
        self.last_local_batch_size = None

        self.contrast_weight = contrast_weight
        self.near_weight = near_weight
        self.kl_weight = kl_weight
        self.num_near_samples = num_near_samples
        self.geo_criterion = nn.BCEWithLogitsLoss()

    def forward(self,
                shape_embed: torch.FloatTensor,
                text_embed: torch.FloatTensor,
                image_embed: torch.FloatTensor,
                logit_scale: torch.FloatTensor,
                posteriors: Optional[DiagonalGaussianDistribution],
                shape_logits: torch.FloatTensor,
                shape_labels: torch.FloatTensor,
                split: Optional[str] = "train",
                **kwargs):

        local_batch_size = shape_embed.size(0)

        if local_batch_size != self.last_local_batch_size:
            self.labels = local_batch_size * misc.get_rank() + torch.arange(
                local_batch_size, device=shape_embed.device
            ).long()
            self.last_local_batch_size = local_batch_size

        # normalized features
        shape_embed = F.normalize(shape_embed, dim=-1, p=2)
        text_embed = F.normalize(text_embed, dim=-1, p=2)
        image_embed = F.normalize(image_embed, dim=-1, p=2)

        # gather features from all GPUs
        shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch(
            [shape_embed, text_embed, image_embed]
        )

        # cosine similarity as logits
        logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t()
        logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t()
        logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t()
        logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t()
        contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels)
                         + F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \
                        (F.cross_entropy(logits_per_shape_image, self.labels)
                         + F.cross_entropy(logits_per_image_shape, self.labels)) / 2

        # shape reconstruction
        if self.num_near_samples is None:
            num_vol = shape_logits.shape[1] // 2
        else:
            num_vol = shape_logits.shape[1] - self.num_near_samples

        vol_logits = shape_logits[:, 0:num_vol]
        vol_labels = shape_labels[:, 0:num_vol]

        near_logits = shape_logits[:, num_vol:]
        near_labels = shape_labels[:, num_vol:]

        # occupancy loss
        vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
        near_bce = self.geo_criterion(near_logits.float(), near_labels.float())

        if posteriors is None:
            kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
        else:
            kl_loss = posteriors.kl(dims=(1, 2))
            kl_loss = torch.mean(kl_loss)

        loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight

        # compute accuracy
        with torch.no_grad():
            pred = torch.argmax(logits_per_shape_text, dim=-1)
            correct = pred.eq(self.labels).sum()
            shape_text_acc = 100 * correct / local_batch_size

            pred = torch.argmax(logits_per_shape_image, dim=-1)
            correct = pred.eq(self.labels).sum()
            shape_image_acc = 100 * correct / local_batch_size

            preds = shape_logits >= 0
            accuracy = (preds == shape_labels).float()
            accuracy = accuracy.mean()

        log = {
            "{}/contrast".format(split): contrast_loss.clone().detach(),
            "{}/near".format(split): near_bce.detach(),
            "{}/far".format(split): vol_bce.detach(),
            "{}/kl".format(split): kl_loss.detach(),
            "{}/shape_text_acc".format(split): shape_text_acc,
            "{}/shape_image_acc".format(split): shape_image_acc,
            "{}/total_loss".format(split): loss.clone().detach(),
            "{}/accuracy".format(split): accuracy,
        }

        if posteriors is not None:
            log[f"{split}/mean"] = posteriors.mean.mean().detach()
            log[f"{split}/std_mean"] = posteriors.std.mean().detach()
            log[f"{split}/std_max"] = posteriors.std.max().detach()

        return loss, log
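
# Usage sketch for ContrastKLNearFar (illustrative only): in real training the
# embeddings come from the aligned shape/text/image encoders, and
# misc.all_gather_batch gathers them across ranks in a distributed run. The
# embedding width (768) and logit_scale value below are assumptions.
#
#   criterion = ContrastKLNearFar(contrast_weight=0.1, near_weight=0.1, kl_weight=1e-3)
#   B, C, N = 4, 768, 1024
#   shape_embed, text_embed, image_embed = (torch.randn(B, C) for _ in range(3))
#   logit_scale = torch.tensor(100.0)              # exp of a learnable temperature
#   shape_logits = torch.randn(B, 2 * N)
#   shape_labels = (torch.rand(B, 2 * N) > 0.5).float()
#   loss, log = criterion(shape_embed, text_embed, image_embed, logit_scale,
#                         None, shape_logits, shape_labels, split="train")
#   # log reports "train/contrast", "train/shape_text_acc", "train/shape_image_acc", ...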