# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Argument parser functions."""

import argparse
import sys

import timesformer.utils.checkpoint as cu
from timesformer.config.defaults import get_cfg


def parse_args():
    """
    Build the default command-line parser for PySlowFast users and parse
    the arguments of the current process.

    Command-line arguments:
        shard_id (int): shard id for the current machine. Starts from 0 to
            num_shards - 1. If a single machine is used, set shard id to 0.
        num_shards (int): number of shards used by the job.
        init_method (str): initialization method to launch the job with
            multiple devices. Options include TCP or shared file-system for
            initialization. Details can be found in
            https://pytorch.org/docs/stable/distributed.html#tcp-initialization
        cfg (str): path to the config file (stored as `cfg_file`).
        opts (argument): additional options from the command line; they
            overwrite the config loaded from file.

    Returns:
        argparse.Namespace: parsed arguments with the attributes `shard_id`,
            `num_shards`, `init_method`, `cfg_file` and `opts`.
    """
    parser = argparse.ArgumentParser(
        description="Provide SlowFast video training and testing pipeline."
    )
    parser.add_argument(
        "--shard_id",
        help="The shard id of current node, Starts from 0 to num_shards - 1",
        default=0,
        type=int,
    )
    parser.add_argument(
        "--num_shards",
        help="Number of shards using by the job",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--init_method",
        help="Initialization method, includes TCP or shared file-system",
        default="tcp://localhost:9999",
        type=str,
    )
    parser.add_argument(
        "--cfg",
        dest="cfg_file",
        help="Path to the config file",
        default="configs/Kinetics/SLOWFAST_4x16_R50.yaml",
        type=str,
    )
    # Everything left on the command line after the known flags is collected
    # verbatim into `opts`.
    parser.add_argument(
        "opts",
        help="See slowfast/config/defaults.py for all options",
        default=None,
        nargs=argparse.REMAINDER,
    )
    # When invoked without any argument, show the usage text as a hint;
    # parsing still proceeds afterwards using the defaults above.
    if len(sys.argv) == 1:
        parser.print_help()
    return parser.parse_args()


def load_config(args):
    """
    Given the arguments, load and initialize the configs.

    Args:
        args (argument): arguments that include `cfg_file` and `opts`, and
            optionally `shard_id`, `num_shards`, `rng_seed` and `output_dir`.

    Returns:
        The config object obtained from `get_cfg()` after merging the config
        file, the command-line overrides and the values inherited from `args`.
    """
    # Setup cfg.
    cfg = get_cfg()
    # Load config from cfg.
    if args.cfg_file is not None:
        cfg.merge_from_file(args.cfg_file)
    # Load config from command line, overwrite config from opts.
    if args.opts is not None:
        cfg.merge_from_list(args.opts)

    # Inherit parameters from args. The sharding values are only copied when
    # both attributes are present on `args`.
    if hasattr(args, "num_shards") and hasattr(args, "shard_id"):
        cfg.NUM_SHARDS = args.num_shards
        cfg.SHARD_ID = args.shard_id
    # `rng_seed` and `output_dir` are not defined by `parse_args`; they are
    # honoured only when the caller supplies a namespace that carries them.
    if hasattr(args, "rng_seed"):
        cfg.RNG_SEED = args.rng_seed
    if hasattr(args, "output_dir"):
        cfg.OUTPUT_DIR = args.output_dir

    # Create the checkpoint dir.
    cu.make_checkpoint_dir(cfg.OUTPUT_DIR)
    return cfg