File size: 5,163 Bytes
032e687 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# Copyright (c) OpenMMLab. All rights reserved.
import json
import logging
import os
import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
from xtuner.registry import BUILDER
from xtuner.dataset.utils import expand2square, encode_fn
from xtuner.dataset.llava import load_jsonl
from xtuner.dataset.huggingface import build_origin_dataset
class LLaVALazyDataset(Dataset):
    """Image-text dataset that loads images (and optionally text) on access.

    Text records come either from a dataset previously saved to disk
    (``offline_processed_text_folder``) or from a ``.json`` / ``.jsonl`` file
    (``data_path``). Images are always opened and preprocessed inside
    ``__getitem__``. When ``lazy`` is true, ``__getitem__`` additionally runs
    ``dataset_map_fn``, ``template_map_fn`` and ``encode_fn`` on each sample,
    so the stored records may be raw (untokenized).

    Args:
        image_folder: Directory that the ``image`` field of each record is
            joined onto.
        image_processor: A processor object, or a dict / ``Config`` /
            ``ConfigDict`` that ``BUILDER.build`` turns into one.
        data_path: Path to a ``.json`` or ``.jsonl`` file of records. Used
            only when ``offline_processed_text_folder`` is ``None``.
        tokenizer: Tokenizer object or config; built via ``BUILDER`` only in
            lazy mode. Required together with ``data_path``.
        offline_processed_text_folder: Folder passed to ``load_from_disk``.
            Takes precedence over ``data_path``.
        max_dataset_length: Must be ``None`` when loading from ``data_path``;
            it is not otherwise used by this class.
        dataset_map_fn: Callable applied to each sample in lazy mode
            (skipped when ``None``).
        template_map_fn: Callable or config applied to each sample in lazy
            mode (skipped when ``None``).
        max_length: ``max_length`` forwarded to ``encode_fn`` in lazy mode.
        pad_image_to_square: If true, images go through ``expand2square``
            before preprocessing.
        lazy: If true, run the text pipeline per sample in ``__getitem__``.

    Raises:
        AssertionError: If neither ``offline_processed_text_folder`` nor both
            ``data_path`` and ``tokenizer`` are given, or if
            ``max_dataset_length`` is set while loading from ``data_path``.
        NotImplementedError: If ``data_path`` is neither ``.json`` nor
            ``.jsonl``.
    """

    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 lazy=False,
                 ):
        super().__init__()

        assert offline_processed_text_folder or (data_path and tokenizer)
        if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and '
                '`data_path` are set, and we load dataset from'
                '`offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)

        if offline_processed_text_folder is not None:
            self.text_data = load_from_disk(offline_processed_text_folder)
        else:
            if data_path.endswith('.json'):
                # Context manager so the handle is closed even if parsing
                # fails.
                with open(data_path) as json_file:
                    json_data = json.load(json_file)
            elif data_path.endswith('.jsonl'):
                json_data = load_jsonl(data_path)
            else:
                raise NotImplementedError

            # Integer ids are converted to strings before the records are
            # handed to `HFDataset.from_list`.
            for record in json_data:
                if 'id' in record and isinstance(record['id'], int):
                    record['id'] = str(record['id'])

            json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
            assert max_dataset_length is None, "max_dataset_length is not supported in Lazy mode"
            self.text_data = build_origin_dataset(json_data, 'train')

            if not lazy:
                # Records loaded here are never tokenized by this class
                # unless `lazy=True`; make that visible instead of failing
                # later.
                print_log(
                    '`data_path` is used with `lazy=False`: samples will be '
                    'returned without tokenization.',
                    logger='current',
                    level=logging.WARNING)

        self.image_folder = image_folder
        self.image_processor = self._build_if_config(image_processor)
        self.pad_image_to_square = pad_image_to_square

        # Always define the text-pipeline attributes so `__getitem__` never
        # hits a missing attribute; configs are only built when they will
        # actually be used (lazy mode).
        self.lazy = lazy
        self.max_length = max_length
        self.dataset_map_fn = dataset_map_fn
        if lazy:
            self.tokenizer = self._build_if_config(tokenizer)
            self.template_map_fn = self._build_if_config(template_map_fn)
        else:
            self.tokenizer = tokenizer
            self.template_map_fn = template_map_fn

    @staticmethod
    def _build_if_config(obj):
        """Build ``obj`` through ``BUILDER`` if it is a config, else return it."""
        if isinstance(obj, (dict, Config, ConfigDict)):
            return BUILDER.build(obj)
        return obj

    @staticmethod
    def _estimate_text_length(data_dict):
        """Approximate the text length of an untokenized record.

        NOTE(review): assumes LLaVA-style records with a ``conversations``
        list of dicts holding a ``value`` string -- confirm against the data
        files. Returns at least 1 so the sign applied by ``modality_length``
        is never lost on a zero.
        """
        turns = data_dict.get('conversations') or []
        word_count = sum(
            len(str(turn.get('value', '')).split())
            for turn in turns
            if isinstance(turn, dict))
        return max(word_count, 1)

    @property
    def modality_length(self):
        """Per-sample lengths, negated for samples without an image.

        Uses ``len(input_ids)`` when the record is already tokenized;
        otherwise falls back to a word-count estimate of the raw text.
        """
        length_list = []
        for data_dict in self.text_data:
            input_ids = data_dict.get('input_ids', None)
            if input_ids is not None:
                cur_len = len(input_ids)
            else:
                cur_len = self._estimate_text_length(data_dict)
            if data_dict.get('image', None) is None:
                cur_len = -cur_len
            length_list.append(cur_len)
        return length_list

    def __len__(self):
        """Return the number of text records."""
        return len(self.text_data)

    def __getitem__(self, index):
        """Return record ``index`` with ``pixel_values`` attached.

        Records without an ``image`` get an all-zero tensor of the
        processor's crop size. In lazy mode the record is then passed through
        the map functions and ``encode_fn``, each result being merged into
        the returned dict.
        """
        data_dict = self.text_data[index]
        image_file = data_dict.get('image', None)

        if image_file is not None:
            image_path = os.path.join(self.image_folder, image_file)
            # `convert` yields an independent image, so the file can be
            # closed right away.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            if self.pad_image_to_square:
                image = expand2square(
                    image,
                    tuple(
                        int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            data_dict['pixel_values'] = image
        else:
            if hasattr(self.image_processor, 'crop_size'):
                crop_size = self.image_processor.crop_size
            else:
                crop_size = self.image_processor.size
            data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
                                                    crop_size['width'])

        if not self.lazy:
            # Text is taken as stored; nothing further to do.
            return data_dict

        if self.dataset_map_fn is not None:
            data_dict.update(self.dataset_map_fn(data_dict))
        if self.template_map_fn is not None:
            data_dict.update(self.template_map_fn(data_dict))
        data_dict.update(
            encode_fn(
                data_dict,
                tokenizer=self.tokenizer,
                max_length=self.max_length,
                with_image_token=True))
        return data_dict
|