# MMFS/data/static_data.py
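"""Static (on-disk) dataset loading for MMFS.

Defines StaticData, which collects file paths for paired or unpaired data
groups (images, numpy arrays, landmarks, patches) from a config-driven layout
and builds the transforms that turn those paths into training tensors.
"""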
import os
import random
import sys

import numpy as np
import torch
from PIL import Image

from utils.augmentation import ImagePathToImage, NumpyToTensor
from utils.data_utils import Transforms
from utils.util import check_path_is_static_data
def check_dataname_folder_correspondence(data_names, group, group_name):
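    """Check that every name in data_names has a '<data_name>_folder' entry in group.

    Returns True if all folders are specified; otherwise prints a notice and
    returns False so the caller can fall back to dataroot mode for group_name.
    """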
for data_name in data_names:
if data_name + '_folder' not in group:
print("%s not found in config file. Going to use dataroot mode to load group %s." % (data_name + '_folder', group_name))
return False
return True
def custom_check_path_exists(str1):
    return str1 == "None" or os.path.exists(str1)
def custom_getsize(str1):
return 1 if str1 == "None" else os.path.getsize(str1)
def check_different_extension_path_exists(str1):
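    """Try str1 with each acceptable image/numpy extension and return the first
    variant that exists on disk, or None if no variant exists."""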
acceptable_extensions = ['png', 'jpg', 'jpeg', 'npy', 'npz', 'PNG', 'JPG', 'JPEG']
    # Swap only the trailing extension so that extension-like substrings
    # elsewhere in the path (e.g. a 'jpg_data' folder) are left untouched.
    base, _ = os.path.splitext(str1)
    for extension in acceptable_extensions:
        str2 = base + '.' + extension
if os.path.exists(str2):
return str2
return None
class StaticData(object):
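    """Loads static (on-disk) data groups described in the config.

    Each group declares its data names and types, whether it is paired, and
    how its files are located (a file list, one folder per data name, or a
    dataroot). load_static_data() collects the file paths, create_transforms()
    builds one transform pipeline per data type, and get_item() assembles a
    single training sample, including landmark scaling and patch cropping.
    """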
def __init__(self, config, shuffle=False):
# private variables
self.file_groups = []
self.type_groups = []
self.group_names = []
self.pair_type_groups = []
self.len_of_groups = []
self.transforms = {}
# parameters
self.shuffle = shuffle
self.config = config
def load_static_data(self):
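        """Collect file paths for every data group of the current phase.

        Depending on the group config, paths are read from a space-separated
        file list, from one folder per data name, or from a dataroot laid out
        as <dataroot>/<phase>/<data_name> (new format) or
        <dataroot>/<phase><dataname> (old format, underscores removed).
        Patch data is skipped here because patches are cropped on the fly in
        get_item().
        """
        # Illustrative sketch of one data group as this method expects it.
        # The group and data names below ('group_1', 'real_A', ...) are
        # hypothetical examples; only the keys are the ones read by the code:
        #
        #   train_data:
        #     group_1:
        #       data_names: ['real_A', 'real_B']
        #       data_types: ['image', 'image']
        #       paired: true
        #       dataroot: /path/to/dataset   # or 'file_list' / '<name>_folder'
        #       preprocess: [...]            # used later by create_transforms()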
data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
print("----------------loading %s static data.---------------------" % self.config['common']['phase'])
if len(data_dict) == 0:
self.len_of_groups.append(0)
return
self.group_names = list(data_dict.keys())
for i, group in enumerate(data_dict.values()): # examples: (0, group_1), (1, group_2)
data_types = group['data_types'] # examples: 'image', 'patch'
data_names = group['data_names'] # examples: 'real_A', 'patch_A'
self.file_groups.append({})
self.type_groups.append({})
self.len_of_groups.append(0)
self.pair_type_groups.append(group['paired'])
# exclude patch data since they are not stored on disk. They will be handled later.
data_types, data_names = self.exclude_patch_data(data_types, data_names)
assert(len(data_types) == len(data_names))
if len(data_names) == 0:
continue
for data_name, data_type in zip(data_names, data_types):
self.file_groups[i][data_name] = []
self.type_groups[i][data_name] = data_type
# paired data
if group['paired']:
# First way to load data: load a file list
if 'file_list' in group:
file_list = group['file_list']
paired_file = open(file_list, 'rt')
lines = paired_file.readlines()
if self.shuffle:
random.shuffle(lines)
for line in lines:
items = line.strip().split(' ')
if len(items) == len(data_names):
ok = True
for item in items:
ok = ok and os.path.exists(item) and os.path.getsize(item) > 0
if ok:
for data_name, item in zip(data_names, items):
self.file_groups[i][data_name].append(item)
paired_file.close()
                # Second and third ways to load data: specify one folder per data name, or specify a dataroot folder.
elif check_dataname_folder_correspondence(data_names, group, self.group_names[i]) or 'dataroot' in group:
dataname_to_dir_dict = {}
for data_name, data_type in zip(data_names, data_types):
if 'dataroot' in group:
# In new data config format, data is stored in dataroot_name/mode/dataname. e.g. FFHQ/train/pairedA
# In old format, data is stored in dataroot_name/mode_dataname. e.g. FFHQ/train_pairedA
# So we need to check both.
dir = os.path.join(group['dataroot'], self.config['common']['phase'], data_name)
if not os.path.exists(dir):
old_dir = os.path.join(group['dataroot'], self.config['common']['phase'] + data_name.replace('_', ''))
if 'numpy' in data_type:
old_dir += 'numpy'
if not os.path.exists(old_dir):
print("Both %s and %s does not exist. Please check." % (dir, old_dir))
sys.exit()
else:
dir = old_dir
else:
dir = group[data_name + '_folder']
if not os.path.exists(dir):
print("directory %s does not exist. Please check." % dir)
sys.exit()
dataname_to_dir_dict[data_name] = dir
filenames = os.listdir(dataname_to_dir_dict[data_names[0]])
if self.shuffle:
random.shuffle(filenames)
for filename in filenames:
if not check_path_is_static_data(filename):
continue
file_paths = []
for data_name in data_names:
file_path = os.path.join(dataname_to_dir_dict[data_name], filename)
checked_extension = check_different_extension_path_exists(file_path)
if checked_extension is not None:
file_paths.append(checked_extension)
if len(file_paths) != len(data_names):
print("for file %s , looks like some of the other pair data is missing. Ignoring and proceeding." % filename)
continue
else:
for j in range(len(data_names)):
data_name = data_names[j]
self.file_groups[i][data_name].append(file_paths[j])
else:
print("method for loading data is incorrect/unspecified for data group %s." % self.group_names)
sys.exit()
self.len_of_groups[i] = len(self.file_groups[i][data_names[0]])
# unpaired data
else:
# First way to load data: load a file list
if 'file_list' in group:
file_list = group['file_list']
unpaired_file = open(file_list, 'rt')
lines = unpaired_file.readlines()
if self.shuffle:
random.shuffle(lines)
item_count = 0
for line in lines:
items = line.strip().split(' ')
if len(items) == len(data_names):
ok = True
for item in items:
ok = ok and custom_check_path_exists(item) and custom_getsize(item) > 0
if ok:
has_data = False
for data_name, item in zip(data_names, items):
if item != 'None':
self.file_groups[i][data_name].append(item)
has_data = True
if has_data:
item_count += 1
unpaired_file.close()
self.len_of_groups[i] = item_count
                # Second and third ways to load data: specify one folder per data name, or specify a dataroot folder.
elif check_dataname_folder_correspondence(data_names, group, self.group_names[i]) or 'dataroot' in group:
max_length = 0
for data_name, data_type in zip(data_names, data_types):
if 'dataroot' in group:
# In new data config format, data is stored in dataroot_name/mode/dataname. e.g. FFHQ/train/pairedA
# In old format, data is stored in dataroot_name/mode_dataname. e.g. FFHQ/train_pairedA
# So we need to check both.
dir = os.path.join(group['dataroot'], self.config['common']['phase'], data_name)
if not os.path.exists(dir):
old_dir = os.path.join(group['dataroot'], self.config['common']['phase'] + data_name.replace('_', ''))
if 'numpy' in data_type:
old_dir += 'numpy'
if not os.path.exists(old_dir):
print("Both %s and %s does not exist. Please check." % (dir, old_dir))
sys.exit()
else:
dir = old_dir
else:
dir = group[data_name + '_folder']
if not os.path.exists(dir):
print("directory %s does not exist. Please check." % dir)
sys.exit()
filenames = os.listdir(dir)
if self.shuffle:
random.shuffle(filenames)
item_count = 0
for filename in filenames:
if not check_path_is_static_data(filename):
continue
fullpath = os.path.join(dir, filename)
if os.path.exists(fullpath):
self.file_groups[i][data_name].append(fullpath)
item_count += 1
max_length = max(item_count, max_length)
self.len_of_groups[i] = max_length
else:
print("method for loading data is incorrect/unspecified for data group %s." % self.group_names)
sys.exit()
def create_transforms(self):
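        """Build one transform pipeline per (group, data type).

        Pipelines are created from the group's 'preprocess' list; a loading
        step (ImagePathToImage for image files, NumpyToTensor for .npy/.npz
        files) is prepended based on the extension of the group's first file.
        """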
btoA = self.config['dataset']['direction'] == 'BtoA'
input_nc = self.config['model']['output_nc'] if btoA else self.config['model']['input_nc']
output_nc = self.config['model']['input_nc'] if btoA else self.config['model']['output_nc']
input_grayscale_flag = (input_nc == 1)
output_grayscale_flag = (output_nc == 1)
data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
for i, group in enumerate(data_dict.values()): # examples: (0, group_1), (1, group_2)
if i not in self.transforms:
self.transforms[i] = {}
data_types = group['data_types'] # examples: 'image', 'patch'
data_names = group['data_names'] # examples: 'real_A', 'patch_A'
data_types, data_names = self.exclude_patch_data(data_types, data_names)
for data_name, data_type in zip(data_names, data_types):
if data_type in self.transforms[i]:
continue
self.transforms[i][data_type] = Transforms(self.config, input_grayscale_flag=input_grayscale_flag,
output_grayscale_flag=output_grayscale_flag)
self.transforms[i][data_type].create_transforms_from_list(group['preprocess'])
                first_file = self.file_groups[i][data_name][0].lower()
                if first_file.endswith(('.png', '.jpg', '.jpeg')):
                    self.transforms[i][data_type].get_transforms().insert(0, ImagePathToImage())
                elif first_file.endswith(('.npy', '.npz')):
                    self.transforms[i][data_type].get_transforms().insert(0, NumpyToTensor())
self.transforms[i][data_type] = self.transforms[i][data_type].compose_transforms()
def apply_transformations_to_images(self, img_list, img_dataname_list, transform, return_dict,
next_img_paths_bucket, next_img_dataname_list):
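        """Apply one joint transform call to a list of image (or numpy) paths.

        The first path is passed as the primary input and the remaining paths
        as a list, so all paired data receives identical random augmentations.
        Any appended 'next' paths (diff_patch mode) have their outputs stored
        in return_dict under '<data_name>_next' keys.
        """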
if len(img_list) == 1:
return_dict[img_dataname_list[0]], _ = transform(img_list[0], None)
elif len(img_list) > 1:
next_data_count = len(next_img_paths_bucket)
img_list += next_img_paths_bucket
img_dataname_list += next_img_dataname_list
input1, input2 = img_list[0], img_list[1:]
output1, output2 = transform(input1, input2) # output1 is one image. output2 is a list of images.
if next_data_count != 0:
output2, next_outputs = output2[:-next_data_count], output2[-next_data_count:]
for i in range(next_data_count):
return_dict[img_dataname_list[-next_data_count+i] + '_next'] = next_outputs[i]
return_dict[img_dataname_list[0]] = output1
for j in range(0, len(output2)):
return_dict[img_dataname_list[j+1]] = output2[j]
return return_dict
def calculate_landmark_scale(self, data_path, data_type, i):
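        """Return the (width, height) scale factors between the transformed
        output of data_path and its original size, for rescaling landmarks."""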
if data_type == 'image':
original_image = Image.open(data_path)
original_width, original_height = original_image.size
else:
original_image = np.load(data_path)
original_height, original_width = original_image.shape[0], original_image.shape[1]
transformed_image, _ = self.transforms[i][data_type](data_path, None)
transformed_height, transformed_width = transformed_image.size()[1:]
landmark_scale = (transformed_width / original_width, transformed_height / original_height)
return landmark_scale
def get_item(self, idx):
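        """Assemble one sample as a dict: transformed images/numpy arrays,
        their source paths under '<data_name>_path', landmark arrays rescaled
        to the transformed image size, and any requested patch crops."""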
return_dict = {}
data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
for i, group in enumerate(data_dict.values()):
if self.file_groups[i] == {}:
continue
paired_type = self.pair_type_groups[i]
inner_idx = idx if idx < self.len_of_groups[i] else random.randint(0, self.len_of_groups[i] - 1)
landmark_scale = None
            # Buckets for 'next' sample paths, used when diff_patch is set so patches can be cropped from a different image.
next_img_paths_bucket = []
next_img_dataname_list = []
next_numpy_paths_bucket = []
next_numpy_dataname_list = []
# First, handle all non-patch data
if paired_type:
img_paths_bucket = []
img_dataname_list = []
numpy_paths_bucket = []
numpy_dataname_list = []
for data_name, data_list in self.file_groups[i].items():
data_type = self.type_groups[i][data_name]
if data_type in ['image', 'numpy']:
if paired_type:
                        # Augmentations are applied to every image in a paired group at once, so just collect the paths here.
if data_type == 'image':
img_paths_bucket.append(data_list[inner_idx])
img_dataname_list.append(data_name)
else:
numpy_paths_bucket.append(data_list[inner_idx])
numpy_dataname_list.append(data_name)
return_dict[data_name + '_path'] = data_list[inner_idx]
if landmark_scale is None:
landmark_scale = self.calculate_landmark_scale(data_list[inner_idx], data_type, i)
if 'diff_patch' in self.config['dataset'] and self.config['dataset']['diff_patch'] and \
data_name in group['patch_sources']:
next_idx = (inner_idx + 1) % (len(data_list) - 1)
if data_type == 'image':
next_img_paths_bucket.append(data_list[next_idx])
next_img_dataname_list.append(data_name)
else:
next_numpy_paths_bucket.append(data_list[next_idx])
next_numpy_dataname_list.append(data_name)
else:
unpaired_inner_idx = random.randint(0, len(data_list) - 1)
return_dict[data_name], _ = self.transforms[i][data_type](data_list[unpaired_inner_idx], None)
if landmark_scale is None:
landmark_scale = self.calculate_landmark_scale(data_list[unpaired_inner_idx], data_type, i)
if 'diff_patch' in self.config['dataset'] and self.config['dataset']['diff_patch'] and \
data_name in group['patch_sources']:
next_idx = (unpaired_inner_idx + 1) % (len(data_list) - 1)
return_dict[data_name + '_next'], _ = self.transforms[i][data_type](data_list[next_idx], None)
return_dict[data_name + '_path'] = data_list[unpaired_inner_idx]
elif self.type_groups[i][data_name] == 'landmark':
                    # Landmarks are not transformed; they are only rescaled to the transformed image's size.
                    # The landmark array is also passed to the network as a numpy array, not a tensor.
lmk = np.load(data_list[inner_idx])
if self.config['dataset']['landmark_scale'] is not None:
lmk[:, 0] *= self.config['dataset']['landmark_scale'][0]
lmk[:, 1] *= self.config['dataset']['landmark_scale'][1]
else:
if landmark_scale is None:
print("landmark_scale is None. If you have not defined it in config file, please specify "
"image and numpy data before landmark data and the proper scale will be automatically calculated.")
else:
lmk[:, 0] *= landmark_scale[0]
lmk[:, 1] *= landmark_scale[1]
return_dict[data_name] = lmk
return_dict[data_name + '_path'] = data_list[inner_idx]
if paired_type:
# apply augmentations to all images and numpy arrays
if 'image' in self.transforms[i]:
return_dict = self.apply_transformations_to_images(img_paths_bucket, img_dataname_list,
self.transforms[i]['image'], return_dict,
next_img_paths_bucket,
next_img_dataname_list)
if 'numpy' in self.transforms[i]:
return_dict = self.apply_transformations_to_images(numpy_paths_bucket, numpy_dataname_list,
self.transforms[i]['numpy'], return_dict,
next_numpy_paths_bucket,
next_numpy_dataname_list)
# Handle patch data
data_types = group['data_types'] # examples: 'image', 'patch'
data_names = group['data_names'] # examples: 'real_A', 'patch_A'
data_types, data_names = self.filter_patch_data(data_types, data_names)
if 'patch_sources' in group:
patch_sources = group['patch_sources']
return_dict = self.load_patches(
data_names,
self.config['dataset']['patch_batch_size'],
self.config['dataset']['batch_size'],
self.config['dataset']['patch_size'],
self.config['dataset']['patch_batch_size'] // self.config['dataset']['batch_size'],
self.config['dataset']['diff_patch'],
patch_sources,
return_dict,
)
return return_dict
def get_len(self):
if len(self.len_of_groups) == 0:
return 0
else:
return max(self.len_of_groups)
def exclude_patch_data(self, data_types, data_names):
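        """Return copies of data_types/data_names with all 'patch' entries removed."""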
data_types_patch_excluded = []
data_names_patch_excluded = []
for data_name, data_type in zip(data_names, data_types):
if data_type != 'patch':
data_types_patch_excluded.append(data_type)
data_names_patch_excluded.append(data_name)
return data_types_patch_excluded, data_names_patch_excluded
def filter_patch_data(self, data_types, data_names):
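        """Return only the 'patch' entries of data_types/data_names."""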
data_types_patch = []
data_names_patch = []
for data_name, data_type in zip(data_names, data_types):
if data_type == 'patch':
data_types_patch.append(data_type)
data_names_patch.append(data_name)
return data_types_patch, data_names_patch
def load_patches(self, data_names, patch_batch_size, batch_size, patch_size,
num_patch, diff_patch, patch_sources, return_dict):
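        """Crop num_patch random patch_size x patch_size patches from each
        already-transformed patch source in return_dict and concatenate them
        along the channel dimension under the corresponding patch data name."""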
if patch_size > 0:
assert (patch_batch_size % batch_size == 0), \
"patch_batch_size is not divisible by batch_size."
            assert (len(patch_sources) == len(data_names)), \
                "patch_sources must have the same length as the number of patch data names specified. Please check the config file."
rlist = [] # used for cropping patches
clist = [] # used for cropping patches
for _ in range(num_patch):
r = random.randint(0, self.config['dataset']['crop_size'] - patch_size - 1)
c = random.randint(0, self.config['dataset']['crop_size'] - patch_size - 1)
rlist.append(r)
clist.append(c)
for i in range(len(data_names)):
# load transformed image
patch = return_dict[patch_sources[i]] if not diff_patch else return_dict[patch_sources[i] + '_next']
# crop patch
patchs = []
_, h, w = patch.size()
for j in range(num_patch):
patchs.append(patch[:, rlist[j]:rlist[j] + patch_size, clist[j]:clist[j] + patch_size])
patchs = torch.cat(patchs, 0)
return_dict[data_names[i]] = patchs
return return_dict
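

if __name__ == '__main__':
    # Minimal usage sketch, not part of the library. The config below is a
    # hypothetical example assembled only from the keys this class reads;
    # real configs are larger and all paths/group names here are placeholders.
    example_config = {
        'common': {'phase': 'train'},
        'model': {'input_nc': 3, 'output_nc': 3},
        'dataset': {
            'direction': 'AtoB',
            'landmark_scale': None,
            'train_data': {
                'group_1': {
                    'data_names': ['real_A', 'real_B'],
                    'data_types': ['image', 'image'],
                    'paired': True,
                    'dataroot': '/path/to/dataset',  # placeholder
                    'preprocess': [],  # contents depend on utils.data_utils.Transforms
                },
            },
        },
    }
    dataset = StaticData(example_config, shuffle=True)
    dataset.load_static_data()   # exits with a message if the placeholder folders are missing
    dataset.create_transforms()
    print("number of samples:", dataset.get_len())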