import copy
import sys
from multiprocessing import Pool

import torch.utils.data as data

from data.static_data import StaticData
from data.test_data import add_test_data, apply_test_transforms
from data.test_video_data import TestVideoData
from utils.data_utils import check_img_loaded, check_numpy_loaded


class DataBin(object):
    """A lightweight container holding one worker's share of the file groups."""

    def __init__(self, filegroups):
        self.filegroups = filegroups


class SuperDataset(data.Dataset):
    def __init__(self, config, shuffle=False, check_all_data=False, DDP_device=None):
        self.config = config
        self.check_all_data = check_all_data
        self.DDP_device = DDP_device
        # Keys are data names (e.g. 'paired_A', 'patch_A'); values are lists of
        # the associated data items.
        self.data = {}
        self.transforms = {}
        if self.config['common']['phase'] == 'test':
            if self.config['testing']['test_video'] is not None:
                self.test_video_data = TestVideoData(self.config)
            else:
                add_test_data(self.data, self.transforms, self.config)
            return
        self.static_data = StaticData(self.config, shuffle)
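    # Illustrative sketch (hypothetical file names): in the test phase,
    # add_test_data fills self.data with one list per data name and
    # self.transforms with the matching transform per name, e.g.
    #   self.data       = {'paired_A': ['0001.png', '0002.png', ...],
    #                      'paired_B': ['0001.png', '0002.png', ...]}
    #   self.transforms = {'paired_A': <transform>, 'paired_B': <transform>}
    # __len__ and get_item_logic below index into these dictionaries.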
    def convert_old_config_to_new(self):
        phase = self.config['common']['phase']
        ds_cfg = self.config['dataset']
        data_types = ds_cfg['data_type']
        preprocess_list = ds_cfg['preprocess']
        if len(data_types) == 1 and data_types[0] == 'custom':
            # Convert a custom data configuration to the new data configuration.
            old_dict = ds_cfg['custom_' + phase + '_data']
            new_datadict = ds_cfg[phase + '_data'] = old_dict
            for group in new_datadict.values():  # e.g. group_1, group_2
                group['paired'] = True
                group['preprocess'] = preprocess_list
            # Custom data does not support patches; 'patch' never appears in
            # data_types here, so the patch logic below is skipped.
        else:
            new_datadict = ds_cfg[phase + '_data'] = {}
            new_datadict['paired_group'] = {
                'paired': True,
                'data_types': [],
                'data_names': [],
                'preprocess': preprocess_list,
            }
            new_datadict['unpaired_group'] = {
                'paired': False,
                'data_types': [],
                'data_names': [],
                'preprocess': preprocess_list,
            }
            for data_type in data_types:
                if data_type in ('paired', 'paired_numpy'):
                    if ds_cfg['paired_' + phase + '_filelist'] != '':
                        new_datadict['paired_group']['file_list'] = \
                            ds_cfg['paired_' + phase + '_filelist']
                    elif ds_cfg['paired_' + phase + 'A_folder'] != '' and \
                            ds_cfg['paired_' + phase + 'B_folder'] != '':
                        new_datadict['paired_group']['paired_A_folder'] = \
                            ds_cfg['paired_' + phase + 'A_folder']
                        new_datadict['paired_group']['paired_B_folder'] = \
                            ds_cfg['paired_' + phase + 'B_folder']
                    else:
                        new_datadict['paired_group']['dataroot'] = ds_cfg['dataroot']
                    new_datadict['paired_group']['data_names'] += ['paired_A', 'paired_B']
                    if data_type == 'paired':
                        new_datadict['paired_group']['data_types'] += ['image', 'image']
                    else:
                        new_datadict['paired_group']['data_types'] += ['numpy', 'numpy']
                elif data_type in ('unpaired', 'unpaired_numpy'):
                    if ds_cfg['unpaired_' + phase + 'A_filelist'] != '' and \
                            ds_cfg['unpaired_' + phase + 'B_filelist'] != '':
                        # Combine the two filelists into a single one.
                        self.combine_two_filelists_into_one(
                            ds_cfg['unpaired_' + phase + 'A_filelist'],
                            ds_cfg['unpaired_' + phase + 'B_filelist'])
                        new_datadict['unpaired_group']['file_list'] = './tmp_filelist.txt'
                    elif ds_cfg['unpaired_' + phase + 'A_folder'] != '' and \
                            ds_cfg['unpaired_' + phase + 'B_folder'] != '':
                        new_datadict['unpaired_group']['unpaired_A_folder'] = \
                            ds_cfg['unpaired_' + phase + 'A_folder']
                        new_datadict['unpaired_group']['unpaired_B_folder'] = \
                            ds_cfg['unpaired_' + phase + 'B_folder']
                    else:
                        new_datadict['unpaired_group']['dataroot'] = ds_cfg['dataroot']
                    new_datadict['unpaired_group']['data_names'] += ['unpaired_A', 'unpaired_B']
                    if data_type == 'unpaired':
                        new_datadict['unpaired_group']['data_types'] += ['image', 'image']
                    else:
                        new_datadict['unpaired_group']['data_types'] += ['numpy', 'numpy']
                elif data_type == 'landmark':
                    if ds_cfg['paired_' + phase + '_filelist'] != '':
                        new_datadict['paired_group']['file_list'] = \
                            ds_cfg['paired_' + phase + '_filelist']
                    elif 'paired_' + phase + 'A_lmk_folder' in ds_cfg and \
                            'paired_' + phase + 'B_lmk_folder' in ds_cfg and \
                            ds_cfg['paired_' + phase + 'A_lmk_folder'] != '' and \
                            ds_cfg['paired_' + phase + 'B_lmk_folder'] != '':
                        new_datadict['paired_group']['lmk_A_folder'] = \
                            ds_cfg['paired_' + phase + 'A_lmk_folder']
                        new_datadict['paired_group']['lmk_B_folder'] = \
                            ds_cfg['paired_' + phase + 'B_lmk_folder']
                    else:
                        new_datadict['paired_group']['dataroot'] = ds_cfg['dataroot']
                    new_datadict['paired_group']['data_names'] += ['lmk_A', 'lmk_B']
                    new_datadict['paired_group']['data_types'] += ['landmark', 'landmark']

        # Handle patches. This must happen after all non-patch data have been added.
        if 'patch' in data_types:
            # Determine whether the patches come from the paired or the unpaired images.
            if 'paired_A' in new_datadict['paired_group']['data_names']:
                patch_group = new_datadict['paired_group']
                patch_sources = ['paired_A', 'paired_B']
            else:
                patch_group = new_datadict['unpaired_group']
                patch_sources = ['unpaired_A', 'unpaired_B']
            patch_group['data_types'] += ['patch', 'patch']
            patch_group['data_names'] += ['patch_A', 'patch_B']
            patch_group.setdefault('patch_sources', []).extend(patch_sources)

        if 'diff_patch' not in ds_cfg:
            ds_cfg['diff_patch'] = False

        # Drop empty groups, and write the filtered dict back into the config so
        # that downstream consumers see the same configuration that is printed.
        new_datadict = {key: value for key, value in new_datadict.items()
                        if len(value['data_names']) > 0}
        ds_cfg[phase + '_data'] = new_datadict

        print('-----------------------------------------------------------------------')
        print("converted %s data configuration: " % phase)
        for key, value in new_datadict.items():
            print(key + ': ', value)
        print('-----------------------------------------------------------------------')
        return self.config
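    # Illustrative sketch (hypothetical values): an old-style config such as
    #   config['dataset']['data_type']             = ['paired']
    #   config['dataset']['paired_train_filelist'] = 'lists/train_pairs.txt'
    #   config['dataset']['preprocess']            = ['resize', 'crop']
    # is converted, for phase == 'train', into
    #   config['dataset']['train_data'] = {
    #       'paired_group': {'paired': True,
    #                        'file_list': 'lists/train_pairs.txt',
    #                        'data_names': ['paired_A', 'paired_B'],
    #                        'data_types': ['image', 'image'],
    #                        'preprocess': ['resize', 'crop']}}
    # (the empty 'unpaired_group' is dropped by the filtering step above).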
    def combine_two_filelists_into_one(self, filelist1, filelist2):
        # Merge the two filelists line by line into './tmp_filelist.txt'; when
        # one list is longer, the missing entries are padded with 'None'.
        with open(filelist1, 'r') as f1, open(filelist2, 'r') as f2, \
                open('./tmp_filelist.txt', 'w+') as tmp_file:
            f1_lines = f1.readlines()
            f2_lines = f2.readlines()
            min_index = min(len(f1_lines), len(f2_lines))
            for i in range(min_index):
                tmp_file.write(f1_lines[i].strip() + ' ' + f2_lines[i].strip() + '\n')
            if min_index == len(f1_lines):
                for i in range(min_index, len(f2_lines)):
                    tmp_file.write('None ' + f2_lines[i].strip() + '\n')
            else:
                for i in range(min_index, len(f1_lines)):
                    tmp_file.write(f1_lines[i].strip() + ' None\n')

    def __len__(self):
        if self.config['common']['phase'] == 'test':
            if self.config['testing']['test_video'] is not None:
                return self.test_video_data.get_len()
            if len(self.data) == 0:
                return 0
            # The usable length is bounded by the shortest data list.
            return min(len(v) for v in self.data.values())
        return self.static_data.get_len()

    def get_item_logic(self, index):
        if self.config['common']['phase'] == 'test':
            if self.config['testing']['test_video'] is not None:
                return self.test_video_data.get_item()
            return_dict = {}
            apply_test_transforms(index, self.data, self.transforms, return_dict)
            return return_dict
        return self.static_data.get_item(index)

    def __getitem__(self, index):
        if self.config['dataset']['accept_data_error']:
            # Keep advancing the index until an item loads successfully. Note
            # that this loops forever if every item is broken.
            while True:
                try:
                    return self.get_item_logic(index)
                except Exception as e:
                    print("Exception encountered in super_dataset's __getitem__: ", e)
                    index = (index + 1) % self.__len__()
        else:
            return self.get_item_logic(index)

    def split_data(self, value_mode, value, mode='split'):
        # Split off part of the data into a deep-copied dataset. With
        # value_mode == 'count', `value` items stay in this dataset; otherwise
        # `value` is treated as the fraction that goes to the new dataset.
        new_dataset = copy.deepcopy(self)
        ret1, new_dataset.static_data = self.split_data_helper(
            self.static_data, new_dataset.static_data, value_mode, value, mode=mode)
        if ret1 is not None:
            self.static_data = ret1
        return self, new_dataset
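    # Illustrative usage (hypothetical): carve a validation set off a training set.
    #   train_set, val_set = train_set.split_data(value_mode='ratio', value=0.1)
    # keeps roughly 90% of each file group in train_set and moves the remaining
    # 10% into val_set; with any mode other than 'split', train_set is left
    # untouched and val_set still receives its slice.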
    def split_data_helper(self, dataset, new_dataset, value_mode, value, mode='split'):
        for i in range(len(dataset.file_groups)):
            max_split_index = 0
            for k in dataset.file_groups[i].keys():
                length = len(dataset.file_groups[i][k])
                if value_mode == 'count':
                    split_index = min(length, value)
                else:
                    split_index = int((1 - value) * length)
                max_split_index = max(max_split_index, split_index)
                new_dataset.file_groups[i][k] = new_dataset.file_groups[i][k][split_index:]
                if mode == 'split':
                    dataset.file_groups[i][k] = dataset.file_groups[i][k][:split_index]
            new_dataset.len_of_groups[i] -= max_split_index
            if mode == 'split':
                dataset.len_of_groups[i] = max_split_index
        if mode == 'split':
            return dataset, new_dataset
        return None, new_dataset

    def check_data_helper(self, databin):
        all_pass = True
        for group in databin.filegroups:
            for data_name, data_list in group.items():
                for data_path in data_list:
                    if data_path.endswith('.npy'):
                        # Case: numpy array or landmark.
                        all_pass = all_pass and check_numpy_loaded(data_path)
                    else:
                        # Case: image.
                        all_pass = all_pass and check_img_loaded(data_path)
        return all_pass

    def check_data(self):
        if self.DDP_device is None or self.DDP_device == 0:
            print("-----------------------Checking all data-------------------------------")
            data_ok = True
            if self.config['dataset']['n_threads'] == 0:
                # Wrap the static data's file groups in a DataBin so the helper
                # sees the same interface as in the multiprocessing path.
                data_ok = data_ok and self.check_data_helper(
                    DataBin(filegroups=self.static_data.file_groups))
            else:
                # Start n_threads workers to perform the data checking in parallel.
                with Pool(processes=self.config['dataset']['n_threads']) as pool:
                    checks = pool.map(self.check_data_helper,
                                      self.split_data_into_bins(self.config['dataset']['n_threads']))
                for check in checks:
                    data_ok = data_ok and check
            if data_ok:
                print("---------------------all data passed check.-----------------------")
            else:
                print("---------------------The above data have failed the data check. "
                      "Please fix them first.---------------------------")
                sys.exit()

    def split_data_into_bins(self, num_bins):
        bins = [DataBin(filegroups=[]) for _ in range(num_bins)]
        # Handle static data.
        bins = self.split_data_into_bins_helper(bins, self.static_data)
        return bins

    def split_data_into_bins_helper(self, bins, dataset):
        num_bins = len(bins)
        for data_bin in bins:
            for _ in range(len(dataset.file_groups)):
                data_bin.filegroups.append({})
        for group_idx in range(len(dataset.file_groups)):
            file_group = dataset.file_groups[group_idx]
            for data_name, data_list in file_group.items():
                # Guard against lists shorter than the number of bins, which
                # would otherwise cause a division by zero below.
                num_items_in_bin = max(1, len(data_list) // num_bins)
                for data_index in range(len(data_list)):
                    which_bin = min(data_index // num_items_in_bin, num_bins - 1)
                    if data_name not in bins[which_bin].filegroups[group_idx]:
                        bins[which_bin].filegroups[group_idx][data_name] = []
                    bins[which_bin].filegroups[group_idx][data_name].append(data_list[data_index])
        return bins
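if __name__ == '__main__':
    # Minimal sketch (hypothetical paths, not part of the training pipeline):
    # exercise the filelist-merging logic on two toy lists. __init__ is bypassed
    # via __new__ because combine_two_filelists_into_one does not touch any
    # instance state.
    with open('a_list.txt', 'w') as f:
        f.write('a/0.png\na/1.png\na/2.png\n')
    with open('b_list.txt', 'w') as f:
        f.write('b/0.png\nb/1.png\n')
    ds = SuperDataset.__new__(SuperDataset)
    ds.combine_two_filelists_into_one('a_list.txt', 'b_list.txt')
    with open('./tmp_filelist.txt') as f:
        print(f.read())
    # Expected output: the matched lines 'a/0.png b/0.png' and 'a/1.png b/1.png',
    # followed by the padded 'a/2.png None'.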