Spaces:
Sleeping
Sleeping
import os | |
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' | |
import sys | |
sys.path.append('./') | |
sys.path.append('../') | |
import skimage.io as skio | |
import skimage.transform as skt | |
import numpy as np | |
from data import CustomDataLoader | |
from data.super_dataset import SuperDataset | |
from models import create_model | |
from configs import parse_config | |
from utils.util import check_path | |
import random | |
import argparse | |
def make_toy_dataset(): | |
check_path('./toy_dataset') | |
# paired | |
check_path('./toy_dataset/trainpairedA') | |
check_path('./toy_dataset/trainpairedB') | |
# paired numpy | |
check_path('./toy_dataset/trainnumpypairedA') | |
check_path('./toy_dataset/trainnumpypairedB') | |
# unpaired | |
check_path('./toy_dataset/trainunpairedA') | |
check_path('./toy_dataset/trainunpairedB') | |
# unpaired numpy | |
check_path('./toy_dataset/trainnumpyunpairedA') | |
check_path('./toy_dataset/trainnumpyunpairedB') | |
# landmark | |
check_path('./toy_dataset/trainlmkA') | |
check_path('./toy_dataset/trainlmkB') | |
for i in range(6): | |
A0 = np.random.randn(8, 8, 3) * 0.5 + 0.5 | |
A0[:,:,0] = 0 | |
A0 = np.clip(A0, 0, 1) | |
A1 = np.random.randn(8, 8, 3) * 0.5 + 0.5 | |
A1[:,:,1] = 0 | |
A1 = np.clip(A1, 0, 1) | |
A2 = np.random.randn(8, 8, 3) * 0.5 + 0.5 | |
A2[:,:,2] = 0 | |
A2 = np.clip(A2, 0, 1) | |
B = np.random.randn(8, 8, 3) * 0.5 + 0.5 | |
B = np.clip(B, 0, 1) | |
A0 = skt.resize(A0, (128, 128)) | |
A1 = skt.resize(A1, (128, 128)) | |
A2 = skt.resize(A2, (128, 128)) | |
B = skt.resize(B, (128, 128)) | |
# paired numpy | |
np.save('./toy_dataset/trainnumpypairedA/%d.npy' % i, A0.astype(np.float32)) | |
np.save('./toy_dataset/trainnumpypairedB/%d.npy' % i, B.astype(np.float32)) | |
# unpaired numpy | |
np.save('./toy_dataset/trainnumpyunpairedA/%d.npy' % i, A0.astype(np.float32)) | |
np.save('./toy_dataset/trainnumpyunpairedB/%d.npy' % i, B.astype(np.float32)) | |
A0 = A0 * 255.0 | |
A1 = A1 * 255.0 | |
A2 = A2 * 255.0 | |
B = B * 255.0 | |
# paired | |
skio.imsave('./toy_dataset/trainpairedA/%d.png' % i, A0.astype(np.uint8)) | |
skio.imsave('./toy_dataset/trainpairedB/%d.png' % i, B.astype(np.uint8)) | |
# unpaired | |
skio.imsave('./toy_dataset/trainunpairedA/%d.png' % i, A0.astype(np.uint8)) | |
skio.imsave('./toy_dataset/trainunpairedB/%d.png' % i, B.astype(np.uint8)) | |
landmark = np.random.rand(101, 2) * 0.5 + 0.5 | |
landmark = np.clip(landmark, 0, 1) | |
# landmark | |
np.save('./toy_dataset/trainlmkA/%d.npy' % i, landmark.astype(np.float32)) | |
np.save('./toy_dataset/trainlmkB/%d.npy' % i, landmark.astype(np.float32)) | |
def main(args): | |
make_toy_dataset() | |
config_dir = './exp' | |
if not os.path.exists(config_dir): | |
config_dir = './../exp' | |
config_files = os.listdir(config_dir) | |
if not args.all_tests: | |
random.shuffle(config_files) | |
config_files = config_files[:2] | |
for cfg in config_files: | |
if (not cfg.endswith('.yaml')) or "example" in cfg: | |
continue | |
print('Current:', cfg) | |
try: | |
# parse config | |
config = parse_config(os.path.join(config_dir, cfg)) | |
config['common']['gpu_ids'] = None | |
config['training']['continue_train'] = False | |
config['dataset']['n_threads'] = 0 | |
config['dataset']['batch_size'] = 2 | |
if 'patch_size' in config['dataset']: | |
config['dataset']['patch_size'] = 64 | |
if 'patch_batch_size' in config['dataset']: | |
config['dataset']['patch_batch_size'] = 2 | |
config['dataset']['preprocess'] = ['scale_width'] | |
config['dataset']['paired_trainA_folder'] = '' | |
config['dataset']['paired_trainB_folder'] = '' | |
config['dataset']['paired_train_filelist'] = '' | |
config['dataset']['paired_valA_folder'] = '' | |
config['dataset']['paired_valB_folder'] = '' | |
config['dataset']['paired_val_filelist'] = '' | |
config['dataset']['unpaired_trainA_folder'] = '' | |
config['dataset']['unpaired_trainB_folder'] = '' | |
config['dataset']['unpaired_trainA_filelist'] = '' | |
config['dataset']['unpaired_trainB_filelist'] = '' | |
config['dataset']['unpaired_valA_folder'] = '' | |
config['dataset']['unpaired_valB_folder'] = '' | |
config['dataset']['unpaired_valA_filelist'] = '' | |
config['dataset']['unpaired_valB_filelist'] = '' | |
config['dataset']['dataroot'] = "./toy_dataset" | |
# create dataset | |
dataset = SuperDataset(config) | |
dataset.config = dataset.convert_old_config_to_new() | |
dataset.static_data.load_static_data() | |
dataset.static_data.create_transforms() | |
print('The number of training images = %d' % len(dataset)) | |
dataloader = CustomDataLoader(config, dataset) | |
# create model | |
model = create_model(config) | |
model.setup(config) | |
# train | |
for data in dataloader: | |
model.set_input(data) | |
model.optimize_parameters() | |
losses = model.get_current_losses() | |
print(losses) | |
except ImportError as error: | |
print(error) | |
if __name__ == '__main__': | |
parser = argparse.ArgumentParser(description='ci_test') | |
parser.add_argument('--all_tests', action='store_true') | |
args = parser.parse_args() | |
main(args) | |