# ReMoDiffuse / tools / test.py
# Hugging Face file-viewer header (mingyuan, "initial commit", a0d91d3, 3.79 kB).
import argparse
import os
import os.path as osp
import mmcv
import torch
from mmcv import DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (
get_dist_info,
init_dist,
load_checkpoint,
wrap_fp16_model,
)
from mogen.apis import multi_gpu_test, single_gpu_test
from mogen.datasets import build_dataloader, build_dataset
from mogen.models import build_architecture
def parse_args():
    """Parse the command-line arguments of the evaluation script.

    Positional arguments are the config file path and the checkpoint file.
    Optional flags select the work directory, the output file, the result
    collection mode, config overrides, the job launcher and the device.

    As a side effect, ``LOCAL_RANK`` is exported to the environment from
    ``--local_rank`` when the variable is not already set.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(description='mogen evaluation')
    parser.add_argument('config', help='test config file path')
    parser.add_argument(
        '--work-dir', help='the dir to save evaluation results')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--out', help='output result file')
    parser.add_argument(
        '--gpu_collect',
        action='store_true',
        help='whether to use gpu to collect results')
    parser.add_argument('--tmpdir', help='tmp dir for writing some results')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # ``torch.distributed.launch`` passes ``--local_rank`` on PyTorch < 2.0
    # and ``--local-rank`` from 2.0 onwards; accept both spellings. The
    # destination stays ``args.local_rank`` (taken from the first option).
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    parser.add_argument(
        '--device',
        choices=['cpu', 'cuda'],
        default='cuda',
        help='device used for testing')
    args = parser.parse_args()

    # Do not override a LOCAL_RANK that the launcher already exported.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
def main():
    """Evaluate a checkpoint on the test dataset described by a config file.

    Steps: parse the command line, load the config (merging any
    ``--cfg-options`` overrides), build the test dataset, dataloader and
    model, load the checkpoint, run single-process or distributed inference,
    then on rank 0 evaluate the outputs, print every metric and, if ``--out``
    was given, dump the metrics to that file.

    The directory passed to ``dataset.evaluate`` is ``--work-dir`` when given,
    otherwise the config's ``work_dir`` entry, otherwise
    ``./work_dirs/<config file name without extension>``.
    """
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # ``--work-dir`` is optional, but rank 0 needs a real path further down
    # (``osp.abspath(None)`` raises TypeError). Resolve it up front so a
    # missing flag cannot crash the script after inference has finished.
    work_dir = args.work_dir
    if work_dir is None:
        work_dir = cfg.get('work_dir', None)
    if work_dir is None:
        config_stem = osp.splitext(osp.basename(args.config))[0]
        work_dir = osp.join('./work_dirs', config_stem)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.data.test.test_mode = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    # the extra round_up data will be removed during gpu/cpu collect
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=cfg.data.samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False,
        round_up=False)

    # build the model and load checkpoint
    model = build_architecture(cfg.model)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    # Weights are first loaded on CPU; the model is moved/wrapped afterwards.
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    if not distributed:
        if args.device == 'cpu':
            model = model.cpu()
        else:
            model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)

    # Only rank 0 evaluates, prints and writes the metrics.
    rank, _ = get_dist_info()
    if rank == 0:
        mmcv.mkdir_or_exist(osp.abspath(work_dir))
        results = dataset.evaluate(outputs, work_dir)
        for k, v in results.items():
            print(f'\n{k} : {v:.4f}')

    # ``results`` exists only on rank 0, hence the repeated rank check.
    if args.out and rank == 0:
        print(f'\nwriting results to {args.out}')
        mmcv.dump(results, args.out)
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()