# Pixart-Sigma / diffusion/data/datasets/InternalData_ms.py
# Uploaded by artificialguybr, commit eadd7b4 ("Hi"); 15.3 kB; file-page scan: no virus.
import os
import numpy as np
import torch
import random
from torchvision.datasets.folder import default_loader
from diffusion.data.datasets.InternalData import InternalData, InternalDataSigma
from diffusion.data.builder import get_data_path, DATASETS
from diffusion.utils.logger import get_root_logger
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from diffusion.data.datasets.utils import *
def get_closest_ratio(height: float, width: float, ratios: dict):
    """Select the entry of ``ratios`` whose key is nearest to ``height / width``.

    ``ratios`` maps ratio keys (anything ``float()`` accepts, e.g. ``'0.5'``) to
    an associated value (in this file, the bucket size used for resizing).
    Nearness is ``abs(float(key) - height / width)``; when several keys are
    equally near, the one seen first while iterating ``ratios`` is chosen.

    Returns:
        A ``(ratios[best_key], float(best_key))`` pair.
    """
    target_ratio = height / width

    def _distance(key):
        # Absolute gap between this bucket's ratio and the requested one.
        return abs(float(key) - target_ratio)

    best_key = min(ratios, key=_distance)
    return ratios[best_key], float(best_key)
