import os, sys
import random
import numpy as np
from utils.augmentation import ImagePathToImage, NumpyToTensor
from utils.data_utils import Transforms
from utils.util import check_path_is_static_data
import torch
from PIL import Image
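
# StaticData indexes on-disk data (image files, numpy arrays, and landmark
# coordinates) described by a config dict, builds per-group preprocessing
# transforms, and serves samples as dictionaries keyed by data name.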


def check_dataname_folder_correspondence(data_names, group, group_name):
    for data_name in data_names:
        if data_name + '_folder' not in group:
            print("%s not found in config file. Going to use dataroot mode to load group %s." % (data_name + '_folder', group_name))
            return False
    return True


def custom_check_path_exists(str1):
    # "None" is an allowed placeholder entry in unpaired file lists.
    return str1 == "None" or os.path.exists(str1)


def custom_getsize(str1):
    return 1 if str1 == "None" else os.path.getsize(str1)


def check_different_extension_path_exists(str1):
    acceptable_extensions = ['png', 'jpg', 'jpeg', 'npy', 'npz', 'PNG', 'JPG', 'JPEG']
    curr_extension = str1.split('.')[-1]
    for extension in acceptable_extensions:
        # Swap only the trailing extension; str.replace() would also rewrite any
        # matching substring earlier in the path.
        str2 = str1[:-len(curr_extension)] + extension
        if os.path.exists(str2):
            return str2
    return None
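
# Illustrative behavior of the helpers above (hypothetical paths, not real files):
#   custom_check_path_exists("None")  -> True ("None" is an allowed placeholder)
#   check_different_extension_path_exists("/data/a.png")
#       -> "/data/a.jpg" if only /data/a.jpg exists on disk, or None if no file
#          with an acceptable extension is found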


class StaticData(object):
    def __init__(self, config, shuffle=False):
        # private variables
        self.file_groups = []       # per group: {data_name: [file paths]}
        self.type_groups = []       # per group: {data_name: data type}
        self.group_names = []
        self.pair_type_groups = []  # per group: whether the data is paired
        self.len_of_groups = []
        self.transforms = {}        # per group index: {data type: transform}
        # parameters
        self.shuffle = shuffle
        self.config = config

    def load_static_data(self):
        data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
        print("----------------loading %s static data.---------------------" % self.config['common']['phase'])
        if len(data_dict) == 0:
            self.len_of_groups.append(0)
            return
        self.group_names = list(data_dict.keys())
        for i, group in enumerate(data_dict.values()):  # examples: (0, group_1), (1, group_2)
            data_types = group['data_types']  # examples: 'image', 'patch'
            data_names = group['data_names']  # examples: 'real_A', 'patch_A'
            self.file_groups.append({})
            self.type_groups.append({})
            self.len_of_groups.append(0)
            self.pair_type_groups.append(group['paired'])
            # Exclude patch data since patches are not stored on disk; they are handled later.
            data_types, data_names = self.exclude_patch_data(data_types, data_names)
            assert len(data_types) == len(data_names)
            if len(data_names) == 0:
                continue
            for data_name, data_type in zip(data_names, data_types):
                self.file_groups[i][data_name] = []
                self.type_groups[i][data_name] = data_type
            # paired data
            if group['paired']:
                # First way to load data: read a file list, where each line holds
                # one space-separated path per data name.
                if 'file_list' in group:
                    file_list = group['file_list']
                    with open(file_list, 'rt') as paired_file:
                        lines = paired_file.readlines()
                    if self.shuffle:
                        random.shuffle(lines)
                    for line in lines:
                        items = line.strip().split(' ')
                        if len(items) == len(data_names):
                            ok = True
                            for item in items:
                                ok = ok and os.path.exists(item) and os.path.getsize(item) > 0
                            if ok:
                                for data_name, item in zip(data_names, items):
                                    self.file_groups[i][data_name].append(item)
                # Second and third ways to load data: specify one folder per data name,
                # or specify a single dataroot folder.
                elif check_dataname_folder_correspondence(data_names, group, self.group_names[i]) or 'dataroot' in group:
                    dataname_to_dir_dict = {}
                    for data_name, data_type in zip(data_names, data_types):
                        if 'dataroot' in group:
                            # In the new data config format, data is stored in dataroot_name/mode/dataname, e.g. FFHQ/train/pairedA.
                            # In the old format, data is stored in dataroot_name/mode_dataname, e.g. FFHQ/train_pairedA.
                            # So we need to check both.
                            dir = os.path.join(group['dataroot'], self.config['common']['phase'], data_name)
                            if not os.path.exists(dir):
                                old_dir = os.path.join(group['dataroot'], self.config['common']['phase'] + data_name.replace('_', ''))
                                if 'numpy' in data_type:
                                    old_dir += 'numpy'
                                if not os.path.exists(old_dir):
                                    print("Neither %s nor %s exists. Please check." % (dir, old_dir))
                                    sys.exit()
                                else:
                                    dir = old_dir
                        else:
                            dir = group[data_name + '_folder']
                        if not os.path.exists(dir):
                            print("directory %s does not exist. Please check." % dir)
                            sys.exit()
                        dataname_to_dir_dict[data_name] = dir
                    filenames = os.listdir(dataname_to_dir_dict[data_names[0]])
                    if self.shuffle:
                        random.shuffle(filenames)
                    for filename in filenames:
                        if not check_path_is_static_data(filename):
                            continue
                        file_paths = []
                        for data_name in data_names:
                            file_path = os.path.join(dataname_to_dir_dict[data_name], filename)
                            checked_extension = check_different_extension_path_exists(file_path)
                            if checked_extension is not None:
                                file_paths.append(checked_extension)
                        if len(file_paths) != len(data_names):
                            print("For file %s, some of its paired data appears to be missing. Ignoring it and proceeding." % filename)
                            continue
                        for j in range(len(data_names)):
                            self.file_groups[i][data_names[j]].append(file_paths[j])
                else:
                    print("method for loading data is incorrect/unspecified for data group %s." % self.group_names[i])
                    sys.exit()
                self.len_of_groups[i] = len(self.file_groups[i][data_names[0]])
            # unpaired data
            else:
                # First way to load data: read a file list; "None" entries mark
                # data names with no file on a given line.
                if 'file_list' in group:
                    file_list = group['file_list']
                    with open(file_list, 'rt') as unpaired_file:
                        lines = unpaired_file.readlines()
                    if self.shuffle:
                        random.shuffle(lines)
                    item_count = 0
                    for line in lines:
                        items = line.strip().split(' ')
                        if len(items) == len(data_names):
                            ok = True
                            for item in items:
                                ok = ok and custom_check_path_exists(item) and custom_getsize(item) > 0
                            if ok:
                                has_data = False
                                for data_name, item in zip(data_names, items):
                                    if item != 'None':
                                        self.file_groups[i][data_name].append(item)
                                        has_data = True
                                if has_data:
                                    item_count += 1
                    self.len_of_groups[i] = item_count
                # Second and third ways to load data: specify one folder per data name,
                # or specify a single dataroot folder.
                elif check_dataname_folder_correspondence(data_names, group, self.group_names[i]) or 'dataroot' in group:
                    max_length = 0
                    for data_name, data_type in zip(data_names, data_types):
                        if 'dataroot' in group:
                            # In the new data config format, data is stored in dataroot_name/mode/dataname, e.g. FFHQ/train/pairedA.
                            # In the old format, data is stored in dataroot_name/mode_dataname, e.g. FFHQ/train_pairedA.
                            # So we need to check both.
                            dir = os.path.join(group['dataroot'], self.config['common']['phase'], data_name)
                            if not os.path.exists(dir):
                                old_dir = os.path.join(group['dataroot'], self.config['common']['phase'] + data_name.replace('_', ''))
                                if 'numpy' in data_type:
                                    old_dir += 'numpy'
                                if not os.path.exists(old_dir):
                                    print("Neither %s nor %s exists. Please check." % (dir, old_dir))
                                    sys.exit()
                                else:
                                    dir = old_dir
                        else:
                            dir = group[data_name + '_folder']
                        if not os.path.exists(dir):
                            print("directory %s does not exist. Please check." % dir)
                            sys.exit()
                        filenames = os.listdir(dir)
                        if self.shuffle:
                            random.shuffle(filenames)
                        item_count = 0
                        for filename in filenames:
                            if not check_path_is_static_data(filename):
                                continue
                            fullpath = os.path.join(dir, filename)
                            if os.path.exists(fullpath):
                                self.file_groups[i][data_name].append(fullpath)
                                item_count += 1
                        max_length = max(item_count, max_length)
                    self.len_of_groups[i] = max_length
                else:
                    print("method for loading data is incorrect/unspecified for data group %s." % self.group_names[i])
                    sys.exit()
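
    # A minimal sketch of the config layout load_static_data() expects; the group
    # and folder names below are hypothetical, and only keys read above are shown:
    #
    #   config['common']['phase'] = 'train'
    #   config['dataset']['train_data'] = {
    #       'group_1': {
    #           'paired': True,
    #           'data_types': ['image', 'image'],
    #           'data_names': ['real_A', 'real_B'],
    #           'dataroot': '/datasets/FFHQ',   # or 'file_list': ..., or
    #                                           # 'real_A_folder': ..., etc.
    #           'preprocess': [...],            # consumed by create_transforms()
    #       },
    #   }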

    def create_transforms(self):
        btoA = self.config['dataset']['direction'] == 'BtoA'
        input_nc = self.config['model']['output_nc'] if btoA else self.config['model']['input_nc']
        output_nc = self.config['model']['input_nc'] if btoA else self.config['model']['output_nc']
        input_grayscale_flag = (input_nc == 1)
        output_grayscale_flag = (output_nc == 1)
        data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
        for i, group in enumerate(data_dict.values()):  # examples: (0, group_1), (1, group_2)
            if i not in self.transforms:
                self.transforms[i] = {}
            data_types = group['data_types']  # examples: 'image', 'patch'
            data_names = group['data_names']  # examples: 'real_A', 'patch_A'
            data_types, data_names = self.exclude_patch_data(data_types, data_names)
            for data_name, data_type in zip(data_names, data_types):
                if data_type in self.transforms[i]:
                    continue
                self.transforms[i][data_type] = Transforms(self.config, input_grayscale_flag=input_grayscale_flag,
                                                           output_grayscale_flag=output_grayscale_flag)
                self.transforms[i][data_type].create_transforms_from_list(group['preprocess'])
                # Prepend a loader step chosen by the first file's extension (case-insensitive).
                first_file = self.file_groups[i][data_name][0].lower()
                if first_file.endswith(('.png', '.jpg', '.jpeg')):
                    self.transforms[i][data_type].get_transforms().insert(0, ImagePathToImage())
                elif first_file.endswith(('.npy', '.npz')):
                    self.transforms[i][data_type].get_transforms().insert(0, NumpyToTensor())
                self.transforms[i][data_type] = self.transforms[i][data_type].compose_transforms()
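
    # Note: transforms are keyed by (group index, data type), and the prepended
    # loader step means every composed transform takes a file path (plus an
    # optional list of extra paths) and returns tensors.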

    def apply_transformations_to_images(self, img_list, img_dataname_list, transform, return_dict,
                                        next_img_paths_bucket, next_img_dataname_list):
        if len(img_list) == 1:
            return_dict[img_dataname_list[0]], _ = transform(img_list[0], None)
        elif len(img_list) > 1:
            next_data_count = len(next_img_paths_bucket)
            img_list += next_img_paths_bucket
            img_dataname_list += next_img_dataname_list
            input1, input2 = img_list[0], img_list[1:]
            output1, output2 = transform(input1, input2)  # output1 is one image; output2 is a list of images
            if next_data_count != 0:
                output2, next_outputs = output2[:-next_data_count], output2[-next_data_count:]
                for i in range(next_data_count):
                    return_dict[img_dataname_list[-next_data_count + i] + '_next'] = next_outputs[i]
            return_dict[img_dataname_list[0]] = output1
            for j in range(len(output2)):
                return_dict[img_dataname_list[j + 1]] = output2[j]
        return return_dict
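
    # Packing convention used above (illustrative names): with
    # img_dataname_list = ['real_A', 'real_B'] and one "next" path for 'real_A',
    # the transform receives ('A.png', ['B.png', 'A_next.png']) and return_dict
    # gains the keys 'real_A', 'real_B', and 'real_A_next', so all samples share
    # one set of augmentation parameters.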

    def calculate_landmark_scale(self, data_path, data_type, i):
        if data_type == 'image':
            original_image = Image.open(data_path)
            original_width, original_height = original_image.size
        else:
            original_image = np.load(data_path)
            original_height, original_width = original_image.shape[0], original_image.shape[1]
        transformed_image, _ = self.transforms[i][data_type](data_path, None)
        transformed_height, transformed_width = transformed_image.size()[1:]
        landmark_scale = (transformed_width / original_width, transformed_height / original_height)
        return landmark_scale
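
    # The returned scale is (width_ratio, height_ratio) between the transformed
    # C x H x W tensor and the original file; get_item() uses it below to rescale
    # landmark coordinates.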

    def get_item(self, idx):
        return_dict = {}
        data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
        for i, group in enumerate(data_dict.values()):
            if self.file_groups[i] == {}:
                continue
            paired_type = self.pair_type_groups[i]
            inner_idx = idx if idx < self.len_of_groups[i] else random.randint(0, self.len_of_groups[i] - 1)
            landmark_scale = None
            # buckets for "next" samples, used when patches must come from different images
            next_img_paths_bucket = []
            next_img_dataname_list = []
            next_numpy_paths_bucket = []
            next_numpy_dataname_list = []
            # First, handle all non-patch data
            if paired_type:
                img_paths_bucket = []
                img_dataname_list = []
                numpy_paths_bucket = []
                numpy_dataname_list = []
            for data_name, data_list in self.file_groups[i].items():
                data_type = self.type_groups[i][data_name]
                if data_type in ['image', 'numpy']:
                    if paired_type:
                        # Augmentation is applied to all images in a paired group at
                        # once, so just gather the paths here.
                        if data_type == 'image':
                            img_paths_bucket.append(data_list[inner_idx])
                            img_dataname_list.append(data_name)
                        else:
                            numpy_paths_bucket.append(data_list[inner_idx])
                            numpy_dataname_list.append(data_name)
                        return_dict[data_name + '_path'] = data_list[inner_idx]
                        if landmark_scale is None:
                            landmark_scale = self.calculate_landmark_scale(data_list[inner_idx], data_type, i)
                        if 'diff_patch' in self.config['dataset'] and self.config['dataset']['diff_patch'] and \
                                data_name in group['patch_sources']:
                            next_idx = (inner_idx + 1) % (len(data_list) - 1)
                            if data_type == 'image':
                                next_img_paths_bucket.append(data_list[next_idx])
                                next_img_dataname_list.append(data_name)
                            else:
                                next_numpy_paths_bucket.append(data_list[next_idx])
                                next_numpy_dataname_list.append(data_name)
                    else:
                        unpaired_inner_idx = random.randint(0, len(data_list) - 1)
                        return_dict[data_name], _ = self.transforms[i][data_type](data_list[unpaired_inner_idx], None)
                        if landmark_scale is None:
                            landmark_scale = self.calculate_landmark_scale(data_list[unpaired_inner_idx], data_type, i)
                        if 'diff_patch' in self.config['dataset'] and self.config['dataset']['diff_patch'] and \
                                data_name in group['patch_sources']:
                            next_idx = (unpaired_inner_idx + 1) % (len(data_list) - 1)
                            return_dict[data_name + '_next'], _ = self.transforms[i][data_type](data_list[next_idx], None)
                        return_dict[data_name + '_path'] = data_list[unpaired_inner_idx]
                elif self.type_groups[i][data_name] == 'landmark':
                    # We do not apply transformations to landmarks; we only scale them to
                    # the transformed image's size. Landmarks are passed to the network as
                    # numpy arrays, not tensors.
                    lmk = np.load(data_list[inner_idx])
                    if self.config['dataset']['landmark_scale'] is not None:
                        lmk[:, 0] *= self.config['dataset']['landmark_scale'][0]
                        lmk[:, 1] *= self.config['dataset']['landmark_scale'][1]
                    else:
                        if landmark_scale is None:
                            print("landmark_scale is None. If you have not defined it in the config file, "
                                  "list image and numpy data before landmark data so the proper scale "
                                  "can be calculated automatically.")
                        else:
                            lmk[:, 0] *= landmark_scale[0]
                            lmk[:, 1] *= landmark_scale[1]
                    return_dict[data_name] = lmk
                    return_dict[data_name + '_path'] = data_list[inner_idx]
            if paired_type:
                # apply augmentations to all gathered images and numpy arrays at once
                if 'image' in self.transforms[i]:
                    return_dict = self.apply_transformations_to_images(img_paths_bucket, img_dataname_list,
                                                                       self.transforms[i]['image'], return_dict,
                                                                       next_img_paths_bucket,
                                                                       next_img_dataname_list)
                if 'numpy' in self.transforms[i]:
                    return_dict = self.apply_transformations_to_images(numpy_paths_bucket, numpy_dataname_list,
                                                                       self.transforms[i]['numpy'], return_dict,
                                                                       next_numpy_paths_bucket,
                                                                       next_numpy_dataname_list)
            # Handle patch data
            data_types = group['data_types']  # examples: 'image', 'patch'
            data_names = group['data_names']  # examples: 'real_A', 'patch_A'
            data_types, data_names = self.filter_patch_data(data_types, data_names)
            if 'patch_sources' in group:
                patch_sources = group['patch_sources']
                return_dict = self.load_patches(
                    data_names,
                    self.config['dataset']['patch_batch_size'],
                    self.config['dataset']['batch_size'],
                    self.config['dataset']['patch_size'],
                    self.config['dataset']['patch_batch_size'] // self.config['dataset']['batch_size'],
                    self.config['dataset']['diff_patch'],
                    patch_sources,
                    return_dict,
                )
        return return_dict
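
    # For each data name, get_item() can emit up to three keys: the transformed
    # data under 'name', its source path under 'name_path', and, when diff_patch
    # is enabled, a second sample under 'name_next'.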

    def get_len(self):
        if len(self.len_of_groups) == 0:
            return 0
        return max(self.len_of_groups)

    def exclude_patch_data(self, data_types, data_names):
        data_types_patch_excluded = []
        data_names_patch_excluded = []
        for data_name, data_type in zip(data_names, data_types):
            if data_type != 'patch':
                data_types_patch_excluded.append(data_type)
                data_names_patch_excluded.append(data_name)
        return data_types_patch_excluded, data_names_patch_excluded

    def filter_patch_data(self, data_types, data_names):
        data_types_patch = []
        data_names_patch = []
        for data_name, data_type in zip(data_names, data_types):
            if data_type == 'patch':
                data_types_patch.append(data_type)
                data_names_patch.append(data_name)
        return data_types_patch, data_names_patch

    def load_patches(self, data_names, patch_batch_size, batch_size, patch_size,
                     num_patch, diff_patch, patch_sources, return_dict):
        if patch_size > 0:
            assert (patch_batch_size % batch_size == 0), \
                "patch_batch_size is not divisible by batch_size."
            assert (len(patch_sources) == len(data_names)), \
                "The length of patch_sources does not match the number of patch data names; please check the config file."
            rlist = []  # row offsets used for cropping patches
            clist = []  # column offsets used for cropping patches
            for _ in range(num_patch):
                r = random.randint(0, self.config['dataset']['crop_size'] - patch_size - 1)
                c = random.randint(0, self.config['dataset']['crop_size'] - patch_size - 1)
                rlist.append(r)
                clist.append(c)
            for i in range(len(data_names)):
                # load the transformed source image
                patch = return_dict[patch_sources[i]] if not diff_patch else return_dict[patch_sources[i] + '_next']
                # crop num_patch patches and concatenate them along the channel dimension
                patches = []
                for j in range(num_patch):
                    patches.append(patch[:, rlist[j]:rlist[j] + patch_size, clist[j]:clist[j] + patch_size])
                return_dict[data_names[i]] = torch.cat(patches, 0)
        return return_dict
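
# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of this module): StaticData is usually
# driven through a torch Dataset so a DataLoader can batch the dicts produced
# by get_item(). The wrapper class name below is hypothetical.
#
#   from torch.utils.data import Dataset
#
#   class StaticDataset(Dataset):
#       def __init__(self, config):
#           self.data = StaticData(config, shuffle=True)
#           self.data.load_static_data()    # index files on disk
#           self.data.create_transforms()   # build per-group pipelines
#
#       def __len__(self):
#           return self.data.get_len()
#
#       def __getitem__(self, idx):
#           return self.data.get_item(idx)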