onuralpszr's picture
feat: ✨ YOLO-World-Seg files uploaded
b291f6a verified
raw
history blame contribute delete
No virus
5.83 kB
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmdet.structures import DetDataSample
from mmengine import MessageHub
from mmyolo.models import PPYOLOEBatchRandomResize, PPYOLOEDetDataPreprocessor
from mmyolo.models.data_preprocessors import (YOLOv5DetDataPreprocessor,
YOLOXBatchSyncRandomResize)
from mmyolo.utils import register_all_modules
register_all_modules()
class TestYOLOv5DetDataPreprocessor(TestCase):
def test_forward(self):
processor = YOLOv5DetDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])
data = {
'inputs': [torch.randint(0, 256, (3, 11, 10))],
'data_samples': [DetDataSample()]
}
out_data = processor(data, training=False)
batch_inputs, batch_data_samples = out_data['inputs'], out_data[
'data_samples']
self.assertEqual(batch_inputs.shape, (1, 3, 11, 10))
self.assertEqual(len(batch_data_samples), 1)
# test channel_conversion
processor = YOLOv5DetDataPreprocessor(
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
out_data = processor(data, training=False)
batch_inputs, batch_data_samples = out_data['inputs'], out_data[
'data_samples']
self.assertEqual(batch_inputs.shape, (1, 3, 11, 10))
self.assertEqual(len(batch_data_samples), 1)
# test padding, training=False
data = {
'inputs': [
torch.randint(0, 256, (3, 10, 11)),
torch.randint(0, 256, (3, 9, 14))
]
}
processor = YOLOv5DetDataPreprocessor(
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
out_data = processor(data, training=False)
batch_inputs, batch_data_samples = out_data['inputs'], out_data[
'data_samples']
self.assertEqual(batch_inputs.shape, (2, 3, 10, 14))
self.assertIsNone(batch_data_samples)
# test training
data = {
'inputs': torch.randint(0, 256, (2, 3, 10, 11)),
'data_samples': {
'bboxes_labels': torch.randint(0, 11, (18, 6))
},
}
out_data = processor(data, training=True)
batch_inputs, batch_data_samples = out_data['inputs'], out_data[
'data_samples']
self.assertIn('img_metas', batch_data_samples)
self.assertIn('bboxes_labels', batch_data_samples)
self.assertEqual(batch_inputs.shape, (2, 3, 10, 11))
self.assertIsInstance(batch_data_samples['bboxes_labels'],
torch.Tensor)
self.assertIsInstance(batch_data_samples['img_metas'], list)
data = {
'inputs': [torch.randint(0, 256, (3, 11, 10))],
'data_samples': [DetDataSample()]
}
# data_samples must be dict
with self.assertRaises(AssertionError):
processor(data, training=True)
class TestPPYOLOEDetDataPreprocessor(TestCase):
def test_batch_random_resize(self):
processor = PPYOLOEDetDataPreprocessor(
pad_size_divisor=32,
batch_augments=[
dict(
type='PPYOLOEBatchRandomResize',
random_size_range=(320, 480),
interval=1,
size_divisor=32,
random_interp=True,
keep_ratio=False)
],
mean=[0., 0., 0.],
std=[255., 255., 255.],
bgr_to_rgb=True)
self.assertTrue(
isinstance(processor.batch_augments[0], PPYOLOEBatchRandomResize))
message_hub = MessageHub.get_instance('test_batch_random_resize')
message_hub.update_info('iter', 0)
# test training
data = {
'inputs': [
torch.randint(0, 256, (3, 10, 11)),
torch.randint(0, 256, (3, 10, 11))
],
'data_samples': {
'bboxes_labels': torch.randint(0, 11, (18, 6)).float()
},
}
out_data = processor(data, training=True)
batch_data_samples = out_data['data_samples']
self.assertIn('img_metas', batch_data_samples)
self.assertIn('bboxes_labels', batch_data_samples)
self.assertIsInstance(batch_data_samples['bboxes_labels'],
torch.Tensor)
self.assertIsInstance(batch_data_samples['img_metas'], list)
data = {
'inputs': [torch.randint(0, 256, (3, 11, 10))],
'data_samples': DetDataSample()
}
# data_samples must be list
with self.assertRaises(AssertionError):
processor(data, training=True)
class TestYOLOXDetDataPreprocessor(TestCase):
def test_batch_sync_random_size(self):
processor = YOLOXBatchSyncRandomResize(
random_size_range=(480, 800), size_divisor=32, interval=1)
self.assertTrue(isinstance(processor, YOLOXBatchSyncRandomResize))
message_hub = MessageHub.get_instance(
'test_yolox_batch_sync_random_resize')
message_hub.update_info('iter', 0)
# test training
inputs = torch.randint(0, 256, (4, 3, 10, 11))
data_samples = {'bboxes_labels': torch.randint(0, 11, (18, 6)).float()}
inputs, data_samples = processor(inputs, data_samples)
self.assertIn('bboxes_labels', data_samples)
self.assertIsInstance(data_samples['bboxes_labels'], torch.Tensor)
self.assertIsInstance(inputs, torch.Tensor)
inputs = torch.randint(0, 256, (4, 3, 10, 11))
data_samples = DetDataSample()
# data_samples must be dict
with self.assertRaises(AssertionError):
processor(inputs, data_samples)