csmithxc's picture
Upload 146 files
1530901 verified
raw
history blame
2.15 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import torch
from mmengine.dataset import COLLATE_FUNCTIONS
@COLLATE_FUNCTIONS.register_module()
def yolow_collate(data_batch: Sequence,
                  use_ms_training: bool = False) -> dict:
    """Rewrite collate_fn to get faster training speed.

    Ground-truth boxes and labels of every sample are packed into a single
    tensor instead of being kept per sample.

    Args:
        data_batch (Sequence): Batch of data. Each item is a dict with an
            ``'inputs'`` entry and a ``'data_samples'`` entry whose
            ``gt_instances`` provides ``bboxes`` (with a ``.tensor``),
            ``labels`` and, optionally, ``masks``. A data sample may also
            carry ``texts`` and/or ``is_detection``. Must be non-empty.
        use_ms_training (bool): Whether to use multi-scale training. If
            True, ``inputs`` is returned as a plain list of the per-sample
            inputs; otherwise they are stacked with ``torch.stack``.
            Defaults to False.

    Returns:
        dict: ``{'data_samples': {...}, 'inputs': ...}`` where
        ``data_samples`` contains:

        - ``bboxes_labels``: the instances of all samples concatenated
          along dim 0; each row is ``[batch_idx, label, *bbox]``.
        - ``masks`` (only if at least one sample has masks): all masks
          converted to ``torch.bool`` and concatenated along dim 0.
        - ``texts`` (only if the first sample has ``texts``): list of the
          per-sample ``texts``.
        - ``is_detection`` (only if the first sample has it): tensor built
          from the per-sample ``is_detection`` flags.
    """
    batch_imgs = []
    batch_bboxes_labels = []
    batch_masks = []
    for i, sample in enumerate(data_batch):
        datasamples = sample['data_samples']
        batch_imgs.append(sample['inputs'])

        gt_instances = datasamples.gt_instances
        gt_bboxes = gt_instances.bboxes.tensor
        gt_labels = gt_instances.labels
        if 'masks' in gt_instances:
            # Masks are moved to the boxes' device and stored as booleans.
            batch_masks.append(
                gt_instances.masks.to(dtype=torch.bool,
                                      device=gt_bboxes.device))
        # Prefix every instance row with the index of its image within the
        # batch, presumably so downstream code can map flattened rows back
        # to their image.
        batch_idx = gt_labels.new_full((len(gt_labels), 1), i)
        bboxes_labels = torch.cat((batch_idx, gt_labels[:, None], gt_bboxes),
                                  dim=1)
        batch_bboxes_labels.append(bboxes_labels)

    data_samples = {'bboxes_labels': torch.cat(batch_bboxes_labels, 0)}
    collated_results = {'data_samples': data_samples}
    if batch_masks:
        # NOTE(review): if only some samples carry masks, these rows do not
        # line up with ``bboxes_labels``; confirm batches are homogeneous.
        data_samples['masks'] = torch.cat(batch_masks, 0)

    if use_ms_training:
        collated_results['inputs'] = batch_imgs
    else:
        collated_results['inputs'] = torch.stack(batch_imgs, 0)

    # Only the first sample is inspected; all samples are assumed to share
    # the same optional fields.
    first_sample = data_batch[0]['data_samples']
    if hasattr(first_sample, 'texts'):
        data_samples['texts'] = [
            item['data_samples'].texts for item in data_batch
        ]
    if hasattr(first_sample, 'is_detection'):
        # detection flag
        data_samples['is_detection'] = torch.tensor(
            [item['data_samples'].is_detection for item in data_batch])

    return collated_results