@DATASETS.register_module()
class InternalDataMS(InternalData):
    """Multi-scale (aspect-ratio bucketed) variant of ``InternalData``.

    Every sample is assigned to the entry of the ``aspect_ratio`` table whose
    ratio key is closest to its height/width (see ``get_closest_ratio``). Raw
    images are resized and center-cropped to that entry's size; pre-extracted
    VAE features are loaded unchanged. Caption features are always read from
    pre-extracted ``.npz`` files.
    """

    def __init__(self,
                 root,
                 image_list_json='data_info.json',
                 transform=None,
                 resolution=256,
                 sample_subset=None,
                 load_vae_feat=False,
                 input_size=32,
                 patch_size=2,
                 mask_ratio=0.0,
                 mask_type='null',
                 load_mask_index=False,
                 real_prompt_ratio=1.0,
                 max_length=120,
                 config=None,
                 **kwargs):
        """Index the samples listed in one or more metadata JSON files.

        Args:
            root: dataset root, resolved through ``get_data_path``.
            image_list_json: metadata file name, or list of names, relative to
                ``root``. Entries must provide ``path``, ``ratio``, ``height``
                and ``width``.
            transform: stored as ``self.transform``. It is not applied to raw
                images (``getdata`` builds a per-sample resize/crop transform)
                and is forced to ``None`` when ``load_vae_feat`` is set.
            resolution: part of the VAE-feature directory name; also used to
                derive ``self.N``.
            sample_subset: if not None, passed to the inherited ``sample_subset``.
            load_vae_feat: load pre-extracted VAE features (``.npy``) instead of
                images.
            input_size, patch_size: used only to compute ``self.N``.
            mask_ratio, mask_type, load_mask_index, real_prompt_ratio: stored as
                attributes; ``mask_type`` is copied into every ``data_info``.
            max_length: stored as ``self.max_lenth`` and logged.
            config: if given, logging goes to ``config.work_dir/train_log.log``.
            **kwargs: must contain ``aspect_ratio_type``, the name of an
                aspect-ratio table whose last ``_``-separated part is the base
                size (e.g. ``ASPECT_RATIO_2048``).
        """
        self.root = get_data_path(root)
        self.transform = transform
        self.load_vae_feat = load_vae_feat
        self.ori_imgs_nums = 0
        self.resolution = resolution
        self.N = int(resolution // (input_size // patch_size))
        self.mask_ratio = mask_ratio
        self.load_mask_index = load_mask_index
        self.mask_type = mask_type
        self.real_prompt_ratio = real_prompt_ratio
        self.max_lenth = max_length  # (sic) attribute name kept for existing readers
        self.base_size = int(kwargs['aspect_ratio_type'].split('_')[-1])
        # NOTE(review): eval() of a config-supplied string, resolved against this
        # module's globals (where the aspect-ratio tables are star-imported).
        # Only safe with trusted configs.
        self.aspect_ratio = eval(kwargs.pop('aspect_ratio_type'))  # base aspect ratio
        self.meta_data_clean = []
        self.img_samples = []
        self.txt_feat_samples = []
        self.vae_feat_samples = []
        self.mask_index_samples = []
        # Per-bucket bookkeeping keyed by float(ratio key):
        #   ratio_index -> indices used as substitutes for failed samples (used for self.getitem)
        #   ratio_nums  -> sample counts (used for batch-sampler)
        self.ratio_index = {float(k): [] for k in self.aspect_ratio}
        self.ratio_nums = {float(k): 0 for k in self.aspect_ratio}

        if not isinstance(image_list_json, list):
            image_list_json = [image_list_json]
        # Raw images live in a sibling tree: "...InternData..." -> "...InternImgs...".
        img_root = self.root.replace('InternData', "InternImgs")
        for json_file in image_list_json:
            meta_data = self.load_json(os.path.join(self.root, json_file))
            self.ori_imgs_nums += len(meta_data)
            # Drop entries whose metadata 'ratio' exceeds 4.
            meta_data_clean = [item for item in meta_data if item['ratio'] <= 4]
            self.meta_data_clean.extend(meta_data_clean)
            for item in meta_data_clean:
                # Feature files are flat: the last '/' of the relative path becomes '_'.
                flat_name = '_'.join(item['path'].rsplit('/', 1))
                self.img_samples.append(os.path.join(img_root, item['path']))
                self.txt_feat_samples.append(
                    os.path.join(self.root, 'caption_features', flat_name.replace('.png', '.npz')))
                self.vae_feat_samples.append(
                    os.path.join(self.root, f'img_vae_fatures_{resolution}_multiscale/ms',
                                 flat_name.replace('.png', '.npy')))

        # Set loader and extensions
        if load_vae_feat:
            self.transform = None
            self.loader = self.vae_feat_loader
        else:
            self.loader = default_loader

        if sample_subset is not None:
            self.sample_subset(sample_subset)  # sample dataset for local debug

        # Scan the first third of the metadata to seed the bucket statistics.
        # NOTE(review): only a third is counted, so ratio_nums is a partial count
        # and each bucket's ratio_index starts with at most one index.
        for i, info in enumerate(self.meta_data_clean[:len(self.meta_data_clean) // 3]):
            _, closest_ratio = get_closest_ratio(info['height'], info['width'], self.aspect_ratio)
            self.ratio_nums[closest_ratio] += 1
            if not self.ratio_index[closest_ratio]:
                self.ratio_index[closest_ratio].append(i)

        logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
        logger.info(f"T5 max token length: {self.max_lenth}")

    def _register_index(self, ratio, index):
        """Append ``index`` to ``self.ratio_index[ratio]`` unless it is already there.

        A per-bucket set mirrors each list so the membership test is O(1)
        instead of a linear scan of a list that grows with the dataset.
        """
        members = getattr(self, '_ratio_index_members', None)
        if members is None:
            members = self._ratio_index_members = {}
        bucket = members.get(ratio)
        if bucket is None:
            bucket = members[ratio] = set(self.ratio_index[ratio])
        if index not in bucket:
            bucket.add(index)
            self.ratio_index[ratio].append(index)

    @staticmethod
    def _build_bucket_transform(closest_size, ori_h, ori_w):
        """Return a transform that resizes to cover ``closest_size`` (h, w), center-crops, and maps to [-1, 1]."""
        # Scale by the larger of the two factors so both sides cover the target.
        if closest_size[0] / ori_h > closest_size[1] / ori_w:
            resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h)
        else:
            resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1]
        return T.Compose([
            T.Lambda(lambda img: img.convert('RGB')),
            T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC),  # Image.BICUBIC
            T.CenterCrop(closest_size),
            T.ToTensor(),
            T.Normalize([.5], [.5]),
        ])

    def getdata(self, index):
        """Load sample ``index``.

        Returns:
            ``(img, txt_fea, attention_mask, data_info)``:
            ``img`` is the VAE feature from ``self.vae_feat_loader`` (when
            ``load_vae_feat``) or the image resized/cropped to the bucket size
            and normalized; ``txt_fea`` / ``attention_mask`` come from the
            caption ``.npz``; ``data_info`` holds ``img_hw``, ``aspect_ratio``
            and ``mask_type``.

        Raises:
            Whatever the loaders raise. Retrying is left to ``__getitem__``.
        """
        img_path = self.img_samples[index]
        npz_path = self.txt_feat_samples[index]
        npy_path = self.vae_feat_samples[index]
        meta = self.meta_data_clean[index]
        ori_h, ori_w = meta['height'], meta['width']
        # Calculate the closest aspect ratio and resize & crop image[w, h]
        closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
        closest_size = [int(x) for x in closest_size]
        self.closest_ratio = closest_ratio  # kept for external readers of the last ratio

        if self.load_vae_feat:
            img = self.loader(npy_path)
            if len(img.shape) < 3:
                raise ValueError(f"unexpected VAE feature shape {tuple(img.shape)} in {npy_path}")
            # NOTE(review): the spatial size is not compared with the metadata;
            # confirm whether latents are stored at ori/8 or bucket/8 before adding a check.
            # Only successfully loaded samples become substitutes for bad ones.
            self._register_index(closest_ratio, index)
            transform = self.transform
        else:
            img = self.loader(img_path)
            transform = self._build_bucket_transform(closest_size, ori_h, ori_w)

        data_info = {'img_hw': torch.tensor([ori_h, ori_w], dtype=torch.float32)}
        data_info['aspect_ratio'] = closest_ratio
        data_info["mask_type"] = self.mask_type

        # NpzFile keeps its file open until closed; read what is needed, then close it.
        # Assumes caption_feature is (1, L, C) with tokens on axis 1 - TODO confirm.
        with np.load(npz_path) as txt_info:
            txt_fea = torch.from_numpy(txt_info['caption_feature'])
            if 'attention_mask' in txt_info.keys():
                attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
            else:
                attention_mask = torch.ones(1, 1, txt_fea.shape[1])

        if transform:
            img = transform(img)
        return img, txt_fea, attention_mask, data_info

    def __getitem__(self, idx):
        """Return ``getdata(idx)``, substituting other samples on failure.

        Substitutes are drawn from the bucket of the originally requested
        sample (same ratio, hence the same crop size). At most 20 attempts are
        made. An out-of-range ``idx`` raises ``IndexError``.

        Raises:
            RuntimeError: if every attempt fails (chained to the last error).
        """
        info = self.meta_data_clean[idx]
        _, ratio = get_closest_ratio(info['height'], info['width'], self.aspect_ratio)
        last_error = None
        for _ in range(20):
            try:
                return self.getdata(idx)
            except Exception as e:
                last_error = e
                print(f"Error details: {str(e)}")
                candidates = self.ratio_index[ratio]
                # An empty bucket has no substitute yet; retry the same index.
                idx = random.choice(candidates) if candidates else idx
        raise RuntimeError('Too many bad data.') from last_error
