Spaces:
Build error
Build error
# python3.7
"""Main function for model inference."""
import argparse
import os.path
import shutil
import subprocess

import torch
import torch.distributed as dist

import runners
from utils.logger import build_logger
from utils.misc import init_dist
from utils.misc import DictAction, parse_config, update_config
def parse_args(argv=None):
    """Parses command-line arguments for model inference.

    Args:
        argv: Optional list of argument strings to parse. When `None` (the
            default), `argparse` falls back to `sys.argv[1:]`, so existing
            callers that pass nothing behave exactly as before.

    Returns:
        An `argparse.Namespace` with the parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Run model inference.')
    parser.add_argument('config', type=str,
                        help='Path to the inference configuration.')
    parser.add_argument('--work_dir', type=str, required=True,
                        help='The work directory to save logs and checkpoints.')
    # Required argument: it has no default, so the help must not advertise
    # one (`%(default)s` would render as "None").
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to the checkpoint to load.')
    parser.add_argument('--synthesis_num', type=int, default=1000,
                        help='Number of samples to synthesize. Set as 0 to '
                             'disable synthesis. (default: %(default)s)')
    parser.add_argument('--fid_num', type=int, default=50000,
                        help='Number of samples to compute FID. Set as 0 to '
                             'disable FID test. (default: %(default)s)')
    parser.add_argument('--use_torchvision', action='store_true',
                        help='Whether to use the Inception model from '
                             '`torchvision` to compute FID. (default: False)')
    parser.add_argument('--launcher', type=str, default='pytorch',
                        choices=['pytorch', 'slurm'],
                        help='Launcher type. (default: %(default)s)')
    parser.add_argument('--backend', type=str, default='nccl',
                        help='Backend for distributed launcher. (default: '
                             '%(default)s)')
    # NOTE(review): `--rank` and `--local_rank` are accepted here but never
    # read by `main()`; presumably they exist so launcher-injected flags parse
    # cleanly while `init_dist` obtains ranks elsewhere — confirm.
    parser.add_argument('--rank', type=int, default=-1,
                        help='Node rank for distributed running. (default: '
                             '%(default)s)')
    parser.add_argument('--local_rank', type=int, default=0,
                        help='Rank of the current node. (default: %(default)s)')
    # Parsed by the project's `DictAction`; `main()` hands the result to
    # `update_config` to override entries of the loaded configuration.
    parser.add_argument('--options', nargs='+', action=DictAction,
                        help='arguments in dict')
    return parser.parse_args(argv)
def _get_commit_id():
    """Returns the current git commit hash, or '' if it cannot be determined.

    Best-effort only: if git is missing or the working directory is not a git
    repository, an empty string is returned instead of raising. The command is
    run without a shell, its pipe is closed by `subprocess.run`, and the
    trailing newline is stripped.
    """
    try:
        result = subprocess.run(['git', 'rev-parse', 'HEAD'],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.DEVNULL,
                                universal_newlines=True,
                                check=False)
    except (OSError, subprocess.SubprocessError):
        return ''
    if result.returncode != 0:
        return ''
    return result.stdout.strip()


def main():
    """Runs synthesis and/or FID evaluation for a trained checkpoint.

    Parses command-line arguments, merges them into the inference config,
    initializes the distributed environment, and builds a logger. Rank 0 gets
    the configured logger; every other rank gets a `'dumb'` one. It then loads
    the checkpoint with every training-state flag disabled. Finally, it
    synthesizes samples if `--synthesis_num > 0` and computes FID if
    `--fid_num > 0`.

    Raises:
        FileNotFoundError: If `--checkpoint` does not point to an existing file.
    """
    # Parse arguments.
    args = parse_args()

    # Parse configurations and apply command-line overrides.
    config = parse_config(args.config)
    config = update_config(config, args.options)
    config.work_dir = args.work_dir
    config.checkpoint = args.checkpoint
    config.launcher = args.launcher
    config.backend = args.backend
    if not os.path.isfile(config.checkpoint):
        raise FileNotFoundError(f'Checkpoint file `{config.checkpoint}` is '
                                f'missing!')

    # Set CUDNN.
    config.cudnn_benchmark = config.get('cudnn_benchmark', True)
    config.cudnn_deterministic = config.get('cudnn_deterministic', False)
    torch.backends.cudnn.benchmark = config.cudnn_benchmark
    torch.backends.cudnn.deterministic = config.cudnn_deterministic

    # Setting for launcher. Inference always runs in distributed mode.
    config.is_distributed = True
    init_dist(config.launcher, backend=config.backend)
    config.num_gpus = dist.get_world_size()

    # Setup logger. Only rank 0 writes real logs and snapshots the config.
    if dist.get_rank() == 0:
        # Ensure the destination exists before anything is written into it.
        os.makedirs(config.work_dir, exist_ok=True)
        logger_type = config.get('logger_type', 'normal')
        logger = build_logger(logger_type, work_dir=config.work_dir)
        shutil.copy(args.config, os.path.join(config.work_dir, 'config.py'))
        logger.info(f'Commit ID: {_get_commit_id()}')
    else:
        logger = build_logger('dumb', work_dir=config.work_dir)

    # Start inference. The runner class is looked up by name on `runners`.
    runner = getattr(runners, config.runner_type)(config, logger)
    # All training-state flags are disabled; presumably only the model
    # weights are restored from the checkpoint.
    runner.load(filepath=config.checkpoint,
                running_metadata=False,
                learning_rate=False,
                optimizer=False,
                running_stats=False)

    if args.synthesis_num > 0:
        num = args.synthesis_num
        logger.print()
        logger.info('Synthesizing images ...')
        runner.synthesize(num, html_name=f'synthesis_{num}.html')
        logger.info(f'Finish synthesizing {num} images.')

    if args.fid_num > 0:
        num = args.fid_num
        logger.print()
        logger.info('Testing FID ...')
        # `align_tf` is the inverse of `--use_torchvision`.
        fid_value = runner.fid(num, align_tf=not args.use_torchvision)
        logger.info(f'Finish testing FID on {num} samples. '
                    f'The result is {fid_value:.6f}.')
if __name__ == '__main__':
    # Run inference only when executed as a script, not when imported.
    main()