import copy
from typing import List, Optional, Tuple

import torch
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.utils import (filter_gt_instances, rename_loss_dict,
                                reweight_loss_dict)
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox2roi, bbox_project
from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig
from ..utils.misc import unpack_gt_instances
from .semi_base import SemiBaseDetector


@MODELS.register_module()
class SoftTeacher(SemiBaseDetector):
r"""Implementation of `End-to-End Semi-Supervised Object Detection |
|
with Soft Teacher <https://arxiv.org/abs/2106.09018>`_ |
|
|
|
Args: |
|
detector (:obj:`ConfigDict` or dict): The detector config. |
|
semi_train_cfg (:obj:`ConfigDict` or dict, optional): |
|
The semi-supervised training config. |
|
semi_test_cfg (:obj:`ConfigDict` or dict, optional): |
|
The semi-supervised testing config. |
|
data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of |
|
:class:`DetDataPreprocessor` to process the input data. |
|
Defaults to None. |
|
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or |
|
list[dict], optional): Initialization config dict. |
|
Defaults to None. |
|
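
    Examples:
        >>> # A minimal, illustrative ``semi_train_cfg`` sketch. The keys are
        >>> # the ones this class reads; the values are placeholders to adapt
        >>> # to your own setup, not a recommended recipe.
        >>> semi_train_cfg = dict(
        ...     unsup_weight=4.0,
        ...     pseudo_label_initial_score_thr=0.5,
        ...     rpn_pseudo_thr=0.9,
        ...     cls_pseudo_thr=0.9,
        ...     reg_pseudo_thr=0.02,
        ...     jitter_times=10,
        ...     jitter_scale=0.06)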
""" |
|
|
|
def __init__(self, |
|
detector: ConfigType, |
|
semi_train_cfg: OptConfigType = None, |
|
semi_test_cfg: OptConfigType = None, |
|
data_preprocessor: OptConfigType = None, |
|
init_cfg: OptMultiConfig = None) -> None: |
|
super().__init__( |
|
detector=detector, |
|
semi_train_cfg=semi_train_cfg, |
|
semi_test_cfg=semi_test_cfg, |
|
data_preprocessor=data_preprocessor, |
|
init_cfg=init_cfg) |
|
|
|
    def loss_by_pseudo_instances(self,
                                 batch_inputs: Tensor,
                                 batch_data_samples: SampleList,
                                 batch_info: Optional[dict] = None) -> dict:
        """Calculate losses from a batch of inputs and pseudo data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` or `gt_sem_seg`, which
                here actually hold `pseudo_instance`, `pseudo_panoptic_seg`
                or `pseudo_sem_seg`.
            batch_info (dict, optional): Batch information from the teacher
                model's forward pass. Defaults to None.

        Returns:
            dict: A dictionary of loss components.
        """
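        # Extract student features once and reuse them for the RPN and both
        # R-CNN pseudo-label losses below.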
        x = self.student.extract_feat(batch_inputs)

        losses = {}
        rpn_losses, rpn_results_list = self.rpn_loss_by_pseudo_instances(
            x, batch_data_samples)
        losses.update(**rpn_losses)
        losses.update(**self.rcnn_cls_loss_by_pseudo_instances(
            x, rpn_results_list, batch_data_samples, batch_info))
        losses.update(**self.rcnn_reg_loss_by_pseudo_instances(
            x, rpn_results_list, batch_data_samples))
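
        # Reweight the unsupervised losses and prefix them with 'unsup_' so
        # they can be told apart from the supervised branch.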
        unsup_weight = self.semi_train_cfg.get('unsup_weight', 1.)
        return rename_loss_dict('unsup_',
                                reweight_loss_dict(losses, unsup_weight))

    @torch.no_grad()
    def get_pseudo_instances(
            self, batch_inputs: Tensor, batch_data_samples: SampleList
    ) -> Tuple[SampleList, Optional[dict]]:
        """Get pseudo instances from teacher model."""
        assert self.teacher.with_bbox, 'Bbox head must be implemented.'
        x = self.teacher.extract_feat(batch_inputs)

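        # Reuse pre-computed proposals if present; otherwise predict them
        # with the teacher's RPN.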
        if batch_data_samples[0].get('proposals', None) is None:
            rpn_results_list = self.teacher.rpn_head.predict(
                x, batch_data_samples, rescale=False)
        else:
            rpn_results_list = [
                data_sample.proposals for data_sample in batch_data_samples
            ]

        results_list = self.teacher.roi_head.predict(
            x, rpn_results_list, batch_data_samples, rescale=False)

        for data_samples, results in zip(batch_data_samples, results_list):
            data_samples.gt_instances = results

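        # Keep only pseudo instances whose score exceeds the initial
        # filtering threshold.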
        batch_data_samples = filter_gt_instances(
            batch_data_samples,
            score_thr=self.semi_train_cfg.pseudo_label_initial_score_thr)

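        # Estimate a per-box regression uncertainty from jittered versions of
        # the pseudo boxes.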
        reg_uncs_list = self.compute_uncertainty_with_aug(
            x, batch_data_samples)

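        # Attach the uncertainties and project the pseudo boxes from the
        # teacher's augmented view back to the original image space.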
        for data_samples, reg_uncs in zip(batch_data_samples, reg_uncs_list):
            data_samples.gt_instances['reg_uncs'] = reg_uncs
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.from_numpy(data_samples.homography_matrix).inverse().to(
                    self.data_preprocessor.device), data_samples.ori_shape)

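        # Cache teacher features and per-image meta information so that
        # ``rcnn_cls_loss_by_pseudo_instances`` can rescore student proposals
        # with the teacher later on.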
        batch_info = {
            'feat': x,
            'img_shape': [],
            'homography_matrix': [],
            'metainfo': []
        }
        for data_samples in batch_data_samples:
            batch_info['img_shape'].append(data_samples.img_shape)
            batch_info['homography_matrix'].append(
                torch.from_numpy(data_samples.homography_matrix).to(
                    self.data_preprocessor.device))
            batch_info['metainfo'].append(data_samples.metainfo)
        return batch_data_samples, batch_info

    def rpn_loss_by_pseudo_instances(
            self, x: Tuple[Tensor],
            batch_data_samples: SampleList) -> Tuple[dict, InstanceList]:
        """Calculate rpn loss from a batch of inputs and pseudo data samples.

        Args:
            x (tuple[Tensor]): Features from FPN.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` or `gt_sem_seg`, which
                here actually hold `pseudo_instance`, `pseudo_panoptic_seg`
                or `pseudo_sem_seg`.

        Returns:
            tuple: A dictionary of rpn loss components and the list of
            proposals predicted by the student RPN.
        """
        rpn_data_samples = copy.deepcopy(batch_data_samples)
        rpn_data_samples = filter_gt_instances(
            rpn_data_samples, score_thr=self.semi_train_cfg.rpn_pseudo_thr)
        proposal_cfg = self.student.train_cfg.get('rpn_proposal',
                                                  self.student.test_cfg.rpn)

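        # The RPN is class-agnostic, so every pseudo label is mapped to the
        # single foreground class 0.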
        for data_sample in rpn_data_samples:
            data_sample.gt_instances.labels = \
                torch.zeros_like(data_sample.gt_instances.labels)

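        # Prefix loss keys that do not already mention 'rpn' so they do not
        # clash with the roi_head losses.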
        rpn_losses, rpn_results_list = self.student.rpn_head.loss_and_predict(
            x, rpn_data_samples, proposal_cfg=proposal_cfg)
        for key in list(rpn_losses.keys()):
            if 'loss' in key and 'rpn' not in key:
                rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
        return rpn_losses, rpn_results_list

    def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor],
                                          unsup_rpn_results_list: InstanceList,
                                          batch_data_samples: SampleList,
                                          batch_info: dict) -> dict:
        """Calculate classification loss from a batch of inputs and pseudo
        data samples.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            unsup_rpn_results_list (list[:obj:`InstanceData`]):
                List of region proposals.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` or `gt_sem_seg`, which
                here actually hold `pseudo_instance`, `pseudo_panoptic_seg`
                or `pseudo_sem_seg`.
            batch_info (dict): Batch information from the teacher model's
                forward pass.

        Returns:
            dict[str, Tensor]: A dictionary of rcnn classification loss
            components.
        """
        rpn_results_list = copy.deepcopy(unsup_rpn_results_list)
        cls_data_samples = copy.deepcopy(batch_data_samples)
        cls_data_samples = filter_gt_instances(
            cls_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr)

        outputs = unpack_gt_instances(cls_data_samples)
        batch_gt_instances, batch_gt_instances_ignore, _ = outputs

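        # Assign and sample the student proposals against the filtered pseudo
        # boxes; the assigner and sampler expect the proposals under the
        # ``priors`` field, hence the rename below.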
        num_imgs = len(cls_data_samples)
        sampling_results = []
        for i in range(num_imgs):
            rpn_results = rpn_results_list[i]
            rpn_results.priors = rpn_results.pop('bboxes')
            assign_result = self.student.roi_head.bbox_assigner.assign(
                rpn_results, batch_gt_instances[i],
                batch_gt_instances_ignore[i])
            sampling_result = self.student.roi_head.bbox_sampler.sample(
                assign_result,
                rpn_results,
                batch_gt_instances[i],
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results.append(sampling_result)

        selected_bboxes = [res.priors for res in sampling_results]
        rois = bbox2roi(selected_bboxes)
        bbox_results = self.student.roi_head._bbox_forward(x, rois)

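        # cls_reg_targets is a tuple of (labels, label_weights, bbox_targets,
        # bbox_weights).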
        cls_reg_targets = self.student.roi_head.bbox_head.get_targets(
            sampling_results, self.student.train_cfg.rcnn)

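        # Project the sampled student proposals into the teacher's image
        # space (student view -> original image -> teacher view) so the
        # teacher can rescore them.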
        selected_results_list = []
        for bboxes, data_samples, teacher_matrix, teacher_img_shape in zip(
                selected_bboxes, batch_data_samples,
                batch_info['homography_matrix'], batch_info['img_shape']):
            student_matrix = torch.tensor(
                data_samples.homography_matrix, device=teacher_matrix.device)
            homography_matrix = teacher_matrix @ student_matrix.inverse()
            projected_bboxes = bbox_project(bboxes, homography_matrix,
                                            teacher_img_shape)
            selected_results_list.append(InstanceData(bboxes=projected_bboxes))

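        # Score the projected proposals with the teacher; its background
        # probability is later used as a soft weight for negative proposals.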
        with torch.no_grad():
            results_list = self.teacher.roi_head.predict_bbox(
                batch_info['feat'],
                batch_info['metainfo'],
                selected_results_list,
                rcnn_test_cfg=None,
                rescale=False)
            bg_score = torch.cat(
                [results.scores[:, -1] for results in results_list])

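        # Proposals assigned to background have their label weight replaced
        # by the teacher's background score.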
        neg_inds = cls_reg_targets[
            0] == self.student.roi_head.bbox_head.num_classes
        cls_reg_targets[1][neg_inds] = bg_score[neg_inds].detach()

        losses = self.student.roi_head.bbox_head.loss(
            bbox_results['cls_score'], bbox_results['bbox_pred'], rois,
            *cls_reg_targets)

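        # Rescale ``loss_cls`` so it is effectively normalized by the total
        # soft weight rather than the raw sample count.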
        losses['loss_cls'] = losses['loss_cls'] * len(
            cls_reg_targets[1]) / max(cls_reg_targets[1].sum(), 1.0)
        return losses

    def rcnn_reg_loss_by_pseudo_instances(
            self, x: Tuple[Tensor], unsup_rpn_results_list: InstanceList,
            batch_data_samples: SampleList) -> dict:
        """Calculate rcnn regression loss from a batch of inputs and pseudo
        data samples.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            unsup_rpn_results_list (list[:obj:`InstanceData`]):
                List of region proposals.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` or `gt_sem_seg`, which
                here actually hold `pseudo_instance`, `pseudo_panoptic_seg`
                or `pseudo_sem_seg`.

        Returns:
            dict[str, Tensor]: A dictionary of rcnn regression loss
            components.
        """
        rpn_results_list = copy.deepcopy(unsup_rpn_results_list)
        reg_data_samples = copy.deepcopy(batch_data_samples)

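        # Keep only pseudo boxes whose jitter-based regression uncertainty is
        # below the configured threshold.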
        for data_samples in reg_data_samples:
            if data_samples.gt_instances.bboxes.shape[0] > 0:
                data_samples.gt_instances = data_samples.gt_instances[
                    data_samples.gt_instances.reg_uncs <
                    self.semi_train_cfg.reg_pseudo_thr]

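        # Run the full roi_head loss but keep only the box regression term;
        # the classification term is handled by
        # ``rcnn_cls_loss_by_pseudo_instances``.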
        roi_losses = self.student.roi_head.loss(x, rpn_results_list,
                                                reg_data_samples)
        return {'loss_bbox': roi_losses['loss_bbox']}

    def compute_uncertainty_with_aug(
            self, x: Tuple[Tensor],
            batch_data_samples: SampleList) -> List[Tensor]:
        """Compute uncertainty with augmented bboxes.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` or `gt_sem_seg`, which
                here actually hold `pseudo_instance`, `pseudo_panoptic_seg`
                or `pseudo_sem_seg`.

        Returns:
            list[Tensor]: A list of uncertainties for pseudo bboxes.
        """
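        # Jitter every pseudo box ``jitter_times`` times and refine the
        # jittered boxes with the teacher's box head.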
        auged_results_list = self.aug_box(batch_data_samples,
                                          self.semi_train_cfg.jitter_times,
                                          self.semi_train_cfg.jitter_scale)

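        # Flatten (jitter_times, num_boxes, 4) to (jitter_times * num_boxes,
        # 4) so the jittered boxes can be used as proposals.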
        auged_results_list = [
            InstanceData(bboxes=auged.reshape(-1, auged.shape[-1]))
            for auged in auged_results_list
        ]

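        # Temporarily disable the rcnn test cfg so the raw regressed boxes are
        # returned without NMS or score filtering, then restore it.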
        self.teacher.roi_head.test_cfg = None
        results_list = self.teacher.roi_head.predict(
            x, auged_results_list, batch_data_samples, rescale=False)
        self.teacher.roi_head.test_cfg = self.teacher.test_cfg.rcnn

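        # reg_channel > 1 when the box head regresses boxes per class; images
        # without pseudo boxes get an empty placeholder of matching shape.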
        reg_channel = max(
            [results.bboxes.shape[-1] for results in results_list]) // 4
        bboxes = [
            results.bboxes.reshape(self.semi_train_cfg.jitter_times, -1,
                                   results.bboxes.shape[-1])
            if results.bboxes.numel() > 0 else results.bboxes.new_zeros(
                self.semi_train_cfg.jitter_times, 0, 4 * reg_channel).float()
            for results in results_list
        ]

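        # The std over the jittered predictions is the uncertainty; the mean
        # is kept as the box estimate.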
        box_unc = [bbox.std(dim=0) for bbox in bboxes]
        bboxes = [bbox.mean(dim=0) for bbox in bboxes]
        labels = [
            data_samples.gt_instances.labels
            for data_samples in batch_data_samples
        ]
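        # For class-specific regression, pick the 4 coordinates belonging to
        # the predicted class of each pseudo box.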
        if reg_channel != 1:
            bboxes = [
                bbox.reshape(bbox.shape[0], reg_channel,
                             4)[torch.arange(bbox.shape[0]), label]
                for bbox, label in zip(bboxes, labels)
            ]
            box_unc = [
                unc.reshape(unc.shape[0], reg_channel,
                            4)[torch.arange(unc.shape[0]), label]
                for unc, label in zip(box_unc, labels)
            ]

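        # Normalize the uncertainty by the box width and height so that large
        # and small boxes are comparable.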
        box_shape = [(bbox[:, 2:4] - bbox[:, :2]).clamp(min=1.0)
                     for bbox in bboxes]
        box_unc = [
            torch.mean(
                unc / wh[:, None, :].expand(-1, 2, 2).reshape(-1, 4), dim=-1)
            if wh.numel() > 0 else unc for unc, wh in zip(box_unc, box_shape)
        ]
        return box_unc

    @staticmethod
    def aug_box(batch_data_samples, times, frac):
        """Augment bboxes with jitter."""

        def _aug_single(box):
            box_scale = box[:, 2:4] - box[:, :2]
            box_scale = (
                box_scale.clamp(min=1)[:, None, :].expand(-1, 2,
                                                          2).reshape(-1, 4))
            aug_scale = box_scale * frac

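            # Gaussian offsets whose scale is proportional to the box size,
            # drawn independently for each of the ``times`` jitters.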
            offset = (
                torch.randn(times, box.shape[0], 4, device=box.device) *
                aug_scale[None, ...])
            new_box = box.clone()[None, ...].expand(times, box.shape[0],
                                                    -1) + offset
            return new_box

        return [
            _aug_single(data_samples.gt_instances.bboxes)
            for data_samples in batch_data_samples
        ]