"""Implementation of Pointnet."""
from __future__ import annotations

import torch
from torch import nn

from vis4d.common.ckpt import load_model_checkpoint
from vis4d.common.typing import LossesType, ModelOutput
from vis4d.data.const import CommonKeys
from vis4d.op.base.pointnet import PointNetSegmentation, PointNetSemanticsOut
from vis4d.op.loss.orthogonal_transform_loss import (
    OrthogonalTransformRegularizationLoss,
)


class PointnetSegmentationModel(nn.Module):
    """Simple Segmentation Model using Pointnet."""

    def __init__(
        self,
        num_classes: int = 11,
        in_dimensions: int = 3,
        weights: str | None = None,
    ) -> None:
        """Simple Segmentation Model using Pointnet.

        Args:
            num_classes: Number of semantic classes
            in_dimensions: Input dimension
            weights: Path to weight file. If given, the checkpoint is loaded
                into this module right after construction.
        """
        super().__init__()
        self.model = PointNetSegmentation(
            n_classes=num_classes, in_dimensions=in_dimensions
        )
        if weights is not None:
            load_model_checkpoint(self, weights)

    def __call__(
        self, data: torch.Tensor, target: torch.Tensor | None = None
    ) -> PointNetSemanticsOut | ModelOutput:
        """Runs the semantic model.

        Thin typed wrapper that forwards to ``nn.Module``'s call
        implementation, which in turn dispatches to :meth:`forward`.

        Args:
            data: Input Tensor Shape [N, C, n_pts]
            target: Target Classes shape [N, n_pts]

        Returns:
            Whatever :meth:`forward` returns for the given arguments.
        """
        return self._call_impl(data, target)

    def forward(
        self, data: torch.Tensor, target: torch.Tensor | None = None
    ) -> PointNetSemanticsOut | ModelOutput:
        """Runs the semantic model.

        Args:
            data: Input Tensor Shape [N, C, n_pts]
            target: Target Classes shape [N, n_pts]

        Returns:
            The raw network output (training path) when ``target`` is given,
            otherwise the prediction dictionary of the test path.
        """
        if target is not None:
            return self.forward_train(data, target)
        return self.forward_test(data)

    def forward_train(
        self,
        points: torch.Tensor,
        target: torch.Tensor,
    ) -> PointNetSemanticsOut:
        """Forward training stage.

        Args:
            points: Input Tensor Shape [N, C, n_pts]
            target: Target Classes shape [N, n_pts]. Not used by this
                method; the raw network output is returned unchanged.

        Returns:
            The output of the wrapped Pointnet segmentation network.
        """
        out = self.model(points)
        return out

    def forward_test(
        self,
        points: torch.Tensor,
    ) -> ModelOutput:
        """Forward test stage.

        Args:
            points: Input Tensor Shape [N, C, n_pts]

        Returns:
            Dictionary with a single entry, ``CommonKeys.semantics3d``,
            holding the argmax of the class logits along dimension 1.
        """
        return {
            CommonKeys.semantics3d: torch.argmax(
                self.model(points).class_logits, dim=1
            )
        }


class PointnetSegmentationLoss(nn.Module):
    """PointnetSegmentationLoss Loss."""

    def __init__(
        self,
        regularize_transform: bool = True,
        ignore_index: int = 255,
        transform_weight: float = 1e-3,
        semantic_weights: torch.Tensor | None = None,
    ) -> None:
        """Creates an instance of the class.

        Args:
            regularize_transform: If true add transforms to loss
            ignore_index: Semantic class that should be ignored
            transform_weight: Loss weight factor for transform
                regularization loss
            semantic_weights: Classwise weights for semantic loss
        """
        super().__init__()
        self.segmentation_loss = nn.CrossEntropyLoss(
            weight=semantic_weights, ignore_index=ignore_index
        )
        self.transformation_loss = OrthogonalTransformRegularizationLoss()
        self.regularize_transform = regularize_transform
        self.transform_weight = transform_weight

    def forward(
        self, outputs: PointNetSemanticsOut, target: torch.Tensor
    ) -> LossesType:
        """Calculates the loss.

        Args:
            outputs: Pointnet output
            target: Target Labels

        Returns:
            Dictionary containing ``segmentation_loss`` and, only when
            ``regularize_transform`` is enabled, ``transform_loss`` (the
            transform regularization loss scaled by ``transform_weight``).
        """
        losses = {
            "segmentation_loss": self.segmentation_loss(
                outputs.class_logits, target
            )
        }
        # The regularization term is only computed and reported when it was
        # requested. Previously the segmentation-only dictionary was built
        # but never returned, so the transform loss was always included.
        if self.regularize_transform:
            losses["transform_loss"] = (
                self.transform_weight
                * self.transformation_loss(outputs.transformations)
            )
        return losses