@DATASETS.register_module()
class InternalDataMSSigma(InternalDataSigma):
    """Multi-scale (aspect-ratio bucketed) dataset for the Sigma pipeline.

    As with ``InternalDataMS``, each sample is bucketed by the closest ratio key
    of the ``aspect_ratio`` table. Additionally, each access chooses between the
    original prompt and the ShareGPT4V caption (``real_prompt_ratio``). It
    returns either pre-extracted T5 features (``load_t5_feat``) padded to
    ``max_length``, or the raw caption string.
    """

    def __init__(self,
                 root,
                 image_list_json='data_info.json',
                 transform=None,
                 resolution=256,
                 sample_subset=None,
                 load_vae_feat=False,
                 load_t5_feat=False,
                 input_size=32,
                 patch_size=2,
                 mask_ratio=0.0,
                 mask_type='null',
                 load_mask_index=False,
                 real_prompt_ratio=1.0,
                 max_length=300,
                 config=None,
                 **kwargs):
        """Index the samples listed in one or more metadata JSON files.

        Args:
            root: dataset root, resolved through ``get_data_path``.
            image_list_json: metadata file name, or list of names, relative to
                ``root``. Entries must provide ``path``, ``ratio``, ``height``,
                ``width`` and ``prompt``. ``sharegpt4v`` is optional.
            transform: stored as ``self.transform``. It is not applied to raw
                images and is forced to ``None`` when ``load_vae_feat`` is set.
            resolution: part of the VAE-feature directory name; also used to
                derive ``self.N``.
            sample_subset: if not None, passed to the inherited ``sample_subset``.
            load_vae_feat: load pre-extracted VAE features instead of images.
            load_t5_feat: return pre-extracted caption features instead of text.
            real_prompt_ratio: probability of using ``prompt`` rather than the
                ShareGPT4V caption. Also selects ``weight_dtype`` (float16 if
                > 0, else float32).
            max_length: token length that caption features are padded to
                (stored as ``self.max_lenth``).
            config: if given, logging goes to ``config.work_dir/train_log.log``.
            **kwargs: must contain ``aspect_ratio_type``, the name of an
                aspect-ratio table whose last ``_``-separated part is the base size.
        """
        self.root = get_data_path(root)
        self.transform = transform
        self.load_vae_feat = load_vae_feat
        self.load_t5_feat = load_t5_feat
        self.ori_imgs_nums = 0
        self.resolution = resolution
        self.N = int(resolution // (input_size // patch_size))
        self.mask_ratio = mask_ratio
        self.load_mask_index = load_mask_index
        self.mask_type = mask_type
        self.real_prompt_ratio = real_prompt_ratio
        self.max_lenth = max_length  # (sic) attribute name kept for existing readers
        self.base_size = int(kwargs['aspect_ratio_type'].split('_')[-1])
        # NOTE(review): eval() of a config-supplied string, resolved against this
        # module's globals (where the aspect-ratio tables are star-imported).
        # Only safe with trusted configs.
        self.aspect_ratio = eval(kwargs.pop('aspect_ratio_type'))  # base aspect ratio
        self.meta_data_clean = []
        self.img_samples = []
        self.txt_samples = []
        self.sharegpt4v_txt_samples = []
        self.txt_feat_samples = []
        self.vae_feat_samples = []
        self.mask_index_samples = []
        self.gpt4v_txt_feat_samples = []
        self.weight_dtype = torch.float16 if self.real_prompt_ratio > 0 else torch.float32
        # The two largest tables are resized with LANCZOS; all others use BICUBIC.
        self.interpolate_model = InterpolationMode.BICUBIC
        if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]:
            self.interpolate_model = InterpolationMode.LANCZOS
        suffix = ''
        # Per-bucket bookkeeping keyed by float(ratio key):
        #   ratio_index -> indices used as substitutes for failed samples (used for self.getitem)
        #   ratio_nums  -> sample counts (used for batch-sampler)
        self.ratio_index = {float(k): [] for k in self.aspect_ratio}
        self.ratio_nums = {float(k): 0 for k in self.aspect_ratio}

        logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
        logger.info(f"T5 max token length: {self.max_lenth}")
        logger.info(f"ratio of real user prompt: {self.real_prompt_ratio}")

        if not isinstance(image_list_json, list):
            image_list_json = [image_list_json]
        # Raw images live in a sibling tree: "...InternData..." -> "...InternImgs...".
        img_root = self.root.replace('InternData' + suffix, 'InternImgs')
        for json_file in image_list_json:
            meta_data = self.load_json(os.path.join(self.root, json_file))
            logger.info(f"{json_file} data volume: {len(meta_data)}")
            self.ori_imgs_nums += len(meta_data)
            # Drop entries whose metadata 'ratio' exceeds 4.5.
            meta_data_clean = [item for item in meta_data if item['ratio'] <= 4.5]
            self.meta_data_clean.extend(meta_data_clean)
            for item in meta_data_clean:
                # Feature files are flat: the last '/' of the relative path becomes '_'.
                flat_name = '_'.join(item['path'].rsplit('/', 1))
                npz_name = flat_name.replace('.png', '.npz')
                self.img_samples.append(os.path.join(img_root, item['path']))
                self.txt_samples.append(item['prompt'])
                self.sharegpt4v_txt_samples.append(item['sharegpt4v'] if 'sharegpt4v' in item else '')
                self.txt_feat_samples.append(os.path.join(self.root, 'caption_features_new', npz_name))
                self.gpt4v_txt_feat_samples.append(
                    os.path.join(self.root, 'sharegpt4v_caption_features_new', npz_name))
                self.vae_feat_samples.append(
                    os.path.join(self.root + suffix,
                                 f'img_sdxl_vae_features_{resolution}resolution_ms_new',
                                 flat_name.replace('.png', '.npy')))

        if self.real_prompt_ratio < 1:
            # Only the first entry is checked for a ShareGPT4V caption.
            assert len(self.sharegpt4v_txt_samples[0]) != 0, \
                "real_prompt_ratio < 1 requires 'sharegpt4v' captions in the metadata"

        # Set loader and extensions
        if load_vae_feat:
            self.transform = None
            self.loader = self.vae_feat_loader
        else:
            self.loader = default_loader

        if sample_subset is not None:
            self.sample_subset(sample_subset)  # sample dataset for local debug

        # Scan the first third of the metadata to seed the bucket statistics.
        # NOTE(review): only a third is counted, so ratio_nums is a partial count
        # and each bucket's ratio_index starts with at most one index.
        for i, info in enumerate(self.meta_data_clean[:len(self.meta_data_clean) // 3]):
            _, closest_ratio = get_closest_ratio(info['height'], info['width'], self.aspect_ratio)
            self.ratio_nums[closest_ratio] += 1
            if not self.ratio_index[closest_ratio]:
                self.ratio_index[closest_ratio].append(i)

    def _register_index(self, ratio, index):
        """Append ``index`` to ``self.ratio_index[ratio]`` unless it is already there.

        A per-bucket set mirrors each list so the membership test is O(1)
        instead of a linear scan of a list that grows with the dataset.
        """
        members = getattr(self, '_ratio_index_members', None)
        if members is None:
            members = self._ratio_index_members = {}
        bucket = members.get(ratio)
        if bucket is None:
            bucket = members[ratio] = set(self.ratio_index[ratio])
        if index not in bucket:
            bucket.add(index)
            self.ratio_index[ratio].append(index)

    def _build_bucket_transform(self, closest_size, ori_h, ori_w):
        """Return a transform that resizes to cover ``closest_size`` (h, w), center-crops, and maps to [-1, 1]."""
        # Scale by the larger of the two factors so both sides cover the target.
        if closest_size[0] / ori_h > closest_size[1] / ori_w:
            resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h)
        else:
            resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1]
        return T.Compose([
            T.Lambda(lambda img: img.convert('RGB')),
            T.Resize(resize_size, interpolation=self.interpolate_model),  # Image.BICUBIC
            T.CenterCrop(closest_size),
            T.ToTensor(),
            T.Normalize([.5], [.5]),
        ])

    def getdata(self, index):
        """Load sample ``index``.

        Returns:
            ``(img, txt_fea, attention_mask, data_info)``:
            ``img`` is the VAE feature (``load_vae_feat``) or the image
            resized/cropped to the bucket size and normalized; ``txt_fea`` is
            the caption feature padded to ``max_length`` (``load_t5_feat``) or
            the caption string; ``attention_mask`` is an int16 tensor;
            ``data_info`` holds ``img_hw``, ``aspect_ratio`` and ``mask_type``.

        Raises:
            ValueError: if a caption feature is longer than ``max_length``.
            Whatever the loaders raise. Retrying is left to ``__getitem__``.
        """
        img_path = self.img_samples[index]
        real_prompt = random.random() < self.real_prompt_ratio
        npz_path = self.txt_feat_samples[index] if real_prompt else self.gpt4v_txt_feat_samples[index]
        txt = self.txt_samples[index] if real_prompt else self.sharegpt4v_txt_samples[index]
        npy_path = self.vae_feat_samples[index]
        meta = self.meta_data_clean[index]
        ori_h, ori_w = meta['height'], meta['width']
        # Calculate the closest aspect ratio and resize & crop image[w, h]
        closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
        closest_size = [int(x) for x in closest_size]
        self.closest_ratio = closest_ratio  # kept for external readers of the last ratio

        if self.load_vae_feat:
            img = self.loader(npy_path)
            if len(img.shape) < 3:
                raise ValueError(f"unexpected VAE feature shape {tuple(img.shape)} in {npy_path}")
            # NOTE(review): the spatial size is not compared with the metadata;
            # confirm whether latents are stored at ori/8 or bucket/8 before adding a check.
            # Only successfully loaded samples become substitutes for bad ones.
            self._register_index(closest_ratio, index)
            transform = self.transform
        else:
            img = self.loader(img_path)
            transform = self._build_bucket_transform(closest_size, ori_h, ori_w)

        data_info = {
            'img_hw': torch.tensor([ori_h, ori_w], dtype=torch.float32),
            'aspect_ratio': closest_ratio,
            "mask_type": self.mask_type,
        }

        attention_mask = torch.ones(1, 1, self.max_lenth)
        if self.load_t5_feat:
            # NpzFile keeps its file open until closed; read what is needed, then close it.
            # Assumes caption_feature is (1, L, C) with tokens on axis 1 - TODO confirm.
            with np.load(npz_path) as txt_info:
                txt_fea = torch.from_numpy(txt_info['caption_feature'])
                if 'attention_mask' in txt_info.keys():
                    attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
                else:
                    # Cover only the stored tokens so padded positions are masked out below.
                    attention_mask = torch.ones(1, 1, txt_fea.shape[1])
            seq_len = txt_fea.shape[1]
            if seq_len > self.max_lenth:
                raise ValueError(
                    f"caption feature length {seq_len} exceeds max_length {self.max_lenth}: {npz_path}")
            if seq_len < self.max_lenth:
                # Pad by repeating the last token; padded positions get mask 0.
                # NOTE(review): only the padded path is cast to weight_dtype.
                pad = txt_fea[:, -1:].repeat(1, self.max_lenth - seq_len, 1)
                txt_fea = torch.cat([txt_fea, pad], dim=1).to(self.weight_dtype)
                attention_mask = torch.cat(
                    [attention_mask, torch.zeros(1, 1, self.max_lenth - attention_mask.shape[-1])], dim=-1)
        else:
            txt_fea = txt

        if transform:
            img = transform(img)
        return img, txt_fea, attention_mask.to(torch.int16), data_info

    def __getitem__(self, idx):
        """Return ``getdata(idx)``, substituting other samples on failure.

        Substitutes are drawn from the bucket of the originally requested
        sample (same ratio, hence the same crop size). At most 20 attempts are
        made. An out-of-range ``idx`` raises ``IndexError``.

        Raises:
            RuntimeError: if every attempt fails (chained to the last error).
        """
        info = self.meta_data_clean[idx]
        _, ratio = get_closest_ratio(info['height'], info['width'], self.aspect_ratio)
        last_error = None
        for _ in range(20):
            try:
                return self.getdata(idx)
            except Exception as e:
                last_error = e
                print(f"Error details: {str(e)}")
                candidates = self.ratio_index[ratio]
                # An empty bucket has no substitute yet; retry the same index.
                idx = random.choice(candidates) if candidates else idx
        raise RuntimeError('Too many bad data.') from last_error