Spaces:
Running
on
Zero
Running
on
Zero
"""YOLOX model implementation and runtime.""" | |
from __future__ import annotations | |
import torch | |
from torch import nn | |
from vis4d.common.ckpt import load_model_checkpoint | |
from vis4d.op.base import BaseModel, CSPDarknet | |
from vis4d.op.box.box2d import scale_and_clip_boxes | |
from vis4d.op.detect.common import DetOut | |
from vis4d.op.detect.yolox import YOLOXHead, YOLOXOut, YOLOXPostprocess | |
from vis4d.op.fpp import YOLOXPAFPN, FeaturePyramidProcessing | |
# Ordered (regex pattern, replacement) pairs handed to
# ``load_model_checkpoint`` as ``rev_keys`` when loading "mmdet://" or
# "bdd100k://" checkpoints. They rename the checkpoint's ``backbone`` /
# ``bbox_head`` / ``neck`` prefixes to this module's attribute names
# (``basemodel`` / ``yolox_head`` / ``fpn``) and rewrite ``.bn.`` /
# ``.conv.weight`` / ``.conv.bias`` key fragments.
# NOTE(review): presumably each pair is applied in order with ``re.sub``
# semantics inside ``load_model_checkpoint`` -- confirm against that helper.
REV_KEYS = [
    (r"^backbone\.", "basemodel."),
    (r"^bbox_head\.", "yolox_head."),
    (r"^neck\.", "fpn."),
    (r"\.bn\.", ".norm."),
    # Both dots are escaped so that only a literal ".conv.weight" /
    # ".conv.bias" fragment is rewritten; an unescaped "." would also match
    # arbitrary characters such as in ".conv_weight".
    (r"\.conv\.weight", ".weight"),
    (r"\.conv\.bias", ".bias"),
]
class YOLOX(nn.Module):
    """YOLOX detector.

    Chains a base model, a feature pyramid processing module and a YOLOX
    head. In training mode the raw head outputs are returned; otherwise the
    head outputs are run through the postprocessor and the resulting boxes
    are passed through ``scale_and_clip_boxes`` together with the original
    and input image resolutions.
    """

    def __init__(
        self,
        num_classes: int,
        basemodel: BaseModel | None = None,
        fpn: FeaturePyramidProcessing | None = None,
        yolox_head: YOLOXHead | None = None,
        postprocessor: YOLOXPostprocess | None = None,
        weights: None | str = None,
    ) -> None:
        """Creates an instance of the class.

        Args:
            num_classes (int): Number of classes. Only used to build the
                default head, i.e. when ``yolox_head`` is None.
            basemodel (BaseModel, optional): Base model. Defaults to None. If
                None, will use CSPDarknet(deepen_factor=0.33,
                widen_factor=0.5).
            fpn (FeaturePyramidProcessing, optional): Feature Pyramid
                Processing. Defaults to None. If None, will use
                YOLOXPAFPN([128, 256, 512], 128, num_csp_blocks=1).
            yolox_head (YOLOXHead, optional): YOLOX head. Defaults to None. If
                None, will use YOLOXHead with 128 input and feature channels.
            postprocessor (YOLOXPostprocess, optional): Post processor.
                Defaults to None. If None, will use YOLOXPostprocess built
                from the head's point generator and box decoder with
                nms_threshold=0.65 and score_thr=0.01.
            weights (None | str, optional): Weights to load for model. If
                the string starts with "mmdet://" or "bdd100k://", the
                checkpoint is loaded with the REV_KEYS key remapping;
                any other string is loaded without remapping. Defaults to
                None (no checkpoint is loaded).
        """
        super().__init__()
        self.basemodel = (
            CSPDarknet(deepen_factor=0.33, widen_factor=0.5)
            if basemodel is None
            else basemodel
        )
        self.fpn = (
            YOLOXPAFPN([128, 256, 512], 128, num_csp_blocks=1)
            if fpn is None
            else fpn
        )
        self.yolox_head = (
            YOLOXHead(
                num_classes=num_classes, in_channels=128, feat_channels=128
            )
            if yolox_head is None
            else yolox_head
        )
        # The default postprocessor reuses the head's point generator and box
        # decoder, so it must be created after the head.
        self.postprocessor = (
            YOLOXPostprocess(
                self.yolox_head.point_generator,
                self.yolox_head.box_decoder,
                nms_threshold=0.65,
                score_thr=0.01,
            )
            if postprocessor is None
            else postprocessor
        )

        if weights is not None:
            # Checkpoints from these sources use different parameter names
            # and need the REV_KEYS renaming when loaded.
            if weights.startswith(("mmdet://", "bdd100k://")):
                load_model_checkpoint(self, weights, rev_keys=REV_KEYS)
            else:
                load_model_checkpoint(self, weights)

    def forward(
        self,
        images: torch.Tensor,
        input_hw: None | list[tuple[int, int]] = None,
        original_hw: None | list[tuple[int, int]] = None,
    ) -> YOLOXOut | DetOut:
        """Forward pass.

        Args:
            images (torch.Tensor): Input images.
            input_hw (None | list[tuple[int, int]], optional): Input image
                resolutions. Required for testing. Defaults to None.
            original_hw (None | list[tuple[int, int]], optional): Original
                image resolutions (before padding and resizing). Required for
                testing. Defaults to None.

        Returns:
            YOLOXOut | DetOut: Either raw model outputs (for training) or
                predicted outputs (for testing).

        Raises:
            ValueError: If the model is not in training mode and input_hw or
                original_hw is None.
        """
        if self.training:
            return self.forward_train(images)
        # Explicit check instead of ``assert`` so the validation is not
        # stripped when Python runs with -O.
        if input_hw is None or original_hw is None:
            raise ValueError(
                "input_hw and original_hw are required when the model is "
                "not in training mode."
            )
        return self.forward_test(images, input_hw, original_hw)

    def forward_train(self, images: torch.Tensor) -> YOLOXOut:
        """Forward training stage.

        Args:
            images (torch.Tensor): Input images.

        Returns:
            YOLOXOut: Raw model outputs.
        """
        # Images are made contiguous before being fed to the base model.
        features = self.fpn(self.basemodel(images.contiguous()))
        # Only the last three feature levels are consumed by the head.
        return self.yolox_head(features[-3:])

    def forward_test(
        self,
        images: torch.Tensor,
        images_hw: list[tuple[int, int]],
        original_hw: list[tuple[int, int]],
    ) -> DetOut:
        """Forward testing stage.

        Args:
            images (torch.Tensor): Input images.
            images_hw (list[tuple[int, int]]): Input image resolutions.
            original_hw (list[tuple[int, int]]): Original image resolutions
                (before padding and resizing).

        Returns:
            DetOut: Predicted outputs.
        """
        features = self.fpn(self.basemodel(images))
        # Only the last three feature levels are consumed by the head.
        outs = self.yolox_head(features[-3:])
        boxes, scores, class_ids = self.postprocessor(
            cls_outs=outs.cls_score,
            reg_outs=outs.bbox_pred,
            obj_outs=outs.objectness,
            images_hw=images_hw,
        )
        # Per image: hand the boxes to scale_and_clip_boxes together with the
        # original and the input resolution of that image. The list returned
        # by the postprocessor is updated in place.
        for idx, image_boxes in enumerate(boxes):
            boxes[idx] = scale_and_clip_boxes(
                image_boxes, original_hw[idx], images_hw[idx]
            )
        return DetOut(boxes, scores, class_ids)