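# NOTE(review): the page-header residue "Spaces: / Sleeping / Sleeping" that
# preceded this script was captured along with it; kept here as a comment so
# the file parses as Python.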
import argparse
import os
import subprocess

import cv2
from tqdm import tqdm

from utils.util import *
from data import CustomDataLoader
from data.super_dataset import SuperDataset
from models import create_model
from configs import parse_config
# Test / inference entry point: loads a config and checkpoint, runs the model
# over the test data, and writes every visual either as image files or, when
# --test_video is given, as one H.264 video per visual.


def _parse_args():
    """Parse the command-line options of the test script.

    Returns:
        argparse.Namespace with string attributes cfg_file, test_img,
        test_video, test_folder, ckpt, overwrite_output_dir and gpus.
    """
    parser = argparse.ArgumentParser(description='Style Master')
    parser.add_argument('--cfg_file', type=str, default='./exp/sp2pII-phase2.yaml')
    parser.add_argument('--test_img', type=str, default='', help='path to your test img')
    parser.add_argument('--test_video', type=str, default='')
    parser.add_argument('--test_folder', type=str, default='./example/source')
    parser.add_argument('--ckpt', type=str, default='./pretrained_models/phase2_pretrain_90000.pth')
    parser.add_argument('--overwrite_output_dir', type=str, default='./example/outputs/multi-model')
    parser.add_argument('--gpus', type=str, default='0')
    return parser.parse_args()


def _visible_gpu_ids(config_gpu_ids, gpus_arg):
    """Return the device ids to use once CUDA_VISIBLE_DEVICES is set to ``gpus_arg``.

    CUDA renumbers the devices listed in CUDA_VISIBLE_DEVICES as 0..k-1 inside
    the process, so the model must use ids from that range. The count follows
    the config, capped at the number of devices actually listed in
    ``gpus_arg``. This means the model never requests a device that is not
    visible (e.g. config asks for two GPUs but ``--gpus 0``, or ``--gpus ''``).

    Args:
        config_gpu_ids: the config's ``gpu_ids`` sequence; only its length is used.
        gpus_arg: comma-separated device list such as ``'0'`` or ``'1,3'``.

    Returns:
        list[int]: ``[0, 1, ..., n - 1]`` (empty when no device is visible).
    """
    n_visible = len([g for g in gpus_arg.split(',') if g.strip()])
    return list(range(min(len(config_gpu_ids), n_visible)))


def _build_test_config(args):
    """Load the config file and override it for deterministic, one-at-a-time testing.

    Side effect: sets the CUDA_VISIBLE_DEVICES environment variable from ``--gpus``.

    Returns:
        The parsed config (nested dict) with test-time overrides applied.
    """
    config = parse_config(args.cfg_file)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    config['common']['gpu_ids'] = _visible_gpu_ids(config['common']['gpu_ids'], args.gpus)
    # hard-code some parameters for test
    config['common']['phase'] = 'test'
    config['dataset']['n_threads'] = 0  # test code only supports num_threads = 0
    config['dataset']['batch_size'] = 1  # test code only supports batch_size = 1
    # disable data shuffling; comment this line if results on randomly chosen images are needed.
    config['dataset']['serial_batches'] = True
    # no flip; comment this line if results on flipped images are needed.
    config['dataset']['no_flip'] = True
    # override data augmentation with the testing settings
    for key in ('load_size', 'crop_size', 'preprocess'):
        config['dataset'][key] = config['testing'][key]
    # add testing paths; an empty string on the command line means "not given"
    config['testing']['test_img'] = args.test_img or None
    config['testing']['test_video'] = args.test_video or None
    config['testing']['test_folder'] = args.test_folder
    config['training']['pretrained_model'] = args.ckpt
    return config


def _resolve_save_path(args, config):
    """Return the output directory, creating it if it does not exist yet.

    ``--overwrite_output_dir`` wins when non-empty. Otherwise the path is
    ``<results_dir>/<cfg file stem>/<experiment name>``.
    """
    if args.overwrite_output_dir != '':
        save_path = args.overwrite_output_dir
    else:
        cfg_stem = os.path.splitext(os.path.basename(args.cfg_file))[0]
        save_path = os.path.join(config['testing']['results_dir'], cfg_stem,
                                 config['common']['name'])
    os.makedirs(save_path, exist_ok=True)
    return save_path


def _unwrap_path(x):
    """Return the first string reached by repeatedly taking element 0 of ``x``.

    The model's image paths may arrive wrapped in (nested) sequences by the
    data loader. ``isinstance`` also accepts str subclasses such as
    ``numpy.str_``, which would otherwise be indexed down to a single character.

    Raises:
        ValueError: if an empty or non-indexable value is reached before a string.
    """
    while not isinstance(x, str):
        try:
            x = x[0]
        except (IndexError, KeyError, TypeError) as err:
            raise ValueError('could not extract an image path from {!r}'.format(x)) from err
    return x


def _append_video_frames(visuals, save_path, writers, video_paths):
    """Write one frame per visual, opening a writer per visual key on first use.

    Args:
        visuals: mapping of visual name -> tensor from ``model.get_current_visuals()``.
        save_path: output directory.
        writers: mapping of visual name -> cv2.VideoWriter; updated in place.
        video_paths: list of temporary video paths; appended to in place.
    """
    for name, tensor in visuals.items():
        img = tensor2im(tensor)
        if name not in writers:
            h, w = img.shape[:2]
            # temporary mp4v file, re-encoded to H.264 by _convert_to_h264()
            video_path = os.path.join(save_path, name + '_.mp4')
            video_paths.append(video_path)
            # NOTE(review): the frame rate is fixed at 30 fps regardless of the
            # source video's rate - confirm this is intended.
            writers[name] = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'),
                                            30.0, (w, h))
        # reverse the channel order (presumably RGB from tensor2im -> BGR for OpenCV)
        writers[name].write(img[:, :, ::-1])


def _convert_to_h264(video_paths):
    """Re-encode each temporary ``*_.mp4`` file to ``*.mp4`` with libx264.

    ffmpeg is run without a shell, so paths with spaces or shell
    metacharacters are safe. ``-y`` replaces a previous result instead of
    blocking on ffmpeg's overwrite prompt. The temporary file is deleted only
    when ffmpeg succeeds; otherwise it is kept so no output is lost.
    """
    for tmp_path in video_paths:
        final_path = tmp_path[:-len('_.mp4')] + '.mp4'
        try:
            result = subprocess.run(['ffmpeg', '-y', '-i', tmp_path, '-c:v', 'libx264', final_path])
        except FileNotFoundError:
            print('ffmpeg not found; keeping mp4v video at {}'.format(tmp_path))
            continue
        if result.returncode == 0:
            os.remove(tmp_path)
        else:
            print('ffmpeg failed with exit code {}; keeping mp4v video at {}'.format(
                result.returncode, tmp_path))


def main():
    """Run inference on the configured test inputs and save every visual."""
    args = _parse_args()
    config = _build_test_config(args)
    dataset = SuperDataset(config)
    dataloader = CustomDataLoader(config, dataset)
    model = create_model(config)  # create a model given opt.model and other options
    model.load_networks(0, ckpt=args.ckpt)
    model.eval()
    save_path = _resolve_save_path(args, config)

    ext_name = config['testing']['image_format']
    use_input_format = ext_name == 'input'  # keep each input file's own extension
    output_video = config['testing']['test_video'] is not None
    writers = {}
    video_paths = []
    try:
        for i, data in enumerate(tqdm(dataloader)):
            if i >= config['testing']['num_test']:  # only apply our model to num_test inputs
                break
            model.set_input(data)  # unpack data from data loader
            model.test()  # run inference
            visuals = model.get_current_visuals()
            img_path = _unwrap_path(model.get_image_paths())
            img_fn, img_ext = os.path.splitext(os.path.basename(img_path))
            if use_input_format:
                ext_name = img_ext[1:]
            if output_video:
                _append_video_frames(visuals, save_path, writers, video_paths)
            else:
                for name, tensor in visuals.items():
                    tensor2file(tensor, os.path.join(save_path, img_fn + '_' + name), ext_name)
    finally:
        # finalize the video containers even if inference raised part-way
        for writer in writers.values():
            writer.release()
    _convert_to_h264(video_paths)


if __name__ == '__main__':
    main()