import json
import logging
import os
import copy
from typing import Any
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
from projects.lisa.multiprocess_eval_refcoco import template


class OspreyRegionCaptionDataset(Dataset):
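    """Region-caption dataset in the Osprey style.

    Loads (image, region description) pairs from a JSON annotation file,
    keeps only items whose descriptions and annotations align, and exposes
    them through the HuggingFace `datasets` machinery used by xtuner.
    """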
def __init__(self,
image_folder,
image_processor,
data_path=None,
tokenizer=None,
offline_processed_text_folder=None,
max_dataset_length=None,
dataset_map_fn=None,
template_map_fn=None,
max_length=2048,
pad_image_to_square=False,
num_proc=32,
lazy=False,
repeats=1):
super().__init__()
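        # Either a pre-processed offline text folder or raw annotations
        # plus a tokenizer must be provided.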
assert offline_processed_text_folder or (data_path and tokenizer)
self.lazy = lazy
self.max_length = max_length
self.dataset_map_fn = dataset_map_fn
self.template_map_fn = template_map_fn
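        # In lazy mode the template map function may be passed as a config
        # dict; instantiate it here so it can be applied per sample.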
if isinstance(self.template_map_fn, dict) and self.lazy:
_type = self.template_map_fn.pop('type')
self.template_map_fn = _type(**self.template_map_fn)
if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and `data_path` are '
                'set; loading the dataset from '
                '`offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)
if offline_processed_text_folder is not None:
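            # Loading pre-tokenized text from disk is not supported yet.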
raise NotImplementedError
else:
json_data = self.json_file_preprocess(data_path)
self.json_data = json_data
json_data = self.filter_hf_require_infos(json_data)
            # Wrap the filtered records so the xtuner helpers can split
            # and consume them as a HuggingFace dataset.
            hf_json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
            if self.lazy:
                self.text_data = build_origin_dataset(hf_json_data, 'train')
else:
raise NotImplementedError
        self.image_folder = image_folder
        # Build the image processor first when it is given as a config;
        # otherwise `crop_size` would be read from a plain dict below.
        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor
        size = self.image_processor.crop_size
        if isinstance(size, int):
            self.image_h, self.image_w = size, size
        elif isinstance(size, dict):
            # CLIP-style processors store crop_size as a dict.
            self.image_h, self.image_w = size['height'], size['width']
        else:
            self.image_w, self.image_h = size
self.pad_image_to_square = pad_image_to_square
self.down_ratio = 1
self.repeats = repeats
self.tokenizer = tokenizer
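        # Ensure RGB input and resize to the processor's crop size.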
        self.transformer = T.Compose([
            T.Lambda(lambda img: img.convert('RGB')
                     if img.mode != 'RGB' else img),
            T.Resize((self.image_h, self.image_w))
        ])

    def filter_hf_require_infos(self, dataset_infos):
        # Keep only the fields needed to build the HF text dataset.
        ret = []
        for dataset_info in dataset_infos:
            description = dataset_info['description']
            image = dataset_info['file_name']
            ret.append({'image': image, 'description': description})
        return ret

    def json_file_preprocess(self, data_path):
with open(data_path, 'r') as f:
json_file = json.load(f)
ret = []
for item in json_file:
item.update({'image': item['file_name']})
            if len(item['description']) != len(item['annotation']):
                # Skip items whose descriptions do not align one-to-one
                # with their region annotations.
                print_log(
                    'The number of descriptions does not match the number '
                    'of annotations; the item is skipped.',
                    logger='current',
                    level=logging.WARNING)
            else:
                ret.append(item)
return ret
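
    # Used by xtuner's LengthGroupedSampler; a negative length marks a
    # text-only sample.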
@property
def modality_length(self):
length_list = []
for data_dict in self.text_data:
if self.lazy:
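                # Token length is unknown before lazy mapping; use a
                # constant placeholder.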
cur_len = 100
else:
cur_len = len(data_dict['input_ids'])
if data_dict.get('image', None) is None:
cur_len = -cur_len
length_list.append(cur_len)
return length_list

    def __len__(self):
        # `repeats` virtually enlarges the dataset: every sample is
        # visited `repeats` times per epoch.
        return len(self.text_data) * self.repeats

    def real_len(self):
        return len(self.text_data)

    def multi_modal_get_item(self, data_item):
        # Build transformation function (not implemented yet).
        return

    def __getitem__(self, index) -> Any:
        # Wrap the index so that `repeats > 1` cycles over the real data.
        index = index % self.real_len()
data_dict = copy.deepcopy(self.json_data[index])
data_dict.update(self.text_data[index])
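        # In lazy mode the conversation text is built on the fly by the
        # dataset map function.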
if self.lazy:
result = self.dataset_map_fn(data_dict)
data_dict.update(result)
assert 'image' in data_dict.keys()
if data_dict.get('image', None) is not None:
image_file = data_dict['image']
|