# mshukor
# init
# 3eb682b
# raw
# history blame
# 3 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Argument parser functions."""
import argparse
import sys
import timesformer.utils.checkpoint as cu
from timesformer.config.defaults import get_cfg
def parse_args():
    """
    Build the default command-line parser and parse `sys.argv` with it.

    Recognized arguments:
        --shard_id (int): shard id for the current machine, from 0 to
            num_shards - 1. Use 0 when a single machine is used. Default: 0.
        --num_shards (int): number of shards used by the job. Default: 1.
        --init_method (str): initialization method used to launch the job
            with multiple devices, either TCP or a shared file-system.
            Details can be found in
            https://pytorch.org/docs/stable/distributed.html#tcp-initialization
            Default: "tcp://localhost:9999".
        --cfg (str): path to the config file; stored on the result as
            `cfg_file`.
        opts (list): all remaining command-line tokens, collected verbatim.
            They provide additional options that overwrite the config
            loaded from file.

    Returns:
        argparse.Namespace: the parsed arguments, with attributes
            `shard_id`, `num_shards`, `init_method`, `cfg_file` and `opts`.
    """
    parser = argparse.ArgumentParser(
        description="Provide SlowFast video training and testing pipeline."
    )
    parser.add_argument(
        "--shard_id",
        help="The shard id of current node, Starts from 0 to num_shards - 1",
        default=0,
        type=int,
    )
    parser.add_argument(
        "--num_shards",
        help="Number of shards using by the job",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--init_method",
        help="Initialization method, includes TCP or shared file-system",
        default="tcp://localhost:9999",
        type=str,
    )
    parser.add_argument(
        "--cfg",
        dest="cfg_file",
        help="Path to the config file",
        default="configs/Kinetics/SLOWFAST_4x16_R50.yaml",
        type=str,
    )
    parser.add_argument(
        "opts",
        # The defaults for this project are imported from
        # timesformer.config.defaults, so point users at that file.
        help="See timesformer/config/defaults.py for all options",
        default=None,
        nargs=argparse.REMAINDER,
    )
    # When invoked without any argument, show the usage text as a hint.
    # Parsing still proceeds afterwards, so the defaults are returned.
    if len(sys.argv) == 1:
        parser.print_help()
    return parser.parse_args()
def load_config(args):
    """
    Build the config from the defaults, the config file and the arguments.

    The steps run in this order: start from `get_cfg()`, merge the config
    file, merge the command-line `opts`, copy selected attributes of `args`
    into the config, then create the checkpoint directory.

    Args:
        args (argument): parsed arguments. `cfg_file` and `opts` are always
            read. `num_shards` and `shard_id` are copied only when both are
            present; `rng_seed` and `output_dir` are each copied when
            present.

    Returns:
        the config object obtained from `get_cfg()` after all merges and
        overrides have been applied.
    """
    config = get_cfg()

    # Values from the config file replace the built-in defaults.
    config_path = args.cfg_file
    if config_path is not None:
        config.merge_from_file(config_path)

    # Command-line options are merged after the file, so they take priority.
    overrides = args.opts
    if overrides is not None:
        config.merge_from_list(overrides)

    # Shard settings are copied as a pair: both attributes must exist.
    if all(hasattr(args, name) for name in ("num_shards", "shard_id")):
        config.NUM_SHARDS = args.num_shards
        config.SHARD_ID = args.shard_id

    # Optional single-attribute overrides, applied only when args has them.
    optional_overrides = (
        ("rng_seed", "RNG_SEED"),
        ("output_dir", "OUTPUT_DIR"),
    )
    for arg_name, cfg_key in optional_overrides:
        if hasattr(args, arg_name):
            setattr(config, cfg_key, getattr(args, arg_name))

    # Create the checkpoint dir under the final output directory.
    cu.make_checkpoint_dir(config.OUTPUT_DIR)
    return config