# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from functools import partial
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from mmdet.models.backbones.csp_darknet import Focus
from mmdet.models.layers import ChannelAttention
from mmengine.config import ConfigDict
from torch import Tensor

from mmyolo.models import RepVGGBlock
from mmyolo.models.dense_heads import (PPYOLOEHead, RTMDetHead, YOLOv5Head,
                                       YOLOv7Head, YOLOv8Head, YOLOXHead)
from mmyolo.models.layers import ImplicitA, ImplicitM

from ..backbone import DeployFocus, GConvFocus, NcnnFocus
from ..bbox_code import (rtmdet_bbox_decoder, yolov5_bbox_decoder,
                         yolox_bbox_decoder)
from ..nms import batched_nms, efficient_nms, onnx_nms
from .backend import MMYOLOBackend


class DeployModel(nn.Module):
    """Wrap an MMYOLO detector for deployment.

    The wrapper rewrites deploy-unfriendly modules (``RepVGGBlock``,
    ``Focus``, ``ChannelAttention``) for the target backend and, when
    ``postprocess_cfg`` is given, fuses bbox decoding and NMS into the
    exported graph.
    """
    transpose = False

    def __init__(self,
                 baseModel: nn.Module,
                 backend: MMYOLOBackend,
                 postprocess_cfg: Optional[ConfigDict] = None):
        super().__init__()
        self.baseModel = baseModel
        self.baseHead = baseModel.bbox_head
        self.backend = backend
        if postprocess_cfg is None:
            self.with_postprocess = False
        else:
            self.with_postprocess = True
            self.__init_sub_attributes()
            self.detector_type = type(self.baseHead)
            self.pre_top_k = postprocess_cfg.get('pre_top_k', 1000)
            self.keep_top_k = postprocess_cfg.get('keep_top_k', 100)
            self.iou_threshold = postprocess_cfg.get('iou_threshold', 0.65)
            self.score_threshold = postprocess_cfg.get('score_threshold', 0.25)
        self.__switch_deploy()
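
    # A hedged illustration (not in the original file) of the keys that
    # ``postprocess_cfg`` is read with above; the values shown are simply the
    # defaults used by ``__init__``:
    #
    #   postprocess_cfg = ConfigDict(
    #       pre_top_k=1000,        # candidates kept per image before NMS
    #       keep_top_k=100,        # detections kept per image after NMS
    #       iou_threshold=0.65,    # NMS IoU threshold
    #       score_threshold=0.25)  # minimum score to keep a detection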

    def __init_sub_attributes(self):
        self.bbox_decoder = self.baseHead.bbox_coder.decode
        self.prior_generate = self.baseHead.prior_generator.grid_priors
        self.num_base_priors = self.baseHead.num_base_priors
        self.featmap_strides = self.baseHead.featmap_strides
        self.num_classes = self.baseHead.num_classes

    def __switch_deploy(self):
        headType = type(self.baseHead)
        if not self.with_postprocess:
            if headType in (YOLOv5Head, YOLOv7Head):
                self.baseHead.head_module.forward_single = self.forward_single
            elif headType in (PPYOLOEHead, YOLOv8Head):
                self.baseHead.head_module.reg_max = 0

        if self.backend in (MMYOLOBackend.HORIZONX3, MMYOLOBackend.NCNN,
                            MMYOLOBackend.TORCHSCRIPT):
            self.transpose = True
        for layer in self.baseModel.modules():
            if isinstance(layer, RepVGGBlock):
                layer.switch_to_deploy()
            elif isinstance(layer, ChannelAttention):
                layer.global_avgpool.forward = self.forward_gvp
            elif isinstance(layer, Focus):
                # onnxruntime openvino tensorrt8 tensorrt7
                if self.backend in (MMYOLOBackend.ONNXRUNTIME,
                                    MMYOLOBackend.OPENVINO,
                                    MMYOLOBackend.TENSORRT8,
                                    MMYOLOBackend.TENSORRT7):
                    self.baseModel.backbone.stem = DeployFocus(layer)
                # ncnn
                elif self.backend == MMYOLOBackend.NCNN:
                    self.baseModel.backbone.stem = NcnnFocus(layer)
                # switch focus to group conv
                else:
                    self.baseModel.backbone.stem = GConvFocus(layer)

    def pred_by_feat(self,
                     cls_scores: List[Tensor],
                     bbox_preds: List[Tensor],
                     objectnesses: Optional[List[Tensor]] = None,
                     **kwargs):
        assert len(cls_scores) == len(bbox_preds)
        dtype = cls_scores[0].dtype
        device = cls_scores[0].device

        nms_func = self.select_nms()
        if self.detector_type in (YOLOv5Head, YOLOv7Head):
            bbox_decoder = yolov5_bbox_decoder
        elif self.detector_type is RTMDetHead:
            bbox_decoder = rtmdet_bbox_decoder
        elif self.detector_type is YOLOXHead:
            bbox_decoder = yolox_bbox_decoder
        else:
            bbox_decoder = self.bbox_decoder

        num_imgs = cls_scores[0].shape[0]
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

        mlvl_priors = self.prior_generate(
            featmap_sizes, dtype=dtype, device=device)

        flatten_priors = torch.cat(mlvl_priors)

        mlvl_strides = [
            flatten_priors.new_full(
                (featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
                stride) for featmap_size, stride in zip(
                    featmap_sizes, self.featmap_strides)
        ]
        flatten_stride = torch.cat(mlvl_strides)

        # flatten cls_scores, bbox_preds and objectness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.num_classes)
            for cls_score in cls_scores
        ]
        cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()

        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)

        if objectnesses is not None:
            flatten_objectness = [
                objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
                for objectness in objectnesses
            ]
            flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
            cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))

        scores = cls_scores

        bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
                              flatten_stride)

        return nms_func(bboxes, scores, self.keep_top_k, self.iou_threshold,
                        self.score_threshold, self.pre_top_k, self.keep_top_k)

    def select_nms(self):
        if self.backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO):
            nms_func = onnx_nms
        elif self.backend == MMYOLOBackend.TENSORRT8:
            nms_func = efficient_nms
        elif self.backend == MMYOLOBackend.TENSORRT7:
            nms_func = batched_nms
        else:
            raise NotImplementedError
        if type(self.baseHead) in (YOLOv5Head, YOLOv7Head, YOLOXHead):
            nms_func = partial(nms_func, box_coding=1)

        return nms_func

    def forward(self, inputs: Tensor):
        neck_outputs = self.baseModel(inputs)
        if self.with_postprocess:
            return self.pred_by_feat(*neck_outputs)
        else:
            outputs = []
            if self.transpose:
                for feats in zip(*neck_outputs):
                    if self.backend in (MMYOLOBackend.NCNN,
                                        MMYOLOBackend.TORCHSCRIPT):
                        outputs.append(
                            torch.cat(
                                [feat.permute(0, 2, 3, 1) for feat in feats],
                                -1))
                    else:
                        outputs.append(torch.cat(feats, 1).permute(0, 2, 3, 1))
            else:
                for feats in zip(*neck_outputs):
                    outputs.append(torch.cat(feats, 1))
            return tuple(outputs)

    @staticmethod
    def forward_single(x: Tensor, convs: nn.Module) -> Tuple[Tensor]:
        # Fold the additive (ImplicitA) and multiplicative (ImplicitM)
        # parameters into the conv's weight and bias so the exported graph
        # contains a single convolution.
        if isinstance(convs, nn.Sequential) and any(
                type(m) in (ImplicitA, ImplicitM) for m in convs):
            a, c, m = convs
            aw = a.implicit.clone()
            mw = m.implicit.clone()
            c = deepcopy(c)
            nw, cw, _, _ = c.weight.shape
            na, ca, _, _ = aw.shape
            nm, cm, _, _ = mw.shape
            c.bias = nn.Parameter(c.bias + (
                c.weight.reshape(nw, cw) @ aw.reshape(ca, na)).squeeze(1))
            c.bias = nn.Parameter(c.bias * mw.reshape(cm))
            c.weight = nn.Parameter(c.weight * mw.transpose(0, 1))
            convs = c
        feat = convs(x)
        return (feat, )

    @staticmethod
    def forward_gvp(x: Tensor) -> Tensor:
        # Plain global average pooling, patched onto ChannelAttention's
        # global_avgpool in __switch_deploy for backend compatibility.
        return torch.mean(x, [2, 3], keepdim=True)
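

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a
# detector built with ``init_detector`` from ``mmdet.apis``; the config and
# checkpoint paths below are hypothetical placeholders, and the output names
# assume the fused NMS op returns (num_dets, boxes, scores, labels).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from mmdet.apis import init_detector

    postprocess_cfg = ConfigDict(
        pre_top_k=1000,
        keep_top_k=100,
        iou_threshold=0.65,
        score_threshold=0.25)

    # Hypothetical config/checkpoint paths; replace with real MMYOLO files.
    detector = init_detector('yolov5_s.py', 'yolov5_s.pth', device='cpu')

    deploy_model = DeployModel(
        baseModel=detector,
        backend=MMYOLOBackend.ONNXRUNTIME,
        postprocess_cfg=postprocess_cfg)
    deploy_model.eval()

    fake_input = torch.randn(1, 3, 640, 640)
    torch.onnx.export(
        deploy_model,
        fake_input,
        'end2end.onnx',
        input_names=['images'],
        output_names=['num_dets', 'boxes', 'scores', 'labels'],
        opset_version=11)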