import json
import os

import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset

from xtuner.registry import BUILDER

from .huggingface import process_hf_dataset
from .utils import expand2square


class LLaVADataset(Dataset):
    """Dataset for LLaVA-style visual instruction tuning.

    Loads a LLaVA-format JSON annotation file, tokenizes the conversations
    via ``process_hf_dataset`` and, in ``__getitem__``, pairs each sample
    with its preprocessed image (or an all-zero image for text-only
    samples).
    """

    def __init__(self,
                 data_path,
                 image_folder,
                 tokenizer,
                 image_processor,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False):
        super().__init__()

        # Normalize ``id`` to ``str`` so that ``HFDataset.from_list`` infers
        # a consistent column type across all samples.
        with open(data_path) as f:
            json_data = json.load(f)
        for idx in range(len(json_data)):
            if isinstance(json_data[idx]['id'], int):
                json_data[idx]['id'] = str(json_data[idx]['id'])
        json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
        self.text_data = process_hf_dataset(
            dataset=json_data,
            tokenizer=tokenizer,
            max_length=max_length,
            dataset_map_fn=dataset_map_fn,
            template_map_fn=template_map_fn,
            split='train',
            max_dataset_length=max_dataset_length,
            remove_unused_columns=False,
            pack_to_max_length=False,
            with_image_token=True)

        self.image_folder = image_folder
        # ``image_processor`` may be an already-built object or a config
        # (dict / Config / ConfigDict) that the registry builds here.
        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor
        self.pad_image_to_square = pad_image_to_square

    @property
    def modality_length(self):
        # One length per sample; text-only samples are marked with a
        # negative value so samplers can tell the two modalities apart.
        length_list = []
        for data_dict in self.text_data:
            cur_len = len(data_dict['input_ids'])
            if data_dict.get('image', None) is None:
                cur_len = -cur_len
            length_list.append(cur_len)
        return length_list

    def __len__(self):
        return len(self.text_data)

    def __getitem__(self, index):
        data_dict = self.text_data[index]
        if data_dict.get('image', None) is not None:
            image_file = data_dict['image']
            image = Image.open(os.path.join(self.image_folder,
                                            image_file)).convert('RGB')
            if self.pad_image_to_square:
                # Pad to a square canvas filled with the processor's mean
                # color so the aspect ratio survives the later resize.
                image = expand2square(
                    image,
                    tuple(
                        int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            data_dict['pixel_values'] = image
        else:
            # Text-only sample: substitute an all-zero image of the expected
            # shape so batches keep a uniform structure.
            size = self.image_processor.size
            data_dict['pixel_values'] = torch.zeros(3, size['height'],
                                                    size['width'])
        return data_dict
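

# ---------------------------------------------------------------------------
# Minimal usage sketch (run with ``python -m xtuner.dataset.llava``). The
# model names, data paths and map functions below are illustrative
# assumptions, not part of this module.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import AutoTokenizer, CLIPImageProcessor

    from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
    from xtuner.utils import PROMPT_TEMPLATE

    tokenizer = AutoTokenizer.from_pretrained(
        'internlm/internlm2-chat-7b', trust_remote_code=True)  # assumption
    image_processor = CLIPImageProcessor.from_pretrained(
        'openai/clip-vit-large-patch14-336')  # assumption

    dataset = LLaVADataset(
        data_path='data/llava_data/llava_v1_5_mix665k.json',  # placeholder
        image_folder='data/llava_data/llava_images',  # placeholder
        tokenizer=tokenizer,
        image_processor=image_processor,
        dataset_map_fn=llava_map_fn,
        template_map_fn=dict(
            type=template_map_fn_factory,
            template=PROMPT_TEMPLATE.internlm2_chat),
        max_length=2048,
        pad_image_to_square=True)
    print(len(dataset), dataset[0].keys())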