""" | |
Default training/testing logic | |
modified from detectron2(https://github.com/facebookresearch/detectron2) | |
Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) | |
Please cite our work if the code is helpful to you. | |
""" | |
import argparse
import multiprocessing as mp
import os
import sys

from torch.nn.parallel import DistributedDataParallel

import pointcept.utils.comm as comm
from pointcept.utils.config import Config, DictAction
from pointcept.utils.env import get_random_seed, set_seed
def create_ddp_model(model, *, fp16_compression=False, **kwargs):
    """
    Create a DistributedDataParallel model if there are >1 processes.

    Args:
        model: a torch.nn.Module
        fp16_compression: add fp16 compression hooks to the ddp object.
            See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
        kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.

    Returns:
        ``model`` unchanged when running single-process, otherwise the model
        wrapped in ``DistributedDataParallel``.
    """
    if comm.get_world_size() == 1:
        return model
    if "device_ids" not in kwargs:
        kwargs["device_ids"] = [comm.get_local_rank()]
    if "output_device" not in kwargs:
        # DDP's `output_device` must be a single int or torch.device — the
        # original code passed a one-element list here, which DDP rejects.
        kwargs["output_device"] = comm.get_local_rank()
    ddp = DistributedDataParallel(model, **kwargs)
    if fp16_compression:
        from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

        ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    return ddp
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed one dataloader worker deterministically.

    Each worker is seeded with ``num_workers * rank + worker_id + seed`` so
    every (process rank, worker) pair draws an independent random stream.

    Args:
        worker_id (int): Worker id.
        num_workers (int): Number of workers.
        rank (int): The rank of current process.
        seed (int): The random seed to use.
    """
    derived_seed = seed + rank * num_workers + worker_id
    set_seed(derived_seed)
def default_argument_parser(epilog=None):
    """Build the default command-line parser for launch scripts.

    Args:
        epilog: text shown after the argument help; a block of usage examples
            is used when not provided.

    Returns:
        argparse.ArgumentParser: parser exposing ``--config-file``,
        ``--num-gpus``, ``--num-machines``, ``--machine-rank``,
        ``--dist-url`` and ``--options``.
    """
    examples = f"""
Examples:
Run on single machine:
$ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml
Change some config options:
$ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001
Run on multiple machines:
(machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
(machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
"""
    parser = argparse.ArgumentParser(
        epilog=epilog or examples,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--config-file", default="", metavar="FILE", help="path to config file"
    )
    parser.add_argument(
        "--num-gpus", type=int, default=1, help="number of gpus *per machine*"
    )
    parser.add_argument(
        "--num-machines", type=int, default=1, help="total number of machines"
    )
    parser.add_argument(
        "--machine-rank",
        type=int,
        default=0,
        help="the rank of this machine (unique per machine)",
    )
    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    # port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    parser.add_argument(
        "--dist-url",
        # default="tcp://127.0.0.1:{}".format(port),
        default="auto",
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "--options", nargs="+", action=DictAction, help="custom options"
    )
    return parser
def default_config_parser(file_path, options):
    """Load a config file and apply command-line overrides.

    Config name protocol: ``dataset_name/model_name-exp_name``. When
    ``file_path`` is not an existing file, the first ``-`` splits it into a
    directory part and a file-name part before loading.

    Args:
        file_path: path to a config file, or a ``dir-name`` style shorthand.
        options: dict of overrides merged into the config, or None.

    Returns:
        The parsed Config, with seed settled, the training-data loop derived,
        and the save directory prepared (config dumped unless resuming).
    """
    if os.path.isfile(file_path):
        config = Config.fromfile(file_path)
    else:
        split_at = file_path.find("-")
        resolved = os.path.join(file_path[:split_at], file_path[split_at + 1 :])
        config = Config.fromfile(resolved)

    if options is not None:
        config.merge_from_dict(options)

    if config.seed is None:
        config.seed = get_random_seed()

    # One pass over the looped training set covers epoch // eval_epoch epochs.
    config.data.train.loop = config.epoch // config.eval_epoch

    os.makedirs(os.path.join(config.save_path, "model"), exist_ok=True)
    if not config.resume:
        config.dump(os.path.join(config.save_path, "config.py"))
    return config
def default_setup(cfg):
    """Derive per-GPU settings and seed RNGs for the current process.

    Divides the total worker count and batch sizes by the world size
    (asserting divisibility) and seeds this process with a rank-dependent
    seed so every rank draws an independent random stream.

    Args:
        cfg: the Config produced by ``default_config_parser``.

    Returns:
        The same Config, updated in place.
    """
    # Scale totals down to per-GPU values.
    world_size = comm.get_world_size()
    if cfg.num_worker is None:
        # Fall back to all visible CPUs when the worker count is unset.
        cfg.num_worker = mp.cpu_count()
    cfg.num_worker_per_gpu = cfg.num_worker // world_size

    assert cfg.batch_size % world_size == 0
    assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0
    assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0
    cfg.batch_size_per_gpu = cfg.batch_size // world_size
    if cfg.batch_size_val is None:
        cfg.batch_size_val_per_gpu = 1
    else:
        cfg.batch_size_val_per_gpu = cfg.batch_size_val // world_size
    if cfg.batch_size_test is None:
        cfg.batch_size_test_per_gpu = 1
    else:
        cfg.batch_size_test_per_gpu = cfg.batch_size_test // world_size

    # The data loop (set in default_config_parser) requires this divisibility.
    assert cfg.epoch % cfg.eval_epoch == 0

    # Settle the random seed, offset by rank so ranks don't share a stream.
    rank = comm.get_rank()
    seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank
    set_seed(seed)
    return cfg