# Copyright (c) Tencent Inc. All rights reserved.
import math
import copy
from typing import List, Optional, Tuple, Union, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mmcv.cnn import ConvModule
from mmengine.config import ConfigDict
from mmengine.model import BaseModule
from torch import Tensor

from mmengine.dist import get_dist_info
from mmengine.structures import InstanceData
from mmdet.structures import SampleList
from mmdet.utils import OptConfigType, InstanceList, OptInstanceList
from mmdet.models.utils import (
    multi_apply,
    unpack_gt_instances,
    filter_scores_and_topk)
from mmyolo.registry import MODELS
from mmyolo.models.dense_heads import YOLOv8HeadModule, YOLOv8Head
from mmyolo.models.utils import gt_instances_preprocess
from mmcv.cnn.bricks import build_norm_layer


@MODELS.register_module()
class ContrastiveHead(BaseModule):
    """Contrastive Head for YOLO-World.

    Computes region-text scores as the scaled, biased dot product between
    L2-normalized image features and L2-normalized text features.

    Args:
        embed_dims (int): embed dim of text and image features. Accepted for
            interface parity with ``BNContrastiveHead``; it is not used by
            this class.
        init_cfg (OptConfigType): initialization config forwarded to
            ``BaseModule``. Defaults to None.
    """

    def __init__(self,
                 embed_dims: int,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)

        # Scalar learnable bias and log-temperature; the temperature starts
        # at 1 / 0.07 once exponentiated in ``forward``.
        self.bias = nn.Parameter(torch.zeros([]))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, x: Tensor, w: Tensor) -> Tensor:
        """Forward function of contrastive learning.

        Args:
            x (Tensor): image features, indexed as (b, c, h, w) by the einsum.
            w (Tensor): text features, indexed as (b, k, c) by the einsum.

        Returns:
            Tensor: region-text logits indexed as (b, k, h, w).
        """
        # Normalize over the channel axis of each input.
        x = F.normalize(x, dim=1, p=2)
        w = F.normalize(w, dim=-1, p=2)
        x = torch.einsum('bchw,bkc->bkhw', x, w)
        x = x * self.logit_scale.exp() + self.bias
        return x


@MODELS.register_module()
class BNContrastiveHead(BaseModule):
    """Batch Norm Contrastive Head for YOLO-World.

    Same scoring as ``ContrastiveHead`` but the image features go through a
    normalization layer built from ``norm_cfg`` instead of L2-normalization.
    The text features are still L2-normalized.

    Args:
        embed_dims (int): embed dim of text and image features
        norm_cfg (dict): normalization params
        init_cfg (OptConfigType): initialization config forwarded to
            ``BaseModule``. Defaults to None.
    """

    def __init__(self,
                 embed_dims: int,
                 norm_cfg: ConfigDict,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # build_norm_layer returns (name, layer); only the layer is kept.
        self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        self.bias = nn.Parameter(torch.zeros([]))
        # use -1.0 is more stable
        self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))

    def forward(self, x: Tensor, w: Tensor) -> Tensor:
        """Forward function of contrastive learning.

        Args:
            x (Tensor): image features, indexed as (b, c, h, w) by the einsum.
            w (Tensor): text features, indexed as (b, k, c) by the einsum.

        Returns:
            Tensor: region-text logits indexed as (b, k, h, w).
        """
        x = self.norm(x)
        w = F.normalize(w, dim=-1, p=2)
        x = torch.einsum('bchw,bkc->bkhw', x, w)
        x = x * self.logit_scale.exp() + self.bias
        return x


