|
"""Run inference with a pretrained detection model."""
|
|
|
import argparse |
|
import os |
|
|
|
import datasets |
|
import torch |
|
from mmcv import Config, DictAction |
|
from mmcv.cnn import fuse_conv_bn |
|
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel |
|
from mmcv.runner import ( |
|
get_dist_info, |
|
init_dist, |
|
load_checkpoint, |
|
wrap_fp16_model, |
|
) |
|
from mmdet.apis import multi_gpu_test, single_gpu_test |
|
from mmdet.datasets import ( |
|
build_dataloader, |
|
build_dataset, |
|
replace_ImageToTensor, |
|
) |
|
from mmdet.models import build_detector |
|
|
|
# Base URL of the hosted pretrained checkpoints. main() appends the config's
# file name (with a .pth extension) when the config does not set load_from.
MODEL_SERVER: str = "https://dl.cv.ethz.ch/bdd100k/det/models/"
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line arguments of the inference script.

    Besides parsing ``sys.argv``, this exports ``--local_rank`` to the
    ``LOCAL_RANK`` environment variable when that variable is not already
    set (presumably for the distributed initialisation code; confirm).

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description="MMDet test (and eval) a model"
    )
    parser.add_argument("config", help="test config file path")
    parser.add_argument(
        "--work-dir",
        # NOTE(review): main() in this file does not read args.work_dir.
        help="the directory to save the file containing evaluation metrics",
    )
    parser.add_argument(
        "--fuse-conv-bn",
        action="store_true",
        # Adjacent literals are concatenated verbatim, so each fragment
        # must carry its own separating space.
        help="Whether to fuse conv and bn, this will slightly increase "
        "the inference speed",
    )
    parser.add_argument(
        "--format-only",
        action="store_true",
        help="Format the output results without perform evaluation. It is "
        "useful when you want to format the result to a specific format and "
        "submit it to the test server",
    )
    parser.add_argument(
        "--format-dir", help="directory where the outputs are saved."
    )
    parser.add_argument("--show", action="store_true", help="show results")
    parser.add_argument(
        "--show-dir", help="directory where painted images will be saved"
    )
    parser.add_argument(
        "--show-score-thr",
        type=float,
        default=0.3,
        help="score threshold (default: 0.3)",
    )
    parser.add_argument(
        "--gpu-collect",
        action="store_true",
        help="whether to use gpu to collect results.",
    )
    parser.add_argument(
        "--tmpdir",
        help="tmp directory used for collecting results from multiple "
        "workers, available when gpu-collect is not specified",
    )
    parser.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file. If the value to "
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        "Note that the quotation marks are necessary and that no white space "
        "is allowed.",
    )
    parser.add_argument(
        "--launcher",
        choices=["none", "pytorch", "slurm", "mpi"],
        default="none",
        help="job launcher",
    )
    # PyTorch 2.x launchers pass ``--local-rank`` while older ones pass
    # ``--local_rank``; accept both. The first long option fixes the
    # destination, so the value is still available as ``args.local_rank``.
    parser.add_argument("--local_rank", "--local-rank", type=int, default=0)
    args = parser.parse_args()
    # Do not clobber a LOCAL_RANK already provided by the environment.
    if "LOCAL_RANK" not in os.environ:
        os.environ["LOCAL_RANK"] = str(args.local_rank)
    return args
|
|
|
|
|
def _default_checkpoint_url(config_path: str) -> str:
    """Return the hosted checkpoint URL named after ``config_path``.

    The config's base name with its extension swapped for ``.pth`` is
    appended to ``MODEL_SERVER``, e.g. ``configs/foo.py`` -> ``.../foo.pth``.
    """
    stem = os.path.splitext(os.path.basename(config_path))[0]
    return MODEL_SERVER + stem + ".pth"


