Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Sequence | |
import torch | |
from mmengine.dataset import COLLATE_FUNCTIONS | |
# NOTE(review): ``COLLATE_FUNCTIONS`` is imported in this file, but no
# ``@COLLATE_FUNCTIONS.register_module()`` decorator is visible on this
# function. Confirm whether config-based lookup by name is expected to work.
def yolow_collate(data_batch: Sequence,
                  use_ms_training: bool = False) -> dict:
    """Collate a batch of samples into flat tensors for faster training.

    Instead of keeping per-sample instance containers, all ground-truth boxes
    of the batch are packed into one tensor whose first column records which
    sample each box came from.

    Args:
        data_batch (Sequence): Batch of samples. Each item is a dict with
            ``'inputs'`` (the image tensor) and ``'data_samples'`` (an object
            exposing ``gt_instances.bboxes.tensor`` and
            ``gt_instances.labels``, and optionally ``gt_instances.masks``,
            ``texts`` and ``is_detection``).
        use_ms_training (bool): Whether to use multi-scale training. When
            True the images are returned as a plain list instead of being
            stacked into one tensor. Defaults to False.

    Returns:
        dict: ``{'data_samples': {...}, 'inputs': ...}`` where
        ``data_samples`` contains:

        - ``'bboxes_labels'``: tensor with one row per box and columns
          ``[sample_index, label, *bbox]``.
        - ``'masks'`` (only if any sample has masks): boolean masks of all
          samples concatenated along dim 0.
        - ``'texts'`` (only if the first sample has ``texts``): list holding
          each sample's ``texts``.
        - ``'is_detection'`` (only if the first sample has
          ``is_detection``): tensor of per-sample flags.

    Raises:
        ValueError: If masks are present but their total number differs from
            the total number of boxes (e.g. only some samples carry masks),
            which would otherwise silently misalign masks and boxes.
    """
    batch_imgs = []
    batch_bboxes_labels = []
    batch_masks = []
    for i, sample in enumerate(data_batch):
        gt_instances = sample['data_samples'].gt_instances
        batch_imgs.append(sample['inputs'])

        gt_bboxes = gt_instances.bboxes.tensor
        gt_labels = gt_instances.labels
        if 'masks' in gt_instances:
            batch_masks.append(
                gt_instances.masks.to_tensor(
                    dtype=torch.bool, device=gt_bboxes.device))

        # Prefix every box with the index of its sample so the per-sample
        # grouping survives flattening the whole batch into one tensor.
        batch_idx = gt_labels.new_full((len(gt_labels), 1), i)
        batch_bboxes_labels.append(
            torch.cat((batch_idx, gt_labels[:, None], gt_bboxes), dim=1))

    bboxes_labels = torch.cat(batch_bboxes_labels, 0)
    collated_results = {'data_samples': {'bboxes_labels': bboxes_labels}}

    if batch_masks:
        masks = torch.cat(batch_masks, 0)
        # Masks are matched to boxes purely by row position, so the counts
        # must agree; a mismatch means some samples lacked masks.
        if masks.shape[0] != bboxes_labels.shape[0]:
            raise ValueError(
                f'Got {masks.shape[0]} masks for {bboxes_labels.shape[0]} '
                'boxes; every sample with boxes must provide masks.')
        collated_results['data_samples']['masks'] = masks

    if use_ms_training:
        collated_results['inputs'] = batch_imgs
    else:
        collated_results['inputs'] = torch.stack(batch_imgs, 0)

    # Optional per-sample fields are detected from the first sample only.
    first_sample = data_batch[0]['data_samples']
    if hasattr(first_sample, 'texts'):
        collated_results['data_samples']['texts'] = [
            sample['data_samples'].texts for sample in data_batch]
    if hasattr(first_sample, 'is_detection'):
        # detection flag
        collated_results['data_samples']['is_detection'] = torch.tensor(
            [sample['data_samples'].is_detection for sample in data_batch])
    return collated_results