# zhouyik's picture
# Upload folder using huggingface_hub
# 032e687 verified
# Copyright (c) OpenMMLab. All rights reserved.
import json
import logging
import os
import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
from xtuner.registry import BUILDER
from xtuner.dataset.utils import expand2square
from .encode_fns import encode_fn
from xtuner.dataset.llava import load_jsonl
from xtuner.dataset.huggingface import build_origin_dataset
import copy
import numpy as np
from projects.omg_llava.dataset.utils import expand2square_mask
import torch.nn.functional as F
class RegionLLaVALazyDataset(Dataset):
    """Image + conversation dataset that tokenizes each sample on access.

    Every ``__getitem__`` call loads the image (or a zero tensor for
    text-only samples), runs ``dataset_map_fn`` and ``template_map_fn`` on the
    raw record, tokenizes it with ``encode_fn`` and attaches ``region_masks``
    resized to the feature-map resolution.

    Args:
        image_folder: Directory that each sample's image file name is joined
            onto.
        image_processor: Processor object, or a config (``dict`` / ``Config``
            / ``ConfigDict``) that is built through ``BUILDER``. It must
            expose ``crop_size``, ``image_mean`` and ``preprocess``.
        data_path: Path to a ``.json`` or ``.jsonl`` annotation file.
        tokenizer: Tokenizer object, or a config built through ``BUILDER``.
        offline_processed_text_folder: Folder loaded with ``load_from_disk``;
            takes precedence over ``data_path`` when both are given.
        max_dataset_length: Must be ``None`` when loading from ``data_path``.
        dataset_map_fn: Callable applied to the raw record in
            ``__getitem__``; its result is merged into the record.
        template_map_fn: Callable (or config built through ``BUILDER``)
            applied after ``dataset_map_fn``.
        max_length: Forwarded to ``encode_fn``.
        pad_image_to_square: Pad images (and provided region masks) to a
            square before preprocessing.
        lazy: When true, ``modality_length`` reports a constant placeholder
            instead of scanning the stored records.
        feat_down_ratio: ``image_size // feat_down_ratio`` is the side length
            of the region-mask grid.
        repeats: Number of times the data is virtually repeated.

    Raises:
        AssertionError: If neither ``offline_processed_text_folder`` nor both
            ``data_path`` and ``tokenizer`` are given, if
            ``max_dataset_length`` is set when loading from ``data_path``, or
            if the processor crop size is not square.
        NotImplementedError: If ``data_path`` is neither ``.json`` nor
            ``.jsonl``.
    """

    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 lazy=False,
                 feat_down_ratio=14,
                 repeats=1,
                 ):
        super().__init__()

        assert offline_processed_text_folder or (data_path and tokenizer)
        if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and '
                '`data_path` are set, and we load dataset from'
                '`offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)

        if offline_processed_text_folder is not None:
            self.text_data = load_from_disk(offline_processed_text_folder)
        else:
            print("Loading {}!!!".format(data_path))
            if data_path.endswith('.json'):
                # Context manager so the file handle is closed deterministically.
                with open(data_path) as json_file:
                    json_data = json.load(json_file)
            elif data_path.endswith('.jsonl'):
                json_data = load_jsonl(data_path)
            else:
                raise NotImplementedError
            print("Loaded {}!!!".format(data_path))

            # Integer ids are turned into strings so the "id" column has a
            # single type when the list is converted to a HF dataset.
            for sample in json_data:
                if isinstance(sample.get('id'), int):
                    sample['id'] = str(sample['id'])
            json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
            assert max_dataset_length is None, "max_dataset_length is not supported in Lazy mode"
            self.text_data = build_origin_dataset(json_data, 'train')

        self.image_folder = image_folder
        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor

        crop_size = self.image_processor.crop_size
        assert crop_size['height'] == crop_size['width']
        self.image_size = crop_size['height']
        self.feat_size = self.image_size // feat_down_ratio
        self.pad_image_to_square = pad_image_to_square

        self.lazy = lazy
        # ``__getitem__`` always runs the map functions and ``encode_fn``, so
        # everything it needs is stored regardless of ``lazy``; gating these on
        # ``lazy`` made the default ``lazy=False`` fail with AttributeError on
        # first access.
        if isinstance(tokenizer, (dict, Config, ConfigDict)):
            tokenizer = BUILDER.build(tokenizer)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.dataset_map_fn = dataset_map_fn
        if isinstance(template_map_fn, (dict, Config, ConfigDict)):
            template_map_fn = BUILDER.build(template_map_fn)
        self.template_map_fn = template_map_fn

        self.repeats = repeats

    @property
    def modality_length(self):
        """Return one length estimate per index in ``range(len(self))``.

        In lazy mode every entry is the placeholder ``1000``. Otherwise the
        length of ``input_ids`` is used when present (``1000`` if absent), and
        the value is negated for records without an image. The list is
        repeated ``repeats`` times so it lines up with ``__len__`` and with the
        ``index % len(self.text_data)`` mapping used by ``__getitem__``.
        """
        if self.lazy:
            return [1000] * len(self.text_data) * self.repeats

        length_list = []
        for data_dict in self.text_data:
            if 'input_ids' in data_dict:
                cur_len = len(data_dict['input_ids'])
            else:
                cur_len = 1000
            if data_dict.get('image', None) is None:
                cur_len = -cur_len
            length_list.append(cur_len)
        return length_list * self.repeats

    def __len__(self):
        """Return the number of stored records times ``repeats``."""
        return len(self.text_data) * self.repeats

    def __getitem__(self, index):
        """Build the training sample for ``index``.

        Returns:
            dict: The raw record updated with ``pixel_values``, the outputs of
            ``dataset_map_fn`` / ``template_map_fn`` / ``encode_fn``, and
            ``region_masks`` as a tensor at ``feat_size`` resolution.
        """
        # Indices beyond the stored data wrap around to implement ``repeats``.
        index = index % len(self.text_data)
        data_dict = copy.deepcopy(self.text_data[index])

        # Some annotation files store the file name under "image_name".
        if 'image' not in data_dict and 'image_name' in data_dict:
            data_dict['image'] = data_dict['image_name']

        if data_dict.get('image', None) is not None:
            image_path = os.path.join(self.image_folder, data_dict['image'])
            # ``convert`` returns a fully loaded copy, so the underlying file
            # can be closed as soon as the block exits.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            if self.pad_image_to_square:
                # Pad with the processor's mean colour scaled to 0-255.
                image = expand2square(
                    image,
                    tuple(
                        int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            data_dict['pixel_values'] = image
        else:
            # Text-only record: use an all-zero image of the processor's size.
            if hasattr(self.image_processor, 'crop_size'):
                crop_size = self.image_processor.crop_size
            else:
                crop_size = self.image_processor.size
            data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
                                                    crop_size['width'])

        result = self.dataset_map_fn(data_dict)
        data_dict.update(result)

        result = self.template_map_fn(data_dict)
        data_dict.update(result)

        result = encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
        data_dict.update(result)

        if 'region_masks' not in data_dict:
            # No region given: a single all-ones mask, already at feature
            # resolution, so no resizing is needed.
            region_masks = np.ones((1, self.feat_size, self.feat_size), dtype=np.uint8)
            region_masks = torch.from_numpy(region_masks)
        else:
            # NOTE(review): assumes ``dataset_map_fn`` yields a NumPy array of
            # shape (num_regions, H, W) - confirm against the map function.
            region_masks = data_dict['region_masks']
            if self.pad_image_to_square:
                region_masks = expand2square_mask(region_masks)
            region_masks = torch.from_numpy(region_masks)
            region_masks = F.interpolate(
                region_masks.unsqueeze(0),
                size=(self.feat_size, self.feat_size),
                mode='nearest').squeeze(0)
        data_dict['region_masks'] = region_masks
        return data_dict
class CombineDataset(Dataset):
    """Concatenate several datasets and serve them in a fixed shuffled order.

    One tokenizer is built from the first dataset config and handed to every
    dataset, so all of them share the same instance. Samples are addressed
    through a permutation that is generated once with seed 42.

    Args:
        datasets_cfgs: Sequence of mapping-style configs. Each must contain a
            ``'type'`` callable; the remaining items (with ``'tokenizer'``
            replaced by the shared tokenizer) are passed to it as keyword
            arguments. The first config's ``'tokenizer'`` entry is either a
            config with its own ``'type'`` callable or an already built
            tokenizer. The configs are not modified.
    """

    def __init__(self,
                 datasets_cfgs,
                 ):
        super().__init__()

        # Build the shared tokenizer without deleting keys from the caller's
        # config, so the same config can be used to build the dataset again.
        tokenizer_cfg = datasets_cfgs[0]['tokenizer']
        if isinstance(tokenizer_cfg, dict) and 'type' in tokenizer_cfg:
            tokenizer_kwargs = {
                key: value
                for key, value in tokenizer_cfg.items() if key != 'type'
            }
            self.tokenizer = tokenizer_cfg['type'](**tokenizer_kwargs)
        else:
            self.tokenizer = tokenizer_cfg

        self.datasets = []
        self.datasets_length = []
        for dataset_cfg in datasets_cfgs:
            dataset_kwargs = {
                key: value
                for key, value in dataset_cfg.items() if key != 'type'
            }
            dataset_kwargs['tokenizer'] = self.tokenizer
            dataset = dataset_cfg['type'](**dataset_kwargs)
            self.datasets.append(dataset)
            self.datasets_length.append(len(dataset))

        # dataset_threthold[i] is the exclusive end offset of dataset i in the
        # concatenated index space.
        self.dataset_threthold = np.cumsum(self.datasets_length).tolist()

        # A private RandomState yields the same permutation as
        # ``np.random.seed(42)`` followed by ``np.random.shuffle`` but leaves
        # the process-wide NumPy RNG untouched.
        rng = np.random.RandomState(42)
        self.shuffled_index = np.arange(self.dataset_threthold[-1])
        rng.shuffle(self.shuffled_index)

    @property
    def modality_length(self):
        """Return per-sample lengths aligned with ``__getitem__`` order.

        The sub-datasets' ``modality_length`` lists are concatenated and then
        permuted with ``shuffled_index`` so that entry ``i`` describes the
        sample returned by ``self[i]``. If the concatenated list does not
        cover every index (a sub-dataset reporting fewer lengths than items),
        it is returned unpermuted because the mapping is undefined.
        """
        length_list = []
        for dataset in self.datasets:
            length_list += dataset.modality_length
        if len(length_list) != len(self.shuffled_index):
            return length_list
        return [length_list[int(flat_index)] for flat_index in self.shuffled_index]

    def __len__(self):
        """Return the total number of samples across all datasets."""
        return self.dataset_threthold[-1]

    def __getitem__(self, index):
        """Return the sample at shuffled position ``index``.

        Raises:
            IndexError: If the resolved position lies outside every dataset.
        """
        flat_index = int(self.shuffled_index[index])
        start = 0
        for dataset, end in zip(self.datasets, self.dataset_threthold):
            if flat_index < end:
                return dataset[flat_index - start]
            start = end
        raise IndexError(f'index {index} is out of range')