from typing import List, Optional, Tuple import torch from torch import nn import torch.nn.functional as F from . import axis_ops, ilnr_loss from .vnl_loss import VNL_Loss from .midas_loss import MidasLoss from .detr.detr import MLP from .detr.transformer import Transformer from .detr.backbone import Backbone, Joiner from .detr.position_encoding import PositionEmbeddingSine from .detr.misc import nested_tensor_from_tensor_list, interpolate from .detr import box_ops from .detr.segmentation import ( MHAttentionMap, MaskHeadSmallConv, dice_loss, sigmoid_focal_loss ) class INTR(torch.nn.Module): """ Implement Interaction 3D Transformer. """ def __init__( self, backbone_name = 'resnet50', image_size = [192, 256], ignore_index = -100, num_classes = 1, num_queries = 15, freeze_backbone = False, transformer_hidden_dim = 256, transformer_dropout = 0.1, transformer_nhead = 8, transformer_dim_feedforward = 2048, transformer_num_encoder_layers = 6, transformer_num_decoder_layers = 6, transformer_normalize_before = False, transformer_return_intermediate_dec = True, layers_movable = 3, layers_rigid = 3, layers_kinematic = 3, layers_action = 3, layers_axis = 2, layers_affordance = 3, affordance_focal_alpha = 0.95, axis_bins = 30, depth_on = True, ): """ Initializes the model. Parameters: backbone: torch module of the backbone to be used. See backbone.py transformer: torch module of the transformer architecture. See transformer.py num_classes: number of object classes num_queries: number of object queries, ie detection slot. This is the maximal number of objects DETR can detect in a single image. For COCO, we recommend 100 queries. aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. """ super().__init__() self._ignore_index = ignore_index self._image_size = image_size self._axis_bins = axis_bins self._affordance_focal_alpha = affordance_focal_alpha # backbone backbone_base = Backbone(backbone_name, not freeze_backbone, True, False) N_steps = transformer_hidden_dim // 2 position_embedding = PositionEmbeddingSine(N_steps, normalize=True) backbone = Joiner(backbone_base, position_embedding) backbone.num_channels = backbone_base.num_channels self.backbone = backbone self.transformer = Transformer( d_model=transformer_hidden_dim, dropout=transformer_dropout, nhead=transformer_nhead, dim_feedforward=transformer_dim_feedforward, num_encoder_layers=transformer_num_encoder_layers, num_decoder_layers=transformer_num_decoder_layers, normalize_before=transformer_normalize_before, return_intermediate_dec=transformer_return_intermediate_dec, ) hidden_dim = self.transformer.d_model self.hidden_dim = hidden_dim nheads = self.transformer.nhead self.num_queries = num_queries # before transformer, input_proj maps 2048 channel resnet50 output to 512-channel # transformer input self.input_proj = nn.Conv2d(self.backbone.num_channels, hidden_dim, kernel_size=1) # query mlp maps 2d keypoint coordinates to 256-dim positional encoding self.query_mlp = MLP(2, hidden_dim, hidden_dim, 2) # bbox MLP self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) if layers_movable > 1: self.movable_embed = MLP(hidden_dim, hidden_dim, 3, layers_movable) elif layers_movable == 1: self.movable_embed = nn.Linear(hidden_dim, 3) else: raise ValueError("not supported") if layers_rigid > 1: self.rigid_embed = MLP(hidden_dim, hidden_dim, 2, layers_rigid) elif layers_rigid == 1: #self.rigid_embed = nn.Linear(hidden_dim, 2) self.rigid_embed = nn.Linear(hidden_dim, 3) else: raise ValueError("not supported") if layers_kinematic > 1: self.kinematic_embed = MLP(hidden_dim, hidden_dim, 3, layers_kinematic) elif layers_kinematic == 1: self.kinematic_embed = nn.Linear(hidden_dim, 3) else: raise ValueError("not supported") if layers_action > 1: self.action_embed = MLP(hidden_dim, hidden_dim, 3, layers_action) elif layers_action == 1: self.action_embed = nn.Linear(hidden_dim, 3) else: raise ValueError("not supported") if layers_axis > 1: #self.axis_embed = MLP(hidden_dim, hidden_dim, 4, layers_axis) self.axis_embed = MLP(hidden_dim, hidden_dim, 3, layers_axis) # classification # self.axis_embed = MLP(hidden_dim, hidden_dim, self._axis_bins * 2, layers_axis) elif layers_axis == 1: self.axis_embed = nn.Linear(hidden_dim, 3) else: raise ValueError("not supported") # affordance if layers_affordance > 1: self.aff_embed = MLP(hidden_dim, hidden_dim, 2, layers_affordance) elif layers_affordance == 1: self.aff_embed = nn.Linear(hidden_dim, 2) else: raise ValueError("not supported") # affordance head self.aff_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0) self.aff_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim, nheads) # mask head self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0) self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim, nheads) # depth head self._depth_on = depth_on if self._depth_on: self.depth_query = nn.Embedding(1, hidden_dim) self.depth_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0) self.depth_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim, nheads) self.depth_loss = ilnr_loss.MEADSTD_TANH_NORM_Loss() fov = torch.tensor(1.0) focal_length = (image_size[1] / 2 / torch.tan(fov / 2)).item() self.vnl_loss = VNL_Loss(focal_length, focal_length, image_size) self.midas_loss = MidasLoss(alpha=0.1) def freeze_layers(self, names): """ Freeze layers in 'names'. """ for name, param in self.named_parameters(): for freeze_name in names: if freeze_name in name: #print(name + ' ' + freeze_name) param.requires_grad = False def forward( self, image: torch.Tensor, valid: torch.Tensor, keypoints: torch.Tensor, bbox: torch.Tensor, masks: torch.Tensor, movable: torch.Tensor, rigid: torch.Tensor, kinematic: torch.Tensor, action: torch.Tensor, affordance: torch.Tensor, affordance_map: torch.FloatTensor, depth: torch.Tensor, axis: torch.Tensor, fov: torch.Tensor, backward: bool = True, **kwargs, ): """ Model forward. Set backward = False if the model is inference only. """ device = image.device # number of queries can be different in runtime num_queries = keypoints.shape[1] # DETR forward samples = image if isinstance(samples, (list, torch.Tensor)): samples = nested_tensor_from_tensor_list(samples) features, pos = self.backbone(samples) bs = features[-1].tensors.shape[0] src, mask = features[-1].decompose() assert mask is not None # sample keypoint queries from the positional embedding use_sine = False if use_sine: anchors = keypoints.float() anchors_float = anchors.clone() anchors_float = anchors_float.reshape(-1, 2) anchors_float[:, 0] = ((anchors_float[:, 0] / self._image_size[1]) - 0.5) * 2 anchors_float[:, 1] = ((anchors_float[:, 1] / self._image_size[0]) - 0.5) * 2 anchors_float = anchors_float.unsqueeze(1).unsqueeze(1) # 4x256x1x1 keypoint_queries = F.grid_sample( #pos[0].repeat(self.num_queries, 1, 1, 1), pos[-1].repeat(self.num_queries, 1, 1, 1), anchors_float, mode='nearest', align_corners=True ) # 4 x 10 (number of object queires) x 256 keypoint_queries = keypoint_queries.squeeze().reshape(-1, self.num_queries, self.hidden_dim) else: # use learned MLP to map postional encoding anchors = keypoints.float() anchors_float = anchors.clone() anchors_float[:, :, 0] = ((anchors_float[:, :, 0] / self._image_size[1]) - 0.5) * 2 anchors_float[:, :, 1] = ((anchors_float[:, :, 1] / self._image_size[0]) - 0.5) * 2 keypoint_queries = self.query_mlp(anchors_float) # append depth_query if the model is learning depth. if self._depth_on: bs = keypoint_queries.shape[0] depth_query = self.depth_query.weight.unsqueeze(0).repeat(bs, 1, 1) keypoint_queries = torch.cat((keypoint_queries, depth_query), dim=1) # transformer forward src_proj = self.input_proj(src) hs, memory = self.transformer(src_proj, mask, keypoint_queries, pos[-1]) if self._depth_on: depth_hs = hs[-1][:, -1:] ord_hs = hs[-1][:, :-1] else: ord_hs = hs[-1] outputs_coord = self.bbox_embed(ord_hs).sigmoid() outputs_movable = self.movable_embed(ord_hs) outputs_rigid = self.rigid_embed(ord_hs) outputs_kinematic = self.kinematic_embed(ord_hs) outputs_action = self.action_embed(ord_hs) # axis forward outputs_axis = self.axis_embed(ord_hs).sigmoid() # sigmoid range is 0 to 1, we want it to be -1 to 1 outputs_axis = (outputs_axis - 0.5) * 2 # affordance forward bbox_aff = self.aff_attention(ord_hs, memory, mask=mask) aff_masks = self.aff_head(src_proj, bbox_aff, [features[2].tensors, features[1].tensors, features[0].tensors]) outputs_aff_masks = aff_masks.view(bs, num_queries, aff_masks.shape[-2], aff_masks.shape[-1]) # mask forward bbox_mask = self.bbox_attention(ord_hs, memory, mask=mask) seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors]) outputs_seg_masks = seg_masks.view(bs, num_queries, seg_masks.shape[-2], seg_masks.shape[-1]) # depth forward outputs_depth = None if self._depth_on: depth_att = self.depth_attention(depth_hs, memory, mask=mask) depth_masks = self.depth_head( src_proj, depth_att, [features[2].tensors, features[1].tensors, features[0].tensors] ) outputs_depth = depth_masks.view(bs, 1, depth_masks.shape[-2], depth_masks.shape[-1]) out = { 'pred_boxes': box_ops.box_cxcywh_to_xyxy(outputs_coord), 'pred_movable': outputs_movable, 'pred_rigid': outputs_rigid, 'pred_kinematic': outputs_kinematic, 'pred_action': outputs_action, 'pred_masks': outputs_seg_masks, 'pred_axis': outputs_axis, 'pred_depth': outputs_depth, 'pred_affordance': outputs_aff_masks, } if not backward: return out # backward src_boxes = outputs_coord target_boxes = bbox target_boxes = box_ops.box_xyxy_to_cxcywh(target_boxes) bbox_valid = bbox[:, :, 0] > -0.5 num_boxes = bbox_valid.sum() if num_boxes == 0: out['loss_bbox'] = torch.tensor(0.0, requires_grad=True).to(device) out['loss_giou'] = torch.tensor(0.0, requires_grad=True).to(device) else: loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') loss_bbox = loss_bbox * bbox_valid.unsqueeze(2) # remove invalid out['loss_bbox'] = loss_bbox.sum() / num_boxes loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( box_ops.box_cxcywh_to_xyxy(src_boxes).reshape(-1, 4), box_ops.box_cxcywh_to_xyxy(target_boxes).reshape(-1, 4), )).reshape(-1, self.num_queries) loss_giou = loss_giou * bbox_valid # remove invalid out['loss_giou'] = loss_giou.sum() / num_boxes # affordance affordance_valid = affordance[:, :, 0] > -0.5 if affordance_valid.sum() == 0: out['loss_affordance'] = torch.tensor(0.0, requires_grad=True).to(device) else: src_aff_masks = outputs_aff_masks[affordance_valid] tgt_aff_masks = affordance_map[affordance_valid] src_aff_masks = src_aff_masks.flatten(1) tgt_aff_masks = tgt_aff_masks.flatten(1) loss_aff = sigmoid_focal_loss( src_aff_masks, tgt_aff_masks, affordance_valid.sum(), alpha=self._affordance_focal_alpha, ) out['loss_affordance'] = loss_aff # axis axis_valid = axis[:, :, 0] > 0.0 num_axis = axis_valid.sum() if num_axis == 0: out['loss_axis_angle'] = torch.tensor(0.0, requires_grad=True).to(device) out['loss_axis_offset'] = torch.tensor(0.0, requires_grad=True).to(device) out['loss_eascore'] = torch.tensor(0.0, requires_grad=True).to(device) else: # regress angle src_axis_angle = outputs_axis[axis_valid] src_axis_angle_norm = F.normalize(src_axis_angle[:, :2]) src_axis_angle = torch.cat((src_axis_angle_norm, src_axis_angle[:, 2:]), dim=-1) target_axis_xyxy = axis[axis_valid] axis_center = target_boxes[axis_valid].clone() axis_center[:, 2:] = axis_center[:, :2] target_axis_angle = axis_ops.line_xyxy_to_angle(target_axis_xyxy, center=axis_center) loss_axis_angle = F.l1_loss(src_axis_angle[:, :2], target_axis_angle[:, :2], reduction='sum') / num_axis loss_axis_offset = F.l1_loss(src_axis_angle[:, 2:], target_axis_angle[:, 2:], reduction='sum') / num_axis out['loss_axis_angle'] = loss_axis_angle out['loss_axis_offset'] = loss_axis_offset src_axis_xyxy = axis_ops.line_angle_to_xyxy(src_axis_angle, center=axis_center) target_axis_xyxy = axis_ops.line_angle_to_xyxy(target_axis_angle, center=axis_center) axis_eascore, _, _ = axis_ops.ea_score(src_axis_xyxy, target_axis_xyxy) loss_eascore = 1 - axis_eascore out['loss_eascore'] = loss_eascore.mean() loss_movable = F.cross_entropy(outputs_movable.permute(0, 2, 1), movable, ignore_index=self._ignore_index) if torch.isnan(loss_movable): loss_movable = torch.tensor(0.0, requires_grad=True).to(device) out['loss_movable'] = loss_movable loss_rigid = F.cross_entropy(outputs_rigid.permute(0, 2, 1), rigid, ignore_index=self._ignore_index) if torch.isnan(loss_rigid): loss_rigid = torch.tensor(0.0, requires_grad=True).to(device) out['loss_rigid'] = loss_rigid loss_kinematic = F.cross_entropy(outputs_kinematic.permute(0, 2, 1), kinematic, ignore_index=self._ignore_index) if torch.isnan(loss_kinematic): loss_kinematic = torch.tensor(0.0, requires_grad=True).to(device) out['loss_kinematic'] = loss_kinematic loss_action = F.cross_entropy(outputs_action.permute(0, 2, 1), action, ignore_index=self._ignore_index) if torch.isnan(loss_action): loss_action = torch.tensor(0.0, requires_grad=True).to(device) out['loss_action'] = loss_action # depth backward if self._depth_on: # (bs, 1, H, W) src_depths = interpolate(outputs_depth, size=depth.shape[-2:], mode='bilinear', align_corners=False) src_depths = src_depths.clamp(min=0.0, max=1.0) tgt_depths = depth.unsqueeze(1) # (bs, H, W) valid_depth = depth[:, 0, 0] > 0 if valid_depth.any(): src_depths = src_depths[valid_depth] tgt_depths = tgt_depths[valid_depth] depth_mask = tgt_depths > 1e-8 midas_loss, ssi_loss, reg_loss = self.midas_loss(src_depths, tgt_depths, depth_mask) loss_vnl = self.vnl_loss(tgt_depths, src_depths) out['loss_depth'] = midas_loss out['loss_vnl'] = loss_vnl else: out['loss_depth'] = torch.tensor(0.0, requires_grad=True).to(device) out['loss_vnl'] = torch.tensor(0.0, requires_grad=True).to(device) else: out['loss_depth'] = torch.tensor(0.0, requires_grad=True).to(device) out['loss_vnl'] = torch.tensor(0.0, requires_grad=True).to(device) # mask backward tgt_masks = masks src_masks = interpolate(outputs_seg_masks, size=tgt_masks.shape[-2:], mode='bilinear', align_corners=False) valid_mask = tgt_masks.sum(dim=-1).sum(dim=-1) > 10 if valid_mask.sum() == 0: out['loss_mask'] = torch.tensor(0.0, requires_grad=True).to(device) out['loss_dice'] = torch.tensor(0.0, requires_grad=True).to(device) else: num_masks = valid_mask.sum() src_masks = src_masks[valid_mask] tgt_masks = tgt_masks[valid_mask] src_masks = src_masks.flatten(1) tgt_masks = tgt_masks.flatten(1) tgt_masks = tgt_masks.view(src_masks.shape) out['loss_mask'] = sigmoid_focal_loss(src_masks, tgt_masks.float(), num_masks) out['loss_dice'] = dice_loss(src_masks, tgt_masks, num_masks) return out