# MIT License # Copyright (c) 2022 Intelligent Systems Lab Org # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# File author: Shariq Farooq Bhat

from zoedepth.utils.misc import count_parameters, parallelize
from zoedepth.utils.config import get_config
from zoedepth.utils.arg_utils import parse_unknown
from zoedepth.trainers.builder import get_trainer
from zoedepth.models.builder import build_model
from zoedepth.data.data_mono import MixedNYUKITTI
import torch.utils.data.distributed
import torch.multiprocessing as mp
import torch
import numpy as np
from pprint import pprint
import argparse
import glob
import os
import random

os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ["WANDB_START_METHOD"] = "thread"


def fix_random_seed(seed: int):
    """Seed every RNG used in training and force deterministic cuDNN kernels.

    Args:
        seed (int): Seed applied to Python's ``random``, NumPy and PyTorch
            (CPU, current CUDA device and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def load_ckpt(config, model, checkpoint_dir="./checkpoints", ckpt_type="best"):
    """Optionally load weights into ``model`` according to ``config``.

    Resolution order:

    * ``config.checkpoint`` present: that path is loaded as-is;
    * else ``config.ckpt_pattern`` present: files in ``checkpoint_dir`` matching
      ``*{pattern}*{ckpt_type}*`` are collected and the lexicographically
      first one is loaded;
    * else the model is returned unchanged.

    Args:
        config: Experiment config (accessed by attribute).
        model: Model to load weights into.
        checkpoint_dir (str): Directory searched when ``ckpt_pattern`` is used.
        ckpt_type (str): Substring that must appear after the pattern in the
            checkpoint file name.

    Returns:
        The model returned by ``load_wts``, or ``model`` untouched when no
        checkpoint is configured.

    Raises:
        ValueError: If ``config.ckpt_pattern`` matches no file.
    """
    from zoedepth.models.model_io import load_wts

    if hasattr(config, "checkpoint"):
        checkpoint = config.checkpoint
    elif hasattr(config, "ckpt_pattern"):
        pattern = config.ckpt_pattern
        # glob() returns matches in arbitrary, filesystem-dependent order;
        # sort so the same checkpoint is picked on every run.
        matches = sorted(glob.glob(os.path.join(
            checkpoint_dir, f"*{pattern}*{ckpt_type}*")))
        if not matches:
            raise ValueError(f"No matches found for the pattern {pattern}")
        if len(matches) > 1:
            print(f"Multiple checkpoints match the pattern {pattern}; "
                  f"using {matches[0]}")
        checkpoint = matches[0]
    else:
        return model

    model = load_wts(model, checkpoint)
    print("Loaded weights from {0}".format(checkpoint))
    return model


def main_worker(gpu, ngpus_per_node, config):
    """Build, optionally restore, and train a model in this process.

    Called directly in the single-process case, or once per process by
    ``mp.spawn`` (which passes the local process index as ``gpu``) when
    ``config.distributed`` is set.

    Args:
        gpu: Local GPU index for this process, or ``None``.
        ngpus_per_node (int): GPUs visible on this node; not used in this
            function's body but part of the ``mp.spawn`` argument tuple.
        config: Experiment config. ``config.gpu`` is overwritten with ``gpu``
            and ``config.total_params`` is set as a side effect.
    """
    try:
        fix_random_seed(43)

        config.gpu = gpu

        model = build_model(config)
        model = load_ckpt(config, model)
        model = parallelize(config, model)

        total_params = f"{round(count_parameters(model)/1e6,2)}M"
        config.total_params = total_params
        print(f"Total parameters : {total_params}")

        train_loader = MixedNYUKITTI(config, "train").data
        test_loader = MixedNYUKITTI(config, "online_eval").data

        trainer = get_trainer(config)(
            config, model, train_loader, test_loader, device=config.gpu)

        trainer.train()
    finally:
        # Close the W&B run even if training failed. The import is
        # best-effort so a missing package cannot mask the real error.
        try:
            import wandb
        except ImportError:
            pass
        else:
            wandb.finish()


def _split_slurm_nodelist(nodelist):
    """Split a SLURM node list on commas that are outside ``[...]`` groups."""
    parts, current, depth = [], [], 0
    for ch in nodelist:
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
        if ch == "," and depth == 0:
            parts.append("".join(current))
            current = []
        else:
            current.append(ch)
    parts.append("".join(current))
    return parts


def _expand_slurm_host(pattern):
    """Expand one host expression, e.g. ``gpu[01-02,07]`` -> gpu01, gpu02, gpu07."""
    if "[" not in pattern:
        return [pattern]
    prefix, rest = pattern.split("[", 1)
    if "]" not in rest:
        # Malformed expression: fall back to just dropping the bracket.
        return [prefix + rest]
    ranges, suffix = rest.split("]", 1)
    # The suffix may itself contain bracket groups (e.g. rack[1-2]-n[1-4]).
    tails = _expand_slurm_host(suffix)
    hosts = []
    for item in ranges.split(","):
        if "-" in item:
            lo, hi = item.split("-", 1)
            # Preserve zero padding given by the lower bound's width.
            middles = [str(i).zfill(len(lo))
                       for i in range(int(lo), int(hi) + 1)]
        else:
            middles = [item]
        hosts.extend(prefix + middle + tail
                     for middle in middles for tail in tails)
    return hosts


def _expand_slurm_nodelist(nodelist):
    """Expand a compressed SLURM node list (``SLURM_JOB_NODELIST``) into host names."""
    return [host
            for part in _split_slurm_nodelist(nodelist)
            for host in _expand_slurm_host(part)]


if __name__ == '__main__':
    mp.set_start_method('forkserver')

    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="synunet")
    parser.add_argument("-d", "--dataset", type=str, default='mix')
    parser.add_argument("--trainer", type=str, default=None)

    args, unknown_args = parser.parse_known_args()
    overwrite_kwargs = parse_unknown(unknown_args)

    overwrite_kwargs["model"] = args.model
    if args.trainer is not None:
        overwrite_kwargs["trainer"] = args.trainer

    config = get_config(args.model, "train", args.dataset, **overwrite_kwargs)

    if config.use_shared_dict:
        shared_dict = mp.Manager().dict()
    else:
        shared_dict = None
    config.shared_dict = shared_dict

    config.batch_size = config.bs
    config.mode = 'train'
    if config.root != ".":
        # exist_ok avoids a check-then-create race between concurrently
        # launched copies of this script.
        os.makedirs(config.root, exist_ok=True)

    try:
        nodes = _expand_slurm_nodelist(os.environ['SLURM_JOB_NODELIST'])
        config.world_size = len(nodes)
        config.rank = int(os.environ['SLURM_PROCID'])
        slurm_job_id = os.environ.get('SLURM_JOB_ID', '')
    except KeyError:
        # We are NOT using SLURM
        config.world_size = 1
        config.rank = 0
        nodes = ["127.0.0.1"]
        slurm_job_id = ''

    if config.distributed:
        print(config.rank)
        if slurm_job_id.isdigit():
            # Under SLURM one copy of this script runs per node (world_size
            # = number of nodes, rank = SLURM_PROCID), so the rendezvous port
            # must come from a value all copies share; an unseeded random
            # port would differ between nodes.
            port = 15000 + int(slurm_job_id) % 25
        else:
            port = np.random.randint(15000, 15025)
        config.dist_url = 'tcp://{}:{}'.format(nodes[0], port)
        print(config.dist_url)
        config.dist_backend = 'nccl'
    config.gpu = None

    ngpus_per_node = torch.cuda.device_count()
    config.num_workers = config.workers
    config.ngpus_per_node = ngpus_per_node
    print("Config:")
    pprint(config)
    if config.distributed:
        # Each node spawns one process per GPU.
        config.world_size = ngpus_per_node * config.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, config))
    else:
        if ngpus_per_node == 1:
            config.gpu = 0
        main_worker(config.gpu, ngpus_per_node, config)