# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint

from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .kd_one_stage import KnowledgeDistillationSingleStageDetector


@DETECTORS.register_module()
class LAD(KnowledgeDistillationSingleStageDetector):
    """Implementation of `LAD <https://arxiv.org/abs/2108.10520>`_."""

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 teacher_backbone,
                 teacher_neck,
                 teacher_bbox_head,
                 teacher_ckpt,
                 eval_teacher=True,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        # Skip KnowledgeDistillationSingleStageDetector.__init__ and build the
        # student directly as a SingleStageDetector; the teacher is assembled
        # manually below from its own backbone/neck/head configs.
        super(KnowledgeDistillationSingleStageDetector,
              self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
                             pretrained)
        self.eval_teacher = eval_teacher
        # A plain nn.Module container registers the teacher's parameters
        # without treating it as a full detector.
        self.teacher_model = nn.Module()
        self.teacher_model.backbone = build_backbone(teacher_backbone)
        if teacher_neck is not None:
            self.teacher_model.neck = build_neck(teacher_neck)
        # The teacher head shares the student's train/test configs so that
        # label assignment is computed under the same settings the student
        # trains with.
        teacher_bbox_head.update(train_cfg=train_cfg)
        teacher_bbox_head.update(test_cfg=test_cfg)
        self.teacher_model.bbox_head = build_head(teacher_bbox_head)
        if teacher_ckpt is not None:
            load_checkpoint(
                self.teacher_model, teacher_ckpt, map_location='cpu')

    @property
    def with_teacher_neck(self):
        """bool: whether the detector has a teacher neck"""
        return hasattr(self.teacher_model, 'neck') and \
            self.teacher_model.neck is not None

    def extract_teacher_feat(self, img):
        """Directly extract teacher features from the backbone + neck."""
        x = self.teacher_model.backbone(img)
        if self.with_teacher_neck:
            x = self.teacher_model.neck(x)
        return x

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """
        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dicts where each
                dict has: 'img_shape', 'scale_factor', 'flip', and may also
                contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'. For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item is the ground-truth boxes of
                one image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each
                box.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # get the label assignment from the teacher; no gradients flow back
        with torch.no_grad():
            x_teacher = self.extract_teacher_feat(img)
            outs_teacher = self.teacher_model.bbox_head(x_teacher)
            label_assignment_results = \
                self.teacher_model.bbox_head.get_label_assignment(
                    *outs_teacher, gt_bboxes, gt_labels, img_metas,
                    gt_bboxes_ignore)

        # the student uses the teacher's label assignment to learn
        x = self.extract_feat(img)
        losses = self.bbox_head.forward_train(x, label_assignment_results,
                                              img_metas, gt_bboxes, gt_labels,
                                              gt_bboxes_ignore)
        return losses
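

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream module): a toy, pure-PyTorch
# rendering of the label-assignment-distillation pattern used in
# ``forward_train`` above. A frozen teacher scores candidate anchors under
# ``torch.no_grad()`` to decide which ones are positives, and the student is
# trained against that assignment. Every name, shape, and the threshold rule
# below are assumptions made for this demo (LAD itself computes assignments
# with the PAA assigner inside ``get_label_assignment``), not mmdet APIs.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)
    num_anchors, num_classes, feat_dim = 8, 3, 16

    teacher_head = nn.Linear(feat_dim, num_classes)  # stand-in teacher head
    student_head = nn.Linear(feat_dim, num_classes)  # stand-in student head
    feats = torch.randn(num_anchors, feat_dim)  # per-anchor features
    gt_labels = torch.randint(0, num_classes, (num_anchors, ))

    # The teacher produces the label assignment; no gradients flow to it.
    with torch.no_grad():
        teacher_scores = teacher_head(feats).softmax(dim=-1)
        # Toy assignment rule: anchors where the teacher is confident about
        # the ground-truth class become positive samples.
        pos_mask = teacher_scores[torch.arange(num_anchors), gt_labels] > 0.3

    # The student learns only on the teacher-assigned positives.
    student_logits = student_head(feats)
    if pos_mask.any():
        loss = nn.functional.cross_entropy(student_logits[pos_mask],
                                           gt_labels[pos_mask])
    else:  # no positives assigned in this toy batch
        loss = student_logits.sum() * 0
    loss.backward()
    print(f'{int(pos_mask.sum())} positives, loss={loss.item():.4f}')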