# Copyright (c) Tencent Inc. All rights reserved.
import copy
import math
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import build_norm_layer
from mmdet.models.utils import (filter_scores_and_topk, multi_apply,
                                unpack_gt_instances)
from mmdet.structures import SampleList
from mmdet.utils import InstanceList, OptConfigType, OptInstanceList
from mmengine.config import ConfigDict
from mmengine.dist import get_dist_info
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from mmyolo.models.dense_heads import YOLOv8Head, YOLOv8HeadModule
from mmyolo.models.utils import gt_instances_preprocess
from mmyolo.registry import MODELS
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm


@MODELS.register_module()
class ContrastiveHead(BaseModule):
"""Contrastive Head for YOLO-World
compute the region-text scores according to the
similarity between image and text features
Args:
embed_dims (int): embed dim of text and image features
"""
def __init__(self,
embed_dims: int,
init_cfg: OptConfigType = None,
use_einsum: bool = True) -> None:
super().__init__(init_cfg=init_cfg)
self.bias = nn.Parameter(torch.zeros([]))
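        # Learnable temperature, initialized to log(1 / 0.07) following CLIP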
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.use_einsum = use_einsum
def forward(self, x: Tensor, w: Tensor) -> Tensor:
"""Forward function of contrastive learning."""
x = F.normalize(x, dim=1, p=2)
w = F.normalize(w, dim=-1, p=2)
if self.use_einsum:
x = torch.einsum('bchw,bkc->bkhw', x, w)
else:
batch, channel, height, width = x.shape
_, k, _ = w.shape
x = x.permute(0, 2, 3, 1) # bchw->bhwc
x = x.reshape(batch, -1, channel) # bhwc->b(hw)c
w = w.permute(0, 2, 1) # bkc->bck
x = torch.matmul(x, w)
x = x.reshape(batch, height, width, k)
x = x.permute(0, 3, 1, 2)
x = x * self.logit_scale.exp() + self.bias
return x
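    # A minimal usage sketch (shapes below are illustrative assumptions, not
    # taken from any YOLO-World config):
    #
    #     head = ContrastiveHead(embed_dims=256)
    #     img_feat = torch.randn(2, 256, 20, 20)  # (B, C, H, W) image features
    #     txt_feat = torch.randn(2, 80, 256)      # (B, K, C) text embeddings
    #     logits = head(img_feat, txt_feat)       # (B, K, H, W) region-text scores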
@MODELS.register_module()
class BNContrastiveHead(BaseModule):
""" Batch Norm Contrastive Head for YOLO-World
using batch norm instead of l2-normalization
Args:
embed_dims (int): embed dim of text and image features
norm_cfg (dict): normalization params
"""
def __init__(self,
embed_dims: int,
norm_cfg: ConfigDict,
init_cfg: OptConfigType = None,
use_einsum: bool = True) -> None:
super().__init__(init_cfg=init_cfg)
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
self.bias = nn.Parameter(torch.zeros([]))
        # Using -1.0 as the initial value is more stable
self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
self.use_einsum = use_einsum
def forward(self, x: Tensor, w: Tensor) -> Tensor:
"""Forward function of contrastive learning."""
x = self.norm(x)
w = F.normalize(w, dim=-1, p=2)
if self.use_einsum:
x = torch.einsum('bchw,bkc->bkhw', x, w)
else:
batch, channel, height, width = x.shape
_, k, _ = w.shape
x = x.permute(0, 2, 3, 1) # bchw->bhwc
x = x.reshape(batch, -1, channel) # bhwc->b(hw)c
w = w.permute(0, 2, 1) # bkc->bck
x = torch.matmul(x, w)
x = x.reshape(batch, height, width, k)
x = x.permute(0, 3, 1, 2)
x = x * self.logit_scale.exp() + self.bias
return x
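    # Compared to ContrastiveHead, the image features are batch-normalized
    # rather than L2-normalized, and the temperature starts at exp(-1.0)
    # instead of 1 / 0.07.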
@MODELS.register_module()
class YOLOWorldHeadModule(YOLOv8HeadModule):
"""Head Module for YOLO-World
Args:
embed_dims (int): embed dim for text feautures and image features
use_bn_head (bool): use batch normalization head
"""
def __init__(self,
*args,
embed_dims: int,
use_bn_head: bool = False,
use_einsum: bool = True,
freeze_all: bool = False,
**kwargs) -> None:
self.embed_dims = embed_dims
self.use_bn_head = use_bn_head
self.use_einsum = use_einsum
self.freeze_all = freeze_all
super().__init__(*args, **kwargs)
def init_weights(self, prior_prob=0.01):
"""Initialize the weight and bias of PPYOLOE head."""
super().init_weights()
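        # Initialize the contrastive bias so the initial scores encode a
        # small prior: roughly 5 expected objects per 640x640 image, spread
        # over num_classes and the grid cells at each stride.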
for cls_pred, cls_contrast, stride in zip(self.cls_preds,
self.cls_contrasts,
self.featmap_strides):
cls_pred[-1].bias.data[:] = 0.0 # reset bias
if hasattr(cls_contrast, 'bias'):
nn.init.constant_(
cls_contrast.bias.data,
math.log(5 / self.num_classes / (640 / stride)**2))
def _init_layers(self) -> None:
"""initialize conv layers in YOLOv8 head."""
# Init decouple head
self.cls_preds = nn.ModuleList()
self.reg_preds = nn.ModuleList()
self.cls_contrasts = nn.ModuleList()
reg_out_channels = max(
(16, self.in_channels[0] // 4, self.reg_max * 4))
cls_out_channels = max(self.in_channels[0], self.num_classes)
for i in range(self.num_levels):
self.reg_preds.append(
nn.Sequential(
ConvModule(in_channels=self.in_channels[i],
out_channels=reg_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(in_channels=reg_out_channels,
out_channels=reg_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(in_channels=reg_out_channels,
out_channels=4 * self.reg_max,
kernel_size=1)))
self.cls_preds.append(
nn.Sequential(
ConvModule(in_channels=self.in_channels[i],
out_channels=cls_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(in_channels=cls_out_channels,
out_channels=cls_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(in_channels=cls_out_channels,
out_channels=self.embed_dims,
kernel_size=1)))
if self.use_bn_head:
self.cls_contrasts.append(
BNContrastiveHead(self.embed_dims,
self.norm_cfg,
use_einsum=self.use_einsum))
else:
self.cls_contrasts.append(
ContrastiveHead(self.embed_dims,
use_einsum=self.use_einsum))
proj = torch.arange(self.reg_max, dtype=torch.float)
self.register_buffer('proj', proj, persistent=False)
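        # DFL integration weights [0, 1, ..., reg_max - 1]; multiplying the
        # softmaxed distribution logits by ``proj`` yields expected distances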
if self.freeze_all:
self._freeze_all()
def _freeze_all(self):
"""Freeze the model."""
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()
for param in m.parameters():
param.requires_grad = False
def train(self, mode=True):
super().train(mode)
if self.freeze_all:
self._freeze_all()
def forward(self, img_feats: Tuple[Tensor],
txt_feats: Tensor) -> Tuple[List]:
"""Forward features from the upstream network."""
assert len(img_feats) == self.num_levels
txt_feats = [txt_feats for _ in range(self.num_levels)]
return multi_apply(self.forward_single, img_feats, txt_feats,
self.cls_preds, self.reg_preds, self.cls_contrasts)
    def forward_single(self, img_feat: Tensor, txt_feat: Tensor,
                       cls_pred: nn.Module, reg_pred: nn.Module,
                       cls_contrast: nn.Module) -> Tuple:
"""Forward feature of a single scale level."""
b, _, h, w = img_feat.shape
cls_embed = cls_pred(img_feat)
cls_logit = cls_contrast(cls_embed, txt_feat)
bbox_dist_preds = reg_pred(img_feat)
if self.reg_max > 1:
bbox_dist_preds = bbox_dist_preds.reshape(
[-1, 4, self.reg_max, h * w]).permute(0, 3, 1, 2)
# TODO: The get_flops script cannot handle the situation of
# matmul, and needs to be fixed later
# bbox_preds = bbox_dist_preds.softmax(3).matmul(self.proj)
bbox_preds = bbox_dist_preds.softmax(3).matmul(
self.proj.view([-1, 1])).squeeze(-1)
bbox_preds = bbox_preds.transpose(1, 2).reshape(b, -1, h, w)
else:
bbox_preds = bbox_dist_preds
if self.training:
return cls_logit, bbox_preds, bbox_dist_preds
else:
return cls_logit, bbox_preds
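    # DFL decoding sketch: each box side is predicted as a distribution over
    # ``reg_max`` bins; its expectation (softmax followed by a dot product
    # with ``proj``) gives the regressed distance. The matmul above is
    # equivalent to:
    #
    #     probs = bbox_dist_preds.softmax(3)   # (b, h*w, 4, reg_max)
    #     dists = (probs * self.proj).sum(-1)  # (b, h*w, 4)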
@MODELS.register_module()
class YOLOWorldHead(YOLOv8Head):
    """YOLO-World v8 head.

    Args:
        world_size (int): Number of distributed processes used when scaling
            the losses. -1 means it is queried from ``get_dist_info()``.
            Defaults to -1.
    """

    def __init__(self, world_size=-1, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.world_size = world_size
def loss(self, img_feats: Tuple[Tensor], txt_feats: Tensor,
batch_data_samples: Union[list, dict]) -> dict:
"""Perform forward propagation and loss calculation of the detection
head on the features of the upstream network."""
outs = self(img_feats, txt_feats)
        # Fast version: gt instances come pre-packed as dense tensors
loss_inputs = outs + (batch_data_samples['bboxes_labels'],
batch_data_samples['img_metas'])
losses = self.loss_by_feat(*loss_inputs)
return losses
def loss_and_predict(
self,
img_feats: Tuple[Tensor],
txt_feats: Tensor,
batch_data_samples: SampleList,
proposal_cfg: Optional[ConfigDict] = None
) -> Tuple[dict, InstanceList]:
"""Perform forward propagation of the head, then calculate loss and
predictions from the features and data samples.
"""
outputs = unpack_gt_instances(batch_data_samples)
(batch_gt_instances, batch_gt_instances_ignore,
batch_img_metas) = outputs
outs = self(img_feats, txt_feats)
loss_inputs = outs + (batch_gt_instances, batch_img_metas,
batch_gt_instances_ignore)
losses = self.loss_by_feat(*loss_inputs)
predictions = self.predict_by_feat(*outs,
batch_img_metas=batch_img_metas,
cfg=proposal_cfg)
return losses, predictions
def forward(self, img_feats: Tuple[Tensor],
txt_feats: Tensor) -> Tuple[List]:
"""Forward features from the upstream network."""
return self.head_module(img_feats, txt_feats)
def predict(self,
img_feats: Tuple[Tensor],
txt_feats: Tensor,
batch_data_samples: SampleList,
rescale: bool = False) -> InstanceList:
"""Perform forward propagation of the detection head and predict
detection results on the features of the upstream network.
"""
batch_img_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
outs = self(img_feats, txt_feats)
predictions = self.predict_by_feat(*outs,
batch_img_metas=batch_img_metas,
rescale=rescale)
return predictions
def aug_test(self,
aug_batch_feats,
aug_batch_img_metas,
rescale=False,
with_ori_nms=False,
**kwargs):
"""Test function with test time augmentation."""
raise NotImplementedError('aug_test is not implemented yet.')
def loss_by_feat(
self,
cls_scores: Sequence[Tensor],
bbox_preds: Sequence[Tensor],
bbox_dist_preds: Sequence[Tensor],
batch_gt_instances: Sequence[InstanceData],
batch_img_metas: Sequence[dict],
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
"""Calculate the loss based on the features extracted by the detection
head.
Args:
cls_scores (Sequence[Tensor]): Box scores for each scale level,
each is a 4D-tensor, the channel number is
num_priors * num_classes.
bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
level, each is a 4D-tensor, the channel number is
num_priors * 4.
            bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
                each scale level with shape (bs, H*W, 4, reg_max).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict[str, Tensor]: A dictionary of losses.
"""
num_imgs = len(batch_img_metas)
current_featmap_sizes = [
cls_score.shape[2:] for cls_score in cls_scores
]
        # Regenerate priors if the feature-map shapes have changed
if current_featmap_sizes != self.featmap_sizes_train:
self.featmap_sizes_train = current_featmap_sizes
mlvl_priors_with_stride = self.prior_generator.grid_priors(
self.featmap_sizes_train,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device,
with_stride=True)
self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
self.flatten_priors_train = torch.cat(mlvl_priors_with_stride,
dim=0)
self.stride_tensor = self.flatten_priors_train[..., [2]]
# gt info
gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
gt_labels = gt_info[:, :, :1]
gt_bboxes = gt_info[:, :, 1:] # xyxy
pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
# pred info
flatten_cls_preds = [
cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.num_classes)
for cls_pred in cls_scores
]
flatten_pred_bboxes = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
]
# (bs, n, 4 * reg_max)
flatten_pred_dists = [
bbox_pred_org.reshape(num_imgs, -1, self.head_module.reg_max * 4)
for bbox_pred_org in bbox_dist_preds
]
flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
flatten_pred_bboxes = self.bbox_coder.decode(
self.flatten_priors_train[..., :2], flatten_pred_bboxes,
self.stride_tensor[..., 0])
assigned_result = self.assigner(
(flatten_pred_bboxes.detach()).type(gt_bboxes.dtype),
flatten_cls_preds.detach().sigmoid(), self.flatten_priors_train,
gt_labels, gt_bboxes, pad_bbox_flag)
assigned_bboxes = assigned_result['assigned_bboxes']
assigned_scores = assigned_result['assigned_scores']
fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']
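        # Clamp the normalizer to at least 1 so the losses below stay well
        # defined even when no priors are assigned as positives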
assigned_scores_sum = assigned_scores.sum().clamp(min=1)
loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores).sum()
loss_cls /= assigned_scores_sum
# rescale bbox
assigned_bboxes /= self.stride_tensor
flatten_pred_bboxes /= self.stride_tensor
# select positive samples mask
num_pos = fg_mask_pre_prior.sum()
if num_pos > 0:
            # When num_pos > 0, assigned_scores_sum > 0 as well, so the
            # normalization below cannot divide by zero.
# iou loss
prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
pred_bboxes_pos = torch.masked_select(
flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
assigned_bboxes_pos = torch.masked_select(
assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
bbox_weight = torch.masked_select(assigned_scores.sum(-1),
fg_mask_pre_prior).unsqueeze(-1)
loss_bbox = self.loss_bbox(
pred_bboxes_pos, assigned_bboxes_pos,
weight=bbox_weight) / assigned_scores_sum
# dfl loss
pred_dist_pos = flatten_dist_preds[fg_mask_pre_prior]
assigned_ltrb = self.bbox_coder.encode(
self.flatten_priors_train[..., :2] / self.stride_tensor,
assigned_bboxes,
max_dis=self.head_module.reg_max - 1,
eps=0.01)
assigned_ltrb_pos = torch.masked_select(
assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
loss_dfl = self.loss_dfl(pred_dist_pos.reshape(
-1, self.head_module.reg_max),
assigned_ltrb_pos.reshape(-1),
weight=bbox_weight.expand(-1,
4).reshape(-1),
avg_factor=assigned_scores_sum)
else:
loss_bbox = flatten_pred_bboxes.sum() * 0
loss_dfl = flatten_pred_bboxes.sum() * 0
if self.world_size == -1:
_, world_size = get_dist_info()
else:
world_size = self.world_size
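        # Scale losses by batch size and world size to compensate for the
        # gradient averaging performed across samples and processes
        # (following the YOLOv8 convention in mmyolo).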
return dict(loss_cls=loss_cls * num_imgs * world_size,
loss_bbox=loss_bbox * num_imgs * world_size,
loss_dfl=loss_dfl * num_imgs * world_size)
def predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
objectnesses: Optional[List[Tensor]] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = True,
with_nms: bool = True) -> List[InstanceData]:
"""Transform a batch of output features extracted by the head into
bbox results.
Args:
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
            objectnesses (list[Tensor], Optional): Score factors for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
list[:obj:`InstanceData`]: Object detection results of each image
            after post-processing. Each item usually contains the following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
assert len(cls_scores) == len(bbox_preds)
if objectnesses is None:
with_objectnesses = False
else:
with_objectnesses = True
assert len(cls_scores) == len(objectnesses)
cfg = self.test_cfg if cfg is None else cfg
cfg = copy.deepcopy(cfg)
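        # multi_label keeps every class whose score passes the threshold for
        # each prior; it is disabled when there is only one class.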
multi_label = cfg.multi_label
multi_label &= self.num_classes > 1
cfg.multi_label = multi_label
num_imgs = len(batch_img_metas)
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
# If the shape does not change, use the previous mlvl_priors
if featmap_sizes != self.featmap_sizes:
self.mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device)
self.featmap_sizes = featmap_sizes
flatten_priors = torch.cat(self.mlvl_priors)
mlvl_strides = [
flatten_priors.new_full(
(featmap_size.numel() * self.num_base_priors, ), stride) for
featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
]
flatten_stride = torch.cat(mlvl_strides)
# flatten cls_scores, bbox_preds and objectness
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.num_classes)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
]
flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
flatten_decoded_bboxes = self.bbox_coder.decode(
flatten_priors[None], flatten_bbox_preds, flatten_stride)
if with_objectnesses:
flatten_objectness = [
objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
for objectness in objectnesses
]
flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
else:
flatten_objectness = [None for _ in range(num_imgs)]
results_list = []
for (bboxes, scores, objectness,
img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
flatten_objectness, batch_img_metas):
ori_shape = img_meta['ori_shape']
scale_factor = img_meta['scale_factor']
if 'pad_param' in img_meta:
pad_param = img_meta['pad_param']
else:
pad_param = None
score_thr = cfg.get('score_thr', -1)
            # yolox_style post-processing does not require the following
            # filtering operations
if objectness is not None and score_thr > 0 and not cfg.get(
'yolox_style', False):
conf_inds = objectness > score_thr
bboxes = bboxes[conf_inds, :]
scores = scores[conf_inds, :]
objectness = objectness[conf_inds]
if objectness is not None:
# conf = obj_conf * cls_conf
scores *= objectness[:, None]
if scores.shape[0] == 0:
empty_results = InstanceData()
empty_results.bboxes = bboxes
empty_results.scores = scores[:, 0]
empty_results.labels = scores[:, 0].int()
results_list.append(empty_results)
continue
nms_pre = cfg.get('nms_pre', 100000)
if cfg.multi_label is False:
scores, labels = scores.max(1, keepdim=True)
scores, _, keep_idxs, results = filter_scores_and_topk(
scores,
score_thr,
nms_pre,
results=dict(labels=labels[:, 0]))
labels = results['labels']
else:
scores, labels, keep_idxs, _ = filter_scores_and_topk(
scores, score_thr, nms_pre)
results = InstanceData(scores=scores,
labels=labels,
bboxes=bboxes[keep_idxs])
if rescale:
if pad_param is not None:
results.bboxes -= results.bboxes.new_tensor([
pad_param[2], pad_param[0], pad_param[2], pad_param[0]
])
results.bboxes /= results.bboxes.new_tensor(
scale_factor).repeat((1, 2))
if cfg.get('yolox_style', False):
# do not need max_per_img
cfg.max_per_img = len(results)
results = self._bbox_post_process(results=results,
cfg=cfg,
rescale=False,
with_nms=with_nms,
img_meta=img_meta)
results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
results.bboxes[:, 1::2].clamp_(0, ori_shape[0])
results_list.append(results)
return results_list
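# A minimal end-to-end sketch of how this head is typically driven. The
# ``neck`` and ``text_encoder`` names are assumptions for illustration, not
# part of this file:
#
#     head = MODELS.build(head_cfg)          # builds a YOLOWorldHead
#     img_feats = neck(images)               # tuple of multi-scale features
#     txt_feats = text_encoder(class_names)  # (B, num_classes, embed_dims)
#     results = head.predict(img_feats, txt_feats, batch_data_samples)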