"""
Default training/testing logic
modified from detectron2(https://github.com/facebookresearch/detectron2)
Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""
import os
import sys
import argparse
import multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
import pointcept.utils.comm as comm
from pointcept.utils.env import get_random_seed, set_seed
from pointcept.utils.config import Config, DictAction
def create_ddp_model(model, *, fp16_compression=False, **kwargs):
    """
    Create a DistributedDataParallel model if there are >1 processes.

    Args:
        model: a torch.nn.Module
        fp16_compression: add fp16 compression hooks to the ddp object.
            See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
        kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
    """
    if comm.get_world_size() == 1:
        return model
    # kwargs['find_unused_parameters'] = True
    if "device_ids" not in kwargs:
        kwargs["device_ids"] = [comm.get_local_rank()]
    if "output_device" not in kwargs:
        # DistributedDataParallel expects an int or torch.device here, not a list.
        kwargs["output_device"] = comm.get_local_rank()
    ddp = DistributedDataParallel(model, **kwargs)
    if fp16_compression:
        from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

        ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    return ddp
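
# Usage sketch (illustrative, not part of the original module): a trainer process
# would typically wrap its model right after moving it to the local GPU, e.g.
#   model = create_ddp_model(model.cuda(), broadcast_buffers=False)
# where broadcast_buffers is just one example of a kwarg forwarded to
# DistributedDataParallel. In a single-process run the bare model is returned,
# so the same code path works with and without distributed training.
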
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Worker init func for dataloader.

    The seed of each worker equals num_workers * rank + worker_id + user_seed.

    Args:
        worker_id (int): Worker id.
        num_workers (int): Number of workers.
        rank (int): The rank of the current process.
        seed (int): The random seed to use.
    """
    worker_seed = num_workers * rank + worker_id + seed
    set_seed(worker_seed)
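
# Usage sketch (illustrative): a dataloader builder is assumed to bind the static
# arguments with functools.partial so PyTorch only supplies worker_id, e.g.
#   from functools import partial
#   from torch.utils.data import DataLoader
#   init_fn = (
#       partial(
#           worker_init_fn,
#           num_workers=cfg.num_worker_per_gpu,
#           rank=comm.get_rank(),
#           seed=cfg.seed,
#       )
#       if cfg.seed is not None
#       else None
#   )
#   loader = DataLoader(dataset, worker_init_fn=init_fn, ...)
# This gives every worker of every rank a distinct, reproducible seed.
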
def default_argument_parser(epilog=None):
    parser = argparse.ArgumentParser(
        epilog=epilog
        or f"""
Examples:

Run on single machine:
    $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml

Change some config options:
    $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001

Run on multiple machines:
    (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
    (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--config-file", default="", metavar="FILE", help="path to config file"
    )
    parser.add_argument(
        "--num-gpus", type=int, default=1, help="number of gpus *per machine*"
    )
    parser.add_argument(
        "--num-machines", type=int, default=1, help="total number of machines"
    )
    parser.add_argument(
        "--machine-rank",
        type=int,
        default=0,
        help="the rank of this machine (unique per machine)",
    )
    # PyTorch may still leave orphan processes in multi-GPU training.
    # Therefore we use a deterministic way to obtain the port,
    # so that users are made aware of orphan processes by seeing the port occupied.
    # port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    parser.add_argument(
        "--dist-url",
        # default="tcp://127.0.0.1:{}".format(port),
        default="auto",
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "--options", nargs="+", action=DictAction, help="custom options"
    )
    return parser
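
# Usage sketch (illustrative): an entry script is assumed to do something like
#   args = default_argument_parser().parse_args()
#   cfg = default_config_parser(args.config_file, args.options)
# and then hand cfg plus the --num-gpus / --num-machines / --machine-rank /
# --dist-url values to a distributed launcher.
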
def default_config_parser(file_path, options):
    # config name protocol: dataset_name/model_name-exp_name
    if os.path.isfile(file_path):
        cfg = Config.fromfile(file_path)
    else:
        sep = file_path.find("-")
        cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :]))

    if options is not None:
        cfg.merge_from_dict(options)

    if cfg.seed is None:
        cfg.seed = get_random_seed()

    cfg.data.train.loop = cfg.epoch // cfg.eval_epoch

    os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True)
    if not cfg.resume:
        cfg.dump(os.path.join(cfg.save_path, "config.py"))
    return cfg
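
# Usage sketch (illustrative, hypothetical names): an existing file path is loaded
# directly, e.g.
#   cfg = default_config_parser("configs/scannet/semseg-example.py", args.options)
# while a protocol-style name "dataset_name/model_name-exp_name" is split at the
# first "-" and re-joined, so "scannet/semseg-example" is looked up as
# os.path.join("scannet/semseg", "example").
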
def default_setup(cfg):
    # scale by world size
    world_size = comm.get_world_size()
    cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count()
    cfg.num_worker_per_gpu = cfg.num_worker // world_size
    assert cfg.batch_size % world_size == 0
    assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0
    assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0
    cfg.batch_size_per_gpu = cfg.batch_size // world_size
    cfg.batch_size_val_per_gpu = (
        cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1
    )
    cfg.batch_size_test_per_gpu = (
        cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1
    )
    # update data loop
    assert cfg.epoch % cfg.eval_epoch == 0
    # settle random seed
    rank = comm.get_rank()
    seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank
    set_seed(seed)
    return cfg
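
# Usage sketch (illustrative): the typical call order in an entry script is
# assumed to be
#   args = default_argument_parser().parse_args()
#   cfg = default_config_parser(args.config_file, args.options)
#   cfg = default_setup(cfg)  # in each worker, after the launcher initializes the process group
# so per-GPU batch sizes, worker counts, and the per-rank seed are derived once
# the distributed world size is known.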