MMFS / test.py
limoran
add basic files
7e2a2a5
import argparse
import os
import cv2
from tqdm import tqdm
from utils.util import *
from data import CustomDataLoader
from data.super_dataset import SuperDataset
from models import create_model
from configs import parse_config
if __name__ == '__main__':
# parse arguments
parser = argparse.ArgumentParser(description='Style Master')
parser.add_argument('--cfg_file', type=str, default='./exp/sp2pII-phase2.yaml')
parser.add_argument('--test_img', type=str, default='', help='path to your test img')
parser.add_argument('--test_video', type=str, default='')
parser.add_argument('--test_folder', type=str, default='./example/source')
parser.add_argument('--ckpt', type=str, default='./pretrained_models/phase2_pretrain_90000.pth')
parser.add_argument('--overwrite_output_dir', type=str, default='./example/outputs/multi-model')
parser.add_argument('--gpus', type=str, default='0')
args = parser.parse_args()
# parse config
config = parse_config(args.cfg_file)
# fix gpu ordering
gpu_string = ','.join(map(str, config['common']['gpu_ids']))
gpu_ids_fix = list(range(len(config['common']['gpu_ids']))) # wants GPU ids match nvidia-smi output order
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
#os.environ['CUDA_VISIBLE_DEVICES'] = gpu_string
config['common']['gpu_ids'] = gpu_ids_fix
# hard-code some parameters for test
config['common']['phase'] = 'test'
config['dataset']['n_threads'] = 0 # test code only supports num_threads = 0
config['dataset']['batch_size'] = 1 # test code only supports batch_size = 1
config['dataset']['serial_batches'] = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
config['dataset']['no_flip'] = True # no flip; comment this line if results on flipped images are needed.
# override data augmentation
config['dataset']['load_size'] = config['testing']['load_size']
config['dataset']['crop_size'] = config['testing']['crop_size']
config['dataset']['preprocess'] = config['testing']['preprocess']
# add testing path
config['testing']['test_img'] = None if args.test_img == '' else args.test_img
config['testing']['test_video'] = None if args.test_video == '' else args.test_video
config['testing']['test_folder'] = args.test_folder
config['training']['pretrained_model'] = args.ckpt
dataset = SuperDataset(config)
dataloader = CustomDataLoader(config, dataset)
model = create_model(config) # create a model given opt.model and other options
model.load_networks(0, ckpt=args.ckpt)
model.eval()
if args.overwrite_output_dir != '':
save_path = args.overwrite_output_dir
else:
save_path = os.path.join(config['testing']['results_dir'], os.path.splitext(os.path.split(args.cfg_file)[1])[0],
config['common']['name'])
if not os.path.exists(save_path):
os.makedirs(save_path)
def reduce(x):
return reduce(x[0]) if not type(x) is str else x
ext_name = config['testing']['image_format']
use_input_format = (ext_name == 'input')
output_video = (not config['testing']['test_video'] is None)
vw_dict = {}
video_paths = []
for i, data in enumerate(tqdm(dataloader)):
if i >= config['testing']['num_test']: # only apply our model to opt.num_test images.
break
model.set_input(data) # unpack data from data loader
model.test() # run inference
visuals = model.get_current_visuals() # get image results
img_path = model.get_image_paths() # get image paths
# save result
items = os.path.splitext(os.path.split(reduce(img_path))[1])
img_fn = items[0]
if use_input_format:
ext_name = items[1][1:]
for k, v in visuals.items():
if not output_video:
tensor2file(v, os.path.join(save_path, img_fn + '_' + k), ext_name)
else:
img = tensor2im(v)
if not k in vw_dict:
h, w = img.shape[:2]
video_path = os.path.join(save_path, k + '_.mp4')
video_paths.append(video_path)
vw_dict[k] = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), 30.0, (w, h))
vw_dict[k].write(img[:,:,::-1])
for _, v in vw_dict.items():
v.release()
# convert to libx264
for video_path in video_paths:
os.system('ffmpeg -i {} -c:v libx264 {}'.format(video_path, video_path[:-5] + '.mp4'))
os.system('rm {}'.format(video_path))