# Copyright (c) Tencent Inc. All rights reserved.
import copy
import json
import logging
from typing import Callable, List, Optional, Union

from mmengine.logging import print_log
from mmengine.dataset.base_dataset import (
    BaseDataset, Compose, force_full_init)
from mmyolo.registry import DATASETS


@DATASETS.register_module()
class MultiModalDataset:
    """Multi-modal dataset.

    Wraps another dataset and, when a class-text file is supplied, attaches
    the loaded texts to every sample under the ``'texts'`` key before the
    sample is passed through this wrapper's own pipeline.

    Args:
        dataset (BaseDataset or dict): The dataset to wrap, or a config dict
            that is built through the ``DATASETS`` registry.
        class_text_path (str, optional): Path to a JSON file whose content
            is stored as ``class_texts``. If ``None``, no texts are attached.
            Defaults to None.
        test_mode (bool): Mode flag of this wrapper, consulted in
            ``__getitem__``. Defaults to True.
        pipeline (list[dict or callable]): Transforms composed and applied
            to each sample in ``__getitem__``. Defaults to an empty list.
        lazy_init (bool): If True, ``full_init`` is not called in the
            constructor. Defaults to False.

    Raises:
        TypeError: If ``dataset`` is neither a dict nor a ``BaseDataset``.
    """

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 class_text_path: Optional[str] = None,
                 test_mode: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 lazy_init: bool = False) -> None:
        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'dataset must be a dict or a BaseDataset, '
                f'but got {dataset}')

        if class_text_path is not None:
            # Use a context manager so the file handle is always closed,
            # and read as UTF-8 so the result does not depend on the
            # platform's default locale encoding.
            with open(class_text_path, 'r', encoding='utf-8') as f:
                self.class_texts = json.load(f)
            # ori_classes = self.dataset.metainfo['classes']
            # assert len(ori_classes) == len(self.class_texts), \
            #     ('The number of classes in the dataset and the class text'
            #      'file must be the same.')
        else:
            self.class_texts = None

        self.test_mode = test_mode
        self._metainfo = self.dataset.metainfo
        # NOTE: ``pipeline`` (including the shared default list) is only
        # read here, never mutated.
        self.pipeline = Compose(pipeline)

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """dict: A deep copy of the wrapped dataset's meta information."""
        return copy.deepcopy(self._metainfo)

    def full_init(self) -> None:
        """``full_init`` dataset.

        Fully initializes the wrapped dataset and caches its length.
        Calling this method again after it has completed is a no-op.
        """
        if self._fully_initialized:
            return

        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Index of the sample in the wrapped dataset.

        Returns:
            dict: The wrapped dataset's data info for ``idx``, with a
            ``'texts'`` entry added when class texts were loaded.
        """
        data_info = self.dataset.get_data_info(idx)
        if self.class_texts is not None:
            data_info.update({'texts': self.class_texts})
        return data_info

    def __getitem__(self, idx):
        """Return sample ``idx`` after applying this wrapper's pipeline.

        If the dataset has not been fully initialized yet, a warning is
        logged and ``full_init`` is called first.
        """
        if not self._fully_initialized:
            print_log(
                'Please call `full_init` method manually to '
                'accelerate the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()

        data_info = self.get_data_info(idx)

        # Expose this wrapper to the pipeline unless both the wrapped
        # dataset and the wrapper are in test mode. A wrapped dataset
        # without a ``test_mode`` attribute is treated as being in test
        # mode. Presumably the reference lets transforms fetch additional
        # samples -- TODO confirm against the transforms in use.
        wrapped_not_testing = not getattr(self.dataset, 'test_mode', True)
        if wrapped_not_testing or not self.test_mode:
            data_info['dataset'] = self
        return self.pipeline(data_info)

    @force_full_init
    def __len__(self) -> int:
        """Return the length of the wrapped dataset cached by ``full_init``."""
        return self._ori_len


@DATASETS.register_module()
class MultiModalMixedDataset(MultiModalDataset):
    """Multi-modal Mixed dataset.

    mix "detection dataset" and "caption dataset"

    Every sample additionally carries an ``'is_detection'`` flag derived
    from ``dataset_type``.

    Args:
        dataset (BaseDataset or dict): The dataset to wrap, or a config dict
            that is built through the ``DATASETS`` registry.
        class_text_path (str, optional): Path to a JSON file with class
            texts. Defaults to None.
        dataset_type (str): dataset type, 'detection' or 'caption'.
            Defaults to 'detection'.
        test_mode (bool): Mode flag of this wrapper. Defaults to True.
        pipeline (list[dict or callable]): Transforms applied to each
            sample. Defaults to an empty list.
        lazy_init (bool): If True, skip ``full_init`` in the constructor.
            Defaults to False.
    """

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 class_text_path: Optional[str] = None,
                 dataset_type: str = 'detection',
                 test_mode: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 lazy_init: bool = False) -> None:
        self.dataset_type = dataset_type
        super().__init__(dataset, class_text_path, test_mode, pipeline,
                         lazy_init)

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Index of the sample in the wrapped dataset.

        Returns:
            dict: The parent class's data info plus ``'is_detection'``,
            which is 1 when ``dataset_type`` is ``'detection'`` and 0
            otherwise.
        """
        # Reuse the parent implementation instead of duplicating it.
        data_info = super().get_data_info(idx)
        data_info['is_detection'] = 1 \
            if self.dataset_type == 'detection' else 0
        return data_info