def _disable_pretrained_weights(cfg: Config) -> None:
    """Clear ``pretrained`` paths in the model config (in place).

    The weights are loaded from the checkpoint later in ``main``, so the
    backbone and any RFP neck backbones must not load their own.
    """
    cfg.model.pretrained = None
    neck = cfg.model.get("neck")
    if not neck:
        return
    # ``neck`` may be a single config or a list of configs.
    neck_cfgs = neck if isinstance(neck, list) else [neck]
    for neck_cfg in neck_cfgs:
        rfp_backbone = neck_cfg.get("rfp_backbone")
        if rfp_backbone and rfp_backbone.get("pretrained"):
            rfp_backbone.pretrained = None


def _prepare_test_data_cfg(cfg: Config) -> int:
    """Put the test dataset config(s) into test mode; return samples/GPU.

    ``cfg.data.test`` may be a single dataset config or a list of them.
    ``samples_per_gpu`` is popped from each one (default 1) and the largest
    value is returned; 1 is returned when there is nothing to inspect.
    When more than one sample per GPU is requested, every test pipeline is
    rewritten with ``replace_ImageToTensor``.
    """
    data_test = cfg.data.test
    if isinstance(data_test, dict):
        ds_cfgs = [data_test]
    elif isinstance(data_test, list):
        ds_cfgs = list(data_test)
    else:
        return 1
    for ds_cfg in ds_cfgs:
        ds_cfg.test_mode = True
    # ``default=1`` keeps an empty dataset list from crashing ``max``.
    samples_per_gpu = max(
        (ds_cfg.pop("samples_per_gpu", 1) for ds_cfg in ds_cfgs), default=1
    )
    if samples_per_gpu > 1:
        # Batched inference needs 'ImageToTensor' replaced in the pipeline.
        for ds_cfg in ds_cfgs:
            ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
    return samples_per_gpu


def main() -> None:
    """Run inference with a pretrained detector and optionally export it.

    Parses the CLI, loads and patches the config, builds the test dataset,
    dataloader and detector, loads the checkpoint (the hosted one named
    after the config when the config sets no ``load_from``), runs single-
    or multi-GPU inference and, on rank 0, converts the outputs with
    ``dataset.convert_format`` when ``--format-only`` is given.

    Raises:
        ValueError: If none of ``--format-only``, ``--show`` or
            ``--show-dir`` is given, or ``--format-only`` is given without
            ``--format-dir``.
    """
    args = parse_args()

    # Validate before the (potentially long) inference run. Explicit raises
    # instead of ``assert`` so the checks also run under ``python -O``.
    if not (args.format_only or args.show or args.show_dir):
        raise ValueError(
            "Please specify at least one operation (save/eval/format/show the "
            "results / save the results) with the argument '--format-only', "
            "'--show' or '--show-dir'"
        )
    if args.format_only and not args.format_dir:
        raise ValueError(
            "'--format-dir' is required when '--format-only' is specified"
        )

    cfg = Config.fromfile(args.config)
    # ``get`` tolerates configs that do not define ``load_from`` at all.
    if cfg.get("load_from") is None:
        cfg.load_from = _default_checkpoint_url(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    if cfg.get("cudnn_benchmark", False):
        torch.backends.cudnn.benchmark = True

    _disable_pretrained_weights(cfg)
    samples_per_gpu = _prepare_test_data_cfg(cfg)

    distributed = args.launcher != "none"
    if distributed:
        init_dist(args.launcher, **cfg.get("dist_params", {}))

    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False,
    )

    # Only the test-time settings are needed to build the model here.
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg"))
    if cfg.get("fp16", None) is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, cfg.load_from, map_location="cpu")
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    # Prefer class names stored in the checkpoint; otherwise use the
    # dataset's.
    if "CLASSES" in checkpoint.get("meta", {}):
        model.CLASSES = checkpoint["meta"]["CLASSES"]
    else:
        model.CLASSES = dataset.CLASSES

    if distributed:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
        )
        outputs = multi_gpu_test(
            model, data_loader, args.tmpdir, args.gpu_collect
        )
    else:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(
            model, data_loader, args.show, args.show_dir, args.show_score_thr
        )

    # Only rank 0 writes the converted results.
    rank, _ = get_dist_info()
    if rank == 0 and args.format_only:
        dataset.convert_format(outputs, args.format_dir)
|
|
|
# Run inference only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|