# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import torch
from mmdet.models.data_preprocessors import DetDataPreprocessor
from mmengine.structures import BaseDataElement
from mmyolo.registry import MODELS

# Alias for the value types a batch may be built from.
# NOTE(review): presumably mirrors what ``cast_data`` accepts -- confirm
# against the base class; it is not referenced elsewhere in this module.
CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str,
                 None]


@MODELS.register_module()
class YOLOWDetDataPreprocessor(DetDataPreprocessor):
    """Data preprocessor with a lightweight training path.

    At training time the batch is expected to arrive already collated
    (``data_samples`` is a plain ``dict``), so this class skips the generic
    per-sample handling of ``DetDataPreprocessor`` to get faster training
    speed.

    Note:
        It must be used together with
        `mmyolo.datasets.utils.yolow_collate`.
    """

    def __init__(self, *args, non_blocking: Optional[bool] = True, **kwargs):
        """Forward all arguments to ``DetDataPreprocessor``.

        Args:
            *args: Positional arguments passed through to the parent class.
            non_blocking (bool, optional): Passed through to the parent
                class; this subclass only changes its default to ``True``.
            **kwargs: Keyword arguments passed through to the parent class.
        """
        super().__init__(*args, non_blocking=non_blocking, **kwargs)

    def forward(self, data: dict, training: bool = False) -> dict:
        """Prepare a batch for the model.

        When ``training`` is false the call is delegated unchanged to
        ``DetDataPreprocessor.forward``. When ``training`` is true the
        batch is cast with ``self.cast_data``, optionally channel-reversed
        and normalized, run through ``self.batch_augments`` and repacked
        into a plain dict.

        Args:
            data (dict): Data sampled from dataloader. In training mode it
                must provide ``'inputs'`` (indexed here as a batched tensor
                whose dim 1 is the channel axis) and ``'data_samples'`` (a
                dict holding at least ``'bboxes_labels'`` and ``'texts'``).
            training (bool): Whether to enable training time augmentation.

        Returns:
            dict: ``{'inputs': ..., 'data_samples': ...}``. In training
            mode ``data_samples`` is a dict with ``'bboxes_labels'``,
            ``'texts'``, ``'img_metas'`` and, when present in the input,
            ``'masks'`` and ``'is_detection'``.

        Raises:
            AssertionError: If ``data['data_samples']`` is not a dict in
                training mode.
        """
        if not training:
            return super().forward(data, training)

        data = self.cast_data(data)
        inputs, data_samples = data['inputs'], data['data_samples']
        assert isinstance(data_samples, dict), (
            'data_samples must be a dict during training, got '
            f'{type(data_samples).__name__}')

        # TODO: Support multi-scale training
        # Reverse the channel order (index [2, 1, 0]) for 3-channel inputs.
        if self._channel_conversion and inputs.shape[1] == 3:
            inputs = inputs[:, [2, 1, 0], ...]
        if self._enable_normalize:
            inputs = (inputs - self.mean) / self.std

        if self.batch_augments is not None:
            for batch_aug in self.batch_augments:
                inputs, data_samples = batch_aug(inputs, data_samples)

        # Build one independent dict per sample. ``[meta] * n`` would repeat
        # a single dict object, so a later per-sample update would silently
        # change the meta of every sample in the batch.
        batch_input_shape = inputs.shape[2:]
        img_metas = [{
            'batch_input_shape': batch_input_shape
        } for _ in range(len(inputs))]

        data_samples_output = {
            'bboxes_labels': data_samples['bboxes_labels'],
            'texts': data_samples['texts'],
            'img_metas': img_metas
        }
        # Optional entries are forwarded only when present in the batch.
        for optional_key in ('masks', 'is_detection'):
            if optional_key in data_samples:
                data_samples_output[optional_key] = data_samples[optional_key]

        return {'inputs': inputs, 'data_samples': data_samples_output}