import numpy as np
import torch

from .alpaca_gpt4_dataset import AlpacaGPT4Dataset  # noqa: F401
from .aokvqa_dataset import AOKVQADataset  # noqa: F401
from .baize_dataset import BaiZeDataset  # noqa: F401
from .cc_sbu_align_dataset import CcSbuAlignDataset  # noqa: F401
from .clevr_dataset import CLEVRDataset  # noqa: F401
from .coco_caption_dataset import COCOCaptionDataset  # noqa: F401
from .dial_dataset import DialDataset  # noqa: F401
from .dolly_dataset import DollyDataset  # noqa: F401
from .gqa_dataset import GQADataset  # noqa: F401
from .llava_dataset import LlavaDataset  # noqa: F401
from .nlvr_dataset import NLVRv1Dataset, NLVRv2Dataset  # noqa: F401
from .ocr_vqa_dataset import OCRVQADataset  # noqa: F401
from .snli_ve_datasets import SNLIVEDataset  # noqa: F401
from .text_ocr_dataset import TextOCRDataset  # noqa: F401
from .vqa_dataset import ConcatDataset, VQADataset  # noqa: F401

# Maps the "type" field of a dataset config to the class that builds it.
# Several types intentionally share a class: "okvqa" reuses VQADataset, and
# both dialogue types ("llava_dial", "coco_dial") reuse DialDataset.
DATASET_REGISTRY = {
    "llava": LlavaDataset,
    "vqa": VQADataset,
    "minigpt4": CcSbuAlignDataset,
    "llava_dial": DialDataset,
    "coco_dial": DialDataset,
    "aokvqa": AOKVQADataset,
    "okvqa": VQADataset,
    "text_ocr": TextOCRDataset,
    "ocr_vqa": OCRVQADataset,
    "coco_caption": COCOCaptionDataset,
    "gqa": GQADataset,
    "clevr": CLEVRDataset,
    "nlvrv1": NLVRv1Dataset,
    "nlvrv2": NLVRv2Dataset,
    "snlive": SNLIVEDataset,
    "dolly": DollyDataset,
    "alpaca_gpt4": AlpacaGPT4Dataset,
    "baize": BaiZeDataset,
}


def build_dataset(dataset_config, **kwargs):
    """Build a dataset (or a ConcatDataset of datasets) from a config.

    Args:
        dataset_config: A dict with a "type" key selecting the dataset class,
            an optional "sample" key (int, default -1) that limits the dataset
            to a random subset of that many examples, and any remaining keys
            forwarded to the dataset constructor. A list of such dicts builds
            each one and wraps them all in a ConcatDataset.
        **kwargs: Extra keyword arguments forwarded to every constructor.
    """
    if isinstance(dataset_config, list):
        datasets = [build_dataset(cfg, **kwargs) for cfg in dataset_config]
        return ConcatDataset(datasets)

    dataset_type = dataset_config.pop("type")
    sample = dataset_config.pop("sample", -1)
    if dataset_type not in DATASET_REGISTRY:
        raise NotImplementedError(f"Unknown dataset type: {dataset_type!r}")
    dataset = DATASET_REGISTRY[dataset_type](**dataset_config, **kwargs)

    if sample > 0:
        # Subsample without replacement, never asking for more examples than
        # the dataset holds.
        random_indices = np.random.choice(len(dataset), min(sample, len(dataset)), replace=False)
        subsample_dataset = torch.utils.data.Subset(dataset, random_indices)
        # torch's Subset does not carry custom attributes over, so re-attach
        # the collate function the training loop expects.
        subsample_dataset.collater = dataset.collater
        return subsample_dataset
    return dataset
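

# ---------------------------------------------------------------------------
# Usage sketch, not a tested entry point: only "type" and "sample" are consumed
# by build_dataset itself; every other key shown here (e.g. "vis_root",
# "ann_path") is a hypothetical constructor argument whose real name depends on
# the individual dataset class, and the tokenizer kwarg is likewise assumed.
# Passing a list of configs concatenates the resulting datasets.
#
#   configs = [
#       {"type": "coco_caption", "vis_root": "...", "ann_path": "...", "sample": 512},
#       {"type": "dolly", "ann_path": "..."},  # text-only instruction data
#   ]
#   train_dataset = build_dataset(configs, tokenizer=tokenizer)
#   loader = torch.utils.data.DataLoader(
#       train_dataset,
#       batch_size=8,
#       collate_fn=train_dataset.collater,  # assumes ConcatDataset also exposes .collater
#   )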