# Custom concat-dataset wrappers: fixed-length, probability-weighted random sampling.
import random
from typing import Iterable, List, Optional

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from mmdet.datasets.base_det_dataset import BaseDetDataset
from mmdet.datasets.base_video_dataset import BaseVideoDataset
from mmdet.registry import DATASETS
from mmengine.dataset import BaseDataset
class RandomSampleConcatDataset(_ConcatDataset):
    """Concatenation of datasets drawn from at random with fixed probabilities.

    Unlike a plain ``ConcatDataset``, every ``__getitem__`` call first picks
    one of the wrapped datasets according to ``sampling_probs`` and then draws
    a uniformly random item from it; the integer index passed in is ignored.
    ``__len__`` reports the artificial ``fixed_length`` rather than the true
    total size.

    Args:
        datasets (Iterable[Dataset]): Datasets to concatenate, either already
            built ``BaseDataset`` instances or registry config dicts.
        sampling_probs (List[float]): Per-dataset sampling probabilities; one
            per dataset, summing to 1.0.
        fixed_length (int): Virtual length reported by ``__len__``; must not
            exceed the summed length of the wrapped datasets.
        lazy_init (bool): If False (default), ``full_init`` runs immediately.
    """

    def __init__(
        self,
        datasets: Iterable[Dataset],
        sampling_probs: List[float],
        fixed_length: int,
        lazy_init: bool = False,
    ):
        super(RandomSampleConcatDataset, self).__init__(datasets)
        assert len(sampling_probs) == len(
            datasets
        ), "Number of sampling probabilities must match the number of datasets"
        # Tolerance instead of exact float equality: configs such as
        # [1/3, 1/3, 1/3] do not sum to exactly 1.0 in floating point.
        assert (
            abs(sum(sampling_probs) - 1.0) < 1e-6
        ), "Sum of sampling probabilities must be 1.0"
        self.datasets: List[BaseDataset] = []
        for dataset in datasets:
            if isinstance(dataset, dict):
                self.datasets.append(DATASETS.build(dataset))
            elif isinstance(dataset, BaseDataset):
                self.datasets.append(dataset)
            else:
                raise TypeError(
                    "elements in datasets sequence should be config or "
                    f"`BaseDataset` instance, but got {type(dataset)}"
                )
        self.sampling_probs = sampling_probs
        self.fixed_length = fixed_length
        # Expose the first dataset's metainfo as the combined metainfo;
        # assumes all wrapped datasets share compatible metainfo.
        self.metainfo = self.datasets[0].metainfo
        total_datasets_length = sum(len(dataset) for dataset in self.datasets)
        assert (
            self.fixed_length <= total_datasets_length
        ), "the length of the concatenated dataset must be less than the sum of the lengths of the individual datasets"
        # Group flag consumed by mmdet group samplers (single group).
        self.flag = np.zeros(self.fixed_length, dtype=np.uint8)
        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    def full_init(self):
        """Loop to ``full_init`` each dataset."""
        if self._fully_initialized:
            return
        for dataset in self.datasets:
            dataset.full_init()
        self._ori_len = self.fixed_length
        self._fully_initialized = True

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index across the concatenated datasets.

        Returns:
            dict: The idx-th annotation of the datasets.

        Raises:
            IndexError: If ``idx`` exceeds the total dataset length.
        """
        # Fix: the original referenced the nonexistent ``self.dataset``
        # (AttributeError).  Resolve the global index by walking the
        # cumulative lengths of the built datasets.
        local_idx = idx
        for dataset in self.datasets:
            if local_idx < len(dataset):
                return dataset.get_data_info(local_idx)
            local_idx -= len(dataset)
        raise IndexError(f"index {idx} out of range for concatenated datasets")

    def __len__(self):
        # Report the configured virtual length, not the true total size.
        return self.fixed_length

    def __getitem__(self, idx):
        # The incoming idx is ignored: choose a dataset according to the
        # sampling probabilities, then draw a uniformly random item from it.
        chosen_dataset_idx = random.choices(
            range(len(self.datasets)), weights=self.sampling_probs, k=1
        )[0]
        chosen_dataset = self.datasets[chosen_dataset_idx]
        sample_idx = random.randrange(0, len(chosen_dataset))
        return chosen_dataset[sample_idx]
class RandomSampleJointVideoConcatDataset(_ConcatDataset):
    """Joint video/image concat dataset with probability-weighted sampling.

    Wraps a mix of ``BaseVideoDataset`` and ``BaseDetDataset`` instances.
    ``__getitem__(0)`` samples from the video datasets and ``__getitem__(1)``
    from the image datasets, each according to its own probability list;
    any other index is rejected.  ``__len__`` reports ``fixed_length``.

    Args:
        datasets (Iterable[Dataset]): Datasets to wrap, either built
            ``BaseDataset`` instances or registry config dicts.
        fixed_length (int): Virtual length reported by ``__len__``; must not
            exceed the summed length of the wrapped datasets.
        lazy_init (bool): If False (default), ``full_init`` runs immediately.
        video_sampling_probs (Optional[List[float]]): Sampling probabilities
            over the video datasets; must sum to 1.0 when given.
        img_sampling_probs (Optional[List[float]]): Sampling probabilities
            over the image datasets; must sum to 1.0 when given.
    """

    def __init__(
        self,
        datasets: Iterable[Dataset],
        fixed_length: int,
        lazy_init: bool = False,
        video_sampling_probs: Optional[List[float]] = None,
        img_sampling_probs: Optional[List[float]] = None,
        *args,
        **kwargs,
    ):
        super(RandomSampleJointVideoConcatDataset, self).__init__(datasets)
        self.datasets: List[BaseDataset] = []
        for dataset in datasets:
            if isinstance(dataset, dict):
                self.datasets.append(DATASETS.build(dataset))
            elif isinstance(dataset, BaseDataset):
                self.datasets.append(dataset)
            else:
                raise TypeError(
                    "elements in datasets sequence should be config or "
                    f"`BaseDataset` instance, but got {type(dataset)}"
                )
        self.video_dataset_idx = []
        self.img_dataset_idx = []
        self.datasets_indices_mapping = {}
        for i, dataset in enumerate(self.datasets):
            if isinstance(dataset, BaseVideoDataset):
                self.video_dataset_idx.append(i)
                # Video datasets are flattened to (video_ind, frame_ind) pairs.
                self.datasets_indices_mapping[i] = [
                    (video_ind, frame_ind)
                    for video_ind in range(len(dataset))
                    for frame_ind in range(dataset.get_len_per_video(video_ind))
                ]
            elif isinstance(dataset, BaseDetDataset):
                self.img_dataset_idx.append(i)
                # Image datasets use plain integer indices.
                self.datasets_indices_mapping[i] = list(range(len(dataset)))
            else:
                raise TypeError(
                    "elements in datasets sequence should be config or "
                    f"`BaseDataset` instance, but got {type(dataset)}"
                )
        self.fixed_length = fixed_length
        # Expose the first dataset's metainfo as the combined metainfo;
        # assumes all wrapped datasets share compatible metainfo.
        self.metainfo = self.datasets[0].metainfo
        total_datasets_length = sum(
            len(indices) for indices in self.datasets_indices_mapping.values()
        )
        assert (
            self.fixed_length <= total_datasets_length
        ), "the length of the concatenated dataset must be less than the sum of the lengths of the individual datasets"
        # Group flag consumed by mmdet group samplers (single group).
        self.flag = np.zeros(self.fixed_length, dtype=np.uint8)
        self._fully_initialized = False
        if not lazy_init:
            self.full_init()
        # Fix: avoid mutable default arguments — normalize None to [].
        self.video_sampling_probs = video_sampling_probs or []
        self.img_sampling_probs = img_sampling_probs or []
        # Tolerance instead of exact float equality for probability sums.
        if self.video_sampling_probs:
            assert (
                abs(sum(self.video_sampling_probs) - 1.0) < 1e-6
            ), "Sum of video sampling probabilities must be 1.0"
        if self.img_sampling_probs:
            assert (
                abs(sum(self.img_sampling_probs) - 1.0) < 1e-6
            ), "Sum of image sampling probabilities must be 1.0"

    def full_init(self):
        """Loop to ``full_init`` each dataset."""
        if self._fully_initialized:
            return
        for dataset in self.datasets:
            dataset.full_init()
        self._ori_len = self.fixed_length
        self._fully_initialized = True

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index across the concatenated datasets.

        Returns:
            dict: The idx-th annotation of the datasets.

        Raises:
            IndexError: If ``idx`` exceeds the total dataset length.
        """
        # Fix: the original referenced the nonexistent ``self.dataset``
        # (AttributeError).  Resolve the global index by walking the
        # cumulative lengths of the built datasets.
        local_idx = idx
        for dataset in self.datasets:
            if local_idx < len(dataset):
                return dataset.get_data_info(local_idx)
            local_idx -= len(dataset)
        raise IndexError(f"index {idx} out of range for concatenated datasets")

    def __len__(self):
        # Report the configured virtual length, not the true total size.
        return self.fixed_length

    def __getitem__(self, idx):
        """Sample a random item: ``idx == 0`` from the video datasets,
        ``idx == 1`` from the image datasets."""
        if idx == 0:
            # Empty weight lists would make random.choices raise; fall back
            # to uniform selection (weights=None) in that case.
            chosen_dataset_idx = random.choices(
                self.video_dataset_idx,
                weights=self.video_sampling_probs or None,
                k=1,
            )[0]
        elif idx == 1:
            chosen_dataset_idx = random.choices(
                self.img_dataset_idx,
                weights=self.img_sampling_probs or None,
                k=1,
            )[0]
        else:
            # Fix: the original fell through with chosen_dataset_idx
            # undefined (NameError) for any other index.
            raise ValueError(
                f"idx must be 0 (video) or 1 (image), but got {idx}"
            )
        chosen_dataset = self.datasets[chosen_dataset_idx]
        # For video datasets the sampled index is a (video_ind, frame_ind)
        # tuple; for image datasets it is a plain int.
        sample_idx = random.choice(self.datasets_indices_mapping[chosen_dataset_idx])
        return chosen_dataset[sample_idx]