"""
Add additional grasp decoder for Segment Anything model.
The structure should follow the grasp decoder structure in GraspDETR.
"""
import torch
import torch.nn as nn
from transformers.models.detr.configuration_detr import DetrConfig
from transformers.models.detr.modeling_detr import DetrHungarianMatcher, DetrLoss, DetrSegmentationOutput, DetrDecoder, sigmoid_focal_loss, dice_loss
from typing import Any, Dict, List, Tuple
from transformers.models.detr.modeling_detr import generalized_box_iou
from transformers.image_transforms import center_to_corners_format
from scipy.optimize import linear_sum_assignment
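

# NOTE: the functions below are meant to be bound onto a SAM-style model instance, i.e. one
# exposing `image_encoder`, `prompt_encoder`, `mask_decoder`, `postprocess_masks` and `device`,
# and whose mask decoder additionally returns its image features (`src`) as a third value.
# Each grasp is assumed to be parameterized by 5 sigmoid-normalized values
# (cx, cy, w, h, rotation); the first four are used for the GIoU-based cost and loss terms.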
def modify_matcher_forward(self):
    @torch.no_grad()
    def matcher_forward(outputs, targets):
        batch_size, num_queries = outputs["logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 5]

        # Also concat the target labels and boxes
        target_ids = torch.cat([v["class_labels"] for v in targets])
        target_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it by 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, so it can be omitted.
        class_cost = -out_prob[:, target_ids]

        # Compute the L1 cost between boxes (all 5 grasp values)
        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)

        # Compute the giou cost between boxes, using only the axis-aligned (cx, cy, w, h) part
        giou_cost = -generalized_box_iou(
            center_to_corners_format(out_bbox[:, :4]), center_to_corners_format(target_bbox[:, :4])
        )

        # Final cost matrix
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()

        # Solve one assignment problem per image
        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

    return matcher_forward
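

# The loss below mirrors DetrLoss but is adapted to 5-value grasps: a class-agnostic
# cross-entropy term ("loss_ce"), an L1 term ("loss_bbox") and a GIoU term over the
# axis-aligned (cx, cy, w, h) part ("loss_giou"). When `ignore_wh=True`, the L1 term
# only supervises (cx, cy, rotation), rescaled by 5/3.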
def modify_grasp_loss_forward(self):
    def modified_loss_labels(outputs, targets, indices, num_boxes):
        """
        Classification loss (NLL). Targets dicts must contain the key "class_labels" containing a tensor of dim
        [nb_target_boxes].
        """
        num_classes = 1  # model v9 always uses class-agnostic grasps
        if "logits" not in outputs:
            raise KeyError("No logits were found in the outputs")
        source_logits = outputs["logits"]

        idx = self._get_source_permutation_idx(indices)
        target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            source_logits.shape[:2], num_classes, dtype=torch.int64, device=source_logits.device
        )
        target_classes[idx] = target_classes_o

        loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes)
        losses = {"loss_ce": loss_ce}
        return losses

    def modified_loss_boxes(outputs, targets, indices, num_boxes, ignore_wh=False):
        if "pred_boxes" not in outputs:
            raise KeyError("No predicted boxes found in outputs")
        idx = self._get_source_permutation_idx(indices)
        source_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        if not ignore_wh:
            loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
        else:
            # Supervise only indices (0, 1, 4), i.e. (cx, cy, rotation); rescale by 5/3 so the
            # summed loss stays comparable to the full 5-value case
            source_xytheta = source_boxes[:, [0, 1, 4]]
            target_xytheta = target_boxes[:, [0, 1, 4]]
            loss_bbox = nn.functional.l1_loss(source_xytheta, target_xytheta, reduction="none") * 5 / 3

        losses = {}
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        if not ignore_wh:
            loss_giou = 1 - torch.diag(
                generalized_box_iou(
                    center_to_corners_format(source_boxes[:, :4]), center_to_corners_format(target_boxes[:, :4])
                )
            )
        else:
            # Overwrite the last two predicted values with the target values before computing
            # the GIoU over the (cx, cy, w, h) part
            source_boxes[:, -2:] = target_boxes[:, -2:].clone()
            source_corners = center_to_corners_format(source_boxes[:, :4])
            target_corners = center_to_corners_format(target_boxes[:, :4])
            loss_giou = 1 - torch.diag(generalized_box_iou(source_corners, target_corners))
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def modified_forward(outputs, targets, ignore_wh=False):
        """
        This performs the loss computation.

        Args:
            outputs (`dict`, *optional*):
                Dictionary of tensors, see the output specification of the model for the format.
            targets (`List[dict]`, *optional*):
                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depend on the
                losses applied, see each loss' doc.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["class_labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        # (Niels): comment out function below, distributed training to be added
        # if is_dist_avail_and_initialized():
        #     torch.distributed.all_reduce(num_boxes)
        # (Niels) in original implementation, num_boxes is divided by get_world_size()
        num_boxes = torch.clamp(num_boxes, min=1).item()

        # Compute all the requested losses
        losses = {}
        losses.update(self.loss_labels(outputs, targets, indices, num_boxes))
        losses.update(self.loss_boxes(outputs, targets, indices, num_boxes, ignore_wh))
        return losses

    return modified_loss_labels, modified_loss_boxes, modified_forward
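

# Shape flow of the grasp branch added in `modify_forward` (batch size B, assuming SAM's
# 256-channel 64x64 image features):
#   mask-decoder features `src`: (B, 256, 64, 64)
#     -> permute + learned positional embedding -> (B, 64, 64, 256) -> flatten to (B, 4096, 256)
#   20 learned grasp queries -> DetrDecoder -> (B, 20, 256)
#     -> grasp_label_classifier: logits (B, 20, 2); grasp_predictor + sigmoid: grasps (B, 20, 5)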
def modify_forward(self):
    """
    Modify the model so that SAM performs grasp detection after segmentation:
    1. Add a parallel decoder for grasp detection: 1(+1) classes, 5 values to regress (bbox & rotation)

    Returns:
        The modified forward method
    """
    # 1. Instantiate a new DETR decoder on the model, used as the grasp decoder
    self.grasp_decoder_config = DetrConfig()
    self.grasp_decoder = DetrDecoder(self.grasp_decoder_config).to(self.device)
    # 20 learned grasp queries with hidden dim 256
    self.grasp_query_position_embeddings = nn.Embedding(20, 256).to(self.device)
    # 2. Base model forward method is not directly used, so no modification needs to be done
    # self.detr.model.forward = modify_base_model_forward(self.detr.model)
    # 3. Add an additional classification head & bbox regression head for the grasp_decoder output
    self.grasp_predictor = torch.nn.Sequential(
        torch.nn.Linear(256, 256),
        torch.nn.Linear(256, 256),
        torch.nn.Linear(256, 5),
    ).to(self.device)
    self.grasp_label_classifier = torch.nn.Linear(256, 2).to(self.device)

    # 4. Add a learned positional embedding for the 64x64 image features;
    # name it grasp_img_pos_embed to avoid a name conflict
    class ImagePosEmbed(nn.Module):
        def __init__(self, img_size=64, hidden_dim=256):
            super().__init__()
            self.pos_embed = nn.Parameter(
                torch.randn(1, img_size, img_size, hidden_dim)
            )

        def forward(self, x):
            return x + self.pos_embed

    self.grasp_img_pos_embed = ImagePosEmbed().to(self.device)

    def modified_forward(
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ):
        input_images = torch.stack([x["image"] for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)
        batch_size = len(batched_input)
        outputs = []
        srcs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions, src = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            outputs.append(
                {
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
            srcs.append(src[0])
        srcs = torch.stack(srcs, dim=0)

        # Forward the grasp decoder here
        # 1. Get encoder hidden states: permute to channels-last and add the learned positional embedding
        grasp_encoder_hidden_states = self.grasp_img_pos_embed(srcs.permute(0, 2, 3, 1))
        # 2. Get query embeddings and repeat them to the batch size
        grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
        grasp_query_pe = grasp_query_pe.repeat(len(batched_input), 1, 1)
        pixel_masks = torch.cat([batched_input[i]["pixel_mask"] for i in range(len(batched_input))], dim=0)
        downsampled_pixel_masks = nn.functional.interpolate(
            pixel_masks.unsqueeze(1).float(), size=(64, 64), mode="nearest"
        ).squeeze(1).bool()
        downsampled_pixel_masks = downsampled_pixel_masks.view(batch_size, 64 * 64).contiguous()
        grasp_encoder_hidden_states = grasp_encoder_hidden_states.view(batch_size, 64 * 64, 256).contiguous()
        grasp_decoder_outputs = self.grasp_decoder(
            inputs_embeds=torch.zeros_like(grasp_query_pe),
            attention_mask=None,
            # encoder positional information is already baked into grasp_encoder_hidden_states
            # via grasp_img_pos_embed, so pass zeros here
            position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
            query_position_embeddings=grasp_query_pe,
            encoder_hidden_states=grasp_encoder_hidden_states,
            encoder_attention_mask=downsampled_pixel_masks,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        grasp_sequence_output = grasp_decoder_outputs[0]
        grasp_logits = self.grasp_label_classifier(grasp_sequence_output)
        pred_grasps = self.grasp_predictor(grasp_sequence_output).sigmoid()
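
        # At this point grasp_logits has shape (batch_size, 20, 2) (grasp vs. no-object per query)
        # and pred_grasps has shape (batch_size, 20, 5), with each grasp assumed to be
        # sigmoid-normalized (cx, cy, w, h, rotation). If supplied, "grasp_labels" in each
        # batched_input entry is expected in the same 5-value format.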
        # 3. Calculate the grasp detection loss
        loss, loss_dict = 0, {}
        if "grasp_labels" in batched_input[0]:
            config = self.grasp_decoder_config
            # All grasps share class 0 (class-agnostic)
            grasp_labels = [
                {
                    "class_labels": torch.zeros([len(x["grasp_labels"])], dtype=torch.long).to(self.device),
                    "boxes": x["grasp_labels"],
                }
                for x in batched_input
            ]
            # First: create the matcher
            matcher = DetrHungarianMatcher(
                class_cost=config.class_cost, bbox_cost=config.bbox_cost, giou_cost=config.giou_cost
            )
            matcher.forward = modify_matcher_forward(matcher)
            # Second: create the criterion
            losses = ["labels", "boxes"]
            criterion = DetrLoss(
                matcher=matcher,
                num_classes=config.num_labels,
                eos_coef=config.eos_coefficient,
                losses=losses,
            )
            criterion.loss_labels, criterion.loss_boxes, criterion.forward = modify_grasp_loss_forward(criterion)
            criterion.to(self.device)
            # Third: compute the losses, based on outputs and labels
            outputs_loss = {}
            outputs_loss["logits"] = grasp_logits
            outputs_loss["pred_boxes"] = pred_grasps
            grasp_loss_dict = criterion(outputs_loss, grasp_labels, ignore_wh=batched_input[0].get("ignore_wh", False))
            # Fourth: compute the total loss, as a weighted sum of the various losses
            weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
            weight_dict["loss_giou"] = config.giou_loss_coefficient
            if config.auxiliary_loss:
                aux_weight_dict = {}
                for i in range(config.decoder_layers - 1):
                    aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
                weight_dict.update(aux_weight_dict)
            grasp_loss = sum(grasp_loss_dict[k] * weight_dict[k] for k in grasp_loss_dict.keys() if k in weight_dict)
            # Merge the grasp branch loss into loss & loss_dict
            loss += grasp_loss
            loss_dict.update(grasp_loss_dict)

        pred_masks = self.postprocess_masks(
            torch.cat([x["low_res_logits"] for x in outputs], dim=0),
            input_size=image_record["image"].shape[-2:],
            original_size=(1024, 1024),
        )
        if "masks" in batched_input[0]:
            # 4. Calculate the segmentation loss (focal + dice)
            sf_loss = sigmoid_focal_loss(
                pred_masks.flatten(1),
                torch.stack([x["masks"] for x in batched_input], dim=0).unsqueeze(1).type(torch.float32).flatten(1),
                len(batched_input),
            )
            d_loss = dice_loss(
                pred_masks.flatten(1),
                torch.stack([x["masks"] for x in batched_input], dim=0).unsqueeze(1).type(torch.float32).flatten(1),
                len(batched_input),
            )
            loss += sf_loss + d_loss
            loss_dict["sf_loss"] = sf_loss
            loss_dict["d_loss"] = d_loss

        return DetrSegmentationOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=grasp_logits,
            pred_boxes=pred_grasps,
            pred_masks=pred_masks,
        )

    return modified_forward
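

# Inference-only counterpart of `modified_forward` above. As written, `infer` only processes
# the first element of `batched_input`; the grasp decoder is then run with one batch element
# per prompt (n_queries of them, taken from the mask decoder's iou_predictions), and no losses
# are computed.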
def add_inference_method(self):
    def infer(
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ):
        input_images = torch.stack([x["image"] for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)
        outputs = []
        srcs = []
        # Only the first image record is processed
        curr_embedding = image_embeddings[0]
        image_record = batched_input[0]
        if "point_coords" in image_record:
            points = (image_record["point_coords"], image_record["point_labels"])
        else:
            points = None
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=image_record.get("boxes", None),
            masks=image_record.get("mask_inputs", None),
        )
        low_res_masks, iou_predictions, src = self.mask_decoder(
            image_embeddings=curr_embedding.unsqueeze(0),
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )
        outputs.append(
            {
                "iou_predictions": iou_predictions,
                "low_res_logits": low_res_masks,
            }
        )
        srcs.append(src[0])
        n_queries = iou_predictions.size(0)
        batch_size = n_queries

        # Forward the grasp decoder here
        # 1. Get encoder hidden states
        grasp_encoder_hidden_states = self.grasp_img_pos_embed(src.permute(0, 2, 3, 1))
        # 2. Get query embeddings and repeat them once per prompt
        grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
        grasp_query_pe = grasp_query_pe.repeat(n_queries, 1, 1)
        pixel_masks = torch.cat([batched_input[i]["pixel_mask"] for i in range(len(batched_input))], dim=0)
        pixel_masks = pixel_masks.repeat(n_queries, 1, 1)
        downsampled_pixel_masks = nn.functional.interpolate(
            pixel_masks.unsqueeze(1).float(), size=(64, 64), mode="nearest"
        ).squeeze(1).bool()
        downsampled_pixel_masks = downsampled_pixel_masks.view(batch_size, 64 * 64).contiguous()
        grasp_encoder_hidden_states = grasp_encoder_hidden_states.view(batch_size, 64 * 64, 256).contiguous()
        grasp_decoder_outputs = self.grasp_decoder(
            inputs_embeds=torch.zeros_like(grasp_query_pe),
            attention_mask=None,
            position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
            query_position_embeddings=grasp_query_pe,
            encoder_hidden_states=grasp_encoder_hidden_states,
            encoder_attention_mask=downsampled_pixel_masks,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        grasp_sequence_output = grasp_decoder_outputs[0]
        grasp_logits = self.grasp_label_classifier(grasp_sequence_output)
        pred_grasps = self.grasp_predictor(grasp_sequence_output).sigmoid()

        pred_masks = self.postprocess_masks(
            torch.cat([x["low_res_logits"] for x in outputs], dim=0),
            input_size=image_record["image"].shape[-2:],
            original_size=(1024, 1024),
        )
        return DetrSegmentationOutput(
            loss=0,
            loss_dict={},
            logits=grasp_logits,
            pred_boxes=pred_grasps,
            pred_masks=pred_masks,
        )

    return infer
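

# Hypothetical usage sketch (not executed here): the helpers above mutate a SAM-style model
# in place and return new forward methods. This assumes a SAM variant whose `mask_decoder`
# also returns its image features (`src`) as a third value, and a `sam_model_registry`
# entry point as in the original segment-anything package; adjust to the actual model class
# used by this repo.
#
#     from segment_anything import sam_model_registry
#
#     sam = sam_model_registry["vit_h"](checkpoint="...")  # path to a SAM checkpoint
#     sam.forward = modify_forward(sam)        # attaches grasp decoder, heads and pos. embedding
#     sam.infer = add_inference_method(sam)    # inference-only forward, no loss computation
#
#     # batched_input follows the SAM convention plus the extra keys used above
#     # ("pixel_mask", and optionally "grasp_labels", "masks", "ignore_wh"):
#     out = sam.forward(batched_input, multimask_output=False)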