# Dataset registry and builder: maps config "type" strings to dataset classes.
import numpy as np | |
import torch | |
from .alpaca_gpt4_dataset import AlpacaGPT4Dataset # noqa: F401 | |
from .aokvqa_dataset import AOKVQADataset # noqa: F401 | |
from .cc_sbu_align_dataset import CcSbuAlignDataset # noqa: F401 | |
from .clevr_dataset import CLEVRDataset # noqa: F401 | |
from .coco_caption_dataset import COCOCaptionDataset # noqa: F401 | |
from .dial_dataset import DialDataset # noqa: F401 | |
from .dolly_dataset import DollyDataset # noqa: F401 | |
from .gqa_dataset import GQADataset # noqa: F401 | |
from .llava_dataset import LlavaDataset # noqa: F401 | |
from .nlvr_dataset import NLVRv1Dataset, NLVRv2Dataset # noqa: F401 | |
from .ocr_vqa_dataset import OCRVQADataset # noqa: F401 | |
from .snli_ve_datasets import SNLIVEDataset # noqa: F401 | |
from .text_ocr_dataset import TextOCRDataset # noqa: F401 | |
from .vqa_dataset import ConcatDataset, VQADataset # noqa: F401 | |
from .baize_dataset import BaiZeDataset # noqa: F401 | |
# Registry mapping config "type" strings to dataset classes.
# Several types deliberately share a class: "okvqa" uses the generic VQA
# loader, and both dial variants ("llava_dial", "coco_dial") use DialDataset.
DATASET_REGISTRY = {
    "llava": LlavaDataset,
    "vqa": VQADataset,
    "minigpt4": CcSbuAlignDataset,
    "llava_dial": DialDataset,
    "coco_dial": DialDataset,
    "aokvqa": AOKVQADataset,
    "okvqa": VQADataset,
    "text_ocr": TextOCRDataset,
    "ocr_vqa": OCRVQADataset,
    "coco_caption": COCOCaptionDataset,
    "gqa": GQADataset,
    "clevr": CLEVRDataset,
    "nlvrv1": NLVRv1Dataset,
    "nlvrv2": NLVRv2Dataset,
    "snlive": SNLIVEDataset,
    "dolly": DollyDataset,
    "alpaca_gpt4": AlpacaGPT4Dataset,
    "baize": BaiZeDataset,
}


def build_dataset(dataset_config, **kwargs):
    """Build a dataset (or a ConcatDataset of several) from a config.

    Args:
        dataset_config: A dict describing one dataset, or a list of such
            dicts (in which case each is built recursively and the results
            are wrapped in a ``ConcatDataset``). The dict must contain a
            ``"type"`` key selecting the dataset class; an optional
            ``"sample"`` key (default -1) requests a random subset of that
            many items. NOTE: both keys are ``pop``-ed, so the dict is
            mutated in place; remaining entries are forwarded to the
            dataset constructor.
        **kwargs: Extra keyword arguments forwarded to every dataset
            constructor (e.g. tokenizer/processor objects).

    Returns:
        The constructed dataset, or a ``torch.utils.data.Subset`` of it
        when ``sample > 0``.

    Raises:
        NotImplementedError: If ``"type"`` is not a known dataset type.
    """
    if isinstance(dataset_config, list):
        datasets = [build_dataset(cfg, **kwargs) for cfg in dataset_config]
        return ConcatDataset(datasets)

    dataset_type = dataset_config.pop("type")
    sample = dataset_config.pop("sample", -1)

    try:
        dataset_cls = DATASET_REGISTRY[dataset_type]
    except KeyError:
        raise NotImplementedError(f"Unsupported dataset type: {dataset_type!r}")
    dataset = dataset_cls(
        **dataset_config,
        **kwargs,
    )

    if sample > 0:
        # Random subsample without replacement, capped at the dataset size.
        random_indices = np.random.choice(len(dataset), min(sample, len(dataset)), replace=False)
        subsample_dataset = torch.utils.data.Subset(dataset, random_indices)
        # Subset does not forward custom attributes, so re-attach the
        # collate function the training loop expects.
        subsample_dataset.collater = dataset.collater
        return subsample_dataset
    else:
        return dataset