import copy
import io
from typing import Optional

import mmengine.fileio as fileio
import torch
import torch.nn.functional as F
from mmcv.transforms import BaseTransform, LoadImageFromFile
from mmengine.logging import print_log
from PIL import Image

from xtuner.registry import BUILDER
from xtuner.utils.constants import (DEFAULT_IMAGE_TOKEN, IGNORE_INDEX,
                                    IMAGE_TOKEN_INDEX)


class PILLoadImageFromFile(LoadImageFromFile):
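    """Variant of ``LoadImageFromFile`` that loads images with PIL.

    The ``PIL.Image`` object is kept in ``results['img']`` so that
    downstream transforms (e.g. an image processor) can consume it
    directly.
    """
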
def __init__(self, **kwargs):
backend_args = kwargs.pop('backend_args', None)
super().__init__(backend_args=backend_args, **kwargs)

    def transform(self, results: dict) -> Optional[dict]:
"""Functions to load image.
Args:
results (dict): Result dict from
:class:`mmengine.dataset.BaseDataset`.
Returns:
dict: The dict contains loaded image and meta information.
"""
filename = results['img_path']
try:
if self.file_client_args is not None:
file_client = fileio.FileClient.infer_client(
self.file_client_args, filename)
img_bytes = file_client.get(filename)
else:
img_bytes = fileio.get(
filename, backend_args=self.backend_args)
img = Image.open(io.BytesIO(img_bytes))
        except Exception:
            if self.ignore_empty:
                return None
            raise
        # in some cases the image is not read successfully and `img` would be
        # `None`; see https://github.com/open-mmlab/mmpretrain/issues/1427
        assert img is not None, f'failed to load image: {filename}'
results['img'] = img
results['img_shape'] = (img.height, img.width)
results['ori_shape'] = (img.height, img.width)
return results


class RefCOCO2PNG(BaseTransform):
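    """Convert a RefCOCO sample into PNG-style grounding inputs.

    Each referring expression is tokenized and appended to a fixed prompt;
    ``mask_ids`` records, for every input token, the index of the instance
    it refers to (-1 for prompt and separator tokens). The ground-truth
    masks are resized and padded to match the processed image.
    """
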
def __init__(self,
image_processor=None,
tokenizer=None,
prompt_template=None,
prompt='<image>\nWhat is shown in this image?',
concat=True,
image2tensor=True,
add_image_token=False,
image_token=DEFAULT_IMAGE_TOKEN):
self.tokenizer = BUILDER.build(tokenizer)
self.image_processor = BUILDER.build(image_processor)
self.concat = concat
self.image2tensor = image2tensor
self.image_token = image_token
self.add_image_token = add_image_token
if add_image_token:
print_log(f"Manually add image token: {self.image_token}")
special_tokens_dict = {'additional_special_tokens': [self.image_token, ]}
num_added_toks = self.tokenizer.add_special_tokens(special_tokens_dict)
assert num_added_toks == 1
self.image_token_idx = self.tokenizer.encode(self.image_token, add_special_tokens=False)[-1]
print_log(f"Image token: {self.tokenizer.decode(self.image_token_idx)}")
self.prompt = self.tokenizer.encode(
prompt_template['INSTRUCTION'].format(input=prompt),
add_special_tokens=True)
self.prompt_template = prompt_template

    def transform(self, results):
if self.concat:
return self.transform_concat(results)
else:
return self.transform_split(results)

    def transform_split(self, results):
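        """Emit one sample per referring expression.

        Each output pairs a single expression with its own ground-truth
        mask and is processed by :meth:`transform_concat`.
        """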
all_results = []
        for inst_id, instance_text in enumerate(results['text']):
            new_results = copy.deepcopy(results)
            new_results['text'] = [instance_text]
new_results['gt_masks'] = results['gt_masks'][inst_id:inst_id+1]
all_results.append(self.transform_concat(new_results))
return all_results

    def transform_concat(self, results: dict):
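        """Concatenate all referring expressions into a single sample."""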
caption_input_ids = []
mask_ids = [-1] * len(self.prompt)
        # a '.' token separates consecutive expressions; like the prompt,
        # separator tokens carry mask id -1
        split_token_id = self.tokenizer.encode('.', add_special_tokens=False)[-1]
        for inst_id, instance_text in enumerate(results['text']):
            segment_input_ids = self.tokenizer.encode(instance_text, add_special_tokens=False)
            caption_input_ids += segment_input_ids
            mask_ids += [inst_id] * len(segment_input_ids)
            caption_input_ids.append(split_token_id)
            mask_ids.append(-1)
input_ids = self.prompt + caption_input_ids
input_ids = torch.tensor(input_ids, dtype=torch.long)
mask_ids = torch.tensor(mask_ids)
image = results['img']
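        # run the image processor; besides pixel values it returns per-image
        # meta data (resized/padded shapes) used below to align the masks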
image_data = self.image_processor.preprocess(image)
pixel_values = image_data['pixel_values'][0]
if self.image2tensor:
pixel_values = torch.from_numpy(pixel_values)
meta_data = image_data['meta_datas'][0]
assert len(results['gt_masks'].masks) == len(results['text'])
mask_cnt = len(results['text'])
masks = torch.from_numpy(results['gt_masks'].masks).float()
h, w = meta_data['image_shape']['height'], meta_data['image_shape']['width']
gt_masks = masks.clone()
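        # resize the binary masks to the processor's resized image shape
        # (F.interpolate defaults to nearest-neighbour)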
masks = F.interpolate(masks[None], size=(h, w))[0]
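        # embed the resized masks into the padded canvas using the same
        # offsets the processor applied when padding the image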
p_h, p_w = meta_data['padded_shape']['height'], meta_data['padded_shape']['width']
padded_masks = torch.zeros(mask_cnt, p_h, p_w, dtype=masks.dtype)
padding = meta_data['padding']
padded_masks[:, padding['before_height']:p_h - padding['after_height'],
padding['before_width']:p_w - padding['after_width']] = masks
        # build labels: ignore the prompt, supervise only the caption tokens
prompt_len = len(self.prompt)
labels = torch.ones_like(input_ids) * IGNORE_INDEX
labels[prompt_len:] = input_ids[prompt_len:]
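        # swap the manually added image token for the IMAGE_TOKEN_INDEX
        # sentinel the model expects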
if self.add_image_token:
input_ids[input_ids == self.image_token_idx] = IMAGE_TOKEN_INDEX
return dict(input_ids=input_ids,
mask_ids=mask_ids,
pixel_values=pixel_values,
padded_masks=padded_masks,
                    masks=masks,  # at the resized (unpadded) image shape
gt_masks=gt_masks,
image_sizes=torch.tensor(image_data['image_sizes'][0]),
image=image,
meta_data=meta_data,
labels=labels)
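

# Smoke test: build the pipeline from the project config and iterate the
# RefCOCO val split, printing the keys of each produced sample.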
if __name__ == '__main__':
from mmdet.datasets import RefCocoDataset
from mmengine.config import Config
from mmdet.datasets.transforms import LoadAnnotations
cfg = Config.fromfile('configs/fuyu/frozen_fuyu_8b_unet_sam_l_refcoco_png.py')
prompt_template = cfg.prompt_template
tokenizer = cfg.tokenizer
image_processor = cfg.image_processor
prompt = cfg.get('prompt', None)
refcoco2png_params = dict(
type=RefCOCO2PNG,
image_processor=image_processor,
tokenizer=tokenizer,
prompt_template=prompt_template,
)
if prompt is not None:
refcoco2png_params.update(prompt=prompt)
test_pipeline = [
dict(type=PILLoadImageFromFile, backend_args=None),
dict(
type=LoadAnnotations,
with_mask=True,
with_bbox=False,
with_seg=False,
with_label=False),
refcoco2png_params
]
dataset = RefCocoDataset(
data_root='data/coco/',
data_prefix=dict(img_path='train2014/'),
text_mode='select_first',
pipeline=test_pipeline,
ann_file='refcoco/instances.json',
split_file='refcoco/refs(unc).p',
split='val'
)
for data in dataset:
print(data.keys())