# MMFS / tools/ci_test.py
# limoran
# add basic files
# 7e2a2a5
# raw
# history blame
# 5.59 kB
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import sys
sys.path.append('./')
sys.path.append('../')
import skimage.io as skio
import skimage.transform as skt
import numpy as np
from data import CustomDataLoader
from data.super_dataset import SuperDataset
from models import create_model
from configs import parse_config
from utils.util import check_path
import random
import argparse
def make_toy_dataset(root='./toy_dataset', num_samples=6, image_size=128):
    """Write a tiny random dataset used by the CI smoke test.

    For every sample index ``i`` in ``range(num_samples)`` this writes:

    * ``train{paired,unpaired}{A,B}/<i>.png`` -- 8-bit 3-channel images;
    * ``trainnumpy{paired,unpaired}{A,B}/<i>.npy`` -- the same images as
      float32 arrays with values in ``[0, 1]`` and shape
      ``(image_size, image_size, 3)``;
    * ``trainlmk{A,B}/<i>.npy`` -- a float32 ``(101, 2)`` array with values
      in ``[0.5, 1)``; the same array is written for both domains.

    Images are 8x8 Gaussian noise, clipped to ``[0, 1]`` and upsampled to
    ``image_size``. Domain A has channel 0 zeroed so the two domains differ.
    The paired and unpaired folders receive identical content.

    Args:
        root: Directory the dataset is written into (created with
            ``check_path`` together with all sub-folders).
        num_samples: Number of samples written to every folder.
        image_size: Side length of the square output images.
    """
    layouts = ('paired', 'unpaired')
    domains = ('A', 'B')

    # Create the root first, in case check_path does not create parent dirs.
    check_path(root)
    for layout in layouts:
        for domain in domains:
            check_path(os.path.join(root, 'train%s%s' % (layout, domain)))
            check_path(os.path.join(root, 'trainnumpy%s%s' % (layout, domain)))
    for domain in domains:
        check_path(os.path.join(root, 'trainlmk%s' % domain))

    def noise_image(zero_channel=None):
        # Low-resolution noise, optionally with one channel blanked, then
        # clipped to [0, 1] and upsampled to the target resolution.
        img = np.random.randn(8, 8, 3) * 0.5 + 0.5
        if zero_channel is not None:
            img[:, :, zero_channel] = 0
        img = np.clip(img, 0, 1)
        return skt.resize(img, (image_size, image_size))

    for i in range(num_samples):
        images = {'A': noise_image(zero_channel=0), 'B': noise_image()}
        # rand() is in [0, 1), so these values already lie in [0.5, 1).
        # NOTE(review): presumably 101 normalised (x, y) landmark points —
        # confirm against the landmark loader.
        landmark = (np.random.rand(101, 2) * 0.5 + 0.5).astype(np.float32)
        for domain, img in images.items():
            as_float = img.astype(np.float32)
            # Scale to [0, 255] and truncate to uint8 for the PNG copies.
            as_uint8 = (img * 255.0).astype(np.uint8)
            for layout in layouts:
                np.save(os.path.join(root, 'trainnumpy%s%s' % (layout, domain), '%d.npy' % i), as_float)
                skio.imsave(os.path.join(root, 'train%s%s' % (layout, domain), '%d.png' % i), as_uint8)
            np.save(os.path.join(root, 'trainlmk%s' % domain, '%d.npy' % i), landmark)
def _find_config_dir():
    """Return the experiment config directory (``./exp``, else ``./../exp``)."""
    # Lets the script run from the repository root or from inside tools/.
    return './exp' if os.path.exists('./exp') else './../exp'


def _select_config_files(config_dir, all_tests, num_sampled=2):
    """Return the config file names in ``config_dir`` to smoke-test.

    Only ``*.yaml`` files whose name does not contain "example" are
    candidates. The filter runs *before* random sampling, so the sampled
    subset never consists only of skipped files. Candidates are sorted so that
    ``all_tests=True`` runs in a deterministic order.

    Args:
        config_dir: Directory to list.
        all_tests: If true, return every candidate; otherwise return a random
            sample of at most ``num_sampled`` candidates.
        num_sampled: Sample size used when ``all_tests`` is false.
    """
    candidates = sorted(
        name for name in os.listdir(config_dir)
        if name.endswith('.yaml') and 'example' not in name
    )
    if not all_tests:
        random.shuffle(candidates)
        candidates = candidates[:num_sampled]
    return candidates


def _override_config_for_ci(config):
    """Mutate ``config`` in place so it trains on CPU against ./toy_dataset."""
    config['common']['gpu_ids'] = None
    config['training']['continue_train'] = False
    dataset_cfg = config['dataset']
    dataset_cfg['n_threads'] = 0
    dataset_cfg['batch_size'] = 2
    # Only shrink patch settings for configs that actually define them.
    if 'patch_size' in dataset_cfg:
        dataset_cfg['patch_size'] = 64
    if 'patch_batch_size' in dataset_cfg:
        dataset_cfg['patch_batch_size'] = 2
    dataset_cfg['preprocess'] = ['scale_width']
    # Clear every explicit folder/filelist setting. NOTE(review): presumably
    # this makes SuperDataset fall back to default sub-folders under
    # `dataroot` (e.g. trainpairedA) — confirm in the dataset code.
    for key in (
        'paired_trainA_folder', 'paired_trainB_folder', 'paired_train_filelist',
        'paired_valA_folder', 'paired_valB_folder', 'paired_val_filelist',
        'unpaired_trainA_folder', 'unpaired_trainB_folder',
        'unpaired_trainA_filelist', 'unpaired_trainB_filelist',
        'unpaired_valA_folder', 'unpaired_valB_folder',
        'unpaired_valA_filelist', 'unpaired_valB_filelist',
    ):
        dataset_cfg[key] = ''
    dataset_cfg['dataroot'] = "./toy_dataset"


def _run_config(config):
    """Build dataset, dataloader and model from ``config``; train over one pass."""
    dataset = SuperDataset(config)
    dataset.config = dataset.convert_old_config_to_new()
    dataset.static_data.load_static_data()
    dataset.static_data.create_transforms()
    print('The number of training images = %d' % len(dataset))
    dataloader = CustomDataLoader(config, dataset)
    model = create_model(config)
    model.setup(config)
    for data in dataloader:
        model.set_input(data)
        model.optimize_parameters()
        losses = model.get_current_losses()
        print(losses)


def main(args):
    """Smoke-test training on the experiment configs.

    Writes the toy dataset, then for each selected config overrides it to
    run on CPU against ``./toy_dataset`` and trains over one pass of the
    dataloader, printing the losses after every step.

    Args:
        args: Parsed CLI namespace. ``args.all_tests`` selects every config;
            otherwise two configs are sampled at random.

    Any exception other than ``ImportError`` propagates, which makes CI fail.
    An ``ImportError`` is printed and that config is skipped.
    """
    make_toy_dataset()
    config_dir = _find_config_dir()
    config_files = _select_config_files(config_dir, args.all_tests)
    if not config_files:
        print('No testable config files found in', config_dir)
    for cfg in config_files:
        print('Current:', cfg)
        try:
            config = parse_config(os.path.join(config_dir, cfg))
            _override_config_for_ci(config)
            _run_config(config)
        except ImportError as error:
            # Best-effort: skip configs whose imports are unavailable here.
            print(error)
if __name__ == '__main__':
    # Command-line entry point: python tools/ci_test.py [--all_tests]
    cli = argparse.ArgumentParser(description='ci_test')
    # Without this flag main() exercises only a random subset of the configs.
    cli.add_argument('--all_tests', action='store_true')
    main(cli.parse_args())