from torch.utils.data import Dataset
import numpy as np
from transformers import AutoConfig, AutoTokenizer, AutoImageProcessor
from .utils import DEFAULT_VISION_PROMPT_TOKEN, VPT_CONTEXT_TOKEN, VPT_START_TOKEN, VPT_END_TOKEN

class LLaVACombineDataset(Dataset):
    """Concatenate several LLaVA-style datasets behind a single, shuffled global index.

    Each entry in ``datasets_cfgs`` is a config dict whose ``type`` key holds the
    dataset class; the remaining keys are passed to that class's constructor.
    """

    def __init__(self,
                 datasets_cfgs,
                 exhibit_special_tokens=False,
                 llava_processor=None,
                 ot_image_processor=None,
                 repeat_time=1,
                 ):
        super().__init__()

        self.datasets = []
        self.datasets_length = []
        
        # Build the optional processors from their config dicts; the 'type' key
        # holds the class and the remaining keys are its constructor kwargs.
        if ot_image_processor:
            process_clazz = ot_image_processor.pop('type')
            ot_image_processor = process_clazz(**ot_image_processor)
        else:
            ot_image_processor = None
        if llava_processor:
            llava_processor_clazz = llava_processor.pop('type')
            self.llava_processor = llava_processor_clazz(**llava_processor)
        else:
            self.llava_processor = None
        # Register the visual-prompt special token unless the tokenizer is
        # expected to already expose it.
        if not exhibit_special_tokens and self.llava_processor is not None:
            self._add_special_tokens()

        # Instantiate every sub-dataset, injecting the shared processors and
        # scaling its own repeat_time by the global one.
        for dataset_cfg in datasets_cfgs:
            dataset_clazz = dataset_cfg.pop('type')
            ori_repeat_time = dataset_cfg.pop('repeat_time')
            dataset_cfg.update(dict(ot_image_processor=ot_image_processor,
                                    llava_processor=self.llava_processor,
                                    repeat_time=ori_repeat_time * repeat_time))
            dataset = dataset_clazz(**dataset_cfg)
            self.datasets.append(dataset)
            self.datasets_length.append(len(dataset))
        
        # Cumulative dataset sizes; dataset_threshold[i] is the exclusive upper
        # bound of dataset i's slice of the global index space.
        self.dataset_threshold = np.cumsum(self.datasets_length).tolist()

        # Deterministic shuffle of the global index space.
        # Note: this seeds NumPy's global RNG.
        np.random.seed(42)
        self.shuffled_index = np.arange(self.dataset_threshold[-1])
        np.random.shuffle(self.shuffled_index)

    def _add_special_tokens(self):
        # Register the visual-prompt context token with the LLaVA tokenizer so
        # it is encoded as a single id.
        special_tokens = [VPT_CONTEXT_TOKEN]
        self.llava_processor.tokenizer.add_tokens(special_tokens, special_tokens=True)
    
    @property
    def modality_length(self):
        # Concatenated per-sample modality lengths of all sub-datasets, in
        # dataset order.
        length_list = []
        for dataset in self.datasets:
            length_list += dataset.modality_length
        return length_list
    
    def __len__(self):
        return self.dataset_threshold[-1]
    
    def __getitem__(self, index):
        # Map the requested position through the fixed shuffle, then locate the
        # sub-dataset whose cumulative threshold covers the shuffled index.
        index = int(self.shuffled_index[index])
        for i, threshold in enumerate(self.dataset_threshold):
            if index < threshold:
                break
        # Convert the global index into a local index within dataset i.
        if i == 0:
            _index = index
        else:
            _index = index - self.dataset_threshold[i - 1]

        return self.datasets[i][_index]
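

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). ``_ToyDataset`` is a hypothetical
# stand-in, not a class from this repository; it exists purely to show the
# config convention LLaVACombineDataset expects: each cfg dict carries the
# dataset class under 'type' plus that class's constructor kwargs, including a
# per-dataset 'repeat_time'.
if __name__ == "__main__":
    class _ToyDataset(Dataset):
        def __init__(self, size, ot_image_processor=None, llava_processor=None, repeat_time=1):
            self.size = int(size * repeat_time)
            # One entry per sample, as the combined dataset's modality_length expects.
            self.modality_length = [1] * self.size

        def __len__(self):
            return self.size

        def __getitem__(self, index):
            return {"index": index}

    combined = LLaVACombineDataset(
        datasets_cfgs=[
            dict(type=_ToyDataset, size=10, repeat_time=1),
            dict(type=_ToyDataset, size=5, repeat_time=2),
        ],
        # No llava_processor is configured here, so skip special-token registration.
        exhibit_special_tokens=True,
    )
    print(len(combined))   # 20 samples in total (10 + 5 * 2)
    print(combined[0])     # a sample drawn from one of the sub-datasets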