|
|
|
from copy import deepcopy |
|
from functools import partial |
|
from typing import List, Optional, Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
from mmdet.models.backbones.csp_darknet import Focus |
|
from mmdet.models.layers import ChannelAttention |
|
from mmengine.config import ConfigDict |
|
from torch import Tensor |
|
|
|
from mmyolo.models import RepVGGBlock |
|
from mmyolo.models.dense_heads import (PPYOLOEHead, RTMDetHead, YOLOv5Head, |
|
YOLOv7Head, YOLOv8Head, YOLOXHead) |
|
from mmyolo.models.layers import ImplicitA, ImplicitM |
|
from ..backbone import DeployFocus, GConvFocus, NcnnFocus |
|
from ..bbox_code import (rtmdet_bbox_decoder, yolov5_bbox_decoder, |
|
yolox_bbox_decoder) |
|
from ..nms import batched_nms, efficient_nms, onnx_nms |
|
from .backend import MMYOLOBackend |
|
|
|
|
|
class DeployModel(nn.Module):
    """Export-friendly wrapper around an MMYOLO detector.

    The wrapped detector is patched in place for the target backend when the
    wrapper is constructed (see ``__switch_deploy``). Two output modes exist:

    * ``postprocess_cfg is None``: :meth:`forward` returns the raw head
      outputs concatenated along channels per feature level. For the
      HorizonX3, NCNN and TorchScript backends they are also permuted
      channel-last.
    * ``postprocess_cfg`` given: :meth:`forward` decodes the boxes and runs
      the backend-specific NMS picked by :meth:`select_nms`.

    Args:
        baseModel (nn.Module): Detector to wrap. It must expose a
            ``bbox_head`` attribute and is modified in place.
        backend (MMYOLOBackend): Target inference backend.
        postprocess_cfg (ConfigDict, optional): NMS settings. Keys read, with
            their defaults: ``pre_top_k`` (1000), ``keep_top_k`` (100),
            ``iou_threshold`` (0.65), ``score_threshold`` (0.25).
    """

    # Whether the no-postprocess outputs are permuted channel-last. This is
    # a class-level default; ``__switch_deploy`` sets it per instance.
    transpose = False

    def __init__(self,
                 baseModel: nn.Module,
                 backend: MMYOLOBackend,
                 postprocess_cfg: Optional[ConfigDict] = None):
        super().__init__()
        self.baseModel = baseModel
        self.baseHead = baseModel.bbox_head
        self.backend = backend
        self.with_postprocess = postprocess_cfg is not None
        if self.with_postprocess:
            # These attributes are only read by the post-processing path
            # (``pred_by_feat``), so they are only set when it is enabled.
            self.__init_sub_attributes()
            self.detector_type = type(self.baseHead)
            self.pre_top_k = postprocess_cfg.get('pre_top_k', 1000)
            self.keep_top_k = postprocess_cfg.get('keep_top_k', 100)
            self.iou_threshold = postprocess_cfg.get('iou_threshold', 0.65)
            self.score_threshold = postprocess_cfg.get('score_threshold', 0.25)
        self.__switch_deploy()

    def __init_sub_attributes(self):
        """Cache the bbox-head attributes that ``pred_by_feat`` relies on."""
        self.bbox_decoder = self.baseHead.bbox_coder.decode
        self.prior_generate = self.baseHead.prior_generator.grid_priors
        self.num_base_priors = self.baseHead.num_base_priors
        self.featmap_strides = self.baseHead.featmap_strides
        self.num_classes = self.baseHead.num_classes

    def __switch_deploy(self):
        """Patch ``baseModel`` in place for the selected backend.

        Without post-processing:

        * YOLOv5/YOLOv7 heads use :meth:`forward_single`, which folds
          ImplicitA/ImplicitM into the adjacent conv.
        * PPYOLOE/YOLOv8 heads get ``reg_max = 0``.
        * The HorizonX3, NCNN and TorchScript backends enable ``transpose``.

        In every mode:

        * ``RepVGGBlock`` layers are switched via ``switch_to_deploy()``.
        * ``ChannelAttention`` pooling is replaced by :meth:`forward_gvp`.
        * A ``Focus`` stem is swapped for a backend-specific variant.
        """
        headType = type(self.baseHead)
        if not self.with_postprocess:
            if headType in (YOLOv5Head, YOLOv7Head):
                self.baseHead.head_module.forward_single = self.forward_single
            elif headType in (PPYOLOEHead, YOLOv8Head):
                # NOTE(review): presumably makes the head emit plain box
                # regressions instead of DFL distributions; confirm against
                # the head module implementation.
                self.baseHead.head_module.reg_max = 0

            if self.backend in (MMYOLOBackend.HORIZONX3, MMYOLOBackend.NCNN,
                                MMYOLOBackend.TORCHSCRIPT):
                self.transpose = True

        for layer in self.baseModel.modules():
            if isinstance(layer, RepVGGBlock):
                layer.switch_to_deploy()
            elif isinstance(layer, ChannelAttention):
                # Plain mean over the spatial dims replaces the pooling
                # module's forward (see ``forward_gvp``).
                layer.global_avgpool.forward = self.forward_gvp
            elif isinstance(layer, Focus):
                # onnxruntime / openvino / tensorrt8 / tensorrt7
                if self.backend in (MMYOLOBackend.ONNXRUNTIME,
                                    MMYOLOBackend.OPENVINO,
                                    MMYOLOBackend.TENSORRT8,
                                    MMYOLOBackend.TENSORRT7):
                    self.baseModel.backbone.stem = DeployFocus(layer)
                # ncnn
                elif self.backend == MMYOLOBackend.NCNN:
                    self.baseModel.backbone.stem = NcnnFocus(layer)
                # every other backend: Focus expressed as a group conv
                else:
                    self.baseModel.backbone.stem = GConvFocus(layer)

    def pred_by_feat(self,
                     cls_scores: List[Tensor],
                     bbox_preds: List[Tensor],
                     objectnesses: Optional[List[Tensor]] = None,
                     **kwargs):
        """Decode raw multi-level head outputs and run backend NMS.

        Feature maps are treated as (N, C, H, W). Each one is permuted
        channel-last and flattened across levels before decoding.

        Args:
            cls_scores (list[Tensor]): Per-level class logits. Each is
                reshaped to (N, -1, num_classes) and passed through a sigmoid.
            bbox_preds (list[Tensor]): Per-level box regressions. Each is
                reshaped to (N, -1, 4).
            objectnesses (list[Tensor], optional): Per-level objectness
                logits. When given, their sigmoid multiplies the class
                scores.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            Whatever the NMS function selected by :meth:`select_nms` returns.
        """
        assert len(cls_scores) == len(bbox_preds)
        dtype = cls_scores[0].dtype
        device = cls_scores[0].device

        nms_func = self.select_nms()
        if self.detector_type in (YOLOv5Head, YOLOv7Head):
            bbox_decoder = yolov5_bbox_decoder
        elif self.detector_type is RTMDetHead:
            bbox_decoder = rtmdet_bbox_decoder
        elif self.detector_type is YOLOXHead:
            bbox_decoder = yolox_bbox_decoder
        else:
            # Fall back to the head's own bbox coder.
            bbox_decoder = self.bbox_decoder

        num_imgs = cls_scores[0].shape[0]
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

        mlvl_priors = self.prior_generate(
            featmap_sizes, dtype=dtype, device=device)
        flatten_priors = torch.cat(mlvl_priors)

        # One stride value per prior, so the decoder can scale per location.
        mlvl_strides = [
            flatten_priors.new_full(
                (featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
                stride) for featmap_size, stride in zip(
                    featmap_sizes, self.featmap_strides)
        ]
        flatten_stride = torch.cat(mlvl_strides)

        # Flatten the class scores, box predictions and objectness per image.
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.num_classes)
            for cls_score in cls_scores
        ]
        scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()

        flatten_bbox_preds = torch.cat([
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ], dim=1)

        if objectnesses is not None:
            flatten_objectness = torch.cat([
                objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
                for objectness in objectnesses
            ], dim=1).sigmoid()
            scores = scores * flatten_objectness.unsqueeze(-1)

        bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
                              flatten_stride)

        # NOTE(review): keep_top_k is passed twice. The first occurrence is
        # presumably the per-class output cap; confirm against the
        # onnx_nms / efficient_nms / batched_nms signatures.
        return nms_func(bboxes, scores, self.keep_top_k, self.iou_threshold,
                        self.score_threshold, self.pre_top_k, self.keep_top_k)

    def select_nms(self):
        """Return the NMS callable for ``self.backend``.

        YOLOv5, YOLOv7 and YOLOX heads get the callable partially applied
        with ``box_coding=1``.

        Raises:
            NotImplementedError: If the backend has no NMS implementation
                here, i.e. anything other than ONNX Runtime, OpenVINO,
                TensorRT 8 or TensorRT 7.
        """
        if self.backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO):
            nms_func = onnx_nms
        elif self.backend == MMYOLOBackend.TENSORRT8:
            nms_func = efficient_nms
        elif self.backend == MMYOLOBackend.TENSORRT7:
            nms_func = batched_nms
        else:
            raise NotImplementedError(
                f'NMS post-processing is not implemented for backend '
                f'{self.backend}')
        if type(self.baseHead) in (YOLOv5Head, YOLOv7Head, YOLOXHead):
            nms_func = partial(nms_func, box_coding=1)
        return nms_func

    def forward(self, inputs: Tensor):
        """Run the detector and return export-ready outputs.

        With post-processing enabled, the head outputs are unpacked into
        :meth:`pred_by_feat`. Otherwise the head outputs are grouped per
        feature level and returned as one tensor per level:

        * The per-kind tensors are concatenated along channels.
        * For NCNN/TorchScript with ``transpose``, each tensor is first
          permuted channel-last and then concatenated on the last dim.
        * For other ``transpose`` backends, the concatenated tensor is
          permuted channel-last afterwards.
        """
        neck_outputs = self.baseModel(inputs)
        if self.with_postprocess:
            return self.pred_by_feat(*neck_outputs)

        # ``neck_outputs`` presumably holds per-level sequences (one per
        # output kind); zip(*...) regroups them level by level.
        if not self.transpose:
            outputs = [torch.cat(feats, 1) for feats in zip(*neck_outputs)]
        elif self.backend in (MMYOLOBackend.NCNN, MMYOLOBackend.TORCHSCRIPT):
            outputs = [
                torch.cat([feat.permute(0, 2, 3, 1) for feat in feats], -1)
                for feats in zip(*neck_outputs)
            ]
        else:
            outputs = [
                torch.cat(feats, 1).permute(0, 2, 3, 1)
                for feats in zip(*neck_outputs)
            ]
        return tuple(outputs)

    @staticmethod
    def forward_single(x: Tensor, convs: nn.Module) -> Tuple[Tensor]:
        """Apply one prediction branch, folding implicit layers if present.

        This is installed as the head module's ``forward_single`` for YOLOv5
        and YOLOv7 heads when post-processing is disabled.

        Args:
            x (Tensor): Input feature map.
            convs (nn.Module): Either a plain module, or an ``nn.Sequential``
                containing ImplicitA/ImplicitM layers. In the latter case it
                is assumed to be exactly ``(ImplicitA, conv, ImplicitM)`` in
                that order.

        Returns:
            tuple[Tensor]: A 1-tuple holding the branch output.
        """
        if isinstance(convs, nn.Sequential) and any(
                type(m) in (ImplicitA, ImplicitM) for m in convs):
            a, c, m = convs
            aw = a.implicit.clone()
            mw = m.implicit.clone()
            # Work on a copy so the original conv's parameters are untouched.
            c = deepcopy(c)
            nw, cw, _, _ = c.weight.shape
            na, ca, _, _ = aw.shape
            _, cm, _, _ = mw.shape
            # Fold into the conv: b' = (b + W @ a) * m and W' = W * m. This
            # gives m * conv(x + a), assuming ``a`` is added to the input and
            # ``m`` scales the output. The reshape(nw, cw) below only works
            # for a 1x1 kernel, and the conv must have a bias.
            c.bias = nn.Parameter(c.bias + (
                c.weight.reshape(nw, cw) @ aw.reshape(ca, na)).squeeze(1))
            c.bias = nn.Parameter(c.bias * mw.reshape(cm))
            c.weight = nn.Parameter(c.weight * mw.transpose(0, 1))
            convs = c
        feat = convs(x)
        return (feat, )

    @staticmethod
    def forward_gvp(x: Tensor) -> Tensor:
        """Mean over dims 2 and 3 (spatial for NCHW input), keeping dims."""
        return torch.mean(x, [2, 3], keepdim=True)
|
|