Spaces:
Runtime error
Runtime error
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Common definitions for GAN metrics.""" | |
import os | |
import time | |
import hashlib | |
import numpy as np | |
import tensorflow as tf | |
import dnnlib | |
import dnnlib.tflib as tflib | |
import config | |
from training import misc | |
from training import dataset | |
#----------------------------------------------------------------------------
# Standard metrics.
#
# Each entry is a keyword bundle handed to dnnlib.util.call_func_by_name by
# MetricGroup; presumably `func_name` selects the metric class and the other
# keys become its constructor arguments (confirm against dnnlib.util).

fid50k = dnnlib.EasyDict(
    func_name='metrics.frechet_inception_distance.FID',
    name='fid50k',
    num_images=50000,
    minibatch_per_gpu=8,
)

ppl_zfull = dnnlib.EasyDict(
    func_name='metrics.perceptual_path_length.PPL',
    name='ppl_zfull',
    num_samples=100000,
    epsilon=1e-4,
    space='z',
    sampling='full',
    minibatch_per_gpu=16,
)

ppl_wfull = dnnlib.EasyDict(
    func_name='metrics.perceptual_path_length.PPL',
    name='ppl_wfull',
    num_samples=100000,
    epsilon=1e-4,
    space='w',
    sampling='full',
    minibatch_per_gpu=16,
)

ppl_zend = dnnlib.EasyDict(
    func_name='metrics.perceptual_path_length.PPL',
    name='ppl_zend',
    num_samples=100000,
    epsilon=1e-4,
    space='z',
    sampling='end',
    minibatch_per_gpu=16,
)

ppl_wend = dnnlib.EasyDict(
    func_name='metrics.perceptual_path_length.PPL',
    name='ppl_wend',
    num_samples=100000,
    epsilon=1e-4,
    space='w',
    sampling='end',
    minibatch_per_gpu=16,
)

ls = dnnlib.EasyDict(
    func_name='metrics.linear_separability.LS',
    name='ls',
    num_samples=200000,
    num_keep=100000,
    attrib_indices=range(40),
    minibatch_per_gpu=4,
)

# Trivial metric that always reports 0.0; useful for debugging the pipeline.
dummy = dnnlib.EasyDict(
    func_name='metrics.metric_base.DummyMetric',
    name='dummy',
)
#----------------------------------------------------------------------------
# Base class for metrics.

class MetricBase:
    """Common driver for GAN metrics.

    Subclasses override `_evaluate()` and call `_report_result()` once per value
    they compute. `run()` loads a network pickle, calls `_evaluate()` inside a
    fresh TensorFlow graph/session, records the elapsed time, and optionally
    logs a one-line summary of the results.
    """

    def __init__(self, name):
        self.name = name
        self._network_pkl = None
        self._dataset_args = None
        self._mirror_augment = None
        self._results = []       # list of EasyDict(value, suffix, fmt) from _report_result
        self._eval_time = None   # seconds spent in the last run(), set by run()

    @staticmethod
    def _fill_from_run_config(dataset_args, mirror_augment, run_config):
        """Fill in whichever of `dataset_args` / `mirror_augment` is None from `run_config`.

        Values supplied explicitly by the caller are kept unchanged. A dataset
        config taken from `run_config` is copied (the run config itself is not
        mutated) and its 'shuffle_mb' entry is forced to 0.

        Args:
            dataset_args: caller-supplied dataset kwargs, or None.
            mirror_augment: caller-supplied mirror flag, or None.
            run_config: mapping with a 'dataset' mapping and a 'train' mapping,
                as returned by misc.parse_config_for_previous_run.

        Returns:
            (dataset_args, mirror_augment) tuple.
        """
        if dataset_args is None:
            dataset_args = dict(run_config['dataset'])
            # Presumably disables minibatch shuffling for evaluation -- TODO confirm
            # against dataset.load_dataset.
            dataset_args['shuffle_mb'] = 0
        if mirror_augment is None:
            mirror_augment = run_config['train'].get('mirror_augment', False)
        return dataset_args, mirror_augment

    def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True):
        """Evaluate this metric for one network pickle.

        Args:
            network_pkl: pickle loaded via misc.load_pkl; its basename labels
                the result line.
            run_dir: optional training-run directory. When given, it supplies
                any of `dataset_args` / `mirror_augment` left as None, and the
                result line is written to <run_dir>/metric-<name>.txt (mode 'a').
            dataset_args: kwargs for dataset.load_dataset (used by _iterate_reals).
            mirror_augment: whether real images go through
                misc.apply_mirror_augment (used by _iterate_reals).
            num_gpus: forwarded to _evaluate.
            tf_config: forwarded to tflib.create_session.
            log_results: if True, print the result line after evaluation.
        """
        self._network_pkl = network_pkl
        self._dataset_args = dataset_args
        self._mirror_augment = mirror_augment
        self._results = []

        # Only the arguments the caller left as None are taken from the run config;
        # explicitly supplied values are never overwritten.
        if (dataset_args is None or mirror_augment is None) and run_dir is not None:
            run_config = misc.parse_config_for_previous_run(run_dir)
            self._dataset_args, self._mirror_augment = self._fill_from_run_config(
                dataset_args, mirror_augment, run_config)

        time_begin = time.time()
        with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager
            _G, _D, Gs = misc.load_pkl(self._network_pkl)
            self._evaluate(Gs, num_gpus=num_gpus)
        self._eval_time = time.time() - time_begin

        if log_results:
            result_str = self.get_result_str()
            if run_dir is not None:
                log = os.path.join(run_dir, 'metric-%s.txt' % self.name)
                with dnnlib.util.Logger(log, 'a'):
                    print(result_str)
            else:
                print(result_str)

    def get_result_str(self):
        """Return one summary line: network name, evaluation time, then every reported result."""
        network_name = os.path.splitext(os.path.basename(self._network_pkl))[0]
        if len(network_name) > 29:
            # '...' + 26 chars = 29 chars, so the name still fits the 30-column field.
            network_name = '...' + network_name[-26:]
        result_str = '%-30s' % network_name
        result_str += ' time %-12s' % dnnlib.util.format_time(self._eval_time)
        for res in self._results:
            result_str += ' ' + self.name + res.suffix + ' '
            result_str += res.fmt % res.value
        return result_str

    def update_autosummaries(self):
        """Emit each reported result as a 'Metrics/<name><suffix>' autosummary."""
        for res in self._results:
            tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value)

    def _evaluate(self, Gs, num_gpus):
        """Compute the metric for generator `Gs`; must be overridden by subclasses."""
        raise NotImplementedError # to be overridden by subclasses

    def _report_result(self, value, suffix='', fmt='%-10.4f'):
        """Record one result value; `suffix` is appended to the metric name, `fmt` formats the value."""
        self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)]

    def _get_cache_file_for_reals(self, extension='pkl', **kwargs):
        """Return a cache-file path under config.cache_dir for statistics of the real images.

        The file name embeds an MD5 digest of the metric name, mirror flag,
        dataset kwargs and `kwargs`, so each distinct configuration gets its own
        file. MD5 is used only as a cache key here, not for security.
        """
        all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment)
        all_args.update(self._dataset_args)
        all_args.update(kwargs)
        md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8'))
        # Last path component of the dataset directory. Trailing separators are
        # stripped so that e.g. 'datasets/ffhq/' yields 'ffhq' instead of ''.
        dataset_name = self._dataset_args['tfrecord_dir'].replace('\\', '/').rstrip('/').split('/')[-1]
        return os.path.join(config.cache_dir, '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension))

    def _iterate_reals(self, minibatch_size):
        """Yield an endless stream of real-image minibatches from the dataset.

        Each minibatch is passed through misc.apply_mirror_augment when
        mirror augmentation is enabled; labels are discarded.
        """
        dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **self._dataset_args)
        while True:
            images, _labels = dataset_obj.get_minibatch_np(minibatch_size)
            if self._mirror_augment:
                images = misc.apply_mirror_augment(images)
            yield images

    def _iterate_fakes(self, Gs, minibatch_size, num_gpus):
        """Yield an endless stream of generated-image minibatches.

        Latents are drawn from a standard normal distribution with shape
        (minibatch_size, *Gs.input_shape[1:]); outputs are converted with
        tflib.convert_images_to_uint8 and transposed NCHW -> NHWC.
        """
        # Loop-invariant output transform, built once instead of per minibatch.
        fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
        while True:
            latents = np.random.randn(minibatch_size, *Gs.input_shape[1:])
            images = Gs.run(latents, None, output_transform=fmt, is_validation=True, num_gpus=num_gpus, assume_frozen=True)
            yield images

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
# Group of multiple metrics.

class MetricGroup:
    """A collection of metrics that are run, reported and summarized together."""

    def __init__(self, metric_kwarg_list):
        # Every entry is expanded into dnnlib.util.call_func_by_name to build one metric.
        self.metrics = [dnnlib.util.call_func_by_name(**spec) for spec in metric_kwarg_list]

    def run(self, *args, **kwargs):
        """Run every metric in order, forwarding the same arguments to each."""
        for member in self.metrics:
            member.run(*args, **kwargs)

    def get_result_str(self):
        """Return all member result lines joined by single spaces."""
        pieces = [member.get_result_str() for member in self.metrics]
        return ' '.join(pieces)

    def update_autosummaries(self):
        """Forward update_autosummaries() to every member metric."""
        for member in self.metrics:
            member.update_autosummaries()

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
# Dummy metric for debugging purposes.

class DummyMetric(MetricBase):
    """Metric that ignores the network and always reports 0.0."""

    def _evaluate(self, Gs, num_gpus):
        del Gs, num_gpus  # intentionally unused
        self._report_result(0.0)

#----------------------------------------------------------------------------