import math
import random
from typing import Iterable, List, Optional

import numpy as np
from mmdet.datasets.base_det_dataset import BaseDetDataset
from mmdet.datasets.base_video_dataset import BaseVideoDataset
from mmdet.registry import DATASETS
from mmengine.dataset import BaseDataset
from torch.utils.data import Dataset
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset


@DATASETS.register_module()
class RandomSampleConcatDataset(_ConcatDataset):
    """Concatenated dataset with a fixed virtual length that draws each
    sample from one of its sub-datasets according to ``sampling_probs``."""

    def __init__(
        self,
        datasets: Iterable[Dataset],
        sampling_probs: List[float],
        fixed_length: int,
        lazy_init: bool = False,
    ):
        super(RandomSampleConcatDataset, self).__init__(datasets)
        assert len(sampling_probs) == len(
            datasets
        ), "Number of sampling probabilities must match the number of datasets"
        # Compare with a tolerance instead of exact float equality.
        assert math.isclose(
            sum(sampling_probs), 1.0
        ), "Sum of sampling probabilities must be 1.0"

        self.datasets: List[BaseDataset] = []
        for i, dataset in enumerate(datasets):
            if isinstance(dataset, dict):
                self.datasets.append(DATASETS.build(dataset))
            elif isinstance(dataset, BaseDataset):
                self.datasets.append(dataset)
            else:
                raise TypeError(
                    "elements in datasets sequence should be config or "
                    f"`BaseDataset` instance, but got {type(dataset)}"
                )

        self.sampling_probs = sampling_probs
        self.fixed_length = fixed_length
        self.metainfo = self.datasets[0].metainfo

        total_datasets_length = sum(len(dataset) for dataset in self.datasets)
        assert self.fixed_length <= total_datasets_length, (
            "the length of the concatenated dataset must not exceed the sum "
            "of the lengths of the individual datasets"
        )

        self.flag = np.zeros(self.fixed_length, dtype=np.uint8)
        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    def full_init(self):
        """Loop to ``full_init`` each dataset."""
        if self._fully_initialized:
            return
        for dataset in self.datasets:
            dataset.full_init()
        self._ori_len = self.fixed_length
        self._fully_initialized = True

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``ConcatDataset``.

        Returns:
            dict: The idx-th annotation of the datasets.
        """
        # Samples are drawn at random, so a global index cannot be mapped
        # back to a specific sub-dataset; resolve it against the first one.
        return self.datasets[0].get_data_info(idx)

    def __len__(self):
        return self.fixed_length

    def __getitem__(self, idx):
        # Choose a dataset based on the sampling probabilities.
        chosen_dataset_idx = random.choices(
            range(len(self.datasets)), weights=self.sampling_probs, k=1
        )[0]
        chosen_dataset = self.datasets[chosen_dataset_idx]
        # Sample a random item from the chosen dataset.
        sample_idx = random.randrange(0, len(chosen_dataset))
        return chosen_dataset[sample_idx]
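# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module): a minimal example of
# how ``RandomSampleConcatDataset`` might appear in an MMDetection config.
# The annotation files and empty pipelines are hypothetical placeholders.
#
# train_dataset = dict(
#     type="RandomSampleConcatDataset",
#     datasets=[
#         dict(type="CocoDataset", ann_file="anns_a.json", pipeline=[]),
#         dict(type="CocoDataset", ann_file="anns_b.json", pipeline=[]),
#     ],
#     sampling_probs=[0.7, 0.3],  # one entry per dataset, summing to 1.0
#     fixed_length=10000,         # virtual epoch length reported by __len__
# )
# ---------------------------------------------------------------------------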
""" return self.dataset.get_data_info(idx) def __len__(self): return self.fixed_length def __getitem__(self, idx): # Choose a dataset based on the sampling probabilities chosen_dataset_idx = random.choices( range(len(self.datasets)), weights=self.sampling_probs, k=1 )[0] chosen_dataset = self.datasets[chosen_dataset_idx] # Sample a random item from the chosen dataset sample_idx = random.randrange(0, len(chosen_dataset)) return chosen_dataset[sample_idx] @DATASETS.register_module() class RandomSampleJointVideoConcatDataset(_ConcatDataset): def __init__( self, datasets: Iterable[Dataset], fixed_length: int, lazy_init: bool = False, video_sampling_probs: List[float] = [], img_sampling_probs: List[float] = [], *args, **kwargs, ): super(RandomSampleJointVideoConcatDataset, self).__init__(datasets) self.datasets: List[BaseDataset] = [] for i, dataset in enumerate(datasets): if isinstance(dataset, dict): self.datasets.append(DATASETS.build(dataset)) elif isinstance(dataset, BaseDataset): self.datasets.append(dataset) else: raise TypeError( "elements in datasets sequence should be config or " f"`BaseDataset` instance, but got {type(dataset)}" ) self.video_dataset_idx = [] self.img_dataset_idx = [] self.datasets_indices_mapping = {} for i, dataset in enumerate(self.datasets): if isinstance(dataset, BaseVideoDataset): self.video_dataset_idx.append(i) num_videos = len(dataset) video_indices = [] for video_ind in range(num_videos): video_indices.extend( [ (video_ind, frame_ind) for frame_ind in range(dataset.get_len_per_video(video_ind)) ] ) self.datasets_indices_mapping[i] = video_indices elif isinstance(dataset, BaseDetDataset): self.img_dataset_idx.append(i) img_indices = [] num_imgs = len(dataset) for img_ind in range(num_imgs): img_indices.extend([img_ind]) self.datasets_indices_mapping[i] = img_indices else: raise TypeError( "elements in datasets sequence should be config or " f"`BaseDataset` instance, but got {type(dataset)}" ) self.fixed_length = fixed_length self.metainfo = self.datasets[0].metainfo total_datasets_length = sum( [len(indices) for key, indices in self.datasets_indices_mapping.items()] ) assert ( self.fixed_length <= total_datasets_length ), "the length of the concatenated dataset must be less than the sum of the lengths of the individual datasets" self.flag = np.zeros(self.fixed_length, dtype=np.uint8) self._fully_initialized = False if not lazy_init: self.full_init() self.video_sampling_probs = video_sampling_probs self.img_sampling_probs = img_sampling_probs if self.video_sampling_probs: assert ( sum(self.video_sampling_probs) == 1.0 ), "Sum of video sampling probabilities must be 1.0" if self.img_sampling_probs: assert ( sum(self.img_sampling_probs) == 1.0 ), "Sum of image sampling probabilities must be 1.0" def full_init(self): """Loop to ``full_init`` each dataset.""" if self._fully_initialized: return for i, dataset in enumerate(self.datasets): dataset.full_init() self._ori_len = self.fixed_length self._fully_initialized = True def get_data_info(self, idx: int) -> dict: """Get annotation by index. Args: idx (int): Global index of ``ConcatDataset``. Returns: dict: The idx-th annotation of the datasets. 
""" return self.dataset.get_data_info(idx) def __len__(self): return self.fixed_length def __getitem__(self, idx): # idx ==0 means samples from video dataset, idx == 1 means samples from image dataset # Choose a dataset based on the sampling probabilities if idx == 0: chosen_dataset_idx = random.choices( self.video_dataset_idx, weights=self.video_sampling_probs, k=1 )[0] elif idx == 1: chosen_dataset_idx = random.choices( self.img_dataset_idx, weights=self.img_sampling_probs, k=1 )[0] chosen_dataset = self.datasets[chosen_dataset_idx] # Sample a random item from the chosen dataset sample_idx = random.choice(self.datasets_indices_mapping[chosen_dataset_idx]) return chosen_dataset[sample_idx]