File size: 2,899 Bytes
032e687 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
from torch.utils.data import Dataset
import numpy as np
from transformers import AutoConfig, AutoTokenizer, AutoImageProcessor
from .utils import DEFAULT_VISION_PROMPT_TOKEN, VPT_CONTEXT_TOKEN, VPT_START_TOKEN, VPT_END_TOKEN
class LLaVACombineDataset(Dataset):
    """Concatenate several sub-datasets behind one deterministically shuffled
    index space.

    Each entry of ``datasets_cfgs`` is a dict of the form
    ``{'type': DatasetClass, 'repeat_time': float, **kwargs}``; the class is
    instantiated with the shared processors injected.  Global sample indices
    are mapped through a fixed (seed 42) permutation and routed to the owning
    sub-dataset via cumulative length thresholds.

    Args:
        datasets_cfgs: list of sub-dataset config dicts (see above). The dicts
            are NOT mutated, so the same cfgs can build several instances.
        exhibit_special_tokens: when False (default), the VPT context token is
            registered in the tokenizer of ``llava_processor``.
        llava_processor: optional cfg dict ``{'type': cls, **kwargs}`` for the
            LLaVA processor shared by all sub-datasets.
        ot_image_processor: optional cfg dict for the auxiliary image
            processor, same ``{'type': cls, **kwargs}`` shape.
        repeat_time: global multiplier applied on top of each sub-dataset's
            own ``repeat_time``.
    """

    def __init__(self,
                 datasets_cfgs,
                 exhibit_special_tokens=False,
                 llava_processor=None,
                 ot_image_processor=None,
                 repeat_time=1,
                 ):
        super().__init__()
        self.datasets = []
        self.datasets_length = []

        # Instantiate the optional processors from their cfg dicts.
        # Copy before pop() so the caller's dicts stay intact.
        if ot_image_processor:
            ot_cfg = dict(ot_image_processor)
            process_clazz = ot_cfg.pop('type')
            ot_image_processor = process_clazz(**ot_cfg)
        else:
            ot_image_processor = None
        if llava_processor:
            llava_cfg = dict(llava_processor)
            llava_processor_clazz = llava_cfg.pop('type')
            self.llava_processor = llava_processor_clazz(**llava_cfg)
        else:
            self.llava_processor = None

        # Register special tokens unless the caller opted out.  Guarded so a
        # missing processor no longer crashes with AttributeError.
        if not exhibit_special_tokens and self.llava_processor is not None:
            self._add_special_tokens()

        for dataset_cfg in datasets_cfgs:
            # Shallow-copy: the original code mutated the caller's cfg
            # (del 'type', update(...)), which broke cfg reuse.
            cfg = dict(dataset_cfg)
            dataset_clazz = cfg.pop('type')
            ori_repeat_time = cfg.pop('repeat_time')
            cfg.update(dict(ot_image_processor=ot_image_processor,
                            llava_processor=self.llava_processor,
                            repeat_time=ori_repeat_time * repeat_time))
            dataset = dataset_clazz(**cfg)
            self.datasets.append(dataset)
            self.datasets_length.append(len(dataset))

        # dataset_threshold[i] is the first global index past sub-dataset i
        # (i.e. the cumulative length), so the last entry is the total size.
        self.dataset_threshold = np.cumsum(self.datasets_length).tolist()

        # Deterministic shuffle of the flat index space.  A private
        # RandomState(42) yields exactly the same permutation as the old
        # np.random.seed(42) + np.random.shuffle, but without clobbering the
        # process-wide numpy RNG as a constructor side effect.
        self.shuffled_index = np.arange(self.dataset_threshold[-1])
        np.random.RandomState(42).shuffle(self.shuffled_index)

    def _add_special_tokens(self):
        """Add the VPT context token to the shared tokenizer's vocabulary."""
        special_tokens = [VPT_CONTEXT_TOKEN]
        # The count of newly added tokens is not needed here.
        self.llava_processor.tokenizer.add_tokens(special_tokens, special_tokens=True)

    @property
    def modality_length(self):
        """Per-sample modality lengths of all sub-datasets, concatenated."""
        length_list = []
        for dataset in self.datasets:
            length_list += dataset.modality_length
        return length_list

    def __len__(self):
        return self.dataset_threshold[-1]

    def __getitem__(self, index):
        # Map through the fixed permutation, then locate the owning
        # sub-dataset via the cumulative thresholds and re-base the index.
        index = int(self.shuffled_index[index])
        for i, threshold in enumerate(self.dataset_threshold):
            if index < threshold:
                break
        _index = index if i == 0 else index - self.dataset_threshold[i - 1]
        return self.datasets[i][_index]