@MODELS.register_module()
class YOLOWorldHeadModule(YOLOv8HeadModule):
    """Head Module for YOLO-World

    Args:
        embed_dims (int): embed dim for text features and image features
        use_bn_head (bool): use batch normalization head
    """

    def __init__(self,
                 *args,
                 embed_dims: int,
                 use_bn_head: bool = False,
                 **kwargs) -> None:
        # NOTE(review): these are assigned before super().__init__() --
        # presumably because the parent constructor calls ``_init_layers``,
        # which reads both attributes. Confirm before reordering.
        self.embed_dims = embed_dims
        self.use_bn_head = use_bn_head
        super().__init__(*args, **kwargs)

    def init_weights(self, prior_prob=0.01):
        """Initialize the weight and bias of the YOLO-World head module.

        Runs the parent initialization, then zeroes the bias of the last
        layer of every classification branch and sets the bias of every
        contrastive head to a stride-dependent constant.

        Args:
            prior_prob (float): kept for interface compatibility; it is not
                used by this implementation.
        """
        super().init_weights()
        for cls_pred, cls_contrast, stride in zip(self.cls_preds,
                                                  self.cls_contrasts,
                                                  self.featmap_strides):
            cls_pred[-1].bias.data[:] = 0.0  # reset bias
            if hasattr(cls_contrast, 'bias'):
                nn.init.constant_(
                    cls_contrast.bias.data,
                    math.log(5 / self.num_classes / (640 / stride)**2))

    def _init_layers(self) -> None:
        """initialize conv layers in YOLOv8 head.

        Builds, per feature level, a regression branch, a classification
        (embedding) branch whose last conv outputs ``embed_dims`` channels,
        and a contrastive head that scores embeddings against text features.
        """
        # Init decouple head
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.cls_contrasts = nn.ModuleList()

        reg_out_channels = max(
            (16, self.in_channels[0] // 4, self.reg_max * 4))
        cls_out_channels = max(self.in_channels[0], self.num_classes)

        for i in range(self.num_levels):
            self.reg_preds.append(
                nn.Sequential(
                    ConvModule(in_channels=self.in_channels[i],
                               out_channels=reg_out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               norm_cfg=self.norm_cfg,
                               act_cfg=self.act_cfg),
                    ConvModule(in_channels=reg_out_channels,
                               out_channels=reg_out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               norm_cfg=self.norm_cfg,
                               act_cfg=self.act_cfg),
                    nn.Conv2d(in_channels=reg_out_channels,
                              out_channels=4 * self.reg_max,
                              kernel_size=1)))
            self.cls_preds.append(
                nn.Sequential(
                    ConvModule(in_channels=self.in_channels[i],
                               out_channels=cls_out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               norm_cfg=self.norm_cfg,
                               act_cfg=self.act_cfg),
                    ConvModule(in_channels=cls_out_channels,
                               out_channels=cls_out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               norm_cfg=self.norm_cfg,
                               act_cfg=self.act_cfg),
                    nn.Conv2d(in_channels=cls_out_channels,
                              out_channels=self.embed_dims,
                              kernel_size=1)))
            if self.use_bn_head:
                self.cls_contrasts.append(
                    BNContrastiveHead(self.embed_dims, self.norm_cfg))
            else:
                self.cls_contrasts.append(ContrastiveHead(self.embed_dims))

        # Bin indices 0..reg_max-1, used to take the expectation of the
        # predicted distance distribution in ``forward_single``.
        proj = torch.arange(self.reg_max, dtype=torch.float)
        self.register_buffer('proj', proj, persistent=False)

    def forward(self, img_feats: Tuple[Tensor],
                txt_feats: Tensor) -> Tuple[List]:
        """Forward features from the upstream network.

        Args:
            img_feats (Tuple[Tensor]): one feature map per level; its length
                must equal ``self.num_levels``.
            txt_feats (Tensor): text features shared by every level.

        Returns:
            Tuple[List]: per-output lists over levels, as assembled by
            ``multi_apply`` from ``forward_single``.
        """
        assert len(img_feats) == self.num_levels
        # The same text features are paired with every level.
        txt_feats = [txt_feats for _ in range(self.num_levels)]
        return multi_apply(self.forward_single, img_feats, txt_feats,
                           self.cls_preds, self.reg_preds, self.cls_contrasts)

    def forward_single(self, img_feat: Tensor, txt_feat: Tensor,
                       cls_pred: nn.Module, reg_pred: nn.Module,
                       cls_contrast: nn.Module) -> Tuple:
        """Forward feature of a single scale level.

        Args:
            img_feat (Tensor): feature map of this level, (b, c, h, w).
            txt_feat (Tensor): text features passed to ``cls_contrast``.
            cls_pred (nn.Module): classification (embedding) branch.
            reg_pred (nn.Module): regression branch.
            cls_contrast (nn.Module): contrastive head of this level.

        Returns:
            Tuple: ``(cls_logit, bbox_preds, bbox_dist_preds)`` in training
            mode, ``(cls_logit, bbox_preds)`` otherwise. When
            ``reg_max > 1``, ``bbox_dist_preds`` has shape
            (b, h * w, 4, reg_max) and ``bbox_preds`` has shape (b, 4, h, w).
        """
        b, _, h, w = img_feat.shape
        cls_embed = cls_pred(img_feat)
        cls_logit = cls_contrast(cls_embed, txt_feat)
        bbox_dist_preds = reg_pred(img_feat)
        if self.reg_max > 1:
            bbox_dist_preds = bbox_dist_preds.reshape(
                [-1, 4, self.reg_max, h * w]).permute(0, 3, 1, 2)

            # TODO: The get_flops script cannot handle the situation of
            #  matmul, and needs to be fixed later
            # bbox_preds = bbox_dist_preds.softmax(3).matmul(self.proj)
            bbox_preds = bbox_dist_preds.softmax(3).matmul(
                self.proj.view([-1, 1])).squeeze(-1)
            bbox_preds = bbox_preds.transpose(1, 2).reshape(b, -1, h, w)
        else:
            bbox_preds = bbox_dist_preds
        if self.training:
            return cls_logit, bbox_preds, bbox_dist_preds
        else:
            return cls_logit, bbox_preds


@MODELS.register_module()
class YOLOWorldHead(YOLOv8Head):
    """YOLO-World Head (YOLO World v8 head).

    Args:
        world_size (int): factor the losses are multiplied by in
            ``loss_by_feat``. With the default ``-1`` the value returned by
            ``get_dist_info()`` is used instead.
    """

    def __init__(self, world_size=-1, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.world_size = world_size

    def loss(self, img_feats: Tuple[Tensor], txt_feats: Tensor,
             batch_data_samples: Union[list, dict]) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            img_feats (Tuple[Tensor]): multi-level image features.
            txt_feats (Tensor): text features.
            batch_data_samples (Union[list, dict]): either a list of data
                samples (unpacked with ``unpack_gt_instances``) or a dict
                providing the ``'bboxes_labels'`` and ``'img_metas'`` keys.

        Returns:
            dict: the losses returned by ``loss_by_feat``.
        """
        outs = self(img_feats, txt_feats)
        if isinstance(batch_data_samples, list):
            # The signature admits a list of data samples, which cannot be
            # indexed by key; unpack it the same way loss_and_predict does.
            (batch_gt_instances, batch_gt_instances_ignore,
             batch_img_metas) = unpack_gt_instances(batch_data_samples)
            loss_inputs = outs + (batch_gt_instances, batch_img_metas,
                                  batch_gt_instances_ignore)
        else:
            # Fast version
            loss_inputs = outs + (batch_data_samples['bboxes_labels'],
                                  batch_data_samples['img_metas'])
        losses = self.loss_by_feat(*loss_inputs)

        return losses

    def loss_and_predict(
        self,
        img_feats: Tuple[Tensor],
        txt_feats: Tensor,
        batch_data_samples: SampleList,
        proposal_cfg: Optional[ConfigDict] = None
    ) -> Tuple[dict, InstanceList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples.

        Args:
            img_feats (Tuple[Tensor]): multi-level image features.
            txt_feats (Tensor): text features.
            batch_data_samples (SampleList): data samples to unpack ground
                truth and image meta information from.
            proposal_cfg (ConfigDict, optional): config forwarded to
                ``predict_by_feat`` as ``cfg``. Defaults to None.

        Returns:
            Tuple[dict, InstanceList]: the losses and the predictions.
        """
        outputs = unpack_gt_instances(batch_data_samples)
        (batch_gt_instances, batch_gt_instances_ignore,
         batch_img_metas) = outputs

        outs = self(img_feats, txt_feats)

        loss_inputs = outs + (batch_gt_instances, batch_img_metas,
                              batch_gt_instances_ignore)
        losses = self.loss_by_feat(*loss_inputs)

        # In training mode the head module also returns bbox_dist_preds as a
        # third output. Only the class scores and box predictions belong in
        # predict_by_feat; otherwise the distributions would be bound to its
        # ``objectnesses`` parameter.
        predictions = self.predict_by_feat(*outs[:2],
                                           batch_img_metas=batch_img_metas,
                                           cfg=proposal_cfg)
        return losses, predictions

    def forward(self, img_feats: Tuple[Tensor],
                txt_feats: Tensor) -> Tuple[List]:
        """Forward features from the upstream network.

        Side effect: ``self.num_classes`` is overwritten with
        ``txt_feats.shape[1]`` on every call, so later reshapes follow the
        number of text entries of the current batch.
        """
        self.num_classes = txt_feats.shape[1]
        return self.head_module(img_feats, txt_feats)

    def predict(self,
                img_feats: Tuple[Tensor],
                txt_feats: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.

        Args:
            img_feats (Tuple[Tensor]): multi-level image features.
            txt_feats (Tensor): text features.
            batch_data_samples (SampleList): data samples; only their
                ``metainfo`` is read here.
            rescale (bool): forwarded to ``predict_by_feat``.
                Defaults to False.

        Returns:
            InstanceList: detection results of each image.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        outs = self(img_feats, txt_feats)
        predictions = self.predict_by_feat(*outs,
                                           batch_img_metas=batch_img_metas,
                                           rescale=rescale)
        return predictions

    def aug_test(self,
                 aug_batch_feats,
                 aug_batch_img_metas,
                 rescale=False,
                 with_ori_nms=False,
                 **kwargs):
        """Test function with test time augmentation.

        Raises:
            NotImplementedError: always; test-time augmentation is not
                supported by this head.
        """
        raise NotImplementedError('aug_test is not implemented yet.')

    def loss_by_feat(
            self,
            cls_scores: Sequence[Tensor],
            bbox_preds: Sequence[Tensor],
            bbox_dist_preds: Sequence[Tensor],
            batch_gt_instances: Sequence[InstanceData],
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (Sequence[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
                each scale level. ``YOLOWorldHeadModule.forward_single``
                produces them with shape (bs, H*W, 4, reg_max); they are
                flattened here to (bs, H*W, 4 * reg_max).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None. It is not read by this implementation.
        Returns:
            dict[str, Tensor]: A dictionary of losses with the keys
            ``loss_cls``, ``loss_bbox`` and ``loss_dfl``.
        """
        num_imgs = len(batch_img_metas)

        current_featmap_sizes = [
            cls_score.shape[2:] for cls_score in cls_scores
        ]
        # If the shape does not equal, generate new one
        if current_featmap_sizes != self.featmap_sizes_train:
            self.featmap_sizes_train = current_featmap_sizes

            mlvl_priors_with_stride = self.prior_generator.grid_priors(
                self.featmap_sizes_train,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device,
                with_stride=True)

            self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
            self.flatten_priors_train = torch.cat(mlvl_priors_with_stride,
                                                  dim=0)
            # Column 2 of the priors is kept as the per-prior stride.
            self.stride_tensor = self.flatten_priors_train[..., [2]]

        # gt info
        gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
        gt_labels = gt_info[:, :, :1]
        gt_bboxes = gt_info[:, :, 1:]  # xyxy
        # Rows whose box coordinates sum to 0 are flagged as padding.
        pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()

        # pred info
        flatten_cls_preds = [
            cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                 self.num_classes)
            for cls_pred in cls_scores
        ]
        flatten_pred_bboxes = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        # (bs, n, 4 * reg_max)
        flatten_pred_dists = [
            bbox_pred_org.reshape(num_imgs, -1, self.head_module.reg_max * 4)
            for bbox_pred_org in bbox_dist_preds
        ]

        flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
        flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
        flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
        flatten_pred_bboxes = self.bbox_coder.decode(
            self.flatten_priors_train[..., :2], flatten_pred_bboxes,
            self.stride_tensor[..., 0])

        # The assigner only sees detached predictions.
        assigned_result = self.assigner(
            (flatten_pred_bboxes.detach()).type(gt_bboxes.dtype),
            flatten_cls_preds.detach().sigmoid(), self.flatten_priors_train,
            gt_labels, gt_bboxes, pad_bbox_flag)

        assigned_bboxes = assigned_result['assigned_bboxes']
        assigned_scores = assigned_result['assigned_scores']
        fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']

        # Clamped to at least 1 so the divisions below are safe when nothing
        # is assigned.
        assigned_scores_sum = assigned_scores.sum().clamp(min=1)

        loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores).sum()
        loss_cls /= assigned_scores_sum

        # rescale bbox
        assigned_bboxes /= self.stride_tensor
        flatten_pred_bboxes /= self.stride_tensor

        # select positive samples mask
        num_pos = fg_mask_pre_prior.sum()
        if num_pos > 0:
            # when num_pos > 0, assigned_scores_sum will >0, so the loss_bbox
            # will not report an error
            # iou loss
            prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
            pred_bboxes_pos = torch.masked_select(
                flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
            assigned_bboxes_pos = torch.masked_select(
                assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
            bbox_weight = torch.masked_select(assigned_scores.sum(-1),
                                              fg_mask_pre_prior).unsqueeze(-1)
            loss_bbox = self.loss_bbox(
                pred_bboxes_pos, assigned_bboxes_pos,
                weight=bbox_weight) / assigned_scores_sum

            # dfl loss
            pred_dist_pos = flatten_dist_preds[fg_mask_pre_prior]
            assigned_ltrb = self.bbox_coder.encode(
                self.flatten_priors_train[..., :2] / self.stride_tensor,
                assigned_bboxes,
                max_dis=self.head_module.reg_max - 1,
                eps=0.01)
            assigned_ltrb_pos = torch.masked_select(
                assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
            loss_dfl = self.loss_dfl(pred_dist_pos.reshape(
                -1, self.head_module.reg_max),
                                     assigned_ltrb_pos.reshape(-1),
                                     weight=bbox_weight.expand(-1,
                                                               4).reshape(-1),
                                     avg_factor=assigned_scores_sum)
        else:
            # Zero-valued losses that still depend on the predictions.
            loss_bbox = flatten_pred_bboxes.sum() * 0
            loss_dfl = flatten_pred_bboxes.sum() * 0
        if self.world_size == -1:
            _, world_size = get_dist_info()
        else:
            world_size = self.world_size
        # NOTE(review): the losses are scaled by num_imgs * world_size --
        # presumably to compensate for averaging over the batch and over
        # distributed ranks; confirm against the training setup.
        return dict(loss_cls=loss_cls * num_imgs * world_size,
                    loss_bbox=loss_bbox * num_imgs * world_size,
                    loss_dfl=loss_dfl * num_imgs * world_size)

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        objectnesses: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = True,
                        with_nms: bool = True) -> List[InstanceData]:
        """Transform a batch of output features extracted by the head into
        bbox results.
        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            objectnesses (list[Tensor], Optional): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Defaults to None, but a list is required in practice because
                its length and entries are read unconditionally.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.
        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.
                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(cls_scores) == len(bbox_preds)
        with_objectnesses = objectnesses is not None
        if with_objectnesses:
            assert len(cls_scores) == len(objectnesses)

        cfg = self.test_cfg if cfg is None else cfg
        # Work on a copy: multi_label / max_per_img are overwritten below.
        cfg = copy.deepcopy(cfg)

        multi_label = cfg.multi_label
        multi_label &= self.num_classes > 1
        cfg.multi_label = multi_label

        num_imgs = len(batch_img_metas)
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

        # If the shape does not change, use the previous mlvl_priors
        if featmap_sizes != self.featmap_sizes:
            self.mlvl_priors = self.prior_generator.grid_priors(
                featmap_sizes,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device)
            self.featmap_sizes = featmap_sizes
        flatten_priors = torch.cat(self.mlvl_priors)

        mlvl_strides = [
            flatten_priors.new_full(
                (featmap_size.numel() * self.num_base_priors, ), stride) for
            featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
        ]
        flatten_stride = torch.cat(mlvl_strides)

        # flatten cls_scores, bbox_preds and objectness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.num_classes)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]

        flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_decoded_bboxes = self.bbox_coder.decode(
            flatten_priors[None], flatten_bbox_preds, flatten_stride)

        if with_objectnesses:
            flatten_objectness = [
                objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
                for objectness in objectnesses
            ]
            flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
        else:
            # One placeholder per image so the zip below stays aligned.
            flatten_objectness = [None for _ in range(num_imgs)]

        results_list = []
        for (bboxes, scores, objectness,
             img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
                              flatten_objectness, batch_img_metas):
            ori_shape = img_meta['ori_shape']
            scale_factor = img_meta['scale_factor']
            # None when the image meta carries no padding information.
            pad_param = img_meta.get('pad_param')

            score_thr = cfg.get('score_thr', -1)
            # yolox_style does not require the following operations
            if objectness is not None and score_thr > 0 and not cfg.get(
                    'yolox_style', False):
                conf_inds = objectness > score_thr
                bboxes = bboxes[conf_inds, :]
                scores = scores[conf_inds, :]
                objectness = objectness[conf_inds]

            if objectness is not None:
                # conf = obj_conf * cls_conf
                scores *= objectness[:, None]

            if scores.shape[0] == 0:
                # Nothing survived the filtering: emit empty tensors.
                empty_results = InstanceData()
                empty_results.bboxes = bboxes
                empty_results.scores = scores[:, 0]
                empty_results.labels = scores[:, 0].int()
                results_list.append(empty_results)
                continue

            nms_pre = cfg.get('nms_pre', 100000)
            if cfg.multi_label is False:
                # Single-label mode: keep only the best class per prior.
                scores, labels = scores.max(1, keepdim=True)
                scores, _, keep_idxs, results = filter_scores_and_topk(
                    scores,
                    score_thr,
                    nms_pre,
                    results=dict(labels=labels[:, 0]))
                labels = results['labels']
            else:
                scores, labels, keep_idxs, _ = filter_scores_and_topk(
                    scores, score_thr, nms_pre)

            results = InstanceData(scores=scores,
                                   labels=labels,
                                   bboxes=bboxes[keep_idxs])

            if rescale:
                if pad_param is not None:
                    # NOTE(review): subtracts pad_param[2] from x and
                    # pad_param[0] from y -- presumably the left and top
                    # padding; confirm against the resize/pad transform.
                    results.bboxes -= results.bboxes.new_tensor([
                        pad_param[2], pad_param[0], pad_param[2], pad_param[0]
                    ])
                results.bboxes /= results.bboxes.new_tensor(
                    scale_factor).repeat((1, 2))

            if cfg.get('yolox_style', False):
                # do not need max_per_img
                cfg.max_per_img = len(results)

            # Rescaling was already applied above, hence rescale=False here.
            results = self._bbox_post_process(results=results,
                                              cfg=cfg,
                                              rescale=False,
                                              with_nms=with_nms,
                                              img_meta=img_meta)
            # Clip x to [0, ori_shape[1]] and y to [0, ori_shape[0]].
            results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
            results.bboxes[:, 1::2].clamp_(0, ori_shape[0])

            results_list.append(results)
        return results_list