import copy
import json
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from PIL import Image
from pycocotools import mask as maskUtils
from torch.utils.data import Dataset
from torchvision.transforms.functional import InterpolationMode

from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
from xtuner.utils import DEFAULT_IMAGE_TOKEN

from .encode_fn import video_lisa_encode_fn
from .utils import dynamic_preprocess


class OspreyDataset(Dataset):
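    """Region-level conversation dataset in the Osprey annotation format.

    Each record provides an image ``file_name``, COCO-style ``annotation``
    entries with ``segmentation`` masks, and a ``conversations`` list. The
    masks are decoded, downsampled to the vision-token grid, and exposed to
    the model as ``<vp>`` ... ``</vp>`` visual-prompt spans prepended to the
    first human turn.
    """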
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'

    LIMIT = ''

    VP_START_TOKEN = '<vp>'
    VP_END_TOKEN = '</vp>'

    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self,
                 image_folder,
                 data_path=None,
                 tokenizer=None,
                 max_length=8196,
                 special_tokens=None,
                 template_map_fn=None,
                 extra_image_processor=None,
                 lazy=True,
                 repeats=1,
                 single_image_mode=False,
                 ):
        super().__init__()
        assert lazy
        self.lazy = lazy
        self.max_length = max_length

        json_data = self.json_file_preprocess(data_path)
        self.text_data = json_data

        self.image_folder = image_folder

        self.tokenizer = BUILDER.build(tokenizer)
        if special_tokens is not None:
            self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.template_map_fn = template_map_fn
        if isinstance(self.template_map_fn, dict) and self.lazy:
            _type = self.template_map_fn['type']
            del self.template_map_fn['type']
            self.template_map_fn = _type(**self.template_map_fn)

        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)

        self.repeats = repeats

        self._system = ''

        self.min_dynamic_patch = 1
        self.max_dynamic_patch = 12
        self.downsample_ratio = 0.5
        self.image_size = 448
        self.use_thumbnail = True
        patch_size = 14
        self.patch_size = patch_size
        self.patch_token = int(
            (self.image_size // patch_size)**2 * (self.downsample_ratio**2))

        self.transformer = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((self.image_size, self.image_size),
                     interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
        ])

        self.single_image_mode = single_image_mode

    def json_file_preprocess(self, data_path):
        with open(data_path, 'r') as f:
            json_data = json.load(f)
        return json_data
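
    # Length hints for length-grouped batch sampling: a fixed placeholder is
    # used in lazy mode, and samples without an image get a negative length.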
    @property
    def modality_length(self):
        length_list = []
        for data_dict in self.text_data:
            if self.lazy:
                cur_len = 100
            else:
                cur_len = len(data_dict['input_ids'])
            if data_dict.get('image', None) is None:
                cur_len = -cur_len
            length_list.append(cur_len)
        return length_list * self.repeats

    def __len__(self):
        return len(self.text_data) * self.repeats

    def real_len(self):
        return len(self.text_data)
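
    # Convert a COCO-style segmentation annotation (polygon list, uncompressed
    # RLE, or compressed RLE) into a binary mask of shape (h, w).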
    def annToMask(self, mask_ann, h, w):
        if isinstance(mask_ann, list):
            # polygon: merge all polygon parts into a single RLE
            rles = maskUtils.frPyObjects(mask_ann, h, w)
            rle = maskUtils.merge(rles)
        elif isinstance(mask_ann['counts'], list):
            # uncompressed RLE
            rle = maskUtils.frPyObjects(mask_ann, h, w)
        else:
            # already compressed RLE
            rle = mask_ann
        mask = maskUtils.decode(rle)
        return mask

    def decode_mask(self, object_masks, ori_height, ori_width):
        binary_masks = []
        for object_mask in object_masks:
            binary_mask = self.annToMask(object_mask, ori_height, ori_width)
            binary_masks.append(binary_mask)
        if len(binary_masks) == 0:
            return None
        masks = np.stack(binary_masks, axis=0)
        masks = torch.from_numpy(masks)
        return masks
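
    # Prepend a '<image> There are N part regions in the picture: ...' prefix
    # in which every region contributes a '<vp>' + IMG_CONTEXT_TOKEN * n + '</vp>'
    # span, then fold the human/gpt turns into {'input', 'output'} pairs.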
    def _process_conversation(self, conversations, n_regions, region_pixels):
        start_region_str = '<image> There are {} part regions in the picture: '.format(n_regions)
        for i in range(n_regions):
            start_region_str = start_region_str + \
                f"region{i+1}" + self.VP_START_TOKEN + \
                self.IMG_CONTEXT_TOKEN * region_pixels[i] + self.VP_END_TOKEN
            if i == n_regions - 1:
                start_region_str = start_region_str + '.\n'
            else:
                start_region_str = start_region_str + ', '

        for i, item in enumerate(conversations):
            item['value'] = item['value'].replace('<', '').replace('>', '')
            if item['from'] == 'human':
                item['value'] = item['value'] + self.LIMIT

            if i == 0:
                assert item['from'] == 'human'
                item['value'] = start_region_str + item['value']

        messages = conversations
        input = ''

        conversation = []
        # Drop any leading assistant turns so pairing always starts with a human turn.
        while messages and messages[0]['from'] == 'gpt':
            messages = messages[1:]
        for msg in messages:
            if msg['from'] == 'human':
                if DEFAULT_IMAGE_TOKEN in msg['value']:
                    msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
                                                        '').strip()
                    msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
                    msg['value'] = msg['value'].strip()
                input += msg['value']
            elif msg['from'] == 'gpt':
                conversation.append({'input': input, 'output': msg['value']})
                input = ''
            else:
                raise NotImplementedError

        return conversation
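
    # Downsample each mask to the vision-token grid (image_size / patch_size *
    # downsample_ratio, i.e. 16 x 16 here) and count the grid cells covered by
    # every region; the count sets how many IMG_CONTEXT placeholders it gets.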
    def _get_region_infos(self, masks):
        masks = F.interpolate(
            masks.unsqueeze(0),
            size=(int(self.image_size // self.patch_size * self.downsample_ratio),
                  int(self.image_size // self.patch_size * self.downsample_ratio)),
            mode='nearest').squeeze(0)
        region_pixels = []
        for mask in masks:
            region_pixels.append(mask.bool().to(torch.int64).sum())
        return masks, region_pixels
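
    # Map one raw Osprey record to the intermediate format used by __getitem__:
    # decoded prompt masks plus the region-prefixed conversation.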
    def dataset_map_fn(self, data_dict):
        file_name = data_dict['file_name']
        conversations = data_dict['conversations']
        masks = [anno['segmentation'] for anno in data_dict['annotation']]
        height = data_dict['height']
        width = data_dict['width']
        _ret = {}

        _ret['image'] = file_name
        _ret['height'] = height
        _ret['width'] = width

        masks = self.decode_mask(masks, height, width)
        # Bail out before computing region infos when no mask could be decoded.
        if masks is None:
            return None
        masks, region_pixels = self._get_region_infos(masks)

        conversations = self._process_conversation(conversations, len(masks), region_pixels)
        _ret['conversation'] = conversations
        _ret['prompt_masks'] = masks
        return _ret
    def replace_image_str(self, data_dict, image_str):
        data_dict['conversation'][0]['input'] = \
            data_dict['conversation'][0]['input'].replace(DEFAULT_IMAGE_TOKEN, image_str)
        return data_dict
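
    # Full pipeline for a sample: map the record, load and (optionally) tile the
    # image, expand <image> into '<img>' + IMG_CONTEXT placeholders + '</img>',
    # apply the chat template, then tokenize.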
    def __getitem__(self, index):
        index = index % self.real_len()
        data_dict = copy.deepcopy(self.text_data[index])

        result = self.dataset_map_fn(data_dict)
        if result is None or result['prompt_masks'] is None:
            # Fall back to the first sample when the record has no usable masks.
            return self.__getitem__(0)

        data_dict = result

        image_file = data_dict['image']
        if isinstance(self.image_folder, list):
            for image_folder in self.image_folder:
                image_path = os.path.join(image_folder, image_file)
                if os.path.exists(image_path):
                    image = Image.open(image_path).convert('RGB')
                    break
        else:
            image = Image.open(os.path.join(self.image_folder,
                                            image_file)).convert('RGB')
        ori_width, ori_height = image.size

        if self.single_image_mode:
            images = [image]
        else:
            images = dynamic_preprocess(image, self.min_dynamic_patch,
                                        self.max_dynamic_patch,
                                        self.image_size, self.use_thumbnail)
        # Only the last tile is flagged for visual-prompt feature extraction.
        vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
        data_dict['vp_overall_mask'] = vp_overall_mask

        pixel_values = [self.transformer(image) for image in images]
        pixel_values = torch.stack(pixel_values)
        data_dict['pixel_values'] = pixel_values

        # Every tile contributes `patch_token` IMG_CONTEXT placeholders.
        num_image_tokens = pixel_values.shape[0] * self.patch_token
        image_token_str = f'{self.IMG_START_TOKEN}' \
                          f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
                          f'{self.IMG_END_TOKEN}'

        data_dict = self.replace_image_str(data_dict, image_token_str)

        result = self.template_map_fn(data_dict)
        data_dict.update(result)
        result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer,
                                      max_length=self.max_length,
                                      with_image_token=True)
        data_dict.update(result)

        if data_dict['prompt_masks'] is None:
            return self.__getitem__(0)

        return data_dict
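

# Question templates for region descriptions; '<region>' is replaced with the
# concrete region name (e.g. 'region1') by OspreyDescriptionDataset.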
DETAILED_QUESTIONS = [
    'Can you provide me with a detailed description of the region in the picture marked by <region>?',
    "I'm curious about the region represented by <region> in the picture. Could you describe it in detail?",
    'What can you tell me about the region indicated by <region> in the image?',
    "I'd like to know more about the area in the photo labeled <region>. Can you give me a detailed description?",
    'Could you describe the region shown as <region> in the picture in great detail?',
    'What details can you give me about the region outlined by <region> in the photo?',
    'Please provide me with a comprehensive description of the region marked with <region> in the image.',
    'Can you give me a detailed account of the region labeled as <region> in the picture?',
    "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail?",
    'What is the region outlined by <region> in the picture like? Could you give me a detailed description?',
    'Can you provide me with a detailed description of the region in the picture marked by <region>, please?',
    "I'm curious about the region represented by <region> in the picture. Could you describe it in detail, please?",
    'What can you tell me about the region indicated by <region> in the image, exactly?',
    "I'd like to know more about the area in the photo labeled <region>, please. Can you give me a detailed description?",
    'Could you describe the region shown as <region> in the picture in great detail, please?',
    'What details can you give me about the region outlined by <region> in the photo, please?',
    'Please provide me with a comprehensive description of the region marked with <region> in the image, please.',
    'Can you give me a detailed account of the region labeled as <region> in the picture, please?',
    "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail, please?",
    'What is the region outlined by <region> in the picture like, please? Could you give me a detailed description?',
    'Please describe the region <region> in the image in detail.',
    'Can you offer a thorough analysis of the region <region> in the image?',
    'Could you elaborate on the region highlighted by <region> in the picture provided?',
    'Please share more information about the zone emphasized with <region> in the photo.',
    'What insights can you give about the area denoted by <region> in the image presented?',
    'Can you share a comprehensive rundown of the region denoted by <region> in the presented image?',
    "I'd like to know more about the region highlighted by <region> in the picture provided.",
    'Work through the important details of the area <region> in the image.',
    'Illustrate the area represented by <region> through a descriptive explanation.',
    'Examine the region <region> closely and share its details.'
]


class OspreyDescriptionDataset(OspreyDataset):
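    """Detailed region-description variant of OspreyDataset.

    Records carry a ``description`` list (one caption per annotated region)
    instead of free-form conversations; a question sampled from
    ``DETAILED_QUESTIONS`` is paired with each description to build the dialogue.
    """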
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'

    VP_START_TOKEN = '<vp>'
    VP_END_TOKEN = '</vp>'

    LIMIT = ''

    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self,
                 image_folder,
                 data_path=None,
                 tokenizer=None,
                 max_length=8196,
                 special_tokens=None,
                 template_map_fn=None,
                 extra_image_processor=None,
                 lazy=True,
                 repeats=1,
                 single_image_mode=False,
                 ):
        super(OspreyDescriptionDataset, self).__init__(
            image_folder=image_folder,
            data_path=data_path,
            tokenizer=tokenizer,
            max_length=max_length,
            special_tokens=special_tokens,
            template_map_fn=template_map_fn,
            extra_image_processor=extra_image_processor,
            lazy=lazy,
            repeats=repeats,
            single_image_mode=single_image_mode,
        )

    def dataset_map_fn(self, data_dict):
        file_name = data_dict['file_name']
        descriptions = data_dict['description']
        masks = [anno['segmentation'] for anno in data_dict['annotation']]
        height = data_dict['height']
        width = data_dict['width']
        _ret = {}

        _ret['image'] = file_name
        _ret['height'] = height
        _ret['width'] = width

        masks = self.decode_mask(masks, height, width)
        # Bail out before computing region infos when no mask could be decoded.
        if masks is None:
            return None
        masks, region_pixels = self._get_region_infos(masks)

        conversations = self._process_conversation(descriptions, len(masks), region_pixels)
        _ret['conversation'] = conversations
        _ret['prompt_masks'] = masks
        return _ret
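
    # Unlike the parent class, records here carry plain description strings (one
    # per region); question/answer turns are synthesized by sampling from
    # DETAILED_QUESTIONS before being folded into {'input', 'output'} pairs.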
    def _process_conversation(self, descriptions, n_regions, region_pixels):
        start_region_str = '<image> There are {} part regions in the picture: '.format(n_regions)
        for i in range(n_regions):
            start_region_str = start_region_str + \
                f"region{i+1}" + self.VP_START_TOKEN + \
                self.IMG_CONTEXT_TOKEN * region_pixels[i] + self.VP_END_TOKEN
            if i == n_regions - 1:
                start_region_str = start_region_str + '.\n'
            else:
                start_region_str = start_region_str + ', '

        conversations = []
        for i, item in enumerate(descriptions):
            question = random.choice(DETAILED_QUESTIONS).strip().replace('<region>', f"region{i+1}") + self.LIMIT
            answer = item.replace('<', '').replace('>', '')

            if i == 0:
                question = start_region_str + question
            conversations.append({'from': 'human', 'value': question})
            conversations.append({'from': 'gpt', 'value': answer})

        messages = conversations
        input = ''

        conversation = []
        while messages and messages[0]['from'] == 'gpt':
            messages = messages[1:]
        for msg in messages:
            if msg['from'] == 'human':
                if DEFAULT_IMAGE_TOKEN in msg['value']:
                    msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
                                                        '').strip()
                    msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
                    msg['value'] = msg['value'].strip()
                input += msg['value']
            elif msg['from'] == 'gpt':
                conversation.append({'input': input, 'output': msg['value']})
                input = ''
            else:
                raise NotImplementedError
        return conversation
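

# Short-answer variant: same pipeline, but LIMIT appends an instruction to reply
# with a single word or phrase to every human turn.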
class OspreyShortDescriptionDataset(OspreyDataset):
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'

    VP_START_TOKEN = '<vp>'
    VP_END_TOKEN = '</vp>'

    LIMIT = ' Answer the question using a single word or phrase.'

    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self,
                 image_folder,
                 data_path=None,
                 tokenizer=None,
                 max_length=8196,
                 special_tokens=None,
                 template_map_fn=None,
                 extra_image_processor=None,
                 lazy=True,
                 repeats=1,
                 single_image_mode=False,
                 ):
        super(OspreyShortDescriptionDataset, self).__init__(
            image_folder=image_folder,
            data_path=data_path,
            tokenizer=tokenizer,
            max_length=max_length,
            special_tokens=special_tokens,
            template_map_fn=template_map_fn,
            extra_image_processor=extra_image_processor,
            lazy=lazy,
            repeats=repeats,
            single_image_mode=single_image_mode,
        )
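

# Illustrative construction only; the paths, model name, and chat template below
# are placeholders, not values taken from this repository's configs:
#
#   from transformers import AutoTokenizer
#   from xtuner.dataset.map_fns import template_map_fn_factory
#   from xtuner.utils import PROMPT_TEMPLATE
#
#   dataset = OspreyDataset(
#       image_folder='data/coco/train2014/',                 # hypothetical path
#       data_path='data/osprey/osprey_conversation.json',    # hypothetical path
#       tokenizer=dict(type=AutoTokenizer.from_pretrained,
#                      pretrained_model_name_or_path='your/base-llm',
#                      trust_remote_code=True),
#       special_tokens=['<IMG_CONTEXT>', '<img>', '</img>', '<vp>', '</vp>'],
#       template_map_fn=dict(type=template_map_fn_factory,
#                            template=PROMPT_TEMPLATE.internlm2_chat),
#   )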