import importlib import sys sys.path.append('.') sys.path.append('..') import torch import torch.multiprocessing as mp from networks.managers.evaluator import Evaluator def main_worker(gpu, cfg, seq_queue=None, info_queue=None, enable_amp=False): # Initiate a evaluating manager evaluator = Evaluator(rank=gpu, cfg=cfg, seq_queue=seq_queue, info_queue=info_queue) # Start evaluation if enable_amp: with torch.cuda.amp.autocast(enabled=True): evaluator.evaluating() else: evaluator.evaluating() def main(): import argparse parser = argparse.ArgumentParser(description="Eval VOS") parser.add_argument('--exp_name', type=str, default='default') parser.add_argument('--stage', type=str, default='pre') parser.add_argument('--model', type=str, default='aott') parser.add_argument('--lstt_num', type=int, default=-1) parser.add_argument('--lt_gap', type=int, default=-1) parser.add_argument('--st_skip', type=int, default=-1) parser.add_argument('--max_id_num', type=int, default='-1') parser.add_argument('--gpu_id', type=int, default=0) parser.add_argument('--gpu_num', type=int, default=1) parser.add_argument('--ckpt_path', type=str, default='') parser.add_argument('--ckpt_step', type=int, default=-1) parser.add_argument('--dataset', type=str, default='') parser.add_argument('--split', type=str, default='') parser.add_argument('--ema', action='store_true') parser.set_defaults(ema=False) parser.add_argument('--flip', action='store_true') parser.set_defaults(flip=False) parser.add_argument('--ms', nargs='+', type=float, default=[1.]) parser.add_argument('--max_resolution', type=float, default=480 * 1.3) parser.add_argument('--amp', action='store_true') parser.set_defaults(amp=False) args = parser.parse_args() engine_config = importlib.import_module('configs.' + args.stage) cfg = engine_config.EngineConfig(args.exp_name, args.model) cfg.TEST_EMA = args.ema cfg.TEST_GPU_ID = args.gpu_id cfg.TEST_GPU_NUM = args.gpu_num if args.lstt_num > 0: cfg.MODEL_LSTT_NUM = args.lstt_num if args.lt_gap > 0: cfg.TEST_LONG_TERM_MEM_GAP = args.lt_gap if args.st_skip > 0: cfg.TEST_SHORT_TERM_MEM_SKIP = args.st_skip if args.max_id_num > 0: cfg.MODEL_MAX_OBJ_NUM = args.max_id_num if args.ckpt_path != '': cfg.TEST_CKPT_PATH = args.ckpt_path if args.ckpt_step > 0: cfg.TEST_CKPT_STEP = args.ckpt_step if args.dataset != '': cfg.TEST_DATASET = args.dataset if args.split != '': cfg.TEST_DATASET_SPLIT = args.split cfg.TEST_FLIP = args.flip cfg.TEST_MULTISCALE = args.ms if cfg.TEST_MULTISCALE != [1.]: cfg.TEST_MAX_SHORT_EDGE = args.max_resolution # for preventing OOM else: cfg.TEST_MAX_SHORT_EDGE = None # the default resolution setting of CFBI and AOT cfg.TEST_MAX_LONG_EDGE = args.max_resolution * 800. / 480. if args.gpu_num > 1: mp.set_start_method('spawn') seq_queue = mp.Queue() info_queue = mp.Queue() mp.spawn(main_worker, nprocs=cfg.TEST_GPU_NUM, args=(cfg, seq_queue, info_queue, args.amp)) else: main_worker(0, cfg, enable_amp=args.amp) if __name__ == '__main__': main()