#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.

import torch.nn as nn

from .yolo_head import YOLOXHead
from .yolo_pafpn import YOLOPAFPN


class YOLOX(nn.Module):
    """
    YOLOX model module.

    Composes a ``backbone`` (whose outputs are FPN feature maps) with a
    detection ``head``. In training mode, ``forward`` returns a dict of loss
    terms computed by the head. In eval mode, it returns the head's output
    for the backbone features alone (the detection results).
    """

    def __init__(self, backbone=None, head=None):
        """
        Args:
            backbone: module applied to the input batch; its outputs are fed
                to ``head``. Defaults to ``YOLOPAFPN()`` when None.
            head: detection head applied to the backbone outputs. Defaults to
                ``YOLOXHead(80)`` when None (presumably 80 classes — confirm
                against ``YOLOXHead``'s signature).
        """
        super().__init__()
        if backbone is None:
            backbone = YOLOPAFPN()
        if head is None:
            head = YOLOXHead(80)

        self.backbone = backbone
        self.head = head

    def forward(self, x, targets=None):
        """
        Run the backbone, then the head, on ``x``.

        Args:
            x: input batch. It is passed to the backbone and, in training
                mode, also to the head.
            targets: ground-truth annotations consumed by the head's loss.
                Required when ``self.training`` is True; ignored otherwise.

        Returns:
            In training mode, a dict with the keys ``"total_loss"``,
            ``"iou_loss"``, ``"l1_loss"``, ``"conf_loss"``, ``"cls_loss"``
            and ``"num_fg"``, taken from the head's six return values.
            In eval mode, whatever ``self.head(fpn_outs)`` returns.

        Raises:
            ValueError: if called in training mode with ``targets`` None.
        """
        # fpn output content features of [dark3, dark4, dark5]
        fpn_outs = self.backbone(x)

        if self.training:
            # Explicit check rather than ``assert``: asserts are stripped
            # under ``python -O``, which would let None reach the loss code.
            if targets is None:
                raise ValueError(
                    "targets must be provided when YOLOX is in training mode"
                )
            loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(
                fpn_outs, targets, x
            )
            outputs = {
                "total_loss": loss,
                "iou_loss": iou_loss,
                "l1_loss": l1_loss,
                "conf_loss": conf_loss,
                "cls_loss": cls_loss,
                "num_fg": num_fg,
            }
        else:
            outputs = self.head(fpn_outs)

        return outputs