# Source: LucidDreamer / ZoeDepth / train_mix.py
# Uploaded by ironjr — commit 24f9881: "untroubled files first"
# (file-viewer residue: raw | history | blame; "No virus"; 5.69 kB)
# MIT License
# Copyright (c) 2022 Intelligent Systems Lab Org
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# File author: Shariq Farooq Bhat
from zoedepth.utils.misc import count_parameters, parallelize
from zoedepth.utils.config import get_config
from zoedepth.utils.arg_utils import parse_unknown
from zoedepth.trainers.builder import get_trainer
from zoedepth.models.builder import build_model
from zoedepth.data.data_mono import MixedNYUKITTI
import torch.utils.data.distributed
import torch.multiprocessing as mp
import torch
import numpy as np
from pprint import pprint
import argparse
import os
# Process-wide settings for PyOpenGL's platform backend and wandb's start
# method, applied at import time so they are in place before any training
# code below runs.
os.environ.update({
    "PYOPENGL_PLATFORM": "egl",
    "WANDB_START_METHOD": "thread",
})
def fix_random_seed(seed: int):
    """
    Seed every random number generator used in training so runs repeat.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU
    and CUDA generators (current device and all devices). Then it switches
    cuDNN to deterministic algorithms and disables cuDNN benchmark mode.

    Args:
        seed (int): value handed to every seeding call.
    """
    import random
    import numpy
    import torch

    # Same call order as a sequence of individual seeding statements.
    seeders = (
        random.seed,
        numpy.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def load_ckpt(config, model, checkpoint_dir="./checkpoints", ckpt_type="best"):
    """
    Optionally load saved weights into ``model``, as directed by ``config``.

    Resolution order:
      1. ``config.checkpoint`` exists: it is used directly as the checkpoint path.
      2. ``config.ckpt_pattern`` exists: the lexicographically first file in
         ``checkpoint_dir`` matching ``*{pattern}*{ckpt_type}*`` is used.
      3. Neither attribute exists: ``model`` is returned unchanged.

    Args:
        config: configuration object. Only its ``checkpoint`` and
            ``ckpt_pattern`` attributes are read here.
        model: model whose weights are loaded.
        checkpoint_dir (str): directory searched when ``ckpt_pattern`` is used.
        ckpt_type (str): substring that must appear after the pattern in the
            checkpoint file name.

    Returns:
        The result of ``load_wts(model, checkpoint)``, or ``model`` itself
        when no checkpoint is configured.

    Raises:
        ValueError: if ``ckpt_pattern`` is set but no file matches it.
    """
    if hasattr(config, "checkpoint"):
        checkpoint = config.checkpoint
    elif hasattr(config, "ckpt_pattern"):
        checkpoint = _find_pattern_checkpoint(
            config.ckpt_pattern, checkpoint_dir, ckpt_type)
    else:
        return model

    # Imported only when weights are actually loaded, so the "no checkpoint"
    # path does not depend on this module.
    from zoedepth.models.model_io import load_wts

    model = load_wts(model, checkpoint)
    print("Loaded weights from {0}".format(checkpoint))
    return model


def _find_pattern_checkpoint(pattern, checkpoint_dir, ckpt_type):
    """Return the first (sorted) file in ``checkpoint_dir`` matching ``*{pattern}*{ckpt_type}*``."""
    import glob
    import os

    # glob.glob returns results in filesystem-dependent order. Sort them so the
    # selected checkpoint is deterministic when several files match.
    matches = sorted(glob.glob(os.path.join(
        checkpoint_dir, f"*{pattern}*{ckpt_type}*")))
    if not matches:
        raise ValueError(f"No matches found for the pattern {pattern}")
    return matches[0]
def main_worker(gpu, ngpus_per_node, config):
    """
    Training entry point for one process, called directly or through ``mp.spawn``.

    Seeds the RNGs and builds the model, optionally loading a checkpoint. It
    then wraps the model with ``parallelize``, builds the ``MixedNYUKITTI``
    "train" and "online_eval" loaders, and runs the trainer returned by
    ``get_trainer``. ``wandb.finish()`` runs on exit whether training
    completes or raises.

    Args:
        gpu: GPU index for this process, stored on ``config.gpu``. May be
            None (the launcher passes ``config.gpu`` through unchanged when
            it has not pinned a device).
        ngpus_per_node: number of visible GPUs. Not read inside this function.
        config: training configuration. Mutated in place (``gpu``,
            ``total_params``).
    """
    try:
        fix_random_seed(43)
        config.gpu = gpu

        # Build, optionally restore, then wrap the network for this device.
        net = build_model(config)
        net = load_ckpt(config, net)
        net = parallelize(config, net)

        params_in_millions = round(count_parameters(net) / 1e6, 2)
        total_params = f"{params_in_millions}M"
        config.total_params = total_params
        print(f"Total parameters : {total_params}")

        train_loader = MixedNYUKITTI(config, "train").data
        eval_loader = MixedNYUKITTI(config, "online_eval").data

        trainer_cls = get_trainer(config)
        trainer = trainer_cls(config, net, train_loader, eval_loader,
                              device=config.gpu)
        trainer.train()
    finally:
        import wandb
        wandb.finish()
if __name__ == '__main__':
    mp.set_start_method('forkserver')

    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="synunet")
    parser.add_argument("-d", "--dataset", type=str, default='mix')
    parser.add_argument("--trainer", type=str, default=None)
    # Any flags argparse does not recognise are turned into config overrides.
    args, unknown_args = parser.parse_known_args()
    overwrite_kwargs = parse_unknown(unknown_args)

    overwrite_kwargs["model"] = args.model
    if args.trainer is not None:
        overwrite_kwargs["trainer"] = args.trainer

    config = get_config(args.model, "train", args.dataset, **overwrite_kwargs)

    if config.use_shared_dict:
        shared_dict = mp.Manager().dict()
    else:
        shared_dict = None
    config.shared_dict = shared_dict

    config.batch_size = config.bs
    config.mode = 'train'
    if config.root != ".":
        # exist_ok avoids the check-then-create race. A non-directory at this
        # path still raises FileExistsError.
        os.makedirs(config.root, exist_ok=True)

    try:
        # NOTE(review): stripping brackets does not expand SLURM range syntax.
        # A list such as "node[01-04]" is counted as a single node.
        node_str = os.environ['SLURM_JOB_NODELIST'].replace(
            '[', '').replace(']', '')
        nodes = node_str.split(',')
        config.world_size = len(nodes)
        config.rank = int(os.environ['SLURM_PROCID'])
    except KeyError:
        # We are NOT using SLURM: single node, rank 0.
        config.world_size = 1
        config.rank = 0
        nodes = ["127.0.0.1"]

    if config.distributed:
        print(config.rank)
        slurm_job_id = os.environ.get("SLURM_JOB_ID")
        if slurm_job_id is not None and slurm_job_id.isdigit():
            # Under SLURM every task (one per node, as world_size = len(nodes)
            # assumes) runs this script on its own. Independently drawn random
            # ports would usually differ between nodes and break the TCP
            # rendezvous. Derive the port from the job id, which all tasks
            # share. Same range as below.
            port = 15000 + int(slurm_job_id) % 25
        else:
            port = np.random.randint(15000, 15025)
        config.dist_url = 'tcp://{}:{}'.format(nodes[0], port)
        print(config.dist_url)
        config.dist_backend = 'nccl'
        config.gpu = None

    ngpus_per_node = torch.cuda.device_count()
    config.num_workers = config.workers
    config.ngpus_per_node = ngpus_per_node
    print("Config:")
    pprint(config)

    if config.distributed:
        # world_size counted nodes so far. Scale it to the total number of
        # processes, one per GPU.
        config.world_size = ngpus_per_node * config.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, config))
    else:
        if ngpus_per_node == 1:
            config.gpu = 0
        main_worker(config.gpu, ngpus_per_node, config)