# Copyright (c) Tencent Inc. All rights reserved.
import copy
import math
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import build_norm_layer
from mmdet.models.utils import (filter_scores_and_topk, multi_apply,
                                unpack_gt_instances)
from mmdet.structures import SampleList
from mmdet.utils import InstanceList, OptConfigType, OptInstanceList
from mmengine.config import ConfigDict
from mmengine.dist import get_dist_info
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from mmyolo.models.dense_heads import YOLOv8Head, YOLOv8HeadModule
from mmyolo.models.utils import gt_instances_preprocess
from mmyolo.registry import MODELS
from torch import Tensor

@MODELS.register_module()
class ContrastiveHead(BaseModule):
    """Contrastive Head for YOLO-World.

    Computes region-text scores from the similarity between image and text
    features.

    Args:
        embed_dims (int): Embedding dimension of the text and image features.
    """

    def __init__(self,
                 embed_dims: int,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.bias = nn.Parameter(torch.zeros([]))
        # CLIP-style temperature: logit_scale starts at log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, x: Tensor, w: Tensor) -> Tensor:
        """Forward function of contrastive learning."""
        x = F.normalize(x, dim=1, p=2)   # (B, C, H, W) image features
        w = F.normalize(w, dim=-1, p=2)  # (B, K, C) text features
        x = torch.einsum('bchw,bkc->bkhw', x, w)
        x = x * self.logit_scale.exp() + self.bias
        return x
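
# A minimal usage sketch (illustrative, not part of the upstream module):
# image features (B, C, H, W) are scored against per-class text embeddings
# (B, K, C), yielding region-text logits (B, K, H, W).
def _demo_contrastive_similarity() -> Tensor:
    head = ContrastiveHead(embed_dims=256)
    img_feat = torch.randn(2, 256, 20, 20)  # (B, C, H, W) region features
    txt_feat = torch.randn(2, 80, 256)      # (B, K, C), one embedding per class
    return head(img_feat, txt_feat)         # (B, K, H, W) = (2, 80, 20, 20)
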
@MODELS.register_module()
class BNContrastiveHead(BaseModule):
    """Batch Norm Contrastive Head for YOLO-World.

    Uses batch norm on the image features instead of L2 normalization.

    Args:
        embed_dims (int): Embedding dimension of the text and image features.
        norm_cfg (dict): Normalization parameters.
    """

    def __init__(self,
                 embed_dims: int,
                 norm_cfg: ConfigDict,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        self.bias = nn.Parameter(torch.zeros([]))
        # Initializing logit_scale to -1.0 is more stable than log(1 / 0.07).
        self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))

    def forward(self, x: Tensor, w: Tensor) -> Tensor:
        """Forward function of contrastive learning."""
        x = self.norm(x)
        w = F.normalize(w, dim=-1, p=2)
        x = torch.einsum('bchw,bkc->bkhw', x, w)
        x = x * self.logit_scale.exp() + self.bias
        return x
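
# A minimal construction sketch (illustrative; the norm_cfg values are
# assumptions, not taken from a YOLO-World config): the image branch is
# whitened with BatchNorm2d instead of L2 normalization, and logit_scale
# starts at -1.0, i.e. a temperature of exp(-1.0) ~= 0.37.
def _demo_bn_contrastive_head() -> BNContrastiveHead:
    norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)  # assumed values
    return BNContrastiveHead(embed_dims=256, norm_cfg=norm_cfg)
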
@MODELS.register_module()
class YOLOWorldHeadModule(YOLOv8HeadModule):
    """Head Module for YOLO-World.

    Args:
        embed_dims (int): Embedding dimension of the text and image features.
        use_bn_head (bool): Whether to use the batch-norm contrastive head.
    """

    def __init__(self,
                 *args,
                 embed_dims: int,
                 use_bn_head: bool = False,
                 **kwargs) -> None:
        self.embed_dims = embed_dims
        self.use_bn_head = use_bn_head
        super().__init__(*args, **kwargs)

    def init_weights(self, prior_prob=0.01):
        """Initialize the weights and biases of the YOLO-World head."""
        super().init_weights()
        for cls_pred, cls_contrast, stride in zip(self.cls_preds,
                                                  self.cls_contrasts,
                                                  self.featmap_strides):
            cls_pred[-1].bias.data[:] = 0.0  # reset bias
            if hasattr(cls_contrast, 'bias'):
                nn.init.constant_(
                    cls_contrast.bias.data,
                    math.log(5 / self.num_classes / (640 / stride)**2))
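
    # Worked example (illustrative, assuming num_classes = 80 and 640-px
    # training images): at stride 8 the contrastive bias initializes to
    # log(5 / 80 / (640 / 8)**2) = log(5 / 512000) ≈ -11.5, encoding a prior
    # of roughly 5 positive predictions per image at the start of training.
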
    def _init_layers(self) -> None:
        """Initialize the conv layers in the YOLO-World head."""
        # Init decoupled head
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.cls_contrasts = nn.ModuleList()

        reg_out_channels = max(
            (16, self.in_channels[0] // 4, self.reg_max * 4))
        cls_out_channels = max(self.in_channels[0], self.num_classes)

        for i in range(self.num_levels):
            self.reg_preds.append(
                nn.Sequential(
                    ConvModule(in_channels=self.in_channels[i],
                               out_channels=reg_out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               norm_cfg=self.norm_cfg,
                               act_cfg=self.act_cfg),
                    ConvModule(in_channels=reg_out_channels,
                               out_channels=reg_out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               norm_cfg=self.norm_cfg,
                               act_cfg=self.act_cfg),
                    nn.Conv2d(in_channels=reg_out_channels,
                              out_channels=4 * self.reg_max,
                              kernel_size=1)))
            self.cls_preds.append(
                nn.Sequential(
                    ConvModule(in_channels=self.in_channels[i],
                               out_channels=cls_out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               norm_cfg=self.norm_cfg,
                               act_cfg=self.act_cfg),
                    ConvModule(in_channels=cls_out_channels,
                               out_channels=cls_out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               norm_cfg=self.norm_cfg,
                               act_cfg=self.act_cfg),
                    nn.Conv2d(in_channels=cls_out_channels,
                              out_channels=self.embed_dims,
                              kernel_size=1)))
            if self.use_bn_head:
                self.cls_contrasts.append(
                    BNContrastiveHead(self.embed_dims, self.norm_cfg))
            else:
                self.cls_contrasts.append(ContrastiveHead(self.embed_dims))

        proj = torch.arange(self.reg_max, dtype=torch.float)
        self.register_buffer('proj', proj, persistent=False)
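
    # Shape sketch (illustrative; channel sizes assumed): with
    # in_channels[i] = 256, reg_max = 16 and embed_dims = 512, an input
    # feature map (B, 256, H, W) at level i produces
    #   reg_preds[i]      -> (B, 64, H, W)   # 4 * reg_max DFL logits
    #   cls_preds[i]      -> (B, 512, H, W)  # per-location region embeddings
    #   cls_contrasts[i]  -> (B, K, H, W)    # region-text logits, K classes
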
    def forward(self, img_feats: Tuple[Tensor],
                txt_feats: Tensor) -> Tuple[List]:
        """Forward features from the upstream network."""
        assert len(img_feats) == self.num_levels
        txt_feats = [txt_feats for _ in range(self.num_levels)]
        return multi_apply(self.forward_single, img_feats, txt_feats,
                           self.cls_preds, self.reg_preds, self.cls_contrasts)

    def forward_single(self, img_feat: Tensor, txt_feat: Tensor,
                       cls_pred: nn.Module, reg_pred: nn.Module,
                       cls_contrast: nn.Module) -> Tuple:
        """Forward feature of a single scale level."""
        b, _, h, w = img_feat.shape
        cls_embed = cls_pred(img_feat)
        cls_logit = cls_contrast(cls_embed, txt_feat)
        bbox_dist_preds = reg_pred(img_feat)
        if self.reg_max > 1:
            bbox_dist_preds = bbox_dist_preds.reshape(
                [-1, 4, self.reg_max, h * w]).permute(0, 3, 1, 2)

            # TODO: The get_flops script cannot handle the situation of
            # matmul, and needs to be fixed later
            # bbox_preds = bbox_dist_preds.softmax(3).matmul(self.proj)
            bbox_preds = bbox_dist_preds.softmax(3).matmul(
                self.proj.view([-1, 1])).squeeze(-1)
            bbox_preds = bbox_preds.transpose(1, 2).reshape(b, -1, h, w)
        else:
            bbox_preds = bbox_dist_preds
        if self.training:
            return cls_logit, bbox_preds, bbox_dist_preds
        else:
            return cls_logit, bbox_preds
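
# A minimal sketch (illustrative, not part of the upstream module) of the DFL
# decoding in forward_single: each box side is a categorical distribution over
# reg_max bins, and the predicted offset is its expectation,
# softmax(logits) @ [0, 1, ..., reg_max - 1].
def _demo_dfl_expectation(reg_max: int = 16) -> Tensor:
    dist_logits = torch.randn(8, 4, reg_max)     # 8 priors, 4 box sides
    proj = torch.arange(reg_max, dtype=torch.float)
    return dist_logits.softmax(-1).matmul(proj)  # (8, 4) expected bin offsets
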
@MODELS.register_module()
class YOLOWorldHead(YOLOv8Head):
    """YOLO-World head built on the YOLOv8 head."""

    def __init__(self, world_size=-1, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.world_size = world_size

    def loss(self, img_feats: Tuple[Tensor], txt_feats: Tensor,
             batch_data_samples: Union[list, dict]) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network."""
        outs = self(img_feats, txt_feats)
        # Fast version
        loss_inputs = outs + (batch_data_samples['bboxes_labels'],
                              batch_data_samples['img_metas'])
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def loss_and_predict(
        self,
        img_feats: Tuple[Tensor],
        txt_feats: Tensor,
        batch_data_samples: SampleList,
        proposal_cfg: Optional[ConfigDict] = None
    ) -> Tuple[dict, InstanceList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples.
        """
        outputs = unpack_gt_instances(batch_data_samples)
        (batch_gt_instances, batch_gt_instances_ignore,
         batch_img_metas) = outputs

        outs = self(img_feats, txt_feats)

        loss_inputs = outs + (batch_gt_instances, batch_img_metas,
                              batch_gt_instances_ignore)
        losses = self.loss_by_feat(*loss_inputs)

        predictions = self.predict_by_feat(*outs,
                                           batch_img_metas=batch_img_metas,
                                           cfg=proposal_cfg)
        return losses, predictions

    def forward(self, img_feats: Tuple[Tensor],
                txt_feats: Tensor) -> Tuple[List]:
        """Forward features from the upstream network."""
        self.num_classes = txt_feats.shape[1]
        return self.head_module(img_feats, txt_feats)
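
    # Open-vocabulary note (illustrative; shapes assumed): ``num_classes``
    # follows the second dim of ``txt_feats`` on every call, so the same head
    # can score a different vocabulary per request, e.g.
    #   txt_feats = torch.randn(1, 3, 512)  # 3 prompts, 512-d embeddings
    #   cls_scores, bbox_preds = head(img_feats, txt_feats)  # K == 3 logits
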
    def predict(self,
                img_feats: Tuple[Tensor],
                txt_feats: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        outs = self(img_feats, txt_feats)
        predictions = self.predict_by_feat(*outs,
                                           batch_img_metas=batch_img_metas,
                                           rescale=rescale)
        return predictions

    def aug_test(self,
                 aug_batch_feats,
                 aug_batch_img_metas,
                 rescale=False,
                 with_ori_nms=False,
                 **kwargs):
        """Test function with test time augmentation."""
        raise NotImplementedError('aug_test is not implemented yet.')
    def loss_by_feat(
            self,
            cls_scores: Sequence[Tensor],
            bbox_preds: Sequence[Tensor],
            bbox_dist_preds: Sequence[Tensor],
            batch_gt_instances: Sequence[InstanceData],
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (Sequence[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each
                scale level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
                each scale level with shape (bs, H*W, 4, reg_max).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of losses.
        """
        num_imgs = len(batch_img_metas)

        current_featmap_sizes = [
            cls_score.shape[2:] for cls_score in cls_scores
        ]
        # If the feature map sizes changed, regenerate the priors.
        if current_featmap_sizes != self.featmap_sizes_train:
            self.featmap_sizes_train = current_featmap_sizes

            mlvl_priors_with_stride = self.prior_generator.grid_priors(
                self.featmap_sizes_train,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device,
                with_stride=True)

            self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
            self.flatten_priors_train = torch.cat(
                mlvl_priors_with_stride, dim=0)
            self.stride_tensor = self.flatten_priors_train[..., [2]]

        # gt info
        gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
        gt_labels = gt_info[:, :, :1]
        gt_bboxes = gt_info[:, :, 1:]  # xyxy
        pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()

        # pred info
        flatten_cls_preds = [
            cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                 self.num_classes)
            for cls_pred in cls_scores
        ]
        flatten_pred_bboxes = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        # (bs, n, 4 * reg_max)
        flatten_pred_dists = [
            bbox_pred_org.reshape(num_imgs, -1, self.head_module.reg_max * 4)
            for bbox_pred_org in bbox_dist_preds
        ]

        flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
        flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
        flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
        flatten_pred_bboxes = self.bbox_coder.decode(
            self.flatten_priors_train[..., :2], flatten_pred_bboxes,
            self.stride_tensor[..., 0])

        assigned_result = self.assigner(
            (flatten_pred_bboxes.detach()).type(gt_bboxes.dtype),
            flatten_cls_preds.detach().sigmoid(), self.flatten_priors_train,
            gt_labels, gt_bboxes, pad_bbox_flag)

        assigned_bboxes = assigned_result['assigned_bboxes']
        assigned_scores = assigned_result['assigned_scores']
        fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']

        assigned_scores_sum = assigned_scores.sum().clamp(min=1)

        loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores).sum()
        loss_cls /= assigned_scores_sum

        # rescale bbox
        assigned_bboxes /= self.stride_tensor
        flatten_pred_bboxes /= self.stride_tensor

        # select positive samples mask
        num_pos = fg_mask_pre_prior.sum()
        if num_pos > 0:
            # When num_pos > 0, assigned_scores_sum will be > 0, so
            # loss_bbox will not raise an error.
            # iou loss
            prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
            pred_bboxes_pos = torch.masked_select(
                flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
            assigned_bboxes_pos = torch.masked_select(
                assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
            bbox_weight = torch.masked_select(
                assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
            loss_bbox = self.loss_bbox(
                pred_bboxes_pos, assigned_bboxes_pos,
                weight=bbox_weight) / assigned_scores_sum

            # dfl loss
            pred_dist_pos = flatten_dist_preds[fg_mask_pre_prior]
            assigned_ltrb = self.bbox_coder.encode(
                self.flatten_priors_train[..., :2] / self.stride_tensor,
                assigned_bboxes,
                max_dis=self.head_module.reg_max - 1,
                eps=0.01)
            assigned_ltrb_pos = torch.masked_select(
                assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
            loss_dfl = self.loss_dfl(
                pred_dist_pos.reshape(-1, self.head_module.reg_max),
                assigned_ltrb_pos.reshape(-1),
                weight=bbox_weight.expand(-1, 4).reshape(-1),
                avg_factor=assigned_scores_sum)
        else:
            loss_bbox = flatten_pred_bboxes.sum() * 0
            loss_dfl = flatten_pred_bboxes.sum() * 0
        if self.world_size == -1:
            _, world_size = get_dist_info()
        else:
            world_size = self.world_size
        return dict(
            loss_cls=loss_cls * num_imgs * world_size,
            loss_bbox=loss_bbox * num_imgs * world_size,
            loss_dfl=loss_dfl * num_imgs * world_size)
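
    # Worked note (illustrative interpretation): the ``num_imgs * world_size``
    # factor above rescales each term so the total loss behaves like a sum
    # over all images across all ranks rather than a per-rank mean; e.g. with
    # 8 images per GPU on 4 GPUs each term is multiplied by 32.
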
    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        objectnesses: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = True,
                        with_nms: bool = True) -> List[InstanceData]:
        """Transform a batch of output features extracted by the head into
        bbox results.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            objectnesses (list[Tensor], optional): Score factors for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
            batch_img_metas (list[dict], optional): Batch image meta info.
                Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration. If None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
            with_nms (bool): If True, do NMS before returning boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains the following
            keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(cls_scores) == len(bbox_preds)
        if objectnesses is None:
            with_objectnesses = False
        else:
            with_objectnesses = True
            assert len(cls_scores) == len(objectnesses)

        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)

        multi_label = cfg.multi_label
        multi_label &= self.num_classes > 1
        cfg.multi_label = multi_label

        num_imgs = len(batch_img_metas)
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

        # If the feature map sizes are unchanged, reuse the previous priors.
        if featmap_sizes != self.featmap_sizes:
            self.mlvl_priors = self.prior_generator.grid_priors(
                featmap_sizes,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device)
            self.featmap_sizes = featmap_sizes
        flatten_priors = torch.cat(self.mlvl_priors)

        mlvl_strides = [
            flatten_priors.new_full(
                (featmap_size.numel() * self.num_base_priors, ), stride) for
            featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
        ]
        flatten_stride = torch.cat(mlvl_strides)

        # flatten cls_scores, bbox_preds and objectness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.num_classes)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]

        flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_decoded_bboxes = self.bbox_coder.decode(
            flatten_priors[None], flatten_bbox_preds, flatten_stride)
        if with_objectnesses:
            flatten_objectness = [
                objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
                for objectness in objectnesses
            ]
            flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
        else:
            flatten_objectness = [None for _ in range(num_imgs)]
        results_list = []
        for (bboxes, scores, objectness,
             img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
                              flatten_objectness, batch_img_metas):
            ori_shape = img_meta['ori_shape']
            scale_factor = img_meta['scale_factor']
            if 'pad_param' in img_meta:
                pad_param = img_meta['pad_param']
            else:
                pad_param = None

            score_thr = cfg.get('score_thr', -1)
            # yolox_style does not require the following operations
            if objectness is not None and score_thr > 0 and not cfg.get(
                    'yolox_style', False):
                conf_inds = objectness > score_thr
                bboxes = bboxes[conf_inds, :]
                scores = scores[conf_inds, :]
                objectness = objectness[conf_inds]

            if objectness is not None:
                # conf = obj_conf * cls_conf
                scores *= objectness[:, None]

            if scores.shape[0] == 0:
                empty_results = InstanceData()
                empty_results.bboxes = bboxes
                empty_results.scores = scores[:, 0]
                empty_results.labels = scores[:, 0].int()
                results_list.append(empty_results)
                continue

            nms_pre = cfg.get('nms_pre', 100000)
            if cfg.multi_label is False:
                scores, labels = scores.max(1, keepdim=True)
                scores, _, keep_idxs, results = filter_scores_and_topk(
                    scores,
                    score_thr,
                    nms_pre,
                    results=dict(labels=labels[:, 0]))
                labels = results['labels']
            else:
                scores, labels, keep_idxs, _ = filter_scores_and_topk(
                    scores, score_thr, nms_pre)

            results = InstanceData(
                scores=scores, labels=labels, bboxes=bboxes[keep_idxs])

            if rescale:
                if pad_param is not None:
                    results.bboxes -= results.bboxes.new_tensor([
                        pad_param[2], pad_param[0], pad_param[2], pad_param[0]
                    ])
                results.bboxes /= results.bboxes.new_tensor(
                    scale_factor).repeat((1, 2))

            if cfg.get('yolox_style', False):
                # do not need max_per_img
                cfg.max_per_img = len(results)

            results = self._bbox_post_process(
                results=results,
                cfg=cfg,
                rescale=False,
                with_nms=with_nms,
                img_meta=img_meta)
            results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
            results.bboxes[:, 1::2].clamp_(0, ori_shape[0])

            results_list.append(results)
        return results_list
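
# A minimal sketch (illustrative, not part of the upstream module) of the
# per-image candidate filtering in predict_by_feat: scores below ``score_thr``
# are dropped and at most ``nms_pre`` candidates survive to NMS.
def _demo_filter_scores_and_topk() -> Tuple[Tensor, Tensor, Tensor]:
    scores = torch.rand(8400, 80)  # (num_priors, num_classes), post-sigmoid
    scores, labels, keep_idxs, _ = filter_scores_and_topk(
        scores, score_thr=0.25, topk=1000)
    return scores, labels, keep_idxs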