import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# The project-local helpers get their own alias. Binding them to ``F`` would
# shadow ``torch.nn.functional`` and leave ``PF`` (used below for
# ``huber_loss``) undefined.
from . import functional as PF

__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d']


class FrustumPointNetLoss(nn.Module):
    """
    Combined training loss: a point-mask classification term plus weighted
    box terms (center, heading, size and box corners).

    :param num_heading_angle_bins: number of heading angle bins (NH); bin
        centers are spaced ``2 * pi / NH`` apart starting at 0
    :param num_size_templates: number of box size templates (NS)
    :param size_templates: template box sizes, a tensor viewable as (NS, 3)
    :param box_loss_weight: weight applied to the sum of all box terms
    :param corners_loss_weight: weight of the corners term inside the box sum
    :param heading_residual_loss_weight: weight of the normalized heading
        residual term inside the box sum
    :param size_residual_loss_weight: weight of the normalized size residual
        term inside the box sum
    """

    def __init__(self, num_heading_angle_bins, num_size_templates, size_templates,
                 box_loss_weight=1.0, corners_loss_weight=10.0,
                 heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0):
        super().__init__()
        self.box_loss_weight = box_loss_weight
        self.corners_loss_weight = corners_loss_weight
        self.heading_residual_loss_weight = heading_residual_loss_weight
        self.size_residual_loss_weight = size_residual_loss_weight

        self.num_heading_angle_bins = num_heading_angle_bins
        self.num_size_templates = num_size_templates
        # Registered as buffers so they follow the module across devices.
        self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3))
        self.register_buffer(
            'heading_angle_bin_centers',
            torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
        )

    def forward(self, inputs, targets):
        """
        Compute the combined loss.

        :param inputs: dict of predictions with keys ``mask_logits``,
            ``center_reg``, ``center``, ``heading_scores``,
            ``heading_residuals_normalized``, ``heading_residuals``,
            ``size_scores``, ``size_residuals_normalized``, ``size_residuals``
        :param targets: dict of ground truth with keys ``mask_logits``,
            ``center``, ``heading_bin_id``, ``heading_residual``,
            ``size_template_id``, ``size_residual``
        :return: ``mask_loss + box_loss_weight * (sum of weighted box terms)``
        """
        mask_logits = inputs['mask_logits']  # (B, 2, N)
        center_reg = inputs['center_reg']  # (B, 3)
        center = inputs['center']  # (B, 3)
        heading_scores = inputs['heading_scores']  # (B, NH)
        heading_residuals_normalized = inputs['heading_residuals_normalized']  # (B, NH)
        heading_residuals = inputs['heading_residuals']  # (B, NH)
        size_scores = inputs['size_scores']  # (B, NS)
        size_residuals_normalized = inputs['size_residuals_normalized']  # (B, NS, 3)
        size_residuals = inputs['size_residuals']  # (B, NS, 3)

        mask_logits_target = targets['mask_logits']  # (B, N)
        center_target = targets['center']  # (B, 3)
        heading_bin_id_target = targets['heading_bin_id']  # (B, )
        heading_residual_target = targets['heading_residual']  # (B, )
        size_template_id_target = targets['size_template_id']  # (B, )
        size_residual_target = targets['size_residual']  # (B, 3)

        batch_size = center.size(0)
        # Row index used to pick, per sample, the entry of the target bin/template.
        batch_id = torch.arange(batch_size, device=center.device)

        # Basic Classification and Regression losses
        mask_loss = F.cross_entropy(mask_logits, mask_logits_target)
        heading_loss = F.cross_entropy(heading_scores, heading_bin_id_target)
        size_loss = F.cross_entropy(size_scores, size_template_id_target)
        center_loss = PF.huber_loss(torch.norm(center_target - center, dim=-1), delta=2.0)
        center_reg_loss = PF.huber_loss(torch.norm(center_target - center_reg, dim=-1), delta=1.0)

        # Refinement losses for size/heading
        heading_residuals_normalized = heading_residuals_normalized[batch_id, heading_bin_id_target]  # (B, )
        # Normalize the target by half of the bin width (bins are 2 * pi / NH wide).
        heading_residual_normalized_target = heading_residual_target / (np.pi / self.num_heading_angle_bins)
        heading_residual_normalized_loss = PF.huber_loss(
            heading_residuals_normalized - heading_residual_normalized_target, delta=1.0
        )
        size_residuals_normalized = size_residuals_normalized[batch_id, size_template_id_target]  # (B, 3)
        # Normalize the target by the size of the ground-truth template.
        size_residual_normalized_target = size_residual_target / self.size_templates[size_template_id_target]
        size_residual_normalized_loss = PF.huber_loss(
            torch.norm(size_residual_normalized_target - size_residuals_normalized, dim=-1), delta=1.0
        )

        # Bounding box losses
        heading = (heading_residuals[batch_id, heading_bin_id_target]
                   + self.heading_angle_bin_centers[heading_bin_id_target])  # (B, )
        # Warning: in origin code, size_residuals are added twice (issue #43 and #49 in charlesq34/frustum-pointnets)
        size = (size_residuals[batch_id, size_template_id_target]
                + self.size_templates[size_template_id_target])  # (B, 3)
        corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False)  # (B, 3, 8)
        heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target  # (B, )
        size_target = self.size_templates[size_template_id_target] + size_residual_target  # (B, 3)
        corners_target, corners_target_flip = get_box_corners_3d(
            centers=center_target, headings=heading_target, sizes=size_target, with_flip=True
        )  # (B, 3, 8)
        # Use the smaller of the distances to the target box and to its flipped
        # (heading + pi) version.
        corners_loss = PF.huber_loss(
            torch.min(
                torch.norm(corners - corners_target, dim=1),
                torch.norm(corners - corners_target_flip, dim=1)
            ),
            delta=1.0
        )

        # Summing up
        loss = mask_loss + self.box_loss_weight * (
            center_loss + center_reg_loss + heading_loss + size_loss
            + self.heading_residual_loss_weight * heading_residual_normalized_loss
            + self.size_residual_loss_weight * size_residual_normalized_loss
            + self.corners_loss_weight * corners_loss
        )
        return loss


