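r"""Inference entry point for Style Master (sp2pII, phase 2).

Loads a pretrained model from a checkpoint and runs it on a single image, a
video, or a folder of images, writing the results as image files or mp4 videos.

Example invocation (a sketch: the script's filename and your local paths may
differ from these repo defaults):

    python test.py --cfg_file ./exp/sp2pII-phase2.yaml \
        --test_folder ./example/source \
        --ckpt ./pretrained_models/phase2_pretrain_90000.pth \
        --gpus 0
"""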
import argparse
import os
import cv2
from tqdm import tqdm

from utils.util import tensor2file, tensor2im
from data import CustomDataLoader
from data.super_dataset import SuperDataset
from models import create_model
from configs import parse_config

if __name__ == '__main__':
    # parse arguments
    parser = argparse.ArgumentParser(description='Style Master')
    parser.add_argument('--cfg_file', type=str, default='./exp/sp2pII-phase2.yaml')
    parser.add_argument('--test_img', type=str, default='', help='path to your test img')
    parser.add_argument('--test_video', type=str, default='')
    parser.add_argument('--test_folder', type=str, default='./example/source')
    parser.add_argument('--ckpt', type=str, default='./pretrained_models/phase2_pretrain_90000.pth')
    parser.add_argument('--overwrite_output_dir', type=str, default='./example/outputs/multi-model')
    parser.add_argument('--gpus', type=str, default='0')
    args = parser.parse_args()

    # parse config
    config = parse_config(args.cfg_file)

    # Select GPUs from the command line; ids follow the nvidia-smi ordering
    # because they are applied through CUDA_VISIBLE_DEVICES. The config's
    # gpu_ids are then renumbered 0..N-1 relative to the visible set.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    # Alternative: select the GPUs listed in the config file instead:
    # os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, config['common']['gpu_ids']))
    config['common']['gpu_ids'] = list(range(len(config['common']['gpu_ids'])))

    # hard-code some parameters for test
    config['common']['phase'] = 'test'
    config['dataset']['n_threads'] = 0   # test code only supports n_threads = 0
    config['dataset']['batch_size'] = 1    # test code only supports batch_size = 1
    config['dataset']['serial_batches'] = True  # disable data shuffling; comment this line out to test on randomly ordered images
    config['dataset']['no_flip'] = True    # no flip; comment this line out to test on flipped images

    # override data augmentation
    config['dataset']['load_size'] = config['testing']['load_size']
    config['dataset']['crop_size'] = config['testing']['crop_size']
    config['dataset']['preprocess'] = config['testing']['preprocess']

    # add testing path
    config['testing']['test_img'] = None if args.test_img == '' else args.test_img
    config['testing']['test_video'] = None if args.test_video == '' else args.test_video
    config['testing']['test_folder'] = args.test_folder

    config['training']['pretrained_model'] = args.ckpt
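
    # build the dataset and dataloader from the test-time config assembled above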
    dataset = SuperDataset(config)
    dataloader = CustomDataLoader(config, dataset)

    model = create_model(config)            # build the model specified by the config
    model.load_networks(0, ckpt=args.ckpt)  # load pretrained weights from the checkpoint

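    # eval mode: deterministic inference (typically disables dropout and
    # freezes batch-norm statistics, depending on the model implementation)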
    model.eval()

    if args.overwrite_output_dir != '':
        save_path = args.overwrite_output_dir
    else:
        # default layout: <results_dir>/<config file name>/<experiment name>
        save_path = os.path.join(config['testing']['results_dir'],
                                 os.path.splitext(os.path.split(args.cfg_file)[1])[0],
                                 config['common']['name'])

    os.makedirs(save_path, exist_ok=True)

    def unwrap_path(x):
        # dataloaders may return the path nested in (lists of) lists; peel until a string remains
        return x if isinstance(x, str) else unwrap_path(x[0])

    ext_name = config['testing']['image_format']
    use_input_format = (ext_name == 'input')  # 'input' means: keep each source image's own extension
    output_video = config['testing']['test_video'] is not None
    vw_dict = {}      # one cv2.VideoWriter per visual key
    video_paths = []

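    # Inference loop: run each input through the model, then either save every
    # visual as an image file or stream it as a frame into a per-key video.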
    for i, data in enumerate(tqdm(dataloader)):
        if i >= config['testing']['num_test']:  # only process the first num_test inputs
            break
        model.set_input(data)  # unpack data from data loader
        model.test()           # run inference
        visuals = model.get_current_visuals()  # get image results
        img_path = model.get_image_paths()     # get image paths

        # save result
        items = os.path.splitext(os.path.split(unwrap_path(img_path))[1])
        img_fn = items[0]
        if use_input_format:
            ext_name = items[1][1:]

        for k, v in visuals.items():
            if not output_video:
                tensor2file(v, os.path.join(save_path, img_fn + '_' + k), ext_name)
            else:
                img = tensor2im(v)
                if k not in vw_dict:
                    # lazily open one writer per visual key, sized to the first frame, at a fixed 30 fps
                    h, w = img.shape[:2]
                    video_path = os.path.join(save_path, k + '_.mp4')
                    video_paths.append(video_path)
                    vw_dict[k] = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), 30.0, (w, h))
                vw_dict[k].write(img[:, :, ::-1])  # tensor2im yields RGB; OpenCV expects BGR

    # flush and close all video writers
    for v in vw_dict.values():
        v.release()

    # Re-encode the mp4v streams with libx264 for broad player compatibility;
    # '[:-5]' strips the intermediate '_.mp4' suffix, leaving '<key>.mp4'.
    # '-y' keeps ffmpeg from blocking on an overwrite prompt when re-running.
    for video_path in video_paths:
        os.system('ffmpeg -y -i {} -c:v libx264 {}'.format(video_path, video_path[:-5] + '.mp4'))
        os.remove(video_path)  # drop the intermediate mp4v file
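
    # Resulting layout: image mode writes <save_path>/<input name>_<visual key>.<ext>
    # (extension handling lives inside tensor2file); video mode leaves one
    # <save_path>/<visual key>.mp4 per visual after the re-encode above.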