Spaces:
Runtime error
Runtime error
import os | |
import json | |
import argparse | |
import numpy as np | |
import torch | |
import stylegan2 | |
from stylegan2 import utils | |
#---------------------------------------------------------------------------- | |
_description = """Metrics evaluation. | |
Run 'python %(prog)s <subcommand> --help' for subcommand help.""" | |
#---------------------------------------------------------------------------- | |
_examples = """examples: | |
# Train a network or convert a pretrained one. In this example we first convert a pretrained one. | |
python run_convert_from_tf --download ffhq-config-f --output G.pth D.pth Gs.pth | |
# Project generated images | |
python %(prog)s project_generated_images --network=Gs.pth --seeds=0,1,5 | |
# Project real images | |
python %(prog)s project_real_images --network=Gs.pth --data-dir=path/to/image_folder | |
""" | |
#---------------------------------------------------------------------------- | |
def _add_shared_arguments(parser): | |
parser.add_argument( | |
'--network', | |
help='Network file path', | |
required=True, | |
metavar='FILE' | |
) | |
parser.add_argument( | |
'--num_samples', | |
type=int, | |
help='Number of samples to gather for evaluating ' + \ | |
'this metric. Default: %(default)s', | |
default=50000, | |
metavar='VALUE' | |
) | |
parser.add_argument( | |
'--size', | |
type=int, | |
help='Rescale images so that this is the size of their ' + \ | |
'smallest side in pixels. Default: Unscaled', | |
default=None, | |
metavar='VALUE' | |
) | |
parser.add_argument( | |
'--batch_size', | |
help='Batch size for generator. Default: %(default)s', | |
type=int, | |
default=1, | |
metavar='VALUE' | |
) | |
parser.add_argument( | |
'--output', | |
help='Root directory for run results. Default: %(default)s', | |
type=str, | |
default='./results', | |
metavar='DIR' | |
) | |
parser.add_argument( | |
'--pixel_min', | |
help='Minumum of the value range of pixels in generated images. ' + \ | |
'Default: %(default)s', | |
default=-1, | |
type=float, | |
metavar='VALUE' | |
) | |
parser.add_argument( | |
'--pixel_max', | |
help='Maximum of the value range of pixels in generated images. ' + \ | |
'Default: %(default)s', | |
default=1, | |
type=float, | |
metavar='VALUE' | |
) | |
parser.add_argument( | |
'--gpu', | |
help='CUDA device indices (given as separate ' + \ | |
'values if multiple, i.e. "--gpu 0 1"). Default: Use CPU', | |
type=int, | |
default=[], | |
nargs='*', | |
metavar='INDEX' | |
) | |
def get_arg_parser():
    """Build the command line parser for the metrics script.

    The parser has two subcommands, stored in ``args.command``:

    * ``fid`` -- takes ``--data_dir`` (required), ``--reals_batch_size``,
      ``--reals_data_workers`` and ``--truncation_psi``.
    * ``ppl`` -- takes ``--epsilon``, ``--use_dlatent``, ``--full_sampling``
      and ``--ppl_ffhq_crop``.

    Both subcommands additionally receive the options registered by
    ``_add_shared_arguments``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description=_description,
        epilog=_examples,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    subparsers = parser.add_subparsers(help='Sub-commands', dest='command')

    fid_parser = subparsers.add_parser('fid', help='Calculate FID')

    fid_parser.add_argument(
        '--data_dir',
        help='Dataset root directory',
        required=True,
        metavar='DIR'
    )

    fid_parser.add_argument(
        '--reals_batch_size',
        help='Batch size for gathering statistics of reals. Default: %(default)s',
        type=int,
        default=1,
        metavar='VALUE'
    )

    fid_parser.add_argument(
        '--reals_data_workers',
        help='Data workers for fetching real data samples. Default: %(default)s',
        type=int,
        default=4,
        metavar='VALUE'
    )

    fid_parser.add_argument(
        '--truncation_psi',
        help='Truncation psi. Default: %(default)s',
        type=float,
        default=1.0,
        metavar='VALUE'
    )

    _add_shared_arguments(fid_parser)

    ppl_parser = subparsers.add_parser('ppl', help='Calculate PPL')

    ppl_parser.add_argument(
        '--epsilon',
        type=float,
        help='Perturbation value. Default: %(default)s',
        default=1e-4,
        metavar='VALUE'
    )

    ppl_parser.add_argument(
        '--use_dlatent',
        type=utils.bool_type,
        help=(
            'Measure on perturbations of disentangled latents '
            'instead of raw latents. Default: %(default)s'
        ),
        default=True,
        const=True,
        nargs='?',
        metavar='BOOL'
    )

    ppl_parser.add_argument(
        '--full_sampling',
        type=utils.bool_type,
        help=(
            'Measure on random interpolation between two inputs '
            'instead of directly on one input. Default: %(default)s'
        ),
        default=False,
        const=True,
        nargs='?',
        metavar='BOOL'
    )

    ffhq_crop_options = dict(
        help=(
            'Crop images evaluated for PPL with crop values '
            'for FFHQ. Default: False'
        ),
        type=utils.bool_type,
        const=True,
        nargs='?',
        metavar='BOOL'
    )

    # Historically this option was only registered on the top-level parser,
    # i.e. it had to be written *before* the subcommand. That spelling is
    # kept so existing invocations continue to work, and it also guarantees
    # that ``args.ppl_ffhq_crop`` always exists (defaulting to False).
    parser.add_argument(
        '--ppl_ffhq_crop',
        default=False,
        **ffhq_crop_options
    )

    # Also accept the option after ``ppl``, like every other PPL option.
    # ``argparse.SUPPRESS`` keeps the subparser from writing a default into
    # the namespace, so a value given before the subcommand is not
    # overwritten when the option is omitted here.
    ppl_parser.add_argument(
        '--ppl_ffhq_crop',
        default=argparse.SUPPRESS,
        **ffhq_crop_options
    )

    _add_shared_arguments(ppl_parser)

    return parser
#---------------------------------------------------------------------------- | |
def _report_metric(value, name, args): | |
fpath = os.path.join(args.output, 'metrics.json') | |
metrics = {} | |
if os.path.exists(fpath): | |
with open(fpath, 'r') as fp: | |
try: | |
metrics = json.load(fp) | |
except Exception: | |
pass | |
metrics[name] = value | |
with open(fpath, 'w') as fp: | |
json.dump(metrics, fp) | |
print('\n\nMetric evaluated!:') | |
print('{}: {}'.format(name, value)) | |
#---------------------------------------------------------------------------- | |
def eval_fid(G, prior_generator, args):
    """Evaluate FID for generator ``G`` and report the result.

    Real images are read from ``args.data_dir``. The metric is stored via
    ``_report_metric`` under a key of the form
    ``FID[(<size>)][trunc<psi>]:<num_samples // 1000>k``.

    Args:
        G: Generator handed to ``stylegan2.metrics.fid.FID``.
        prior_generator: Prior generator handed to the same metric object.
        args: Parsed command line arguments of the ``fid`` subcommand.
            ``args.reals_batch_size`` is raised to at least the number of
            GPUs as a side effect.
    """
    assert args.data_dir, '--data_dir has to be specified.'

    # The dataset and the feature extractor share the same pixel range.
    pixel_range = dict(pixel_min=args.pixel_min, pixel_max=args.pixel_max)

    dataset = utils.ImageFolder(args.data_dir, **pixel_range)
    assert len(dataset), 'No images found at {}'.format(args.data_dir)

    inception = stylegan2.external_models.inception.InceptionV3FeatureExtractor(
        **pixel_range)
    gpu_count = len(args.gpu)
    if gpu_count > 1:
        inception = torch.nn.DataParallel(inception, device_ids=args.gpu)
    # Make sure there is at least one real sample per device in a batch.
    args.reals_batch_size = max(args.reals_batch_size, gpu_count)

    metric = stylegan2.metrics.fid.FID(
        G=G,
        prior_generator=prior_generator,
        dataset=dataset,
        num_samples=args.num_samples,
        fid_model=inception,
        fid_size=args.size,
        truncation_psi=args.truncation_psi,
        reals_batch_size=args.reals_batch_size,
        reals_data_workers=args.reals_data_workers
    )
    result = metric.evaluate()

    # Assemble the key piece by piece; optional parts are only included when
    # the corresponding setting deviates from "unscaled" / "no truncation".
    key_parts = ['FID']
    if args.size:
        key_parts.append('({})'.format(args.size))
    if args.truncation_psi != 1:
        key_parts.append('trunc{}'.format(args.truncation_psi))
    key_parts.append(':{}k'.format(args.num_samples // 1000))

    _report_metric(result, ''.join(key_parts), args)
#---------------------------------------------------------------------------- | |
def eval_ppl(G, prior_generator, args):
    """Evaluate PPL for generator ``G`` and report the result.

    The metric is stored via ``_report_metric`` under a key of the form
    ``PPL[(<size>)]<W|Z>-<full|end>:<num_samples // 1000>k``, where ``W`` is
    used when ``args.use_dlatent`` is set and ``full`` when
    ``args.full_sampling`` is set.

    Args:
        G: Generator handed to ``stylegan2.metrics.ppl.PPL``.
        prior_generator: Prior generator handed to the same metric object.
        args: Parsed command line arguments of the ``ppl`` subcommand.
    """
    lpips = stylegan2.external_models.lpips.LPIPS_VGG16(
        pixel_min=args.pixel_min, pixel_max=args.pixel_max)
    if len(args.gpu) > 1:
        lpips = torch.nn.DataParallel(lpips, device_ids=args.gpu)

    # Only look up the FFHQ crop values when they were asked for.
    crop = stylegan2.metrics.ppl.PPL.FFHQ_CROP if args.ppl_ffhq_crop else None

    metric = stylegan2.metrics.ppl.PPL(
        G=G,
        prior_generator=prior_generator,
        num_samples=args.num_samples,
        epsilon=args.epsilon,
        use_dlatent=args.use_dlatent,
        full_sampling=args.full_sampling,
        crop=crop,
        lpips_model=lpips,
        lpips_size=args.size,
    )
    result = metric.evaluate()

    size_tag = '({})'.format(args.size) if args.size else ''
    space_tag = 'W' if args.use_dlatent else 'Z'
    sampling_tag = '-full' if args.full_sampling else '-end'
    key = 'PPL' + size_tag + space_tag + sampling_tag
    key += ':{}k'.format(args.num_samples // 1000)

    _report_metric(result, key, args)
#---------------------------------------------------------------------------- | |
def main():
    """Parse the command line, load the generator and run the chosen metric.

    Steps:
        1. Parse and validate the arguments. Invalid usage is reported
           through ``parser.error`` (usage message, exit status 2) rather
           than ``assert``, so validation is not stripped by ``python -O``.
        2. Create the output directory if it does not exist.
        3. Load the network from ``args.network`` and move it to the first
           requested GPU, or to the CPU when ``--gpu`` was not given. With
           more than one GPU the generator is wrapped in
           ``torch.nn.DataParallel``.
        4. Dispatch to ``eval_fid`` or ``eval_ppl``.

    Raises:
        TypeError: If the loaded network is not a
            ``stylegan2.models.Generator``, or the subcommand is unknown.
    """
    parser = get_arg_parser()
    args = parser.parse_args()

    if not args.command:
        parser.error('Missing subcommand.')
    # A path that is not an existing directory and carries a file extension
    # is taken to be a file name given by mistake.
    if not os.path.isdir(args.output) and os.path.splitext(args.output)[-1]:
        parser.error(
            '--output argument should specify a directory, not a file.')
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.output, exist_ok=True)

    G = stylegan2.models.load(args.network)
    if not isinstance(G, stylegan2.models.Generator):
        raise TypeError(
            'Model type has to be '
            'stylegan2.models.Generator. Found {}.'.format(type(G))
        )

    # Read the sizes before G is possibly wrapped in DataParallel below.
    latent_size, label_size = G.latent_size, G.label_size

    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)

    G.to(device).eval().requires_grad_(False)
    if len(args.gpu) > 1:
        G = torch.nn.DataParallel(G, device_ids=args.gpu)
    # Make sure there is at least one sample per device in a batch.
    args.batch_size = max(args.batch_size, len(args.gpu))

    prior_generator = utils.PriorGenerator(
        latent_size=latent_size,
        label_size=label_size,
        batch_size=args.batch_size,
        device=device
    )

    if args.command == 'fid':
        eval_fid(G, prior_generator, args)
    elif args.command == 'ppl':
        eval_ppl(G, prior_generator, args)
    else:
        raise TypeError('Unknown command {}'.format(args.command))


if __name__ == '__main__':
    main()