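"""PNG (Panoptic Narrative Grounding) dataset for grounded caption training.

Each sample pairs a narrative caption with COCO panoptic segmentation: noun
phrases in the caption are tokenized and linked, token by token, to the
panoptic segment ids they describe, so a model can be supervised on both the
text and the corresponding binary masks.
"""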
import os
import io
import json
import copy
import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset
from PIL import Image
import random
# petrel_client is optional; it is only needed when images are read from Ceph
try:
    from petrel_client.client import Client
except ImportError:
    Client = None
from xtuner.registry import BUILDER
from mmdet.datasets.api_wrappers.coco_api import COCOPanoptic
import mmcv
from mmengine.fileio import get
from panopticapi import utils
from xtuner.utils.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mmengine.logging import print_log
from typing import Dict, Sequence
from torch.utils.data import ConcatDataset
def concat_datasets(datasets_list):
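    """Build every dataset config via the registry and wrap them in a ConcatDataset."""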
datasets_list = [BUILDER.build(dataset_) for dataset_ in datasets_list]
return ConcatDataset(datasets_list)
def custom_collate_fn(instances: Sequence[Dict]):
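    """Collate samples into a plain list; no padding or stacking is done here.

    Intended to be passed as ``collate_fn`` to a ``torch.utils.data.DataLoader``.
    """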
    return {'data': instances, 'data_samples': None}
class PNGDataset(Dataset):
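    """Panoptic Narrative Grounding dataset built on COCO panoptic annotations.

    ``json_file`` entries carry an ``image_id`` and a list of ``segments``;
    each segment holds an ``utterance`` (a caption span), the panoptic
    ``segment_ids`` it refers to (possibly empty), and a ``plural`` flag.
    ``__getitem__`` returns the tokenized prompt + caption, per-token mask
    ids, the preprocessed image, and the binary masks at original, resized,
    and padded resolutions.
    """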
def __init__(self,
json_file,
panoptic_json_file,
panoptic_png_path,
image_processor=None, tokenizer=None,
ceph_path=None, local_path=None, prompt_template=None,
prompt='<image>\nWhat is shown in this image?',
image2tensor=True,
add_image_token=False,
image_token=DEFAULT_IMAGE_TOKEN):
super().__init__()
with open(json_file, 'r') as f:
self.data = json.load(f)
self.coco = COCOPanoptic(panoptic_json_file)
self.panoptic_png_path = panoptic_png_path
self.ceph_path = ceph_path
self.local_path = local_path
self.FILE_CLIENT = None
self.use_ceph = (Client is not None) and (ceph_path is not None)
if isinstance(tokenizer, dict):
self.tokenizer = BUILDER.build(tokenizer)
else:
self.tokenizer = tokenizer
if isinstance(image_processor, dict):
self.image_processor = BUILDER.build(image_processor)
else:
self.image_processor = image_processor
self.image2tensor = image2tensor
self.image_token = image_token
self.add_image_token = add_image_token
if add_image_token:
print_log(f"Manually add image token: {self.image_token}")
special_tokens_dict = {'additional_special_tokens': [self.image_token,]}
num_added_toks = self.tokenizer.add_special_tokens(special_tokens_dict)
assert num_added_toks == 1
self.image_token_idx = self.tokenizer.encode(self.image_token, add_special_tokens=False)[-1]
print_log(f"Image token: {self.tokenizer.decode(self.image_token_idx)}")
self.prompt = self.tokenizer.encode(
prompt_template['INSTRUCTION'].format(input=prompt),
add_special_tokens=True)
self.prompt_template = prompt_template
@staticmethod
def _load_segm(segm_path):
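        """Load a panoptic PNG file and convert its RGB encoding to a segment-id map."""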
img_bytes = get(segm_path)
pan_png = mmcv.imfrombytes(
img_bytes, flag='color', channel_order='rgb').squeeze()
segm_map = utils.rgb2id(pan_png)
return segm_map
def __len__(self):
return len(self.data)
def read_image(self, image_file):
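        """Read an image through petrel_client (Ceph) when configured, otherwise from local disk."""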
if self.use_ceph:
image_path = os.path.join(self.ceph_path, image_file)
if self.FILE_CLIENT is None:
self.FILE_CLIENT = Client()
img_bytes = self.FILE_CLIENT.get(image_path)
image = Image.open(io.BytesIO(img_bytes))
else:
image_path = os.path.join(self.local_path, image_file)
image = Image.open(image_path)
return image
def __getitem__(self, index):
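        """Assemble one sample: token ids, per-token mask ids, image tensors, masks and labels."""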
data_sample = self.data[index]
mask_cnt = 0
caption_input_ids = []
mask_ids = [-1]*len(self.prompt)
mask_segment_ids = []
mask_infos = [] # record isthing, plural
image_id = int(data_sample['image_id'])
annotations = {ann['id']: ann for ann in self.coco.imgToAnns[image_id]}
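        # walk the caption segments: tokenize each utterance and record which
        # panoptic segment ids (if any) its tokens are grounded to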
for segment in data_sample['segments']:
segment_input_ids = self.tokenizer.encode(segment['utterance'], add_special_tokens=False)
caption_input_ids += segment_input_ids
if len(segment['segment_ids']) == 0:
mask_ids += [-1] * len(segment_input_ids)
else:
mask_ids += [mask_cnt] * len(segment_input_ids)
mask_segment_ids.append(segment['segment_ids'])
if not segment['plural']:
assert len(segment['segment_ids']) == 1
segment_id = int(segment['segment_ids'][0])
isthing = self.coco.cats[annotations[segment_id]['category_id']]['isthing']
else:
isthing = 1
mask_infos.append(dict(plural=segment['plural'],
isthing=isthing > 0))
# todo: load masks
mask_cnt += 1
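        # no grounded segment in this caption: fall back to a random sample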
if mask_cnt == 0:
return self.__getitem__(random.choice(range(self.__len__())))
image_info = self.coco.imgs[image_id]
segm_file = image_info['segm_file']
segm_map = self._load_segm(os.path.join(self.panoptic_png_path, segm_file))
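        # build one binary mask per grounded phrase (union over its segment ids)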
masks = []
for mask_segment_ids_ in mask_segment_ids:
mask = 0
for segment_id in mask_segment_ids_:
mask += (segm_map == int(segment_id)).astype(np.uint8)
masks.append(np.clip(mask, a_max=1, a_min=0))
assert len(masks) == mask_cnt
input_ids = self.prompt + caption_input_ids
input_ids = torch.tensor(input_ids, dtype=torch.long)
mask_ids = torch.tensor(mask_ids)
image = self.read_image(image_info['file_name'])
image_data = self.image_processor.preprocess(image)
pixel_values = image_data['pixel_values'][0]
if self.image2tensor:
pixel_values = torch.from_numpy(pixel_values)
meta_data = image_data['meta_datas'][0]
masks = torch.from_numpy(np.stack(masks))
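        # resize masks to the processed image size, then zero-pad them so they
        # align with the padded pixel values produced by the image processor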
h, w = meta_data['image_shape']['height'], meta_data['image_shape']['width']
gt_masks = masks.clone()
masks = F.interpolate(masks[None], size=(h, w))[0]
p_h, p_w = meta_data['padded_shape']['height'], meta_data['padded_shape']['width']
padded_masks = torch.zeros(mask_cnt, p_h, p_w, dtype=masks.dtype)
padding = meta_data['padding']
padded_masks[:, padding['before_height']:p_h-padding['after_height'],
padding['before_width']:p_w-padding['after_width']] = masks
        # build labels: ignore the prompt tokens, supervise only the caption tokens
prompt_len = len(self.prompt)
labels = torch.ones_like(input_ids) * IGNORE_INDEX
labels[prompt_len:] = input_ids[prompt_len:]
if self.add_image_token:
input_ids[input_ids == self.image_token_idx] = IMAGE_TOKEN_INDEX
return dict(input_ids=input_ids,
mask_ids=mask_ids,
pixel_values=pixel_values,
padded_masks=padded_masks,
masks=masks, # shape is kept
gt_masks=gt_masks,
image_sizes=torch.tensor(image_data['image_sizes'][0]),
mask_infos=mask_infos,
image=image,
file_name=image_info['file_name'],
meta_data=meta_data,
labels=labels)
if __name__ == '__main__':
from xtuner.utils.templates import PROMPT_TEMPLATE
# prompt_template = PROMPT_TEMPLATE.mistral
prompt_template = PROMPT_TEMPLATE.vicuna
    from transformers import AutoTokenizer
# from flmm.datasets.llava_next_image_processor import CustomLlavaNextImageProcessor
from projects.f_llm.datasets.llava_processors import CustomLlavaImageProcessor
from tqdm import tqdm
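    # smoke test: build the dataset on COCO val2017 and iterate over every sample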
dataset = PNGDataset(
json_file='data/coco/annotations/png_coco_val2017.json',
panoptic_json_file='data/coco/annotations/panoptic_val2017.json',
panoptic_png_path='data/coco/annotations/panoptic_val2017',
# tokenizer=dict(
# type=AutoTokenizer.from_pretrained,
# pretrained_model_name_or_path='llava-hf/llava-v1.6-mistral-7b-hf'),
tokenizer=dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path='llava-hf/llava-1.5-7b-hf'),
# image_processor=dict(
# type=CustomLlavaNextImageProcessor.from_pretrained,
# pretrained_model_name_or_path='llava-hf/llava-v1.6-mistral-7b-hf'),
image_processor=dict(
type=CustomLlavaImageProcessor.from_pretrained,
pretrained_model_name_or_path='openai/clip-vit-large-patch14-336'),
prompt_template=prompt_template,
local_path='data/coco/val2017'
)
for i in tqdm(range(len(dataset))):
        data = dataset[i]