# masa/datasets/dataset_wrappers.py
import collections
import copy
import random
from typing import Sequence, Union
import numpy as np
from mmdet.datasets.base_det_dataset import BaseDetDataset
from mmdet.datasets.base_video_dataset import BaseVideoDataset
from mmdet.registry import DATASETS, TRANSFORMS
from mmengine.dataset import BaseDataset, force_full_init
from .rsconcat_dataset import RandomSampleJointVideoConcatDataset
@DATASETS.register_module(force=True)
class SeqMultiImageMixDataset:
"""A wrapper of multiple images mixed dataset.
    Suitable for training on data augmentations that mix multiple images,
    such as mosaic and mixup. For the augmentation pipeline of mixed image
    data, the ``get_indexes`` method needs to be provided to obtain the
    image indexes, and ``skip_type_keys`` can be set to skip selected
    transforms in the running pipeline.

    Args:
        dataset (:obj:`BaseDataset` or dict): The dataset to be mixed.
        pipeline (Sequence[dict]): Sequence of transform objects or
            config dicts to be composed.
        skip_type_keys (list[str], optional): Sequence of transform type
            strings to be skipped in the pipeline. Defaults to None.
        max_refetch (int): The maximum number of retry iterations for
            getting valid results from the pipeline. If the number of
            iterations exceeds ``max_refetch`` and the results are still
            None, the iteration is terminated and an error is raised.
            Defaults to 15.
        lazy_init (bool): Whether to defer full initialization until
            ``full_init`` is called. Defaults to False.
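
    Examples:
        A minimal config sketch (illustrative only; the inner dataset
        configs are placeholders, while ``SeqMosaic`` and ``SeqMixUp`` are
        the sequence transforms this wrapper special-cases):

        >>> train_dataset = dict(
        ...     type='SeqMultiImageMixDataset',
        ...     dataset=dict(type='ConcatDataset', datasets=[...]),
        ...     pipeline=[
        ...         dict(type='SeqMosaic'),
        ...         dict(type='SeqMixUp'),
        ...     ],
        ...     max_refetch=15)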
"""
def __init__(
self,
dataset: Union[BaseDataset, dict],
        pipeline: Sequence[dict],
skip_type_keys: Union[Sequence[str], None] = None,
max_refetch: int = 15,
lazy_init: bool = False,
) -> None:
assert isinstance(pipeline, collections.abc.Sequence)
if skip_type_keys is not None:
assert all(
[isinstance(skip_type_key, str) for skip_type_key in skip_type_keys]
)
self._skip_type_keys = skip_type_keys
self.pipeline = []
self.pipeline_types = []
for transform in pipeline:
if isinstance(transform, dict):
self.pipeline_types.append(transform["type"])
transform = TRANSFORMS.build(transform)
self.pipeline.append(transform)
else:
                raise TypeError("each element of pipeline must be a config dict")
self.dataset: BaseDataset
if isinstance(dataset, dict):
self.dataset = DATASETS.build(dataset)
elif isinstance(dataset, BaseDataset):
self.dataset = dataset
else:
raise TypeError(
"elements in datasets sequence should be config or "
f"`BaseDataset` instance, but got {type(dataset)}"
)
self._metainfo = self.dataset.metainfo
if hasattr(self.dataset, "flag"):
self.flag = self.dataset.flag
self.num_samples = len(self.dataset)
self.max_refetch = max_refetch
self._fully_initialized = False
if not lazy_init:
self.full_init()
self.generate_indices()
def generate_indices(self):
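        """Build the flattened sample indices over the wrapped concat dataset.

        Image datasets contribute plain integer indices, while video
        datasets contribute ``(video_ind, frame_ind)`` tuples; the two are
        interleaved into ``self.indices``.
        """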
        cat_datasets = self.dataset.datasets
        video_indices = []
        img_indices = []
        for dataset in cat_datasets:
            self.test_mode = dataset.test_mode
            assert not self.test_mode, (
                "'ConcatDataset' should not exist in test mode"
            )
            if isinstance(dataset, BaseVideoDataset):
                num_videos = len(dataset)
                for video_ind in range(num_videos):
                    video_indices.extend(
                        [
                            (video_ind, frame_ind)
                            for frame_ind in range(
                                dataset.get_len_per_video(video_ind)
                            )
                        ]
                    )
            elif isinstance(dataset, BaseDetDataset):
                num_imgs = len(dataset)
                img_indices.extend(range(num_imgs))
        # Interleave image and video indices so neighbouring samples mix
        # both sources; this makes debugging the joint training easier.
def alternate_merge(list1, list2):
# Create a new list to hold the merged elements
merged_list = []
# Get the length of the shorter list
min_length = min(len(list1), len(list2))
# Append elements alternately from both lists
for i in range(min_length):
merged_list.append(list1[i])
merged_list.append(list2[i])
# Append the remaining elements from the longer list
if len(list1) > len(list2):
merged_list.extend(list1[min_length:])
else:
merged_list.extend(list2[min_length:])
return merged_list
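        # e.g. alternate_merge([0, 1, 2], [(0, 0), (0, 1)])
        #      -> [0, (0, 0), 1, (0, 1), 2]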
self.indices = alternate_merge(img_indices, video_indices)
@property
def metainfo(self) -> dict:
"""Get the meta information of the multi-image-mixed dataset.
Returns:
dict: The meta information of multi-image-mixed dataset.
"""
return copy.deepcopy(self._metainfo)
def full_init(self):
"""Loop to ``full_init`` each dataset."""
if self._fully_initialized:
return
self.dataset.full_init()
self._ori_len = len(self.dataset)
self._fully_initialized = True
@force_full_init
def get_data_info(self, idx: int) -> dict:
"""Get annotation by index.
Args:
idx (int): Global index of ``ConcatDataset``.
Returns:
dict: The idx-th annotation of the datasets.
"""
return self.dataset.get_data_info(idx)
@force_full_init
def get_transform_indexes(self, transform, results, t_type="SeqMosaic"):
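        """Pre-fetch mix results for a sequence-level mixing transform.

        Calls ``transform.get_indexes`` on the wrapped dataset, loads the
        returned samples, and stores them in ``results`` under the key
        matching ``t_type`` ("SeqMosaic", "SeqMixUp" or "SeqCopyPaste"),
        replicated once per frame of the current sequence.
        """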
num_samples = len(results["img_id"])
for i in range(self.max_refetch):
# Make sure the results passed the loading pipeline
# of the original dataset is not None.
indexes = transform.get_indexes(self.dataset)
if not isinstance(indexes, collections.abc.Sequence):
indexes = [indexes]
mix_results = [copy.deepcopy(self.dataset[index]) for index in indexes]
if None not in mix_results:
if t_type == "SeqMosaic":
results["mosaic_mix_results"] = [mix_results] * num_samples
elif t_type == "SeqMixUp":
results["mixup_mix_results"] = [mix_results] * num_samples
elif t_type == "SeqCopyPaste":
results["copypaste_mix_results"] = [mix_results] * num_samples
return results
        else:
            raise RuntimeError(
                "The loading pipeline of the original dataset always "
                "returns None. Please check the correctness of the "
                "dataset and its pipeline."
            )
@force_full_init
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
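        """Run the mixed-image training pipeline on the ``idx``-th sample.

        Transforms exposing ``get_indexes`` (mosaic, mixup, copy-paste)
        have their extra samples pre-fetched from the wrapped dataset and
        attached to ``results`` before being applied; each transform is
        retried up to ``max_refetch`` times.
        """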
while True:
results = copy.deepcopy(self.dataset[idx])
for (transform, transform_type) in zip(self.pipeline, self.pipeline_types):
if (
self._skip_type_keys is not None
and transform_type in self._skip_type_keys
):
continue
if transform_type == "MasaTransformBroadcaster":
for sub_transform in transform.transforms:
if hasattr(sub_transform, "get_indexes"):
sub_transform_type = type(sub_transform).__name__
results = self.get_transform_indexes(
sub_transform, results, sub_transform_type
)
elif hasattr(transform, "get_indexes"):
for i in range(self.max_refetch):
# Make sure the results passed the loading pipeline
# of the original dataset is not None.
indexes = transform.get_indexes(self.dataset)
if not isinstance(indexes, collections.abc.Sequence):
indexes = [indexes]
mix_results = [
copy.deepcopy(self.dataset[index]) for index in indexes
]
if None not in mix_results:
results["mix_results"] = mix_results
break
                    else:
                        raise RuntimeError(
                            "The loading pipeline of the original dataset "
                            "always returns None. Please check the "
                            "correctness of the dataset and its pipeline."
                        )
for i in range(self.max_refetch):
# To confirm the results passed the training pipeline
# of the wrapper is not None.
try:
updated_results = transform(copy.deepcopy(results))
except Exception as e:
print(
"Error occurred while running pipeline",
f"{transform} with error: {e}",
)
# print('Empty instances due to augmentation, re-sampling...')
idx = self._rand_another(idx)
continue
if updated_results is not None:
results = updated_results
break
                else:
                    raise RuntimeError(
                        "The training pipeline of the dataset wrapper "
                        "always returns None. Please check the correctness "
                        "of the dataset and its pipeline."
                    )
if "mosaic_mix_results" in results:
results.pop("mosaic_mix_results")
if "mixup_mix_results" in results:
results.pop("mixup_mix_results")
if "copypaste_mix_results" in results:
results.pop("copypaste_mix_results")
return results
def update_skip_type_keys(self, skip_type_keys):
"""Update skip_type_keys. It is called by an external hook.
Args:
skip_type_keys (list[str], optional): Sequence of type
string to be skip pipeline.
"""
assert all([isinstance(skip_type_key, str) for skip_type_key in skip_type_keys])
self._skip_type_keys = skip_type_keys
def _rand_another(self, idx):
"""Get another random index from the same group as the given index."""
return np.random.choice(self.indices)
@DATASETS.register_module()
class SeqRandomMultiImageVideoMixDataset(SeqMultiImageMixDataset):
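    """A variant of ``SeqMultiImageMixDataset`` that randomly alternates
    between video and image samples.

    At each ``__getitem__`` call, a video sample (processed with
    ``video_pipeline``) is drawn with probability ``video_sample_ratio``;
    otherwise an image sample is drawn and processed with the inherited
    ``pipeline``. The wrapped dataset must be a
    :obj:`RandomSampleJointVideoConcatDataset`.

    Example config sketch (illustrative only; the inner dataset configs
    are placeholders):

    >>> train_dataset = dict(
    ...     type='SeqRandomMultiImageVideoMixDataset',
    ...     dataset=dict(
    ...         type='RandomSampleJointVideoConcatDataset',
    ...         datasets=[...]),
    ...     pipeline=[dict(type='SeqMosaic'), dict(type='SeqMixUp')],
    ...     video_pipeline=[dict(type='SeqMosaic')],
    ...     video_sample_ratio=0.5)
    """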
def __init__(
        self, video_pipeline: Sequence[dict], video_sample_ratio: float = 0.5, *args, **kwargs
):
super().__init__(*args, **kwargs)
self.video_pipeline = []
self.video_pipeline_types = []
for transform in video_pipeline:
if isinstance(transform, dict):
self.video_pipeline_types.append(transform["type"])
transform = TRANSFORMS.build(transform)
self.video_pipeline.append(transform)
else:
raise TypeError("pipeline must be a dict")
self.video_sample_ratio = video_sample_ratio
assert isinstance(self.dataset, RandomSampleJointVideoConcatDataset)
@force_full_init
def get_transform_indexes(
self, transform, results, sample_video, t_type="SeqMosaic"
):
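        """Pre-fetch mix results from the matching dataset group.

        Same as the parent method, but the extra samples are drawn from
        the video group when ``sample_video`` is True and from the image
        group otherwise.
        """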
num_samples = len(results["img_id"])
for i in range(self.max_refetch):
# Make sure the results passed the loading pipeline
# of the original dataset is not None.
indexes = transform.get_indexes(self.dataset.datasets[0])
if not isinstance(indexes, collections.abc.Sequence):
indexes = [indexes]
            if sample_video:
                mix_results = [copy.deepcopy(self.dataset[0]) for _ in indexes]
            else:
                mix_results = [copy.deepcopy(self.dataset[1]) for _ in indexes]
if None not in mix_results:
if t_type == "SeqMosaic":
results["mosaic_mix_results"] = [mix_results] * num_samples
elif t_type == "SeqMixUp":
results["mixup_mix_results"] = [mix_results] * num_samples
elif t_type == "SeqCopyPaste":
results["copypaste_mix_results"] = [mix_results] * num_samples
return results
        else:
            raise RuntimeError(
                "The loading pipeline of the original dataset always "
                "returns None. Please check the correctness of the "
                "dataset and its pipeline."
            )
def __getitem__(self, idx):
while True:
            sample_video = random.random() < self.video_sample_ratio
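            # Group 0 of the joint dataset is assumed to hold the video
            # datasets and group 1 the image datasets; indexing by group
            # id draws a random sample from that group (see
            # RandomSampleJointVideoConcatDataset).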
if sample_video:
results = copy.deepcopy(self.dataset[0])
pipeline = self.video_pipeline
pipeline_type = self.video_pipeline_types
else:
results = copy.deepcopy(self.dataset[1])
pipeline = self.pipeline
pipeline_type = self.pipeline_types
# if results['img_id'][0] != results['img_id'][1]:
# self.update_skip_type_keys(['SeqMosaic', 'SeqMixUp'])
# else:
# self._skip_type_keys = None
for (transform, transform_type) in zip(pipeline, pipeline_type):
if (
self._skip_type_keys is not None
and transform_type in self._skip_type_keys
):
continue
if transform_type == "MasaTransformBroadcaster":
for sub_transform in transform.transforms:
if hasattr(sub_transform, "get_indexes"):
sub_transform_type = type(sub_transform).__name__
results = self.get_transform_indexes(
sub_transform, results, sample_video, sub_transform_type
)
elif hasattr(transform, "get_indexes"):
for i in range(self.max_refetch):
# Make sure the results passed the loading pipeline
# of the original dataset is not None.
indexes = transform.get_indexes(self.dataset)
if not isinstance(indexes, collections.abc.Sequence):
indexes = [indexes]
mix_results = [
copy.deepcopy(self.dataset[index]) for index in indexes
]
if None not in mix_results:
results["mix_results"] = mix_results
break
                    else:
                        raise RuntimeError(
                            "The loading pipeline of the original dataset "
                            "always returns None. Please check the "
                            "correctness of the dataset and its pipeline."
                        )
for i in range(self.max_refetch):
# To confirm the results passed the training pipeline
# of the wrapper is not None.
try:
updated_results = transform(copy.deepcopy(results))
except Exception as e:
print(
"Error occurred while running pipeline",
f"{transform} with error: {e}",
)
# print('Empty instances due to augmentation, re-sampling...')
# idx = self._rand_another(idx)
continue
if updated_results is not None:
results = updated_results
break
                else:
                    raise RuntimeError(
                        "The training pipeline of the dataset wrapper "
                        "always returns None. Please check the correctness "
                        "of the dataset and its pipeline."
                    )
if "mosaic_mix_results" in results:
results.pop("mosaic_mix_results")
if "mixup_mix_results" in results:
results.pop("mixup_mix_results")
if "copypaste_mix_results" in results:
results.pop("copypaste_mix_results")
return results