# Copyright (c) Tencent Inc. All rights reserved.
import copy
import json
import logging
from typing import Callable, List, Union

from mmengine.logging import print_log
from mmengine.dataset.base_dataset import (
    BaseDataset, Compose, force_full_init)

from mmyolo.registry import DATASETS


@DATASETS.register_module()
class MultiModalDataset:
    """Multi-modal dataset."""

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 class_text_path: str = None,
                 test_mode: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 lazy_init: bool = False) -> None:
        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'dataset must be a dict or a BaseDataset, '
                f'but got {dataset}')

        if class_text_path is not None:
            self.class_texts = json.load(open(class_text_path, 'r'))
            # ori_classes = self.dataset.metainfo['classes']
            # assert len(ori_classes) == len(self.class_texts), \
            #     ('The number of classes in the dataset and the class text'
            #      'file must be the same.')
        else:
            self.class_texts = None

        self.test_mode = test_mode
        self._metainfo = self.dataset.metainfo
        self.pipeline = Compose(pipeline)

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()
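
    # ``class_text_path`` points at a JSON file holding the text prompts for
    # every class. A minimal sketch of the expected layout (the contents
    # below are illustrative, not mandated by this class):
    #
    #     [["person"], ["bicycle"], ["car"]]
    #
    # i.e. one list of name/synonym strings per class; the loaded list is
    # attached to every sample under the ``texts`` key (see get_data_info).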

    @property
    def metainfo(self) -> dict:
        return copy.deepcopy(self._metainfo)

    def full_init(self) -> None:
        """Fully initialize the wrapped dataset and cache its length."""
        if self._fully_initialized:
            return
        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index."""
        data_info = self.dataset.get_data_info(idx)
        if self.class_texts is not None:
            # attach the per-class text prompts so that text-aware
            # transforms further down the pipeline can consume them
            data_info.update({'texts': self.class_texts})
        return data_info

    def __getitem__(self, idx):
        if not self._fully_initialized:
            print_log(
                'Please call `full_init` method manually to '
                'accelerate the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()

        data_info = self.get_data_info(idx)

        # during training, expose the wrapped dataset so that mix-type
        # transforms (e.g. Mosaic or MixUp) can sample additional images
        if hasattr(self.dataset, 'test_mode') and not self.dataset.test_mode:
            data_info['dataset'] = self
        elif not self.test_mode:
            data_info['dataset'] = self
        return self.pipeline(data_info)

    @force_full_init
    def __len__(self) -> int:
        return self._ori_len


@DATASETS.register_module()
class MultiModalMixedDataset(MultiModalDataset):
    """Multi-modal mixed dataset.

    Mixes a "detection dataset" and a "caption dataset".

    Args:
        dataset_type (str): Dataset type, either 'detection' or 'caption'.
    """

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 class_text_path: str = None,
                 dataset_type: str = 'detection',
                 test_mode: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 lazy_init: bool = False) -> None:
        self.dataset_type = dataset_type
        super().__init__(dataset,
                         class_text_path,
                         test_mode,
                         pipeline,
                         lazy_init)

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index."""
        data_info = self.dataset.get_data_info(idx)
        if self.class_texts is not None:
            data_info.update({'texts': self.class_texts})
        # mark whether the sample comes from a detection source (1) or a
        # caption source (0) so later stages can treat them differently
        data_info['is_detection'] = 1 \
            if self.dataset_type == 'detection' else 0
        return data_info
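

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): wrap a COCO-style detection
    # dataset so that every sample also carries class text prompts. The inner
    # dataset type is assumed to be mmyolo's YOLOv5CocoDataset, and the paths
    # and class-text JSON below are placeholders that must exist locally.
    demo_dataset = MultiModalDataset(
        dataset=dict(
            type='YOLOv5CocoDataset',
            data_root='data/coco/',
            ann_file='annotations/instances_val2017.json',
            data_prefix=dict(img='val2017/'),
            test_mode=True),
        class_text_path='data/texts/coco_class_texts.json',
        test_mode=True,
        pipeline=[])
    print(len(demo_dataset))
    print(sorted(demo_dataset.get_data_info(0).keys()))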