import copy
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import OwlViTConfig, OwlViTForObjectDetection

from .matcher import HungarianMatcher
from .segmentation import dice_loss, sigmoid_focal_loss
from .util import box_ops
from .util.misc import accuracy, interpolate, nested_tensor_from_tensor_list


class OwlViT(nn.Module):
    def __init__(self, num_classes, is_eval=False):
        super().__init__()
        if is_eval:
            # Build the architecture from the config only; no pretrained weights
            # are downloaded (useful when loading a fine-tuned checkpoint later).
            owlViT_config = OwlViTConfig.from_pretrained("google/owlvit-base-patch16")
            model_owlViT = OwlViTForObjectDetection(owlViT_config)
        else:
            model_owlViT = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
        self.vision_model = model_owlViT.owlvit.vision_model
        self.class_head = model_owlViT.class_head
        self.box_head = model_owlViT.box_head
        self.layer_norm = model_owlViT.layer_norm
        self.sigmoid = nn.Sigmoid()
        del model_owlViT

        self.matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
        self.weight_dict = {'loss_ce': 2, 'loss_bbox': 5, 'loss_giou': 2}
        self.losses = ['labels', 'boxes']
        self.criterion = SetCriterion(num_classes, self.matcher, self.weight_dict,
                                      self.losses, focal_alpha=0.25)

    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
        # Computes normalized xy corner coordinates from feature_map.
        if feature_map.ndim != 4:
            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")

        device = feature_map.device
        num_patches = feature_map.shape[1]

        box_coordinates = np.stack(
            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
        ).astype(np.float32)
        box_coordinates /= np.array([num_patches, num_patches], np.float32)

        # Flatten (h, w, 2) -> (h*w, 2)
        box_coordinates = box_coordinates.reshape(
            box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
        )
        box_coordinates = torch.from_numpy(box_coordinates).to(device)

        return box_coordinates

    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
        # The box center is biased to its position on the feature grid
        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
        box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)

        # Unnormalize xy (inverse sigmoid, with a small epsilon for stability)
        box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)

        # The box size is biased to the patch size
        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
        box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)

        # Compute box bias
        box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)
        return box_bias
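
    # A minimal sanity check (illustrative only, not part of the model): the
    # bias is the inverse sigmoid of the normalized patch coordinates, so a
    # zero output from box_head yields, after the sigmoid in box_predictor, a
    # box centered on its own patch with width/height of roughly one patch:
    #
    #   feature_map = torch.zeros(1, 48, 48, 768)   # hypothetical 48x48 grid
    #   bias = model.compute_box_bias(feature_map)  # -> (48 * 48, 4)
    #   centers = torch.sigmoid(bias[:, :2])        # ~ the normalized grid coords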

    def box_predictor(
        self,
        image_feats: torch.FloatTensor,
        feature_map: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Args:
            image_feats: Features extracted from the image (the flattened feature map, see `get_visual_embs`).
            feature_map: A spatial re-arrangement of image_feats, as returned by `get_visual_embs`.
        Returns:
            pred_boxes: Predicted boxes (cxcywh, normalized to [0, 1]) of shape [batch_size, num_boxes, 4].
        """
        # Bounding box detection head [batch_size, num_boxes, 4].
        pred_boxes = self.box_head(image_feats)

        # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
        pred_boxes += self.compute_box_bias(feature_map)
        pred_boxes = self.sigmoid(pred_boxes)
        return pred_boxes

    def class_predictor(
        self,
        image_feats: torch.FloatTensor,
        query_embeds: Optional[torch.FloatTensor] = None,
        query_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """
        Args:
            image_feats: Features extracted from the image (see `get_visual_embs`).
            query_embeds: Text query embeddings.
            query_mask: Must be provided with query_embeds. A mask indicating which query embeddings are valid.
        """
        (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)
        return (pred_logits, image_class_embeds)

    def get_visual_embs(self, image):
        vision_outputs = self.vision_model(
            pixel_values=image,
            output_hidden_states=self.vision_model.config.output_hidden_states,
            return_dict=self.vision_model.config.use_return_dict,
        )

        # Get image embeddings
        last_hidden_state = vision_outputs[0]
        image_embeds = self.vision_model.post_layernorm(last_hidden_state)

        # Broadcast the class token over all patch tokens
        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)

        # Merge image embedding with class tokens
        image_embeds = image_embeds[:, 1:, :] * class_token_out
        image_embeds = self.layer_norm(image_embeds)

        # Resize to [batch_size, num_patches, num_patches, hidden_size]
        new_size = (
            image_embeds.shape[0],
            int(np.sqrt(image_embeds.shape[1])),
            int(np.sqrt(image_embeds.shape[1])),
            image_embeds.shape[-1],
        )
        feature_map = image_embeds.reshape(new_size)
        return feature_map

    def forward(
        self,
        image_embeddings: torch.Tensor,
        prompt_embeddings: torch.Tensor,
    ):
        feature_map = image_embeddings
        batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
        image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))
        query_embeds = prompt_embeddings.reshape(batch_size, 1, prompt_embeddings.shape[-1])

        # Predict object classes [batch_size, num_patches*num_patches, num_queries]
        (pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds)

        # Predict object boxes
        pred_boxes = self.box_predictor(image_feats, feature_map)

        out = {'pred_logits': pred_logits, 'pred_boxes': pred_boxes}
        return out
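

# A minimal inference sketch (illustrative only; the 768x768 input size, 48x48
# patch grid, and 768/512 embedding widths assume the google/owlvit-base-patch16
# configuration):
#
#   model = OwlViT(num_classes=1).eval()
#   pixel_values = torch.randn(2, 3, 768, 768)         # preprocessed images
#   feature_map = model.get_visual_embs(pixel_values)  # (2, 48, 48, 768)
#   prompt_embeddings = torch.randn(2, 512)            # hypothetical text/query embeddings
#   out = model(feature_map, prompt_embeddings)
#   # out['pred_logits']: (2, 2304, 1), out['pred_boxes']: (2, 2304, 4)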


class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            losses: list of all the losses to be applied. See get_loss for the list of available losses.
            focal_alpha: alpha in Focal Loss
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (binary focal loss)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
                                            dtype=src_logits.dtype, layout=src_logits.layout,
                                            device=src_logits.device)
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

        # Drop the extra "no-object" column: unmatched queries keep an all-zero target
        target_classes_onehot = target_classes_onehot[:, :, :-1]
        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes,
                                     alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses
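
    # Worked example (illustrative, not executed): with a batch of one image and
    # indices = [(tensor([3, 7]), tensor([1, 0]))] from the matcher, the matched
    # predictions outputs['pred_boxes'][0, [3, 7]] are compared against
    # targets[0]['boxes'][[1, 0]]; both box losses are summed and divided by
    # num_boxes, so their scale does not depend on how many objects an image has.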

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the masks: the focal loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        src_masks = outputs["pred_masks"]

        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose()
        target_masks = target_masks.to(src_masks)

        src_masks = src_masks[src_idx]
        # upsample predictions to the target size
        src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
                                mode="bilinear", align_corners=False)
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks[tgt_idx].flatten(1)

        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
        }
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
            'masks': self.loss_masks,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
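
    # In a typical training step the caller computes the normalizer itself,
    # since this criterion's forward takes num_boxes as an argument (a sketch;
    # the distributed all-reduce used in DETR is omitted here):
    #
    #   num_boxes = sum(len(t["labels"]) for t in targets)
    #   num_boxes = torch.as_tensor([num_boxes], dtype=torch.float,
    #                               device=outputs['pred_logits'].device)
    #   num_boxes = torch.clamp(num_boxes, min=1).item()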

    def forward(self, outputs, targets, num_boxes):
        """ This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depend on the losses applied, see each loss' doc
            num_boxes: total number of target boxes across the batch, used to normalize the losses
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            kwargs = {}
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs['log'] = False
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        if 'enc_outputs' in outputs:
            enc_outputs = outputs['enc_outputs']
            bin_targets = copy.deepcopy(targets)
            for bt in bin_targets:
                # Treat every object as a single foreground class for the encoder proposals
                bt['labels'] = torch.zeros_like(bt['labels'])
            indices = self.matcher(enc_outputs, bin_targets)
            for loss in self.losses:
                if loss == 'masks':
                    # Intermediate masks losses are too costly to compute, we ignore them.
                    continue
                kwargs = {}
                if loss == 'labels':
                    # Logging is enabled only for the last layer
                    kwargs['log'] = False
                l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
                l_dict = {k + '_enc': v for k, v in l_dict.items()}
                losses.update(l_dict)

        return losses
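

if __name__ == "__main__":
    # Smoke test (a sketch, not part of the module's API): the 48x48 grid and
    # the 768/512 embedding widths assume the google/owlvit-base-patch16
    # configuration; is_eval=True builds the architecture without downloading
    # pretrained weights (the config itself is still fetched or read from cache).
    model = OwlViT(num_classes=1, is_eval=True)
    feature_map = torch.randn(2, 48, 48, 768)  # stands in for get_visual_embs output
    prompt_embeddings = torch.randn(2, 512)    # stands in for text/query embeddings
    out = model(feature_map, prompt_embeddings)
    print(out['pred_logits'].shape)  # expected: torch.Size([2, 2304, 1])
    print(out['pred_boxes'].shape)   # expected: torch.Size([2, 2304, 4])

    # Weighted total loss, following the DETR convention:
    #   loss_dict = model.criterion(out, targets, num_boxes)
    #   loss = sum(loss_dict[k] * model.weight_dict[k]
    #              for k in loss_dict if k in model.weight_dict)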