import os
import random
import sys

import numpy as np
import torch
from PIL import Image

from utils.augmentation import ImagePathToImage, NumpyToTensor
from utils.data_utils import Transforms
from utils.util import check_path_is_static_data
def check_dataname_folder_correspondence(data_names, group, group_name):
    for data_name in data_names:
        if data_name + '_folder' not in group:
            print("%s not found in config file. Going to use dataroot mode to load group %s."
                  % (data_name + '_folder', group_name))
            return False
    return True
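# Illustrative example (not from a real config): for data_names ['real_A', 'real_B'], folder
# mode requires the group to define both a 'real_A_folder' and a 'real_B_folder' entry.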
def custom_check_path_exists(str1):
    return str1 == "None" or os.path.exists(str1)

def custom_getsize(str1):
    return 1 if str1 == "None" else os.path.getsize(str1)
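# In unpaired file lists the literal string 'None' is a placeholder for a missing entry; the
# two helpers above treat it as an existing, non-empty file so the remaining entries on the
# same line can still be loaded.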
def check_different_extension_path_exists(str1):
    acceptable_extensions = ['png', 'jpg', 'jpeg', 'npy', 'npz', 'PNG', 'JPG', 'JPEG']
    # Swap only the file extension; replacing every occurrence of the extension string
    # could corrupt directory names that happen to contain it.
    base_path = str1.rsplit('.', 1)[0]
    for extension in acceptable_extensions:
        str2 = base_path + '.' + extension
        if os.path.exists(str2):
            return str2
    return None
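# Example (illustrative): if '/data/trainB/0001.png' is missing but '/data/trainB/0001.jpg'
# exists, the helper returns the '.jpg' path so paired loading can still proceed.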
class StaticData(object):
    def __init__(self, config, shuffle=False):
        # private variables
        self.file_groups = []
        self.type_groups = []
        self.group_names = []
        self.pair_type_groups = []
        self.len_of_groups = []
        self.transforms = {}
        # parameters
        self.shuffle = shuffle
        self.config = config
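    # Illustrative sketch of the config layout load_static_data expects (key names follow the
    # code below; the concrete values are made up, not taken from a real config):
    #
    #   dataset:
    #     train_data:
    #       group_1:
    #         data_names: ['real_A', 'real_B']
    #         data_types: ['image', 'image']
    #         paired: true
    #         # one of the three loading modes:
    #         file_list: lists/train_pairs.txt      # mode 1: explicit file list
    #         real_A_folder: /data/trainA           # mode 2: one folder per data_name
    #         real_B_folder: /data/trainB
    #         dataroot: /data/FFHQ                  # mode 3: dataroot/<phase>/<data_name>
    #         preprocess: [...]                     # passed to Transforms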
    def load_static_data(self):
        data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
        print("----------------loading %s static data.---------------------" % self.config['common']['phase'])
        if len(data_dict) == 0:
            self.len_of_groups.append(0)
            return
        self.group_names = list(data_dict.keys())
        for i, group in enumerate(data_dict.values()):  # examples: (0, group_1), (1, group_2)
            data_types = group['data_types']  # examples: 'image', 'patch'
            data_names = group['data_names']  # examples: 'real_A', 'patch_A'
            self.file_groups.append({})
            self.type_groups.append({})
            self.len_of_groups.append(0)
            self.pair_type_groups.append(group['paired'])
            # exclude patch data since it is not stored on disk; patches are handled later.
            data_types, data_names = self.exclude_patch_data(data_types, data_names)
            assert len(data_types) == len(data_names)
            if len(data_names) == 0:
                continue
            for data_name, data_type in zip(data_names, data_types):
                self.file_groups[i][data_name] = []
                self.type_groups[i][data_name] = data_type
            # paired data
            if group['paired']:
                # First way to load data: load a file list
                if 'file_list' in group:
                    file_list = group['file_list']
                    with open(file_list, 'rt') as paired_file:
                        lines = paired_file.readlines()
                    if self.shuffle:
                        random.shuffle(lines)
                    for line in lines:
                        items = line.strip().split(' ')
                        if len(items) == len(data_names):
                            ok = True
                            for item in items:
                                ok = ok and os.path.exists(item) and os.path.getsize(item) > 0
                            if ok:
                                for data_name, item in zip(data_names, items):
                                    self.file_groups[i][data_name].append(item)
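                    # Each line of the paired file_list is expected to hold one existing path per
                    # data_name, separated by single spaces, e.g. (illustrative paths):
                    #   /data/trainA/0001.png /data/trainB/0001.png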
                # second and third way to load data: specify one folder for each data_name, or specify a dataroot folder
                elif check_dataname_folder_correspondence(data_names, group, self.group_names[i]) or 'dataroot' in group:
                    dataname_to_dir_dict = {}
                    for data_name, data_type in zip(data_names, data_types):
                        if 'dataroot' in group:
                            # In the new data config format, data is stored in dataroot_name/mode/dataname, e.g. FFHQ/train/pairedA.
                            # In the old format, data is stored in dataroot_name/mode_dataname, e.g. FFHQ/train_pairedA.
                            # So we need to check both.
                            dir = os.path.join(group['dataroot'], self.config['common']['phase'], data_name)
                            if not os.path.exists(dir):
                                old_dir = os.path.join(group['dataroot'], self.config['common']['phase'] + data_name.replace('_', ''))
                                if 'numpy' in data_type:
                                    old_dir += 'numpy'
                                if not os.path.exists(old_dir):
                                    print("Neither %s nor %s exists. Please check." % (dir, old_dir))
                                    sys.exit()
                                else:
                                    dir = old_dir
                        else:
                            dir = group[data_name + '_folder']
                            if not os.path.exists(dir):
                                print("directory %s does not exist. Please check." % dir)
                                sys.exit()
                        dataname_to_dir_dict[data_name] = dir
                    filenames = os.listdir(dataname_to_dir_dict[data_names[0]])
                    if self.shuffle:
                        random.shuffle(filenames)
                    for filename in filenames:
                        if not check_path_is_static_data(filename):
                            continue
                        file_paths = []
                        for data_name in data_names:
                            file_path = os.path.join(dataname_to_dir_dict[data_name], filename)
                            checked_extension = check_different_extension_path_exists(file_path)
                            if checked_extension is not None:
                                file_paths.append(checked_extension)
                        if len(file_paths) != len(data_names):
                            print("For file %s, some of the paired data appears to be missing. Ignoring it and proceeding." % filename)
                            continue
                        else:
                            for j in range(len(data_names)):
                                data_name = data_names[j]
                                self.file_groups[i][data_name].append(file_paths[j])
                else:
                    print("Method for loading data is incorrect/unspecified for data group %s." % self.group_names[i])
                    sys.exit()
                self.len_of_groups[i] = len(self.file_groups[i][data_names[0]])
            # unpaired data
            else:
                # First way to load data: load a file list
                if 'file_list' in group:
                    file_list = group['file_list']
                    with open(file_list, 'rt') as unpaired_file:
                        lines = unpaired_file.readlines()
                    if self.shuffle:
                        random.shuffle(lines)
                    item_count = 0
                    for line in lines:
                        items = line.strip().split(' ')
                        if len(items) == len(data_names):
                            ok = True
                            for item in items:
                                ok = ok and custom_check_path_exists(item) and custom_getsize(item) > 0
                            if ok:
                                has_data = False
                                for data_name, item in zip(data_names, items):
                                    if item != 'None':
                                        self.file_groups[i][data_name].append(item)
                                        has_data = True
                                if has_data:
                                    item_count += 1
                    self.len_of_groups[i] = item_count
                # second and third way to load data: specify one folder for each data_name, or specify a dataroot folder
                elif check_dataname_folder_correspondence(data_names, group, self.group_names[i]) or 'dataroot' in group:
                    max_length = 0
                    for data_name, data_type in zip(data_names, data_types):
                        if 'dataroot' in group:
                            # In the new data config format, data is stored in dataroot_name/mode/dataname, e.g. FFHQ/train/pairedA.
                            # In the old format, data is stored in dataroot_name/mode_dataname, e.g. FFHQ/train_pairedA.
                            # So we need to check both.
                            dir = os.path.join(group['dataroot'], self.config['common']['phase'], data_name)
                            if not os.path.exists(dir):
                                old_dir = os.path.join(group['dataroot'], self.config['common']['phase'] + data_name.replace('_', ''))
                                if 'numpy' in data_type:
                                    old_dir += 'numpy'
                                if not os.path.exists(old_dir):
                                    print("Neither %s nor %s exists. Please check." % (dir, old_dir))
                                    sys.exit()
                                else:
                                    dir = old_dir
                        else:
                            dir = group[data_name + '_folder']
                            if not os.path.exists(dir):
                                print("directory %s does not exist. Please check." % dir)
                                sys.exit()
                        filenames = os.listdir(dir)
                        if self.shuffle:
                            random.shuffle(filenames)
                        item_count = 0
                        for filename in filenames:
                            if not check_path_is_static_data(filename):
                                continue
                            fullpath = os.path.join(dir, filename)
                            if os.path.exists(fullpath):
                                self.file_groups[i][data_name].append(fullpath)
                                item_count += 1
                        max_length = max(item_count, max_length)
                    self.len_of_groups[i] = max_length
                else:
                    print("Method for loading data is incorrect/unspecified for data group %s." % self.group_names[i])
                    sys.exit()
    def create_transforms(self):
        btoA = self.config['dataset']['direction'] == 'BtoA'
        input_nc = self.config['model']['output_nc'] if btoA else self.config['model']['input_nc']
        output_nc = self.config['model']['input_nc'] if btoA else self.config['model']['output_nc']
        input_grayscale_flag = (input_nc == 1)
        output_grayscale_flag = (output_nc == 1)
        data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
        for i, group in enumerate(data_dict.values()):  # examples: (0, group_1), (1, group_2)
            if i not in self.transforms:
                self.transforms[i] = {}
            data_types = group['data_types']  # examples: 'image', 'patch'
            data_names = group['data_names']  # examples: 'real_A', 'patch_A'
            data_types, data_names = self.exclude_patch_data(data_types, data_names)
            for data_name, data_type in zip(data_names, data_types):
                if data_type in self.transforms[i]:
                    continue
                self.transforms[i][data_type] = Transforms(self.config, input_grayscale_flag=input_grayscale_flag,
                                                           output_grayscale_flag=output_grayscale_flag)
                self.transforms[i][data_type].create_transforms_from_list(group['preprocess'])
                # choose the loader from the file extension of the first entry for this data_name
                first_path = self.file_groups[i][data_name][0].lower()
                if first_path.endswith(('.png', '.jpg', '.jpeg')):
                    self.transforms[i][data_type].get_transforms().insert(0, ImagePathToImage())
                elif first_path.endswith(('.npy', '.npz')):
                    self.transforms[i][data_type].get_transforms().insert(0, NumpyToTensor())
                self.transforms[i][data_type] = self.transforms[i][data_type].compose_transforms()
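    # After create_transforms, self.transforms is keyed by group index and data_type, e.g.
    # self.transforms[0]['image'] holds the composed transform for image data in the first group.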
    def apply_transformations_to_images(self, img_list, img_dataname_list, transform, return_dict,
                                        next_img_paths_bucket, next_img_dataname_list):
        if len(img_list) == 1:
            return_dict[img_dataname_list[0]], _ = transform(img_list[0], None)
        elif len(img_list) > 1:
            next_data_count = len(next_img_paths_bucket)
            img_list += next_img_paths_bucket
            img_dataname_list += next_img_dataname_list
            input1, input2 = img_list[0], img_list[1:]
            output1, output2 = transform(input1, input2)  # output1 is a single image; output2 is a list of images.
            if next_data_count != 0:
                output2, next_outputs = output2[:-next_data_count], output2[-next_data_count:]
                for i in range(next_data_count):
                    return_dict[img_dataname_list[-next_data_count + i] + '_next'] = next_outputs[i]
            return_dict[img_dataname_list[0]] = output1
            for j in range(0, len(output2)):
                return_dict[img_dataname_list[j + 1]] = output2[j]
        return return_dict
    def calculate_landmark_scale(self, data_path, data_type, i):
        if data_type == 'image':
            original_image = Image.open(data_path)
            original_width, original_height = original_image.size
        else:
            original_image = np.load(data_path)
            original_height, original_width = original_image.shape[0], original_image.shape[1]
        transformed_image, _ = self.transforms[i][data_type](data_path, None)
        transformed_height, transformed_width = transformed_image.size()[1:]
        landmark_scale = (transformed_width / original_width, transformed_height / original_height)
        return landmark_scale
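    # Worked example (illustrative): a 1024x1024 source image that the transform pipeline
    # resizes to 256x256 gives landmark_scale = (0.25, 0.25), which get_item uses to rescale
    # landmark x/y coordinates to the transformed image size.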
    def get_item(self, idx):
        return_dict = {}
        data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
        for i, group in enumerate(data_dict.values()):
            if self.file_groups[i] == {}:
                continue
            paired_type = self.pair_type_groups[i]
            inner_idx = idx if idx < self.len_of_groups[i] else random.randint(0, self.len_of_groups[i] - 1)
            landmark_scale = None
            # buckets for patch source data, since patches might need to be loaded from different images.
            next_img_paths_bucket = []
            next_img_dataname_list = []
            next_numpy_paths_bucket = []
            next_numpy_dataname_list = []
            # First, handle all non-patch data
            if paired_type:
                img_paths_bucket = []
                img_dataname_list = []
                numpy_paths_bucket = []
                numpy_dataname_list = []
            for data_name, data_list in self.file_groups[i].items():
                data_type = self.type_groups[i][data_name]
                if data_type in ['image', 'numpy']:
                    if paired_type:
                        # augmentation is applied to all images in a paired group at once, so gather the paths here.
                        if data_type == 'image':
                            img_paths_bucket.append(data_list[inner_idx])
                            img_dataname_list.append(data_name)
                        else:
                            numpy_paths_bucket.append(data_list[inner_idx])
                            numpy_dataname_list.append(data_name)
                        return_dict[data_name + '_path'] = data_list[inner_idx]
                        if landmark_scale is None:
                            landmark_scale = self.calculate_landmark_scale(data_list[inner_idx], data_type, i)
                        if 'diff_patch' in self.config['dataset'] and self.config['dataset']['diff_patch'] and \
                                data_name in group['patch_sources']:
                            # pick a neighbouring item as the alternative patch source, wrapping at the end of the list.
                            next_idx = (inner_idx + 1) % len(data_list)
                            if data_type == 'image':
                                next_img_paths_bucket.append(data_list[next_idx])
                                next_img_dataname_list.append(data_name)
                            else:
                                next_numpy_paths_bucket.append(data_list[next_idx])
                                next_numpy_dataname_list.append(data_name)
                    else:
                        unpaired_inner_idx = random.randint(0, len(data_list) - 1)
                        return_dict[data_name], _ = self.transforms[i][data_type](data_list[unpaired_inner_idx], None)
                        if landmark_scale is None:
                            landmark_scale = self.calculate_landmark_scale(data_list[unpaired_inner_idx], data_type, i)
                        if 'diff_patch' in self.config['dataset'] and self.config['dataset']['diff_patch'] and \
                                data_name in group['patch_sources']:
                            next_idx = (unpaired_inner_idx + 1) % len(data_list)
                            return_dict[data_name + '_next'], _ = self.transforms[i][data_type](data_list[next_idx], None)
                        return_dict[data_name + '_path'] = data_list[unpaired_inner_idx]
                elif self.type_groups[i][data_name] == 'landmark':
                    # We do not apply transformations to landmarks; we only scale them to the transformed image's size.
                    # Landmark data is passed into the network as a numpy array, not a tensor.
                    lmk = np.load(data_list[inner_idx])
                    if self.config['dataset']['landmark_scale'] is not None:
                        lmk[:, 0] *= self.config['dataset']['landmark_scale'][0]
                        lmk[:, 1] *= self.config['dataset']['landmark_scale'][1]
                    else:
                        if landmark_scale is None:
                            print("landmark_scale is None. If you have not defined it in the config file, please specify "
                                  "image and numpy data before landmark data so the proper scale can be calculated automatically.")
                        else:
                            lmk[:, 0] *= landmark_scale[0]
                            lmk[:, 1] *= landmark_scale[1]
                    return_dict[data_name] = lmk
                    return_dict[data_name + '_path'] = data_list[inner_idx]
            if paired_type:
                # apply augmentations to all images and numpy arrays
                if 'image' in self.transforms[i]:
                    return_dict = self.apply_transformations_to_images(img_paths_bucket, img_dataname_list,
                                                                       self.transforms[i]['image'], return_dict,
                                                                       next_img_paths_bucket,
                                                                       next_img_dataname_list)
                if 'numpy' in self.transforms[i]:
                    return_dict = self.apply_transformations_to_images(numpy_paths_bucket, numpy_dataname_list,
                                                                       self.transforms[i]['numpy'], return_dict,
                                                                       next_numpy_paths_bucket,
                                                                       next_numpy_dataname_list)
            # Handle patch data
            data_types = group['data_types']  # examples: 'image', 'patch'
            data_names = group['data_names']  # examples: 'real_A', 'patch_A'
            data_types, data_names = self.filter_patch_data(data_types, data_names)
            if 'patch_sources' in group:
                patch_sources = group['patch_sources']
                return_dict = self.load_patches(
                    data_names,
                    self.config['dataset']['patch_batch_size'],
                    self.config['dataset']['batch_size'],
                    self.config['dataset']['patch_size'],
                    self.config['dataset']['patch_batch_size'] // self.config['dataset']['batch_size'],
                    self.config['dataset']['diff_patch'],
                    patch_sources,
                    return_dict,
                )
        return return_dict
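    # A typical return_dict for a paired image group with data_names ['real_A', 'real_B']
    # (illustrative) holds the transformed tensors under 'real_A' / 'real_B', the source paths
    # under 'real_A_path' / 'real_B_path', '<name>_next' entries when diff_patch is enabled,
    # and the cropped patch tensors under the patch data_names.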
    def get_len(self):
        if len(self.len_of_groups) == 0:
            return 0
        else:
            return max(self.len_of_groups)
    def exclude_patch_data(self, data_types, data_names):
        data_types_patch_excluded = []
        data_names_patch_excluded = []
        for data_name, data_type in zip(data_names, data_types):
            if data_type != 'patch':
                data_types_patch_excluded.append(data_type)
                data_names_patch_excluded.append(data_name)
        return data_types_patch_excluded, data_names_patch_excluded
    def filter_patch_data(self, data_types, data_names):
        data_types_patch = []
        data_names_patch = []
        for data_name, data_type in zip(data_names, data_types):
            if data_type == 'patch':
                data_types_patch.append(data_type)
                data_names_patch.append(data_name)
        return data_types_patch, data_names_patch
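    # Worked example (illustrative): with patch_batch_size=8 and batch_size=2, load_patches
    # receives num_patch=4 and, for every patch data_name, crops four patch_size x patch_size
    # windows from the transformed source image and concatenates them along dim 0.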
    def load_patches(self, data_names, patch_batch_size, batch_size, patch_size,
                     num_patch, diff_patch, patch_sources, return_dict):
        if patch_size > 0:
            assert (patch_batch_size % batch_size == 0), \
                "patch_batch_size is not divisible by batch_size."
            assert (len(patch_sources) == len(data_names)), \
                "The length of patch_sources does not match the number of patch data names. Please check the config file."
            rlist = []  # row offsets used for cropping patches
            clist = []  # column offsets used for cropping patches
            for _ in range(num_patch):
                r = random.randint(0, self.config['dataset']['crop_size'] - patch_size - 1)
                c = random.randint(0, self.config['dataset']['crop_size'] - patch_size - 1)
                rlist.append(r)
                clist.append(c)
            for i in range(len(data_names)):
                # load the transformed source image for this patch data_name
                patch_source = return_dict[patch_sources[i]] if not diff_patch else return_dict[patch_sources[i] + '_next']
                # crop num_patch patches and concatenate them along dim 0
                patches = []
                for j in range(num_patch):
                    patches.append(patch_source[:, rlist[j]:rlist[j] + patch_size, clist[j]:clist[j] + patch_size])
                patches = torch.cat(patches, 0)
                return_dict[data_names[i]] = patches
        return return_dict
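# Minimal usage sketch (illustrative; assumes a config dict with the keys referenced above):
#
#   dataset = StaticData(config, shuffle=True)
#   dataset.load_static_data()    # builds file_groups from file lists / folders / dataroot
#   dataset.create_transforms()   # builds per-group, per-data_type transform pipelines
#   sample = dataset.get_item(0)  # dict of transformed tensors, paths, landmarks and patches
#   length = dataset.get_len()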