"""Dataset package: registry of dataset classes and a config-driven factory."""
import numpy as np
import torch

from .alpaca_gpt4_dataset import AlpacaGPT4Dataset  # noqa: F401
from .aokvqa_dataset import AOKVQADataset  # noqa: F401
from .baize_dataset import BaiZeDataset  # noqa: F401
from .cc_sbu_align_dataset import CcSbuAlignDataset  # noqa: F401
from .clevr_dataset import CLEVRDataset  # noqa: F401
from .coco_caption_dataset import COCOCaptionDataset  # noqa: F401
from .dial_dataset import DialDataset  # noqa: F401
from .dolly_dataset import DollyDataset  # noqa: F401
from .gqa_dataset import GQADataset  # noqa: F401
from .llava_dataset import LlavaDataset  # noqa: F401
from .nlvr_dataset import NLVRv1Dataset, NLVRv2Dataset  # noqa: F401
from .ocr_vqa_dataset import OCRVQADataset  # noqa: F401
from .snli_ve_datasets import SNLIVEDataset  # noqa: F401
from .text_ocr_dataset import TextOCRDataset  # noqa: F401
from .vqa_dataset import ConcatDataset, VQADataset  # noqa: F401


# Map the "type" field of a dataset config to its dataset class. Some types
# deliberately share a class: "okvqa" reuses VQADataset, and "llava_dial" and
# "coco_dial" both reuse DialDataset.
DATASET_CLASSES = {
    "llava": LlavaDataset,
    "vqa": VQADataset,
    "minigpt4": CcSbuAlignDataset,
    "llava_dial": DialDataset,
    "coco_dial": DialDataset,
    "aokvqa": AOKVQADataset,
    "okvqa": VQADataset,
    "text_ocr": TextOCRDataset,
    "ocr_vqa": OCRVQADataset,
    "coco_caption": COCOCaptionDataset,
    "gqa": GQADataset,
    "clevr": CLEVRDataset,
    "nlvrv1": NLVRv1Dataset,
    "nlvrv2": NLVRv2Dataset,
    "snlive": SNLIVEDataset,
    "dolly": DollyDataset,
    "alpaca_gpt4": AlpacaGPT4Dataset,
    "baize": BaiZeDataset,
}


def build_dataset(dataset_config, **kwargs):
    """Build a dataset from a config dict, or a list of config dicts.

    ``dataset_config`` must contain a ``"type"`` key naming an entry in
    ``DATASET_CLASSES`` and may contain a ``"sample"`` key (randomly keep at
    most that many examples; -1, the default, keeps everything). All remaining
    keys, plus ``**kwargs``, are forwarded to the dataset constructor. A list
    of configs is built recursively and wrapped in a ``ConcatDataset``.
    """
    if isinstance(dataset_config, list):
        datasets = [build_dataset(cfg, **kwargs) for cfg in dataset_config]
        return ConcatDataset(datasets)
    dataset_type = dataset_config.pop("type")
    sample = dataset_config.pop("sample", -1)
    if dataset_type not in DATASET_CLASSES:
        raise NotImplementedError(f"Unsupported dataset type: {dataset_type!r}")
    dataset = DATASET_CLASSES[dataset_type](
        **dataset_config,
        **kwargs,
    )

    if sample > 0:
        # Randomly subsample without replacement, capped at the dataset size.
        random_indices = np.random.choice(len(dataset), min(sample, len(dataset)), replace=False)
        subsample_dataset = torch.utils.data.Subset(dataset, random_indices)
        # torch.utils.data.Subset does not forward custom attributes, so
        # re-attach the collate function the dataloader expects.
        subsample_dataset.collater = dataset.collater
        return subsample_dataset
    return dataset
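

# Usage sketch. Every key below besides "type" and "sample" is a placeholder:
# each dataset class defines its own constructor arguments, and extra kwargs
# (e.g. a tokenizer) are simply forwarded to it.
#
#     config = {
#         "type": "coco_caption",
#         "sample": 512,  # keep a random 512-example subset
#         "ann_path": "path/to/annotations.json",  # hypothetical dataset arg
#     }
#     dataset = build_dataset(config, tokenizer=tokenizer)
#
#     # A list of configs builds each dataset and concatenates them:
#     dataset = build_dataset([config_a, config_b], tokenizer=tokenizer)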