Spaces:
Sleeping
Sleeping
import argparse
import os

import numpy as np
import torch

# NOTE: star import kept for backward compatibility; numpy is now imported
# explicitly above instead of being picked up implicitly from here.
from utils.util import *
from models import create_model
from configs import parse_config
def to_numpy(tensor):
    """Return a detached NumPy copy of *tensor* on the CPU.

    ``.cpu()`` is a no-op for CPU tensors, but it is required before
    ``.numpy()`` when the model (and therefore its outputs) lives on a GPU.
    """
    return tensor.detach().cpu().numpy()


def as_output_list(output):
    """Wrap a single tensor in a list; turn a list/tuple of tensors into a list."""
    if isinstance(output, (list, tuple)):
        return list(output)
    return [output]


def compute_output_diffs(original, traced):
    """Compute element-wise absolute differences between two model outputs.

    Args:
        original: Output of the eager model, either a tensor or a list/tuple
            of tensors.
        traced: Output of the traced model, with the same structure as
            *original*.

    Returns:
        A list of NumPy arrays, one per network output, holding
        ``|original - traced|``.

    Raises:
        ValueError: If the two outputs contain a different number of tensors.
    """
    originals = as_output_list(original)
    traceds = as_output_list(traced)
    # Without this check, a mismatch either raises an opaque IndexError or
    # silently skips the extra outputs.
    if len(originals) != len(traceds):
        raise ValueError(
            "original model produced %d outputs but traced model produced %d"
            % (len(originals), len(traceds))
        )
    return [np.abs(to_numpy(o) - to_numpy(t)) for o, t in zip(originals, traceds)]


def main():
    """Trace the configured model to TorchScript, report the tracing error, and save it."""
    # parse arguments
    parser = argparse.ArgumentParser(description='Style Master')
    parser.add_argument('--cfg_file', type=str, default='./exp/cycle_gan_cfg.yaml')
    parser.add_argument('--ckpt', type=str, default='')
    parser.add_argument("--output_jit_file", type=str, default='./jit_models/cycle_gan.jit')
    args = parser.parse_args()

    # parse config and build the model in test phase
    config = parse_config(args.cfg_file)
    config['common']['phase'] = 'test'
    model = create_model(config)  # create a model from the parsed config
    # NOTE(review): meaning of the leading 0 argument is defined by the model's
    # load_networks implementation — confirm.
    model.load_networks(0, ckpt=args.ckpt)
    model.eval()

    # Random input with the channel count and spatial size used at test time.
    # NOTE(review): created on the CPU; presumably trace_jit handles device
    # placement — confirm.
    load_size = config['testing']['load_size']
    dummy_input = torch.rand(1, config['model']['input_nc'], load_size, load_size)
    traced_script_module, dummy_output, dummy_output_traced = model.trace_jit(dummy_input)

    diffs = compute_output_diffs(dummy_output, dummy_output_traced)
    for index, diff in enumerate(diffs, start=1):
        print('Network output ', index)
        print("average difference between original and traced model: ", np.mean(diff))
        print("max difference between original and traced model: ", np.max(diff))
        print("min difference between original and traced model: ", np.min(diff))

    # Make sure the destination directory exists (e.g. ./jit_models on a fresh checkout).
    output_dir = os.path.dirname(args.output_jit_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    traced_script_module.save(args.output_jit_file)


if __name__ == '__main__':
    main()