import copy
import sys
from multiprocessing import Pool

import torch.utils.data as data

from data.static_data import StaticData
from data.test_data import add_test_data, apply_test_transforms
from data.test_video_data import TestVideoData
from utils.data_utils import check_img_loaded, check_numpy_loaded


class DataBin(object):
    def __init__(self, filegroups):
        self.filegroups = filegroups


class SuperDataset(data.Dataset):
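    """Dataset wrapper that serves training, test, or test-video data.

    In the 'test' phase it builds either TestVideoData (when a test video is
    configured) or plain test data; in all other phases the data are managed by
    StaticData. Samples are exposed through the standard torch.utils.data.Dataset
    interface (__len__ / __getitem__).
    """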
    def __init__(self, config, shuffle=False, check_all_data=False, DDP_device=None):
        self.config = config
        self.check_all_data = check_all_data
        self.DDP_device = DDP_device
        # Keys are data names (e.g. paired_A, patch_A); values are lists of the associated data.
        self.data = {}
        self.transforms = {}
        if self.config['common']['phase'] == 'test':
            if self.config['testing']['test_video'] is not None:
                self.test_video_data = TestVideoData(self.config)
            else:
                add_test_data(self.data, self.transforms, self.config)
            return
        self.static_data = StaticData(self.config, shuffle)

    def convert_old_config_to_new(self):
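        """Translate the legacy flat dataset options (data_type, *_folder, *_filelist)
        into the grouped '<phase>_data' layout consumed by the rest of the pipeline,
        then return the updated config.

        An illustrative result (hypothetical paths) for a paired-image setup:

            {'paired_group': {'paired': True,
                              'data_names': ['paired_A', 'paired_B'],
                              'data_types': ['image', 'image'],
                              'preprocess': [...],
                              'dataroot': '/path/to/data'},
             'unpaired_group': {...}}
        """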
        data_types = self.config['dataset']['data_type']
        phase = self.config['common']['phase']
        if len(data_types) == 1 and data_types[0] == 'custom':
            # Convert a custom data configuration into the new data configuration.
            old_dict = self.config['dataset']['custom_' + phase + '_data']
            preprocess_list = self.config['dataset']['preprocess']
            new_datadict = self.config['dataset'][phase + '_data'] = old_dict
            for group in new_datadict.values():  # e.g. group_1, group_2
                group['paired'] = True
                group['preprocess'] = preprocess_list
            # Custom data does not support patches, so the patch logic below is skipped.
        else:
            new_datadict = self.config['dataset'][phase + '_data'] = {}
            preprocess_list = self.config['dataset']['preprocess']
            new_datadict['paired_group'] = {}
            new_datadict['paired_group']['paired'] = True
            new_datadict['paired_group']['data_types'] = []
            new_datadict['paired_group']['data_names'] = []
            new_datadict['paired_group']['preprocess'] = preprocess_list
            new_datadict['unpaired_group'] = {}
            new_datadict['unpaired_group']['paired'] = False
            new_datadict['unpaired_group']['data_types'] = []
            new_datadict['unpaired_group']['data_names'] = []
            new_datadict['unpaired_group']['preprocess'] = preprocess_list
            for data_type in data_types:
                if data_type == 'paired' or data_type == 'paired_numpy':
                    if self.config['dataset']['paired_' + phase + '_filelist'] != '':
                        new_datadict['paired_group']['file_list'] = \
                            self.config['dataset']['paired_' + phase + '_filelist']
                    elif self.config['dataset']['paired_' + phase + 'A_folder'] != '' and \
                            self.config['dataset']['paired_' + phase + 'B_folder'] != '':
                        new_datadict['paired_group']['paired_A_folder'] = \
                            self.config['dataset']['paired_' + phase + 'A_folder']
                        new_datadict['paired_group']['paired_B_folder'] = \
                            self.config['dataset']['paired_' + phase + 'B_folder']
                    else:
                        new_datadict['paired_group']['dataroot'] = self.config['dataset']['dataroot']
                    new_datadict['paired_group']['data_names'].append('paired_A')
                    new_datadict['paired_group']['data_names'].append('paired_B')
                    if data_type == 'paired':
                        new_datadict['paired_group']['data_types'].append('image')
                        new_datadict['paired_group']['data_types'].append('image')
                    else:
                        new_datadict['paired_group']['data_types'].append('numpy')
                        new_datadict['paired_group']['data_types'].append('numpy')
                elif data_type == 'unpaired' or data_type == 'unpaired_numpy':
                    if self.config['dataset']['unpaired_' + phase + 'A_filelist'] != '' and \
                            self.config['dataset']['unpaired_' + phase + 'B_filelist'] != '':
                        # Combine the two file lists into a single temporary file list.
                        self.combine_two_filelists_into_one(
                            self.config['dataset']['unpaired_' + phase + 'A_filelist'],
                            self.config['dataset']['unpaired_' + phase + 'B_filelist']
                        )
                        new_datadict['unpaired_group']['file_list'] = './tmp_filelist.txt'
                    elif self.config['dataset']['unpaired_' + phase + 'A_folder'] != '' and \
                            self.config['dataset']['unpaired_' + phase + 'B_folder'] != '':
                        new_datadict['unpaired_group']['unpaired_A_folder'] = \
                            self.config['dataset']['unpaired_' + phase + 'A_folder']
                        new_datadict['unpaired_group']['unpaired_B_folder'] = \
                            self.config['dataset']['unpaired_' + phase + 'B_folder']
                    else:
                        new_datadict['unpaired_group']['dataroot'] = self.config['dataset']['dataroot']
                    new_datadict['unpaired_group']['data_names'].append('unpaired_A')
                    new_datadict['unpaired_group']['data_names'].append('unpaired_B')
                    if data_type == 'unpaired':
                        new_datadict['unpaired_group']['data_types'].append('image')
                        new_datadict['unpaired_group']['data_types'].append('image')
                    else:
                        new_datadict['unpaired_group']['data_types'].append('numpy')
                        new_datadict['unpaired_group']['data_types'].append('numpy')
                elif data_type == 'landmark':
                    if self.config['dataset']['paired_' + phase + '_filelist'] != '':
                        new_datadict['paired_group']['file_list'] = \
                            self.config['dataset']['paired_' + phase + '_filelist']
                    elif 'paired_' + phase + 'A_lmk_folder' in self.config['dataset'] and \
                            'paired_' + phase + 'B_lmk_folder' in self.config['dataset'] and \
                            self.config['dataset']['paired_' + phase + 'A_lmk_folder'] != '' and \
                            self.config['dataset']['paired_' + phase + 'B_lmk_folder'] != '':
                        new_datadict['paired_group']['lmk_A_folder'] = \
                            self.config['dataset']['paired_' + phase + 'A_lmk_folder']
                        new_datadict['paired_group']['lmk_B_folder'] = \
                            self.config['dataset']['paired_' + phase + 'B_lmk_folder']
                    else:
                        new_datadict['paired_group']['dataroot'] = self.config['dataset']['dataroot']
                    new_datadict['paired_group']['data_names'].append('lmk_A')
                    new_datadict['paired_group']['data_names'].append('lmk_B')
                    new_datadict['paired_group']['data_types'].append('landmark')
                    new_datadict['paired_group']['data_types'].append('landmark')
            # Handle patches. This must happen after all non-patch data have been added.
            if 'patch' in data_types:
                # Determine whether the patches come from the paired or the unpaired images.
                if 'paired_A' in new_datadict['paired_group']['data_names']:
                    new_datadict['paired_group']['data_types'].append('patch')
                    new_datadict['paired_group']['data_names'].append('patch_A')
                    new_datadict['paired_group']['data_types'].append('patch')
                    new_datadict['paired_group']['data_names'].append('patch_B')
                    if 'patch_sources' not in new_datadict['paired_group']:
                        new_datadict['paired_group']['patch_sources'] = []
                    new_datadict['paired_group']['patch_sources'].append('paired_A')
                    new_datadict['paired_group']['patch_sources'].append('paired_B')
                else:
                    new_datadict['unpaired_group']['data_types'].append('patch')
                    new_datadict['unpaired_group']['data_names'].append('patch_A')
                    new_datadict['unpaired_group']['data_types'].append('patch')
                    new_datadict['unpaired_group']['data_names'].append('patch_B')
                    if 'patch_sources' not in new_datadict['unpaired_group']:
                        new_datadict['unpaired_group']['patch_sources'] = []
                    new_datadict['unpaired_group']['patch_sources'].append('unpaired_A')
                    new_datadict['unpaired_group']['patch_sources'].append('unpaired_B')
        if 'diff_patch' not in self.config['dataset']:
            self.config['dataset']['diff_patch'] = False
        # Empty groups are filtered out of the printed summary only; the stored
        # configuration keeps the dict created above.
        new_datadict = {key: value for key, value in new_datadict.items() if len(value['data_names']) > 0}
        print('-----------------------------------------------------------------------')
        print("converted %s data configuration: " % phase)
        for key, value in new_datadict.items():
            print(key + ': ', value)
        print('-----------------------------------------------------------------------')
        return self.config

    def combine_two_filelists_into_one(self, filelist1, filelist2):
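        """Zip two file lists line by line into './tmp_filelist.txt'.

        Each output line is '<line_from_filelist1> <line_from_filelist2>'; when the
        lists have different lengths, the missing side is written as 'None'.
        """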
        with open(filelist1, 'r') as f1, open(filelist2, 'r') as f2, \
                open('./tmp_filelist.txt', 'w+') as tmp_file:
            f1_lines = f1.readlines()
            f2_lines = f2.readlines()
            min_index = min(len(f1_lines), len(f2_lines))
            for i in range(min_index):
                tmp_file.write(f1_lines[i].strip() + ' ' + f2_lines[i].strip() + '\n')
            if min_index == len(f1_lines):
                for i in range(min_index, len(f2_lines)):
                    tmp_file.write('None ' + f2_lines[i].strip() + '\n')
            else:
                for i in range(min_index, len(f1_lines)):
                    tmp_file.write(f1_lines[i].strip() + ' None\n')

    def __len__(self):
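        """Return the number of samples for the current phase.

        Test-video mode delegates to TestVideoData, plain test mode returns the
        length of the shortest test data list, and training mode delegates to
        StaticData.
        """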
        if self.config['common']['phase'] == 'test':
            if self.config['testing']['test_video'] is not None:
                return self.test_video_data.get_len()
            if len(self.data) == 0:
                return 0
            return min(len(v) for v in self.data.values())
        return self.static_data.get_len()

    def get_item_logic(self, index):
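        """Build the sample dict for `index`, dispatching on the configured phase."""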
        if self.config['common']['phase'] == 'test':
            if self.config['testing']['test_video'] is not None:
                return self.test_video_data.get_item()
            return_dict = {}
            apply_test_transforms(index, self.data, self.transforms, return_dict)
            return return_dict
        return self.static_data.get_item(index)

    def __getitem__(self, index):
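        """Return the sample at `index`.

        When dataset.accept_data_error is enabled, any exception raised while
        loading is printed and the next index is tried until a sample succeeds.
        """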
        if self.config['dataset']['accept_data_error']:
            while True:
                try:
                    return self.get_item_logic(index)
                except Exception as e:
                    print("Exception encountered in super_dataset's getitem function: ", e)
                    index = (index + 1) % self.__len__()
        else:
            return self.get_item_logic(index)

    def split_data(self, value_mode, value, mode='split'):
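        """Split the static data into two datasets and return (self, new_dataset).

        For every data list in every file group a split index is computed: with
        value_mode == 'count' it is `value` (capped at the list length), otherwise
        int((1 - value) * length). The returned copy keeps the items from the split
        index onward; with mode == 'split' this dataset keeps the items before it,
        otherwise this dataset is left unchanged.
        """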
        new_dataset = copy.deepcopy(self)
        ret1, new_dataset.static_data = self.split_data_helper(
            self.static_data, new_dataset.static_data, value_mode, value, mode=mode)
        if ret1 is not None:
            self.static_data = ret1
        return self, new_dataset

    def split_data_helper(self, dataset, new_dataset, value_mode, value, mode='split'):
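        """Apply the split described in split_data to every file group of `dataset`.

        Returns (dataset, new_dataset) when mode == 'split', otherwise
        (None, new_dataset) so the caller knows not to replace its own data.
        """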
        for i in range(len(dataset.file_groups)):
            max_split_index = 0
            for k in dataset.file_groups[i].keys():
                length = len(dataset.file_groups[i][k])
                if value_mode == 'count':
                    split_index = min(length, value)
                else:
                    split_index = int((1 - value) * length)
                max_split_index = max(max_split_index, split_index)
                new_dataset.file_groups[i][k] = new_dataset.file_groups[i][k][split_index:]
                if mode == 'split':
                    dataset.file_groups[i][k] = dataset.file_groups[i][k][:split_index]
            new_dataset.len_of_groups[i] -= max_split_index
            if mode == 'split':
                dataset.len_of_groups[i] = max_split_index
        if mode == 'split':
            return dataset, new_dataset
        else:
            return None, new_dataset

    def check_data_helper(self, databin):
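        """Try to load every file referenced by `databin`; return True only if all succeed.

        Paths containing '.npy' are checked as numpy arrays/landmarks, everything
        else as images.
        """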
        all_pass = True
        for group in databin.filegroups:
            for data_name, data_list in group.items():
                for data in data_list:
                    if '.npy' in data:  # case: numpy array or landmark
                        all_pass = all_pass and check_numpy_loaded(data)
                    else:  # case: image
                        all_pass = all_pass and check_img_loaded(data)
        return all_pass

    def check_data(self):
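        """Check that every referenced file can be loaded; abort the run otherwise.

        Only the main process (DDP_device None or 0) performs the check. With
        n_threads == 0 the check runs serially; otherwise the data are split into
        bins and verified in parallel by a multiprocessing Pool. On failure the
        process exits via sys.exit().
        """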
        if self.DDP_device is None or self.DDP_device == 0:
            print("-----------------------Checking all data-------------------------------")
            data_ok = True
            if self.config['dataset']['n_threads'] == 0:
                data_ok = data_ok and self.check_data_helper(self.static_data)
            else:
                # Start n_threads worker processes to perform the data checking in parallel.
                with Pool(processes=self.config['dataset']['n_threads']) as pool:
                    checks = pool.map(self.check_data_helper,
                                      self.split_data_into_bins(self.config['dataset']['n_threads']))
                for check in checks:
                    data_ok = data_ok and check
            if data_ok:
                print("---------------------all data passed check.-----------------------")
            else:
                print("---------------------The above data have failed in data checking. "
                      "Please fix first.---------------------------")
                sys.exit()

    def split_data_into_bins(self, num_bins):
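        """Distribute the static data's file groups across `num_bins` DataBin objects
        for parallel checking, and return the list of bins.
        """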
        bins = [DataBin(filegroups=[]) for _ in range(num_bins)]
        # Handle static data.
        bins = self.split_data_into_bins_helper(bins, self.static_data)
        return bins

    def split_data_into_bins_helper(self, bins, dataset):
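        """Spread every data list of `dataset` across `bins` in contiguous chunks.

        Each bin receives roughly len(data_list) // len(bins) consecutive items;
        any remainder is assigned to the last bin.
        """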
        num_bins = len(bins)
        for data_bin in bins:
            for group_idx in range(len(dataset.file_groups)):
                data_bin.filegroups.append({})
        for group_idx in range(len(dataset.file_groups)):
            file_group = dataset.file_groups[group_idx]
            for data_name, data_list in file_group.items():
                # Guard against lists shorter than num_bins so the chunk size is never zero.
                num_items_in_bin = max(1, len(data_list) // num_bins)
                for data_index in range(len(data_list)):
                    which_bin = min(data_index // num_items_in_bin, num_bins - 1)
                    if data_name not in bins[which_bin].filegroups[group_idx]:
                        bins[which_bin].filegroups[group_idx][data_name] = []
                    bins[which_bin].filegroups[group_idx][data_name].append(data_list[data_index])
        return bins
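
# Illustrative usage (a sketch, not part of the original module; the 'batch_size'
# key is an assumption, all other keys appear in the code above):
#
#     dataset = SuperDataset(config, shuffle=True, check_all_data=True)
#     if dataset.check_all_data:
#         dataset.check_data()
#     loader = data.DataLoader(dataset,
#                              batch_size=config['dataset'].get('batch_size', 1),
#                              num_workers=config['dataset']['n_threads'])
#     for batch in loader:
#         pass  # each batch is a dict keyed by data name (e.g. 'paired_A')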