import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import math
from typing import Any, Dict, Optional, Tuple, Union
from transformers import OwlViTForObjectDetection, OwlViTConfig
from .util import box_ops
from .util.misc import (nested_tensor_from_tensor_list,
accuracy, interpolate, inverse_sigmoid)
from .matcher import HungarianMatcher
from .segmentation import dice_loss, sigmoid_focal_loss
from .matcher import HungarianMatcher
import copy
class OwlViT(nn.Module):
    """OWL-ViT vision tower plus its class/box heads, queried with external prompt embeddings.

    The vision model, classification head, box head and merge layer-norm are
    taken from a Hugging Face ``OwlViTForObjectDetection`` instance; the text
    tower of that model is discarded.  A ``SetCriterion`` configured with a
    Hungarian matcher is attached as ``self.criterion``.
    """

    def __init__(self, num_classes, is_eval=False):
        """Build the detector.

        Args:
            num_classes: forwarded to ``SetCriterion`` as its class count.
            is_eval: when True, the backbone is instantiated from the
                configuration only (no pretrained weights are downloaded);
                presumably the caller restores weights from its own
                checkpoint afterwards.  When False, the pretrained
                ``google/owlvit-base-patch16`` weights are loaded.
        """
        # nn.Module bookkeeping must exist before any submodule is assigned.
        super().__init__()

        checkpoint = "google/owlvit-base-patch16"
        if is_eval:
            owlvit_config = OwlViTConfig.from_pretrained(checkpoint)
            detector = OwlViTForObjectDetection(owlvit_config)
        else:
            detector = OwlViTForObjectDetection.from_pretrained(checkpoint)

        # Keep only the pieces needed for image-side detection.
        self.vision_model = detector.owlvit.vision_model
        self.class_head = detector.class_head
        self.box_head = detector.box_head
        self.layer_norm = detector.layer_norm
        self.sigmoid = nn.Sigmoid()
        del detector

        self.matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
        self.weight_dict = {'loss_ce': 2, 'loss_bbox': 5, 'loss_giou': 2}
        self.losses = ['labels', 'boxes']
        self.criterion = SetCriterion(
            num_classes, self.matcher, self.weight_dict, self.losses, focal_alpha=0.25
        )

    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
        """Return normalized (x, y) corner coordinates for every cell of a square grid.

        Args:
            feature_map: 4-D tensor ``[batch_size, num_patches, num_patches, hidden_dim]``;
                only its ``shape[1]`` and device are used.

        Returns:
            Float32 tensor of shape ``[num_patches * num_patches, 2]`` with
            values in ``(0, 1]``, on the same device as ``feature_map``.

        Raises:
            ValueError: if ``feature_map`` is not 4-dimensional.
        """
        if feature_map.ndim != 4:
            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")

        num_patches = feature_map.shape[1]
        ticks = np.arange(1, num_patches + 1)
        # Cast to float32 first: the in-place division below is not allowed on
        # the integer grid that np.arange/np.meshgrid produce.
        box_coordinates = np.stack(np.meshgrid(ticks, ticks), axis=-1).astype(np.float32)
        box_coordinates /= np.array([num_patches, num_patches], np.float32)

        # Flatten (h, w, 2) -> (h*w, 2)
        box_coordinates = box_coordinates.reshape(-1, box_coordinates.shape[2])
        return torch.from_numpy(box_coordinates).to(feature_map.device)

    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
        """Compute the per-patch logit-space bias added to raw box predictions.

        The center bias places each box at its patch position on the grid and
        the size bias starts each box at roughly one patch in extent.

        Args:
            feature_map: 4-D tensor ``[batch_size, num_patches, num_patches, hidden_dim]``.

        Returns:
            Tensor of shape ``[num_patches * num_patches, 4]``: two center-bias
            columns followed by two size-bias columns.
        """
        # The box center is biased to its position on the feature grid
        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
        box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)

        # Unnormalize xy (inverse sigmoid; the 1e-4 keeps both logs finite at 0 and 1)
        box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)

        # The box size is biased to the patch size
        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
        box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)

        # Compute box bias
        return torch.cat([box_coord_bias, box_size_bias], dim=-1)

    def box_predictor(
        self,
        image_feats: torch.FloatTensor,
        feature_map: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Predict one box per image token.

        Args:
            image_feats: flattened image features ``[batch_size, num_boxes, hidden_dim]``.
            feature_map: the same features arranged spatially
                ``[batch_size, num_patches, num_patches, hidden_dim]``.

        Returns:
            Tensor ``[batch_size, num_boxes, 4]`` squashed to (0, 1) by a
            sigmoid (cxcywh, normalized, per the original documentation).
        """
        # Bounding box detection head [batch_size, num_boxes, 4].
        pred_boxes = self.box_head(image_feats)

        # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
        pred_boxes += self.compute_box_bias(feature_map)
        return self.sigmoid(pred_boxes)

    def class_predictor(
        self,
        image_feats: torch.FloatTensor,
        query_embeds: Optional[torch.FloatTensor] = None,
        query_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor]:
        """Score every image token against the query embeddings.

        Args:
            image_feats: flattened image features.
            query_embeds: query (prompt) embeddings.
            query_mask: must be provided together with ``query_embeds``; a mask
                indicating which query embeddings are valid.

        Returns:
            ``(pred_logits, image_class_embeds)`` exactly as returned by the
            wrapped class head.
        """
        pred_logits, image_class_embeds = self.class_head(image_feats, query_embeds, query_mask)
        return (pred_logits, image_class_embeds)

    def get_visual_embs(self, image):
        """Run the vision tower and return a spatial feature map.

        Args:
            image: batch of preprocessed images, passed to the vision model as
                ``pixel_values``.  NOTE(review): the original call arguments
                were lost; ``pixel_values=image`` is the reconstruction -
                confirm against the caller.

        Returns:
            Tensor ``[batch_size, num_patches, num_patches, hidden_size]``
            where ``num_patches`` is the integer square root of the number of
            patch tokens (a square patch grid is assumed).
        """
        vision_outputs = self.vision_model(pixel_values=image)

        # Get image embeddings
        last_hidden_state = vision_outputs[0]
        image_embeds = self.vision_model.post_layernorm(last_hidden_state)

        # Merge image embedding with class tokens: the class token (index 0)
        # is broadcast across, and multiplied into, every patch token.
        class_token_out = image_embeds[:, :1, :]
        image_embeds = image_embeds[:, 1:, :] * class_token_out
        image_embeds = self.layer_norm(image_embeds)

        # Resize to [batch_size, num_patches, num_patches, hidden_size]
        grid_side = int(np.sqrt(image_embeds.shape[1]))
        new_size = (
            image_embeds.shape[0],
            grid_side,
            grid_side,
            image_embeds.shape[-1],
        )
        return image_embeds.reshape(new_size)

    def forward(
        self,
        image_embeddings: torch.Tensor,
        prompt_embeddings: torch.Tensor,
    ):
        """Predict class logits and boxes for precomputed image features.

        Args:
            image_embeddings: spatial feature map
                ``[batch_size, num_patches, num_patches, hidden_dim]``
                (the layout produced by ``get_visual_embs``).
            prompt_embeddings: one query embedding per image; reshaped to
                ``[batch_size, 1, embed_dim]``.

        Returns:
            Dict with ``'pred_logits'`` (class-head output) and ``'pred_boxes'``
            (``[batch_size, num_patches * num_patches, 4]``).
        """
        feature_map = image_embeddings
        batch_size, num_patches, _, hidden_dim = feature_map.shape
        image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))

        query_embeds = prompt_embeddings.reshape(batch_size, 1, prompt_embeddings.shape[-1])

        # Predict object classes
        pred_logits, _class_embeds = self.class_predictor(image_feats, query_embeds)

        # Predict object boxes
        pred_boxes = self.box_predictor(image_feats, feature_map)

        return {'pred_logits': pred_logits, 'pred_boxes': pred_boxes}
class SetCriterion(nn.Module):
    """Set-prediction loss in the style of DETR.

    The process happens in two steps:
        1) compute a Hungarian assignment between ground-truth boxes and the
           outputs of the model;
        2) supervise each matched ground-truth / prediction pair (class and box).
    """

    def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25):
        """Create the criterion.

        Args:
            num_classes: number of object categories, omitting the special
                no-object category.
            matcher: callable computing a matching between targets and
                proposals; called as ``matcher(outputs, targets)``.
            weight_dict: dict mapping loss names to their relative weight
                (stored only; not applied inside this class).
            losses: list of the losses to apply. See ``get_loss`` for the
                available names.
            focal_alpha: alpha in the focal loss.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (sigmoid focal loss).

        Args:
            outputs: dict containing ``'pred_logits'`` of shape
                ``[batch, num_queries, num_logits]``.
            targets: list of dicts, each with a ``"labels"`` tensor of dim
                ``[nb_target_boxes]``.
            indices: per-image ``(src_idx, tgt_idx)`` pairs from the matcher.
            num_boxes: normalisation constant passed to the focal loss.
            log: when True, also report ``'class_error'``.

        Returns:
            Dict with ``'loss_ce'`` and, if ``log``, ``'class_error'``.
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # Unmatched queries get the "no-object" index self.num_classes.
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        # One-hot with one extra trailing channel that absorbs "no-object";
        # dropping it afterwards leaves all-zero rows for unmatched queries.
        # NOTE(review): requires self.num_classes <= src_logits.shape[2].
        target_classes_onehot = torch.zeros(
            [src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
        target_classes_onehot = target_classes_onehot[:, :, :-1]

        loss_ce = sigmoid_focal_loss(
            src_logits, target_classes_onehot, num_boxes,
            alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """Compute the cardinality error: the absolute error in the number of
        predicted non-empty boxes.

        This is not really a loss; it is intended for logging purposes only
        and does not propagate gradients.

        Returns:
            Dict with ``'cardinality_error'`` (mean absolute error over the batch).
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        return {'cardinality_error': card_err}

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the bounding-box losses: L1 regression and GIoU.

        Targets dicts must contain the key ``"boxes"`` holding a tensor of dim
        ``[nb_target_boxes, 4]``, in (center_x, center_y, w, h) format
        normalized by the image size (as consumed by
        ``box_ops.box_cxcywh_to_xyxy``).

        Returns:
            Dict with scalar ``'loss_bbox'`` and ``'loss_giou'``, each summed
            over matched pairs and divided by ``num_boxes``.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        # The diagonal of the pairwise GIoU matrix holds the matched pairs.
        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """Compute the mask losses: focal loss and dice loss.

        Targets dicts must contain the key ``"masks"`` holding a tensor of dim
        ``[nb_target_boxes, h, w]``.

        Returns:
            Dict with ``"loss_mask"`` and ``"loss_dice"``.
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        src_masks = outputs["pred_masks"]

        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose()
        # Match the predictions' dtype/device before comparing.
        target_masks = target_masks.to(src_masks)

        src_masks = src_masks[src_idx]
        # upsample predictions to the target size
        src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
                                mode="bilinear", align_corners=False)
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks[tgt_idx].flatten(1)

        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
        }
        return losses

    def _get_src_permutation_idx(self, indices):
        """Flatten matcher output into ``(batch_idx, src_idx)`` for indexing predictions."""
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Flatten matcher output into ``(batch_idx, tgt_idx)`` for indexing targets."""
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        """Dispatch to the loss named ``loss``.

        Available names: ``'labels'``, ``'cardinality'``, ``'boxes'``, ``'masks'``.

        Raises:
            AssertionError: if ``loss`` is not one of the available names.
        """
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
            'masks': self.loss_masks,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets, num_boxes):
        """Perform the loss computation.

        Args:
            outputs: dict of tensors, see the output specification of the
                model for the format. May additionally contain
                ``'aux_outputs'`` (list of dicts) and ``'enc_outputs'`` (dict).
            targets: list of dicts, such that ``len(targets) == batch_size``.
                The expected keys in each dict depend on the losses applied,
                see each loss' doc.
            num_boxes: normalisation constant handed to every loss.

        Returns:
            Dict of all computed losses. Auxiliary-layer entries are suffixed
            with ``_{i}`` and encoder entries with ``_enc``.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs['log'] = False
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    losses.update({k + f'_{i}': v for k, v in l_dict.items()})

        if 'enc_outputs' in outputs:
            enc_outputs = outputs['enc_outputs']
            # Encoder proposals are class-agnostic: collapse every label to 0
            # on a copy so the caller's targets are left untouched.
            bin_targets = copy.deepcopy(targets)
            for bt in bin_targets:
                bt['labels'] = torch.zeros_like(bt['labels'])
            indices = self.matcher(enc_outputs, bin_targets)
            for loss in self.losses:
                if loss == 'masks':
                    # Intermediate masks losses are too costly to compute, we ignore them.
                    continue
                kwargs = {}
                if loss == 'labels':
                    # Logging is enabled only for the last layer
                    kwargs['log'] = False
                l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
                losses.update({k + '_enc': v for k, v in l_dict.items()})

        return losses