Latte-1 / tools /calc_metrics_for_dataset.py
maxin-cn's picture
Upload folder using huggingface_hub
94bafa8 verified
raw
history blame
6.9 kB
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Calculate quality metrics for previous training run or pretrained network pickle."""
import sys; sys.path.extend(['.', 'tools'])
import os
import click
import tempfile
import torch
from omegaconf import OmegaConf
from tools import dnnlib
from metrics import metric_main
from metrics import metric_utils
from tools.torch_utils import training_stats
from tools.torch_utils import custom_ops
#----------------------------------------------------------------------------
def subprocess_fn(rank, args, temp_dir):
    """Run every requested metric inside one worker process.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        args: Option bundle read via attributes (`num_gpus`, `verbose`,
            `metrics`, `dataset_kwargs`, `gen_dataset_kwargs`,
            `generator_as_dataset`, `use_cache`, `num_runs`, `run_dir`).
        temp_dir: Directory that holds the rendezvous file used to set up
            torch.distributed when more than one GPU is requested.
    """
    dnnlib.util.Logger(should_flush=True)

    multi_gpu = args.num_gpus > 1
    # Only rank 0 talks, and only when verbose output was requested.
    chatty = rank == 0 and args.verbose

    # Set up torch.distributed through a shared file: gloo on Windows, nccl elsewhere.
    if multi_gpu:
        rendezvous = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            backend = 'gloo'
            init_method = 'file:///' + rendezvous.replace('\\', '/')
        else:
            backend = 'nccl'
            init_method = f'file://{rendezvous}'
        torch.distributed.init_process_group(
            backend=backend,
            init_method=init_method,
            rank=rank,
            world_size=args.num_gpus,
        )

    # Set up torch_utils; stats are synced across a device only in multi-GPU runs.
    sync_device = torch.device('cuda', rank) if multi_gpu else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    if not chatty:
        custom_ops.verbosity = 'none'

    # Pick this rank's GPU, enable cuDNN autotuning and turn TF32 off.
    device = torch.device('cuda', rank)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    # Evaluate the metrics one by one; every rank computes, rank 0 reports.
    for metric in args.metrics:
        if chatty:
            print(f'Calculating {metric}...')
        monitor = metric_utils.ProgressMonitor(verbose=args.verbose)
        result_dict = metric_main.calc_metric(
            metric=metric,
            dataset_kwargs=args.dataset_kwargs,          # real
            gen_dataset_kwargs=args.gen_dataset_kwargs,  # fake
            generator_as_dataset=args.generator_as_dataset,
            num_gpus=args.num_gpus,
            rank=rank,
            device=device,
            progress=monitor,
            cache=args.use_cache,
            num_runs=args.num_runs,
        )
        if rank == 0:
            metric_main.report_metric(result_dict, run_dir=args.run_dir)
        if chatty:
            print()

    # Done.
    if chatty:
        print('Exiting...')
#----------------------------------------------------------------------------
class CommaSeparatedList(click.ParamType):
    """Click parameter type that turns a comma-separated string into a list.

    An unset value, the empty string, or the word "none" (any letter case)
    yields an empty list. A value that is already a list or tuple is returned
    as a list without further parsing.
    """
    name = 'list'

    def convert(self, value, param, ctx):
        """Return `value` as a list of strings.

        Args:
            value: Raw command-line string, an already converted list/tuple, or None.
            param: Unused; part of the click `ParamType.convert` signature.
            ctx: Unused; part of the click `ParamType.convert` signature.
        """
        _ = param, ctx
        # convert() has to tolerate values that already have the target type
        # (e.g. a list default); calling .lower() on them would raise.
        if isinstance(value, (list, tuple)):
            return list(value)
        if value is None or value == '' or value.lower() == 'none':
            return []
        return value.split(',')
#----------------------------------------------------------------------------
def calc_metrics_for_dataset(ctx, metrics, real_data_path, fake_data_path, mirror, resolution, gpus, verbose, use_cache: bool, num_runs: int):
    """Compute the requested metrics between a real and a generated dataset.

    Validates the arguments, builds dataset options for both data sources and
    then runs `subprocess_fn` either in this process (one GPU) or in one
    spawned process per GPU.

    Args:
        ctx: Click context; invalid arguments are reported through `ctx.fail`.
        metrics: List of metric names; each must pass `metric_main.is_valid_metric`.
        real_data_path: Location of the real data.
        fake_data_path: Location of the generated data.
        mirror: Forwarded as `xflip` for the real dataset only.
        resolution: Forwarded as `resolution` for both datasets.
        gpus: Number of GPUs / worker processes; must be at least 1.
        verbose: Whether to print the options and progress messages.
        use_cache: Forwarded to the metric computation as `cache`.
        num_runs: Forwarded to the metric computation as `num_runs`.
    """
    dnnlib.util.Logger(should_flush=True)

    # Validate arguments.
    args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, verbose=verbose)
    if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
        ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    if args.num_gpus < 1:
        ctx.fail('--gpus must be at least 1')

    # Both datasets share one placeholder dataset config.
    dummy_dataset_cfg = OmegaConf.create({'max_num_frames': 10000})

    # Real data may be mirrored on request; generated data is never flipped.
    args.dataset_kwargs = _video_frames_dataset_kwargs(
        path=real_data_path, cfg=dummy_dataset_cfg, xflip=mirror, resolution=resolution)
    args.gen_dataset_kwargs = _video_frames_dataset_kwargs(
        path=fake_data_path, cfg=dummy_dataset_cfg, xflip=False, resolution=resolution)
    args.generator_as_dataset = True

    # Print dataset options (each set exactly once).
    if args.verbose:
        print('Real data options:')
        print(args.dataset_kwargs)
        print('Fake data options:')
        print(args.gen_dataset_kwargs)
        print('*' * 50 + 'parting line' + '*' * 50)

    # No run directory is associated with a standalone evaluation.
    args.run_dir = None
    args.use_cache = use_cache
    args.num_runs = num_runs

    # Launch processes.
    if args.verbose:
        print('Launching processes...')
    # set_start_method() raises RuntimeError once a start method is fixed, so
    # skip it when 'spawn' is already active (e.g. this function is called twice).
    if torch.multiprocessing.get_start_method(allow_none=True) != 'spawn':
        torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus == 1:
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)


def _video_frames_dataset_kwargs(path, cfg, xflip, resolution):
    """Build the option dict for one unlabeled `VideoFramesFolderDataset`."""
    return dnnlib.EasyDict(
        class_name='utils.dataset.VideoFramesFolderDataset',
        path=path,
        cfg=cfg,
        xflip=xflip,
        resolution=resolution,
        use_labels=False,
    )
#----------------------------------------------------------------------------
@click.command()
@click.pass_context
@click.option(
    '--metrics',
    help='Comma-separated list or "none"',
    type=CommaSeparatedList(),
    default='fvd2048_16f,fid50k_full',
    show_default=True,
)
@click.option(
    '--real_data_path',
    help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]',
    metavar='PATH',
)
@click.option(
    '--fake_data_path',
    help='Generated images (directory or zip)',
    metavar='PATH',
)
@click.option(
    '--mirror',
    help='Should we mirror the real data?',
    type=bool,
    metavar='BOOL',
)
@click.option(
    '--resolution',
    help='Resolution for the source dataset',
    type=int,
    metavar='INT',
)
@click.option(
    '--gpus',
    help='Number of GPUs to use',
    type=int,
    default=1,
    metavar='INT',
    show_default=True,
)
@click.option(
    '--verbose',
    help='Print optional information',
    type=bool,
    default=False,
    metavar='BOOL',
    show_default=True,
)
@click.option(
    '--use_cache',
    help='Use stats cache',
    type=bool,
    default=True,
    metavar='BOOL',
    show_default=True,
)
@click.option(
    '--num_runs',
    help='Number of runs',
    type=int,
    default=1,
    metavar='INT',
    show_default=True,
)
def calc_metrics_cli_wrapper(ctx, *args, **kwargs):
    # Command-line front end: hands the click context and all parsed options
    # straight to calc_metrics_for_dataset. Kept docstring-free on purpose,
    # since click would show a docstring as the command's --help text.
    calc_metrics_for_dataset(ctx, *args, **kwargs)
#----------------------------------------------------------------------------
# Script entry point: the command is invoked without arguments because click
# fills in every parameter from the command line.
if __name__ == "__main__":
    calc_metrics_cli_wrapper() # pylint: disable=no-value-for-parameter
#----------------------------------------------------------------------------