MMFS / tools /test_multi_epoches.py
limoran
add basic files
7e2a2a5
raw
history blame
4.18 kB
import sys
import cv2
import os
import yaml
import numpy as np
from multiprocessing import Process
from tqdm import tqdm
def run_test(param_list, wid, skip):
for cfg_file, test_folder, ckpt_file, out_dir in param_list:
if os.path.exists(out_dir):
if skip:
print(f'{out_dir} exists, skip it.')
continue
else:
print(f'{out_dir} exists, force to remove.')
os.system(f'rm -rf {out_dir}')
cmd = f'python3 test.py --cfg_file {cfg_file} --test_folder {test_folder} --ckpt {ckpt_file} --overwrite_output_dir {out_dir}'
print(f'Working process {wid} runs: {cmd}')
os.system(cmd)
def load_image(input_folder, output_folder, ckpt, fn, postfix, resize_to):
if ckpt == 'blank':
return np.zeros([resize_to[1], resize_to[0], 3], dtype=np.uint8)
if ckpt == 'input':
img = cv2.imread(os.path.join(input_folder, fn.replace(postfix, '')))
else:
img = cv2.imread(os.path.join(output_folder, ckpt, fn))
if not (img.shape[0] == resize_to[1] and img.shape[1] == resize_to[0]):
img = cv2.resize(img, resize_to)
return cv2.putText(img, ckpt, (10, 30), cv2.FONT_HERSHEY_COMPLEX, 1.0, (0, 0, 255), 2)
if __name__ == '__main__':
assert len(sys.argv) == 2, 'Usage: python3 test_multi_epochs.py [path to config].'
# read config
with open(sys.argv[1], 'rt') as f:
config = yaml.safe_load(f)
assert type(config) is dict, 'Loading config file failed.'
# check config
gpu_ids = config['gpu_ids']
skip_existed = config['skip_existed']
cfg_file = config['cfg_file']
ckpt_folder = config['ckpt_folder']
ckpt_layout = config['ckpt_layout']
test_folder = config['test_folder']
output_folder = config['output_folder']
resize_to = config['resize_to']
with open(cfg_file, 'rt') as f:
temp_config = yaml.safe_load(f)
assert type(temp_config) is dict, 'Loading config file failed.'
postfix = temp_config['testing']['visual_names']
assert len(postfix) == 1, 'Count of visual names should be 1.'
postfix = '_' + postfix[0]
temp_config['testing']['image_format'] = 'input'
# run test
if not os.path.exists(output_folder):
os.makedirs(output_folder)
gpu_count = len(gpu_ids)
param_lists = [ [] for _ in range(gpu_count) ]
cnt = 0
for line in ckpt_layout:
assert len(line) == len(ckpt_layout[0]), 'Each row should have the same count of items.'
for ckpt in line:
if not ckpt in ['input', 'blank']:
ckpt_path = os.path.join(ckpt_folder, ckpt + '.pth')
assert os.path.exists(ckpt_path), f'Checkpoint not found: {ckpt_path}'
new_cfg_file = os.path.join(output_folder, ckpt + '.yaml')
out_dir = os.path.join(output_folder, ckpt)
with open(new_cfg_file, 'wt') as f:
temp_config['common']['gpu_ids'] = [ gpu_ids[cnt % gpu_count] ]
yaml.safe_dump(temp_config, f)
param_lists[cnt % gpu_count].append((new_cfg_file, test_folder, ckpt_path, out_dir))
cnt += 1
pool = []
for i in range(gpu_count):
p = Process(target=run_test, args=(param_lists[i], i, skip_existed))
p.start()
pool.append(p)
for p in pool:
p.join()
# concat image
print('Begin to concat.')
out_dir = os.path.join(output_folder, 'concat')
if os.path.exists(out_dir):
print(f'{out_dir} exists, force to remove.')
os.system(f'rm -rf {out_dir}')
os.makedirs(out_dir)
files = os.listdir(param_lists[0][0][3])
for fn in tqdm(files):
if not os.path.splitext(fn)[1].lower() in ['.jpg', '.jpeg', '.png']:
continue
try:
result = np.vstack([
np.hstack([
load_image(test_folder, output_folder, ckpt, fn, postfix, resize_to) for ckpt in line
]) for line in ckpt_layout
])
cv2.imwrite(os.path.join(out_dir, fn.replace(postfix, '')), result)
except Exception as e:
print(fn, e)