xzl12306's picture
first commit
d6bc023
import copy
from dataclasses import dataclass
import json
from typing import Dict, Sequence
import transformers
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageFile
from tinychart.arguments import *
from tinychart.utils import *
from tinychart.data.process import *
from tinychart.constants import *
ImageFile.LOAD_TRUNCATED_IMAGES = True
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
data_args: DataArguments):
super(LazySupervisedDataset, self).__init__()
list_data_dict = json.load(open(data_path, "r"))
rank0_print("Formatting inputs...Skip in lazy mode")
self.tokenizer = tokenizer
self.list_data_dict = list_data_dict
self.data_args = data_args
def __len__(self):
return len(self.list_data_dict)
@property
def lengths(self):
length_list = []
for sample in self.list_data_dict:
img_tokens = 128 if 'image' in sample else 0
length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
cur_len = cur_len if 'image' in sample else -cur_len
length_list.append(cur_len)
return length_list
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
sources = self.list_data_dict[i]
if isinstance(i, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
if 'image' in sources[0]:
image_file = self.list_data_dict[i]['image']
image_folder = self.data_args.image_folder
processor = self.data_args.image_processor
image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
if self.data_args.image_aspect_ratio == 'pad':
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
else:
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args)
else:
sources = copy.deepcopy([e["conversations"] for e in sources])
data_dict = preprocess(
sources,
self.tokenizer,
has_image=('image' in self.list_data_dict[i]))
if isinstance(i, int):
data_dict = dict(input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0])
# image exist in the data
if 'image' in self.list_data_dict[i]:
data_dict['image'] = image
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
crop_size = self.data_args.image_processor.crop_size
data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
return data_dict
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances]
for key in ("input_ids", "labels"))
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
for input_id in input_ids:
input_id[input_id == self.tokenizer.eos_token_id] = -300
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id)
labels = torch.nn.utils.rnn.pad_sequence(labels,
batch_first=True,
padding_value=IGNORE_INDEX)
input_ids = input_ids[:, :self.tokenizer.model_max_length]
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
labels = labels[:, :self.tokenizer.model_max_length]
# FIXME: This is a hack for handling phi and stablelm, as they have the same eos, pad and unk. We want the model
# FIXME: to predict the eos in the input ids, but we also use the id of eos to pad sequence, so we use a temp
# FIXME: eos id first, and convert them back.
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
for input_id in input_ids:
input_id[input_id == -300] = self.tokenizer.eos_token_id
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
)
if 'image' in instances[0]:
images = [instance['image'] for instance in instances]
if all(x is not None and x.shape == images[0].shape for x in images):
batch['images'] = torch.stack(images)
else:
batch['images'] = images
if 'question' in instances[0]:
questions = [instance['question'] for instance in instances]
batch['questions'] = questions
if 'answer' in instances[0]:
answers = [instance['answer'] for instance in instances]
batch['answers'] = answers
return batch
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
data_path=data_args.data_path,
data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset,
eval_dataset=None,
data_collator=data_collator)
def make_supervised_data_module_with_eval(tokenizer: transformers.PreTrainedTokenizer,
data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
data_path=data_args.data_path,
data_args=data_args)
if data_args.eval_data_path is None or data_args.eval_data_path == "":
print('Evaluation dataset not specified, skipping...')
eval_dataset = None
else:
eval_dataset = LazySupervisedDataset(tokenizer=tokenizer,
data_path=data_args.eval_data_path,
data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator)