Spaces:
Running
Running
File size: 6,338 Bytes
10dcc2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import argparse
import os
import pprint
import subprocess
import sys

import numpy as np
import torch
# import torch.multiprocessing
# torch.multiprocessing.set_sharing_strategy('file_system')
from torch.utils.data import DataLoader

from COTR.models import build_model
from COTR.utils import debug_utils, utils
from COTR.datasets import cotr_dataset
from COTR.trainers.cotr_trainer import COTRTrainer
from COTR.global_configs import general_config
from COTR.options.options import *
from COTR.options.options_utils import *
# Runs once at import time with the argument 0, before train() builds anything.
# Presumably seeds the random number generators for reproducibility -- confirm
# against COTR.utils.utils.fix_randomness.
utils.fix_randomness(0)
def train(opt):
    """Build the model, data loaders and optimizer from ``opt`` and run training.

    Prints the process environment and the ``nvidia-smi`` report for
    debugging, moves the model to the current CUDA device, and hands
    everything to ``COTRTrainer.train()``.

    Args:
        opt: Parsed options namespace. This function reads ``enable_zoom``,
            ``batch_size``, ``shuffle_data``, ``workers``, ``learning_rate``
            and ``lr_backbone`` directly; the whole namespace is also passed
            on to ``build_model``, the dataset classes and ``COTRTrainer``.
    """
    # width=1 forces pprint to put every environment variable on its own line.
    pprint.pprint(dict(os.environ), width=1)

    # subprocess.run waits for nvidia-smi to exit and closes the pipe, so no
    # child process or file handle is left dangling (a bare Popen leaked both).
    gpu_report = subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE)
    print(gpu_report.stdout.decode())

    device = torch.cuda.current_device()
    print(f'can see {torch.cuda.device_count()} gpus')
    print(f'current using gpu at {device} -- {torch.cuda.get_device_name(device)}')
    torch.cuda.empty_cache()

    model = build_model(opt)
    model = model.to(device)

    # The zoom and plain datasets share the same (opt, split) constructor call.
    if opt.enable_zoom:
        dataset_cls = cotr_dataset.COTRZoomDataset
    else:
        dataset_cls = cotr_dataset.COTRDataset
    train_dset = dataset_cls(opt, 'train')
    val_dset = dataset_cls(opt, 'val')

    train_loader = DataLoader(train_dset, batch_size=opt.batch_size,
                              shuffle=opt.shuffle_data, num_workers=opt.workers,
                              worker_init_fn=utils.worker_init_fn, pin_memory=True)
    # NOTE(review): the validation loader also follows opt.shuffle_data and is
    # the only one that drops the last partial batch -- confirm this is intended.
    val_loader = DataLoader(val_dset, batch_size=opt.batch_size,
                            shuffle=opt.shuffle_data, num_workers=opt.workers,
                            drop_last=True, worker_init_fn=utils.worker_init_fn, pin_memory=True)

    # One parameter group per non-backbone sub-module, all at the base learning rate.
    head_modules = (model.transformer, model.corr_embed,
                    model.query_proj, model.input_proj)
    optim_list = [{"params": module.parameters(), "lr": opt.learning_rate}
                  for module in head_modules]
    # The backbone is handed to the optimizer only when it has a positive
    # learning rate; otherwise its parameters are left out of the optimizer.
    if opt.lr_backbone > 0:
        optim_list.append({"params": model.backbone.parameters(), "lr": opt.lr_backbone})
    optim = torch.optim.Adam(optim_list)

    trainer = COTRTrainer(opt, model, optim, None, train_loader, val_loader)
    trainer.train()


# Maps the --layer option to the value stored in opt.dim_feedforward
# (presumably the channel count of that backbone layer -- confirm).
_LAYER_2_CHANNELS = {'layer1': 256,
                     'layer2': 512,
                     'layer3': 1024,
                     'layer4': 2048, }

# File name of the checkpoint looked up inside a run's output directory.
_CHECKPOINT_NAME = 'checkpoint.pth.tar'


def _build_parser():
    """Return the argument parser: shared option groups plus the training flags."""
    parser = argparse.ArgumentParser()
    set_general_arguments(parser)
    set_dataset_arguments(parser)
    set_nn_arguments(parser)
    set_COTR_arguments(parser)
    parser.add_argument('--num_kp', type=int,
                        default=100)
    parser.add_argument('--kp_pool', type=int,
                        default=100)
    parser.add_argument('--enable_zoom', type=str2bool,
                        default=False)
    parser.add_argument('--zoom_start', type=float,
                        default=1.0)
    parser.add_argument('--zoom_end', type=float,
                        default=0.1)
    parser.add_argument('--zoom_levels', type=int,
                        default=10)
    parser.add_argument('--zoom_jitter', type=float,
                        default=0.5)
    parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
    parser.add_argument('--tb_dir', type=str, default=general_config['tb_out'], help='tensorboard runs directory')
    parser.add_argument('--learning_rate', type=float,
                        default=1e-4, help='learning rate')
    parser.add_argument('--lr_backbone', type=float,
                        default=1e-5, help='backbone learning rate')
    parser.add_argument('--batch_size', type=int,
                        default=32, help='batch size for training')
    parser.add_argument('--cycle_consis', type=str2bool, default=True,
                        help='cycle consistency')
    parser.add_argument('--bidirectional', type=str2bool, default=True,
                        help='left2right and right2left')
    parser.add_argument('--max_iter', type=int,
                        default=200000, help='total training iterations')
    parser.add_argument('--valid_iter', type=int,
                        default=1000, help='iterval of validation')
    parser.add_argument('--resume', type=str2bool, default=False,
                        help='resume training with same model name')
    parser.add_argument('--cc_resume', type=str2bool, default=False,
                        help='resume from last run if possible')
    parser.add_argument('--need_rotation', type=str2bool, default=False,
                        help='rotation augmentation')
    parser.add_argument('--max_rotation', type=float, default=0,
                        help='max rotation for data augmentation')
    parser.add_argument('--rotation_chance', type=float, default=0,
                        help='the probability of being rotated')
    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
    parser.add_argument('--suffix', type=str, default='', help='model suffix')
    return parser


def _finalize_options(parser, opt):
    """Fill in the derived fields of ``opt`` and resolve the resume/load settings.

    Sets ``command``, ``dim_feedforward``, ``num_queries``, ``name``, ``out``,
    ``tb_out`` and ``scenes_name_list``; sets ``load_weights_path`` only when
    weights are loaded or a run is resumed. Invalid combinations are reported
    through ``parser.error`` (exit status 2) rather than ``assert``, which is
    stripped when Python runs with ``-O``.
    """
    opt.command = ' '.join(sys.argv)

    if opt.layer not in _LAYER_2_CHANNELS:
        parser.error(f'unsupported layer {opt.layer!r}; '
                     f'expected one of {sorted(_LAYER_2_CHANNELS)}')
    opt.dim_feedforward = _LAYER_2_CHANNELS[opt.layer]
    opt.num_queries = opt.num_kp

    # The run name depends on the fields set above, so it is computed after them.
    opt.name = get_compact_naming_cotr(opt)
    opt.out = os.path.join(opt.out_dir, opt.name)
    opt.tb_out = os.path.join(opt.tb_dir, opt.name)
    own_checkpoint = os.path.join(opt.out, _CHECKPOINT_NAME)

    # --cc_resume overrides --resume: continue from this run's own checkpoint
    # when it exists (ignoring --load_weights), otherwise start without resuming.
    if opt.cc_resume:
        if os.path.isfile(own_checkpoint):
            print('resuming from last run')
            opt.load_weights = None
            opt.resume = True
        else:
            opt.resume = False

    if opt.load_weights and opt.resume:
        parser.error('--load_weights and --resume cannot be used together')
    if opt.load_weights:
        opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, _CHECKPOINT_NAME)
    if opt.resume:
        opt.load_weights_path = own_checkpoint

    opt.scenes_name_list = build_scenes_name_list_from_opt(opt)


def _main():
    """Parse the command line, finalize and save the options, then train."""
    parser = _build_parser()
    opt = parser.parse_args()
    _finalize_options(parser, opt)
    if opt.confirm:
        confirm_opt(opt)
    else:
        print_opt(opt)
    save_opt(opt)
    train(opt)


if __name__ == "__main__":
    _main()
|