# Custom paired-data dataset loader: reads whitespace-separated file-list
# columns (images, landmarks, ...) described in the experiment config.
import os | |
import random | |
import numpy as np | |
from utils.augmentation import ImagePathToImage | |
from utils.data_utils import Transforms, check_img_loaded, check_numpy_loaded | |
class CustomData(object):
    """Paired multi-modal dataset backed by whitespace-separated file lists.

    Each config "group" names parallel data streams (e.g. an image and its
    landmark file) whose paths appear column-aligned on each line of a list
    file.  Lines with the wrong column count, or referencing missing/empty
    files, are skipped at load time.
    """

    def __init__(self, config, shuffle=False):
        """Parse the config and load every group's paired file lists.

        Args:
            config: dict-like config; reads config['dataset'][...] and
                config['common']['phase'].
            shuffle: if True, shuffle each group's lines once at load time.
        """
        self.paired_file_groups = []  # per group: {data_name: [file path, ...]}
        self.paired_type_groups = []  # per group: {data_name: data-type string}
        self.len_of_groups = []       # per group: number of usable paired lines
        self.landmark_scale = config['dataset']['landmark_scale']
        self.shuffle = shuffle
        self.config = config
        data_dict = config['dataset']['custom_' + config['common']['phase'] + '_data']
        if not data_dict:
            # No custom data configured: report a single zero-length group and
            # skip transform construction (get_item() must not be called then).
            self.len_of_groups.append(0)
            return
        for i, group in enumerate(data_dict.values()):
            data_types = group['data_types']  # e.g. ['image', 'landmark']
            data_names = group['data_names']  # e.g. ['real_A', 'lmk_A']
            file_list = group['file_list']    # e.g. "lmt/data/trainA.txt"
            assert len(data_types) == len(data_names)
            self.paired_file_groups.append({})
            self.paired_type_groups.append({})
            for data_name, data_type in zip(data_names, data_types):
                self.paired_file_groups[i][data_name] = []
                self.paired_type_groups[i][data_name] = data_type
            # Context manager guarantees the handle is closed even on error
            # (original used bare open()/close()).
            with open(file_list, 'rt') as paired_file:
                lines = paired_file.readlines()
            if self.shuffle:
                random.shuffle(lines)
            for line in lines:
                items = line.strip().split(' ')
                if len(items) != len(data_names):
                    continue  # malformed line: wrong number of columns
                # Keep the line only if every referenced file exists and is non-empty.
                if all(os.path.exists(item) and os.path.getsize(item) > 0
                       for item in items):
                    for data_name, item in zip(data_names, items):
                        self.paired_file_groups[i][data_name].append(item)
            # All columns in a group are parallel, so the first column's
            # length is the group's length.
            self.len_of_groups.append(len(self.paired_file_groups[i][data_names[0]]))
        # Build the image transform pipeline; ImagePathToImage is prepended so
        # the pipeline accepts file paths rather than pre-loaded images.
        self.transform = Transforms(config)
        self.transform.get_transform_from_config()
        self.transform.get_transforms().insert(0, ImagePathToImage())
        self.transform = self.transform.compose_transforms()

    def get_len(self):
        """Return the dataset length: the size of the largest group."""
        return max(self.len_of_groups)

    def get_item(self, idx):
        """Assemble one sample dict across all groups.

        Groups shorter than idx contribute a uniformly random element, so
        every group yields something for every index.

        Returns:
            dict mapping data_name -> transformed image / scaled landmark
            array, plus '<name>_path' entries for landmark sources.
        """
        return_dict = {}
        for i in range(len(self.paired_file_groups)):
            # Out-of-range index for this group: sample a random element.
            inner_idx = idx if idx < self.len_of_groups[i] \
                else random.randint(0, self.len_of_groups[i] - 1)
            img_list = []
            img_k_list = []
            for k, v in self.paired_file_groups[i].items():
                if self.paired_type_groups[i][k] == 'image':
                    # Gather image paths; they are transformed together below
                    # so paired images receive identical augmentation.
                    img_k_list.append(k)
                    img_list.append(v[inner_idx])
                elif self.paired_type_groups[i][k] == 'landmark':
                    # Landmarks skip augmentation; scale coordinates directly.
                    lmk = np.load(v[inner_idx])
                    lmk[:, 0] *= self.landmark_scale[0]
                    lmk[:, 1] *= self.landmark_scale[1]
                    return_dict[k] = lmk
                    return_dict[k + '_path'] = v[inner_idx]
            # Transform all gathered images for this group.
            if len(img_list) == 1:
                return_dict[img_k_list[0]], _ = self.transform(img_list[0], None)
            elif len(img_list) > 1:
                input1, input2 = img_list[0], img_list[1:]
                # output1 is one image; output2 is a list of images.
                output1, output2 = self.transform(input1, input2)
                return_dict[img_k_list[0]] = output1
                for j in range(1, len(img_list)):
                    return_dict[img_k_list[j]] = output2[j - 1]
        return return_dict

    def split_data_into_bins(self, num_bins):
        """Partition every group's file lists into num_bins contiguous chunks.

        Returns:
            list of num_bins bins; each bin is a list (indexed by group) of
            {data_name: [paths]} dicts.  Groups smaller than the configured
            thread count are placed entirely in bin 0.
        """
        bins = [[] for _ in range(num_bins)]
        for i in range(len(self.paired_file_groups)):
            for b in range(num_bins):
                bins[b].append({})
            for dataname, item_list in self.paired_file_groups[i].items():
                if len(item_list) < self.config['dataset']['n_threads']:
                    # Too few items to be worth splitting: keep them in bin 0.
                    bins[0][i][dataname] = item_list
                else:
                    # max(1, ...) guards against ZeroDivisionError when there
                    # are fewer items than bins (the guard above compares with
                    # n_threads, not num_bins); extras spill into the last bin.
                    num_items_in_bin = max(1, len(item_list) // num_bins)
                    for j, item in enumerate(item_list):
                        which_bin = min(j // num_items_in_bin, num_bins - 1)
                        # BUG FIX: the original appended only when the key
                        # already existed (append sat in the else-branch),
                        # silently dropping the first item of every bin.
                        bins[which_bin][i].setdefault(dataname, []).append(item)
        return bins

    def check_data_helper(self, data):
        """Verify that every listed file actually loads.

        Args:
            data: list of {data_name: [paths]} dicts (one per group), e.g. a
                single bin from split_data_into_bins().

        Returns:
            True iff every path loads ('.npy' via numpy, otherwise as image).
        """
        all_pass = True
        for paired_file_group in data:
            for k, v in paired_file_group.items():
                for v1 in v:
                    if '.npy' in v1:  # case: numpy array or landmark
                        all_pass = all_pass and check_numpy_loaded(v1)
                    else:             # case: image
                        all_pass = all_pass and check_img_loaded(v1)
        return all_pass