# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # # This work is licensed under the Creative Commons Attribution-NonCommercial # 4.0 International License. To view a copy of this license, visit # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. """Common definitions for GAN metrics.""" import os import time import hashlib import numpy as np import tensorflow as tf import dnnlib import dnnlib.tflib as tflib import config from training import misc from training import dataset #---------------------------------------------------------------------------- # Standard metrics. fid50k = dnnlib.EasyDict(func_name='metrics.frechet_inception_distance.FID', name='fid50k', num_images=50000, minibatch_per_gpu=8) ppl_zfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zfull', num_samples=100000, epsilon=1e-4, space='z', sampling='full', minibatch_per_gpu=16) ppl_wfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wfull', num_samples=100000, epsilon=1e-4, space='w', sampling='full', minibatch_per_gpu=16) ppl_zend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zend', num_samples=100000, epsilon=1e-4, space='z', sampling='end', minibatch_per_gpu=16) ppl_wend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wend', num_samples=100000, epsilon=1e-4, space='w', sampling='end', minibatch_per_gpu=16) ls = dnnlib.EasyDict(func_name='metrics.linear_separability.LS', name='ls', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4) dummy = dnnlib.EasyDict(func_name='metrics.metric_base.DummyMetric', name='dummy') # for debugging #---------------------------------------------------------------------------- # Base class for metrics. 
class MetricBase:
    """Base class for GAN evaluation metrics.

    Subclasses implement `_evaluate()` and call `_report_result()` one or more
    times; this class handles network loading, dataset configuration, timing,
    result formatting, logging, and caching of real-image statistics.
    """

    def __init__(self, name):
        # name: short identifier used in log file names, result strings,
        # autosummary tags, and cache file names.
        self.name = name
        self._network_pkl = None      # path of the network pickle being evaluated
        self._dataset_args = None     # kwargs for dataset.load_dataset()
        self._mirror_augment = None   # whether reals were mirror-augmented during training
        self._results = []            # list of EasyDict(value, suffix, fmt) from _report_result()
        self._eval_time = None        # wall-clock seconds of the last run()

    def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True):
        """Evaluate the metric for the networks stored in `network_pkl`.

        If `dataset_args` or `mirror_augment` is not given, they are recovered
        from the training configuration found in `run_dir`. Results are
        optionally appended to 'metric-<name>.txt' inside `run_dir`.
        """
        self._network_pkl = network_pkl
        self._dataset_args = dataset_args
        self._mirror_augment = mirror_augment
        self._results = []

        # Fill in missing dataset config from the previous training run, if available.
        if (dataset_args is None or mirror_augment is None) and run_dir is not None:
            run_config = misc.parse_config_for_previous_run(run_dir)
            self._dataset_args = dict(run_config['dataset'])
            self._dataset_args['shuffle_mb'] = 0  # deterministic ordering for evaluation
            self._mirror_augment = run_config['train'].get('mirror_augment', False)

        time_begin = time.time()
        # Fresh graph + session per evaluation so runs do not interfere.
        with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager
            # Pickle contains (G, D, Gs); only Gs is evaluated here
            # (presumably the averaged generator — per StyleGAN convention).
            _G, _D, Gs = misc.load_pkl(self._network_pkl)
            self._evaluate(Gs, num_gpus=num_gpus)
        self._eval_time = time.time() - time_begin

        if log_results:
            result_str = self.get_result_str()
            if run_dir is not None:
                log = os.path.join(run_dir, 'metric-%s.txt' % self.name)
                with dnnlib.util.Logger(log, 'a'):  # tee to file and stdout
                    print(result_str)
            else:
                print(result_str)

    def get_result_str(self):
        """Return a single formatted line: network name, eval time, and all results."""
        network_name = os.path.splitext(os.path.basename(self._network_pkl))[0]
        if len(network_name) > 29:
            # Truncate long names from the left, keeping the tail visible.
            network_name = '...' + network_name[-26:]
        result_str = '%-30s' % network_name
        result_str += ' time %-12s' % dnnlib.util.format_time(self._eval_time)
        for res in self._results:
            result_str += ' ' + self.name + res.suffix + ' '
            result_str += res.fmt % res.value
        return result_str

    def update_autosummaries(self):
        """Push all reported result values into TensorBoard autosummaries."""
        for res in self._results:
            tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value)

    def _evaluate(self, Gs, num_gpus):
        # Subclasses compute the metric here and record it via _report_result().
        raise NotImplementedError # to be overridden by subclasses

    def _report_result(self, value, suffix='', fmt='%-10.4f'):
        """Record one scalar result; `suffix` distinguishes multiple values per metric."""
        self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)]

    def _get_cache_file_for_reals(self, extension='pkl', **kwargs):
        """Return a cache-file path for real-image statistics.

        The file name embeds an MD5 of the metric name, dataset args, mirror
        augmentation flag, and any extra kwargs, so any configuration change
        invalidates the cache.
        """
        all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment)
        all_args.update(self._dataset_args)
        all_args.update(kwargs)
        # Sort items so the hash is insensitive to dict ordering.
        md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8'))
        # Last path component of the tfrecord dir; normalize Windows separators first.
        dataset_name = self._dataset_args['tfrecord_dir'].replace('\\', '/').split('/')[-1]
        return os.path.join(config.cache_dir, '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension))

    def _iterate_reals(self, minibatch_size):
        """Yield an endless stream of real-image minibatches from the dataset."""
        dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **self._dataset_args)
        while True:
            images, _labels = dataset_obj.get_minibatch_np(minibatch_size)
            # Apply the same mirror augmentation the networks saw during training.
            if self._mirror_augment:
                images = misc.apply_mirror_augment(images)
            yield images

    def _iterate_fakes(self, Gs, minibatch_size, num_gpus):
        """Yield an endless stream of generated-image minibatches from Gs."""
        while True:
            # Sample latents matching Gs's input shape (minus the batch dim).
            latents = np.random.randn(minibatch_size, *Gs.input_shape[1:])
            # Convert network output to uint8 NHWC to match the real images.
            fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
            images = Gs.run(latents, None, output_transform=fmt, is_validation=True, num_gpus=num_gpus, assume_frozen=True)
            yield images

#----------------------------------------------------------------------------
# Group of multiple metrics.
class MetricGroup:
    """Convenience wrapper that runs several metrics as one unit.

    Construction, evaluation, result formatting, and autosummary updates are
    all delegated to the individual metric objects.
    """

    def __init__(self, metric_kwarg_list):
        # Instantiate each metric from its keyword-argument spec.
        self.metrics = []
        for metric_kwargs in metric_kwarg_list:
            self.metrics.append(dnnlib.util.call_func_by_name(**metric_kwargs))

    def run(self, *args, **kwargs):
        """Evaluate every metric in the group with the same arguments."""
        for metric in self.metrics:
            metric.run(*args, **kwargs)

    def get_result_str(self):
        """Return the concatenated result lines of all metrics, space-separated."""
        parts = [metric.get_result_str() for metric in self.metrics]
        return ' '.join(parts)

    def update_autosummaries(self):
        """Push every metric's results into TensorBoard autosummaries."""
        for metric in self.metrics:
            metric.update_autosummaries()

#----------------------------------------------------------------------------
# Dummy metric for debugging purposes.

class DummyMetric(MetricBase):
    """Trivial metric that always reports 0.0 — used to exercise the pipeline."""

    def _evaluate(self, Gs, num_gpus):
        # Inputs are deliberately unused; reference them to satisfy linters.
        _ = Gs, num_gpus
        self._report_result(0.0)

#----------------------------------------------------------------------------