""" Add additional grasp decoder for Segment Anything model. The structure should follow the grasp decoder structure in GraspDETR. """ import torch import torch.nn as nn from transformers.models.detr.configuration_detr import DetrConfig from transformers.models.detr.modeling_detr import DetrHungarianMatcher, DetrLoss, DetrSegmentationOutput, DetrDecoder, sigmoid_focal_loss, dice_loss from typing import Any, Dict, List, Tuple from transformers.models.detr.modeling_detr import generalized_box_iou from transformers.image_transforms import center_to_corners_format from scipy.optimize import linear_sum_assignment def modify_matcher_forward(self): @torch.no_grad() def matcher_forward(outputs, targets): batch_size, num_queries = outputs["logits"].shape[:2] # We flatten to compute the cost matrices in a batch out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] # Also concat the target labels and boxes target_ids = torch.cat([v["class_labels"] for v in targets]) target_bbox = torch.cat([v["boxes"] for v in targets]) # Compute the classification cost. Contrary to the loss, we don't use the NLL, # but approximate it in 1 - proba[target class]. # The 1 is a constant that doesn't change the matching, it can be ommitted. class_cost = -out_prob[:, target_ids] # Compute the L1 cost between boxes bbox_cost = torch.cdist(out_bbox, target_bbox, p=1) # Compute the giou cost between boxes giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox[:, :4]), center_to_corners_format(target_bbox[:, :4])) # Final cost matrix cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu() sizes = [len(v["boxes"]) for v in targets] indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))] return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] return matcher_forward def modify_grasp_loss_forward(self): def modified_loss_labels(outputs, targets, indices, num_boxes): """ Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim [nb_target_boxes] """ num_classes = 1 # model v9 always use class agnostic grasp if "logits" not in outputs: raise KeyError("No logits were found in the outputs") source_logits = outputs["logits"] idx = self._get_source_permutation_idx(indices) target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)]) target_classes = torch.full( source_logits.shape[:2], num_classes, dtype=torch.int64, device=source_logits.device ) target_classes[idx] = target_classes_o loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes) losses = {"loss_ce": loss_ce} return losses def modified_loss_boxes(outputs, targets, indices, num_boxes, ignore_wh=False): if "pred_boxes" not in outputs: raise KeyError("No predicted boxes found in outputs") idx = self._get_source_permutation_idx(indices) source_boxes = outputs["pred_boxes"][idx] target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) if not ignore_wh: loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none") else: source_xytheta = source_boxes[:, [0, 1, 4]] target_xytheta = target_boxes[:, [0, 1, 4]] loss_bbox = nn.functional.l1_loss(source_xytheta, target_xytheta, reduction="none") * 5 / 3 losses 

        losses = {}
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        if not ignore_wh:
            loss_giou = 1 - torch.diag(
                generalized_box_iou(
                    center_to_corners_format(source_boxes[:, :4]), center_to_corners_format(target_boxes[:, :4])
                )
            )
        else:
            # Copy the target width/height so the giou term only reflects the predicted (x, y)
            source_boxes[:, -2:] = target_boxes[:, -2:].clone()
            source_corners = center_to_corners_format(source_boxes[:, :4])
            target_corners = center_to_corners_format(target_boxes[:, :4])
            loss_giou = 1 - torch.diag(generalized_box_iou(source_corners, target_corners))
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def modified_forward(outputs, targets, ignore_wh=False):
        """
        This performs the loss computation.

        Args:
            outputs (`dict`, *optional*):
                Dictionary of tensors, see the output specification of the model for the format.
            targets (`List[dict]`, *optional*):
                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depend on the
                losses applied, see each loss' doc.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["class_labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        # (Niels): comment out function below, distributed training to be added
        # if is_dist_avail_and_initialized():
        #     torch.distributed.all_reduce(num_boxes)
        # (Niels) in original implementation, num_boxes is divided by get_world_size()
        num_boxes = torch.clamp(num_boxes, min=1).item()

        # Compute all the requested losses
        losses = {}
        losses.update(self.loss_labels(outputs, targets, indices, num_boxes))
        losses.update(self.loss_boxes(outputs, targets, indices, num_boxes, ignore_wh))
        return losses

    return modified_loss_labels, modified_loss_boxes, modified_forward


def modify_forward(self):
    """
    Modify the following methods to make SAM perform grasp detection after segmentation:
    1. Add a parallel decoder for grasp detection: 1 (+1) classes, 5 values to regress (bbox & rotation)

    Returns:
        Modified model
    """
    # 1. We instantiate a new module in self.base_model, as another decoder
    self.grasp_decoder_config = DetrConfig()
    self.grasp_decoder = DetrDecoder(self.grasp_decoder_config).to(self.device)
    self.grasp_query_position_embeddings = nn.Embedding(20, 256).to(self.device)
    # 2. The base model forward method is not directly used, so no modification is needed
    # self.detr.model.forward = modify_base_model_forward(self.detr.model)
    # 3. Add a classification head & bbox regression head for the grasp_decoder output
    self.grasp_predictor = torch.nn.Sequential(
        torch.nn.Linear(256, 256),
        torch.nn.Linear(256, 256),
        torch.nn.Linear(256, 5),
    ).to(self.device)
    self.grasp_label_classifier = torch.nn.Linear(256, 2).to(self.device)

    # 4. Add positional embedding
    # name it as grasp_img_pos_embed to avoid name conflict
    class ImagePosEmbed(nn.Module):
        def __init__(self, img_size=64, hidden_dim=256):
            super().__init__()
            self.pos_embed = nn.Parameter(torch.randn(1, img_size, img_size, hidden_dim))

        def forward(self, x):
            return x + self.pos_embed

    self.grasp_img_pos_embed = ImagePosEmbed().to(self.device)

    def modified_forward(
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ):
        input_images = torch.stack([x["image"] for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)
        batch_size = len(batched_input)
        outputs = []
        srcs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions, src = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            outputs.append(
                {
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
            srcs.append(src[0])
        srcs = torch.stack(srcs, dim=0)

        # forward grasp decoder here
        # 1. Get encoder hidden states
        grasp_encoder_hidden_states = self.grasp_img_pos_embed(srcs.permute(0, 2, 3, 1))
        # 2. Get query embeddings, repeated to batch size
        grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
        grasp_query_pe = grasp_query_pe.repeat(len(batched_input), 1, 1)
        pixel_masks = torch.cat([batched_input[i]["pixel_mask"] for i in range(len(batched_input))], dim=0)
        downsampled_pixel_masks = (
            nn.functional.interpolate(pixel_masks.unsqueeze(1).float(), size=(64, 64), mode="nearest")
            .squeeze(1)
            .bool()
        )
        downsampled_pixel_masks = downsampled_pixel_masks.view(batch_size, 64 * 64).contiguous()
        grasp_encoder_hidden_states = grasp_encoder_hidden_states.view(batch_size, 64 * 64, 256).contiguous()
        grasp_decoder_outputs = self.grasp_decoder(
            inputs_embeds=torch.zeros_like(grasp_query_pe),
            attention_mask=None,
            position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
            query_position_embeddings=grasp_query_pe,
            encoder_hidden_states=grasp_encoder_hidden_states,
            encoder_attention_mask=downsampled_pixel_masks,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        grasp_sequence_output = grasp_decoder_outputs[0]
        grasp_logits = self.grasp_label_classifier(grasp_sequence_output)
        pred_grasps = self.grasp_predictor(grasp_sequence_output).sigmoid()

        # 3. Calculate loss
        loss, loss_dict = 0, {}
        if "grasp_labels" in batched_input[0]:
            config = self.grasp_decoder_config
            grasp_labels = [
                {
                    "class_labels": torch.zeros([len(x["grasp_labels"])], dtype=torch.long).to(self.device),
                    "boxes": x["grasp_labels"],
                }
                for x in batched_input
            ]
            # First: create the matcher
            matcher = DetrHungarianMatcher(
                class_cost=config.class_cost, bbox_cost=config.bbox_cost, giou_cost=config.giou_cost
            )
            matcher.forward = modify_matcher_forward(matcher)
            # Second: create the criterion
            losses = ["labels", "boxes"]
            criterion = DetrLoss(
                matcher=matcher,
                num_classes=config.num_labels,
                eos_coef=config.eos_coefficient,
                losses=losses,
            )
            criterion.loss_labels, criterion.loss_boxes, criterion.forward = modify_grasp_loss_forward(criterion)
            criterion.to(self.device)
            # Third: compute the losses, based on outputs and labels
            outputs_loss = {}
            outputs_loss["logits"] = grasp_logits
            outputs_loss["pred_boxes"] = pred_grasps
            grasp_loss_dict = criterion(
                outputs_loss, grasp_labels, ignore_wh=batched_input[0].get("ignore_wh", False)
            )
            # Fourth: compute total loss, as a weighted sum of the various losses
            weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
            weight_dict["loss_giou"] = config.giou_loss_coefficient
            if config.auxiliary_loss:
                aux_weight_dict = {}
                for i in range(config.decoder_layers - 1):
                    aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
                weight_dict.update(aux_weight_dict)
            grasp_loss = sum(
                grasp_loss_dict[k] * weight_dict[k] for k in grasp_loss_dict.keys() if k in weight_dict
            )
            # merge grasp branch loss into variable loss & loss_dict
            loss += grasp_loss
            loss_dict.update(grasp_loss_dict)

        pred_masks = self.postprocess_masks(
            torch.cat([x["low_res_logits"] for x in outputs], dim=0),
            input_size=image_record["image"].shape[-2:],
            original_size=(1024, 1024),
        )
        if "masks" in batched_input[0]:
            # 4. Calculate segmentation loss
            sf_loss = sigmoid_focal_loss(
                pred_masks.flatten(1),
                torch.stack([x["masks"] for x in batched_input], dim=0).unsqueeze(1).type(torch.float32).flatten(1),
                len(batched_input),
            )
            d_loss = dice_loss(
                pred_masks.flatten(1),
                torch.stack([x["masks"] for x in batched_input], dim=0).unsqueeze(1).type(torch.float32).flatten(1),
                len(batched_input),
            )
            loss += sf_loss + d_loss
            loss_dict["sf_loss"] = sf_loss
            loss_dict["d_loss"] = d_loss

        return DetrSegmentationOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=grasp_logits,
            pred_boxes=pred_grasps,
            pred_masks=pred_masks,
        )

    return modified_forward


def add_inference_method(self):
    def infer(
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ):
        input_images = torch.stack([x["image"] for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)
        outputs = []
        srcs = []
        curr_embedding = image_embeddings[0]
        image_record = batched_input[0]
        if "point_coords" in image_record:
            points = (image_record["point_coords"], image_record["point_labels"])
        else:
            points = None
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=image_record.get("boxes", None),
            masks=image_record.get("mask_inputs", None),
        )
        low_res_masks, iou_predictions, src = self.mask_decoder(
            image_embeddings=curr_embedding.unsqueeze(0),
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )
        outputs.append(
            {
                "iou_predictions": iou_predictions,
                "low_res_logits": low_res_masks,
            }
        )
        srcs.append(src[0])
        n_queries = iou_predictions.size(0)
        batch_size = n_queries

        # forward grasp decoder here
        # 1. Get encoder hidden states
        grasp_encoder_hidden_states = self.grasp_img_pos_embed(src.permute(0, 2, 3, 1))
        # 2. Get query embeddings, repeated per prompt
        grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
        grasp_query_pe = grasp_query_pe.repeat(n_queries, 1, 1)
        pixel_masks = torch.cat([batched_input[i]["pixel_mask"] for i in range(len(batched_input))], dim=0)
        pixel_masks = pixel_masks.repeat(n_queries, 1, 1)
        downsampled_pixel_masks = (
            nn.functional.interpolate(pixel_masks.unsqueeze(1).float(), size=(64, 64), mode="nearest")
            .squeeze(1)
            .bool()
        )
        downsampled_pixel_masks = downsampled_pixel_masks.view(batch_size, 64 * 64).contiguous()
        grasp_encoder_hidden_states = grasp_encoder_hidden_states.view(batch_size, 64 * 64, 256).contiguous()
        grasp_decoder_outputs = self.grasp_decoder(
            inputs_embeds=torch.zeros_like(grasp_query_pe),
            attention_mask=None,
            position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
            query_position_embeddings=grasp_query_pe,
            encoder_hidden_states=grasp_encoder_hidden_states,
            encoder_attention_mask=downsampled_pixel_masks,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        grasp_sequence_output = grasp_decoder_outputs[0]
        grasp_logits = self.grasp_label_classifier(grasp_sequence_output)
        pred_grasps = self.grasp_predictor(grasp_sequence_output).sigmoid()
        pred_masks = self.postprocess_masks(
            torch.cat([x["low_res_logits"] for x in outputs], dim=0),
            input_size=image_record["image"].shape[-2:],
            original_size=(1024, 1024),
        )
        return DetrSegmentationOutput(
            loss=0,
            loss_dict={},
            logits=grasp_logits,
            pred_boxes=pred_grasps,
            pred_masks=pred_masks,
        )

    return infer
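

# The sketch below is illustrative and not part of the original module: it shows one way these
# helpers might be wired onto a SAM instance. It assumes the `segment_anything` package is
# installed, that the checkpoint path is a placeholder, and that `self.mask_decoder` has been
# patched elsewhere to also return its image features (`src`), which `modified_forward` and
# `infer` above rely on.
if __name__ == "__main__":
    from segment_anything import sam_model_registry

    sam = sam_model_registry["vit_b"](checkpoint="path/to/sam_vit_b.pth")  # placeholder checkpoint path
    sam.forward = modify_forward(sam)  # attaches the grasp decoder, heads, and positional embedding
    sam.infer = add_inference_method(sam)
    # Each element of batched_input follows SAM's promptable format ("image", "point_coords", ...)
    # plus the extra keys used above: "pixel_mask", and "grasp_labels"/"masks" during training.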