import numpy as np
import torch

from .alpaca_gpt4_dataset import AlpacaGPT4Dataset  # noqa: F401
from .aokvqa_dataset import AOKVQADataset  # noqa: F401
from .baize_dataset import BaiZeDataset  # noqa: F401
from .cc_sbu_align_dataset import CcSbuAlignDataset  # noqa: F401
from .clevr_dataset import CLEVRDataset  # noqa: F401
from .coco_caption_dataset import COCOCaptionDataset  # noqa: F401
from .dial_dataset import DialDataset  # noqa: F401
from .dolly_dataset import DollyDataset  # noqa: F401
from .gqa_dataset import GQADataset  # noqa: F401
from .llava_dataset import LlavaDataset  # noqa: F401
from .nlvr_dataset import NLVRv1Dataset, NLVRv2Dataset  # noqa: F401
from .ocr_vqa_dataset import OCRVQADataset  # noqa: F401
from .snli_ve_datasets import SNLIVEDataset  # noqa: F401
from .text_ocr_dataset import TextOCRDataset  # noqa: F401
from .vqa_dataset import ConcatDataset, VQADataset  # noqa: F401

def build_dataset(dataset_config, **kwargs):
    """Build a dataset from a config dict (or a list of config dicts).

    Each config must carry a ``type`` key selecting the dataset class; an
    optional ``sample`` key (> 0) restricts the result to a random subset of
    that many examples. A list of configs is built recursively and wrapped
    in a single ``ConcatDataset``.
    """
    if isinstance(dataset_config, list):
        # Recursively build each sub-dataset, then concatenate them.
        datasets = [build_dataset(cfg, **kwargs) for cfg in dataset_config]
        return ConcatDataset(datasets)
    dataset_type = dataset_config.pop("type")
    sample = dataset_config.pop("sample", -1)
    if dataset_type == "llava":
        dataset = LlavaDataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "vqa":
        dataset = VQADataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "minigpt4":
        dataset = CcSbuAlignDataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "llava_dial":
        dataset = DialDataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "coco_dial":
        dataset = DialDataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "aokvqa":
        dataset = AOKVQADataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "okvqa":
        dataset = VQADataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "text_ocr":
        dataset = TextOCRDataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "ocr_vqa":
        dataset = OCRVQADataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "coco_caption":
        dataset = COCOCaptionDataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "gqa":
        dataset = GQADataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "clevr":
        dataset = CLEVRDataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "nlvrv1":
        dataset = NLVRv1Dataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "nlvrv2":
        dataset = NLVRv2Dataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "snlive":
        dataset = SNLIVEDataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "dolly":
        dataset = DollyDataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "alpaca_gpt4":
        dataset = AlpacaGPT4Dataset(
            **dataset_config,
            **kwargs,
        )
    elif dataset_type == "baize":
        dataset = BaiZeDataset(
            **dataset_config,
            **kwargs,
        )
    else:
        raise NotImplementedError(f"Unsupported dataset type: {dataset_type}")
    if sample > 0:
        # Subsample without replacement, capped at the dataset size, and keep
        # the original collater so DataLoader batching still works on the Subset.
        random_indices = np.random.choice(len(dataset), min(sample, len(dataset)), replace=False)
        subsample_dataset = torch.utils.data.Subset(dataset, random_indices)
        subsample_dataset.collater = dataset.collater
        return subsample_dataset
    else:
        return dataset
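
# Example usage (a minimal sketch; the "vis_root"/"ann_path" keys and the
# tokenizer kwarg are hypothetical placeholders here -- each dataset class
# defines its own constructor arguments, which build_dataset forwards via
# **dataset_config and **kwargs):
#
#     config = [
#         {"type": "coco_caption", "sample": 1000, "vis_root": "...", "ann_path": "..."},
#         {"type": "dolly", "ann_path": "..."},
#     ]
#     dataset = build_dataset(config, tokenizer=tokenizer)
#     # -> ConcatDataset; each sub-config's "sample" caps that piece at a random subset.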