def get_box_corners_3d(centers, headings, sizes, with_flip=False):
    """
    Compute the 8 corners of each 3D box by rotating the axis-aligned box
    around the y axis by its heading and translating it to its center.

    :param centers: coords of box centers, FloatTensor[N, 3]
    :param headings: heading angles, FloatTensor[N, ]
    :param sizes: box sizes, FloatTensor[N, 3]
    :param with_flip: bool, whether to return flipped box (headings + np.pi)
    :return: coords of box corners, FloatTensor[N, 3, 8]; when ``with_flip``
        is true, a tuple ``(corners, flipped_corners)`` of two such tensors
    NOTE: corner points are in counter clockwise order, e.g.,
      2--1
    3--0 5
      7--4
    """
    length = sizes[:, 0]  # (N,)
    width = sizes[:, 1]  # (N,)
    height = sizes[:, 2]  # (N,)
    half_l, half_w, half_h = length / 2, width / 2, height / 2
    x_corners = torch.stack(
        [half_l, half_l, -half_l, -half_l, half_l, half_l, -half_l, -half_l], dim=1
    )  # (N, 8)
    y_corners = torch.stack(
        [half_h, half_h, half_h, half_h, -half_h, -half_h, -half_h, -half_h], dim=1
    )  # (N, 8)
    z_corners = torch.stack(
        [half_w, -half_w, -half_w, half_w, half_w, -half_w, -half_w, half_w], dim=1
    )  # (N, 8)

    c = torch.cos(headings)  # (N,)
    s = torch.sin(headings)  # (N,)
    o = torch.ones_like(headings)  # (N,)
    z = torch.zeros_like(headings)  # (N,)

    centers = centers.unsqueeze(-1)  # (N, 3, 1)
    corners = torch.stack([x_corners, y_corners, z_corners], dim=1)  # (N, 3, 8)
    R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3)  # roty matrix: (N, 3, 3)
    if with_flip:
        # Rotation by (heading + pi): cos and sin both change sign.
        R_flip = torch.stack([-c, z, -s, z, o, z, s, z, -c], dim=1).view(-1, 3, 3)
        return torch.matmul(R, corners) + centers, torch.matmul(R_flip, corners) + centers
    else:
        return torch.matmul(R, corners) + centers

    # Alternative formulations producing an (N, 8, 3) layout, kept for reference:
    #
    # centers = centers.unsqueeze(1)  # (B, 1, 3)
    # corners = torch.stack([x_corners, y_corners, z_corners], dim=-1)  # (N, 8, 3)
    # RT = torch.stack([c, z, -s, z, o, z, s, z, c], dim=1).view(-1, 3, 3)  # (N, 3, 3)
    # if with_flip:
    #     RT_flip = torch.stack([-c, z, s, z, o, z, -s, z, -c], dim=1).view(-1, 3, 3)  # (N, 3, 3)
    #     return torch.matmul(corners, RT) + centers, torch.matmul(corners, RT_flip) + centers  # (N, 8, 3)
    # else:
    #     return torch.matmul(corners, RT) + centers  # (N, 8, 3)
    #
    # corners = torch.stack([x_corners, y_corners, z_corners], dim=1)  # (N, 3, 8)
    # R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3)  # (N, 3, 3)
    # corners = torch.matmul(R, corners) + centers.unsqueeze(2)  # (N, 3, 8)
    # corners = corners.transpose(1, 2)  # (N, 8, 3)