"""
trainer.py - warpper and utility functions for network training
Compute loss, back-prop, update parameters, logging, etc.
"""
import datetime
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from network import XMem
from losses import LossComputer
from util.log_integrator import Integrator
from util.image_saver import pool_pairs
class XMemTrainer:
def __init__(self, config, logger=None, save_path=None, local_rank=0, world_size=1):
self.config = config
self.num_frames = config['num_frames']
self.num_ref_frames = config['num_ref_frames']
self.deep_update_prob = config['deep_update_prob']
self.local_rank = local_rank
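        # One process per GPU: wrap the network in DDP; broadcast_buffers=False skips
        # syncing buffers (e.g. BatchNorm running stats, which stay frozen here).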
self.XMem = nn.parallel.DistributedDataParallel(
XMem(config).cuda(),
device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False)
# Set up logger when local_rank=0
self.logger = logger
self.save_path = save_path
if logger is not None:
self.last_time = time.time()
self.logger.log_string('model_size', str(sum([param.nelement() for param in self.XMem.parameters()])))
self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size)
self.loss_computer = LossComputer(config)
self.train()
self.optimizer = optim.AdamW(filter(
lambda p: p.requires_grad, self.XMem.parameters()), lr=config['lr'], weight_decay=config['weight_decay'])
self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, config['steps'], config['gamma'])
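        # With AMP, the GradScaler scales the loss to keep fp16 gradients from
        # underflowing; scaler.step() unscales them before the optimizer update.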
if config['amp']:
self.scaler = torch.cuda.amp.GradScaler()
# Logging info
self.log_text_interval = config['log_text_interval']
self.log_image_interval = config['log_image_interval']
self.save_network_interval = config['save_network_interval']
self.save_checkpoint_interval = config['save_checkpoint_interval']
if config['debug']:
self.log_text_interval = self.log_image_interval = 1
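
    # A minimal usage sketch (hypothetical driver, not part of this file; assumes a
    # DistributedSampler-backed `train_loader` and a `total_iterations` budget):
    #
    #   trainer = XMemTrainer(config, logger=logger, save_path='saves/xmem',
    #                         local_rank=local_rank, world_size=world_size).train()
    #   for it, data in enumerate(train_loader):
    #       trainer.do_pass(data, max_it=total_iterations, it=it)
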
def do_pass(self, data, max_it, it=0):
# No need to store the gradient outside training
torch.set_grad_enabled(self._is_train)
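        # Move every tensor onto the GPU; lists/dicts/ints in `data` are metadata
        # and stay on the CPU.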
for k, v in data.items():
            if not isinstance(v, (list, dict, int)):
data[k] = v.cuda(non_blocking=True)
out = {}
frames = data['rgb']
first_frame_gt = data['first_frame_gt'].float()
b = frames.shape[0]
num_filled_objects = [o.item() for o in data['info']['num_objects']]
num_objects = first_frame_gt.shape[2]
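        # selector flags which object slots are real rather than padding; reshaped to
        # [B, num_objects, 1, 1] so it broadcasts over the spatial dimensions.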
selector = data['selector'].unsqueeze(2).unsqueeze(2)
global_avg = 0
with torch.cuda.amp.autocast(enabled=self.config['amp']):
# image features never change, compute once
key, shrinkage, selection, f16, f8, f4 = self.XMem('encode_key', frames)
filler_one = torch.zeros(1, dtype=torch.int64)
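            # hidden: per-object recurrent state of the decoder (XMem's sensory
            # memory), shape [B, num_objects, hidden_dim, H/16, W/16].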
hidden = torch.zeros((b, num_objects, self.config['hidden_dim'], *key.shape[-2:]))
v16, hidden = self.XMem('encode_value', frames[:,0], f16[:,0], hidden, first_frame_gt[:,0])
values = v16.unsqueeze(3) # add the time dimension
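            # values acts as the memory bank, [B, num_objects, C_v, T, H/16, W/16];
            # a new time slice is appended after each processed frame.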
for ti in range(1, self.num_frames):
if ti <= self.num_ref_frames:
ref_values = values
ref_keys = key[:,:,:ti]
ref_shrinkage = shrinkage[:,:,:ti] if shrinkage is not None else None
else:
                    # pick num_ref_frames random frames
                    # this is not very efficient, but a vectorized version would need
                    # broadcasting in gather, which PyTorch does not support
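                    # frame 0 (the ground-truth frame) is always kept via filler_one;
                    # the rest are drawn without replacement from frames 1..ti-1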
indices = [
torch.cat([filler_one, torch.randperm(ti-1)[:self.num_ref_frames-1]+1])
for _ in range(b)]
ref_values = torch.stack([
values[bi, :, :, indices[bi]] for bi in range(b)
], 0)
ref_keys = torch.stack([
key[bi, :, indices[bi]] for bi in range(b)
], 0)
ref_shrinkage = torch.stack([
shrinkage[bi, :, indices[bi]] for bi in range(b)
], 0) if shrinkage is not None else None
# Segment frame ti
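                # Attention readout: match the query key at frame ti against the
                # sampled memory keys and aggregate the corresponding memory values.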
memory_readout = self.XMem('read_memory', key[:,:,ti], selection[:,:,ti] if selection is not None else None,
ref_keys, ref_shrinkage, ref_values)
hidden, logits, masks = self.XMem('segment', (f16[:,ti], f8[:,ti], f4[:,ti]), memory_readout,
hidden, selector, h_out=(ti < (self.num_frames-1)))
# No need to encode the last frame
if ti < (self.num_frames-1):
is_deep_update = np.random.rand() < self.deep_update_prob
v16, hidden = self.XMem('encode_value', frames[:,ti], f16[:,ti], hidden, masks, is_deep_update=is_deep_update)
values = torch.cat([values, v16.unsqueeze(3)], 3)
out[f'masks_{ti}'] = masks
out[f'logits_{ti}'] = logits
if self._do_log or self._is_train:
losses = self.loss_computer.compute({**data, **out}, num_filled_objects, it)
# Logging
if self._do_log:
self.integrator.add_dict(losses)
if self._is_train:
if it % self.log_image_interval == 0 and it != 0:
if self.logger is not None:
images = {**data, **out}
size = (384, 384)
self.logger.log_cv2('train/pairs', pool_pairs(images, size, num_filled_objects), it)
if self._is_train:
            if it % self.log_text_interval == 0 and it != 0:
                time_spent = time.time() - self.last_time
                if self.logger is not None:
                    self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it)
                    self.logger.log_metrics('train', 'time', time_spent/self.log_text_interval, it)
                # EMA of per-interval wall time; divide by the interval length to
                # get per-iteration time for the ETA estimate
                global_avg = 0.5*global_avg + 0.5*time_spent
                eta_seconds = global_avg * (max_it - it) / self.log_text_interval
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(f'ETA: {eta_string}')
self.last_time = time.time()
self.train_integrator.finalize('train', it)
self.train_integrator.reset_except_hooks()
if it % self.save_network_interval == 0 and it != 0:
if self.logger is not None:
self.save_network(it)
if it % self.save_checkpoint_interval == 0 and it != 0:
if self.logger is not None:
self.save_checkpoint(it)
# Backward pass
self.optimizer.zero_grad(set_to_none=True)
if self.config['amp']:
self.scaler.scale(losses['total_loss']).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
else:
losses['total_loss'].backward()
self.optimizer.step()
self.scheduler.step()
def save_network(self, it):
if self.save_path is None:
print('Saving has been disabled.')
return
os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
model_path = f'{self.save_path}_{it}.pth'
torch.save(self.XMem.module.state_dict(), model_path)
print(f'Network saved to {model_path}.')
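
    # save_network stores weights only (enough for inference); save_checkpoint also
    # stores the iteration, optimizer, and scheduler so training can resume exactly.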
def save_checkpoint(self, it):
if self.save_path is None:
print('Saving has been disabled.')
return
os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
checkpoint_path = f'{self.save_path}_checkpoint_{it}.pth'
checkpoint = {
'it': it,
'network': self.XMem.module.state_dict(),
'optimizer': self.optimizer.state_dict(),
'scheduler': self.scheduler.state_dict()}
torch.save(checkpoint, checkpoint_path)
print(f'Checkpoint saved to {checkpoint_path}.')
def load_checkpoint(self, path):
# This method loads everything and should be used to resume training
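        # Remap storages tagged 'cpu' in the checkpoint onto this rank's GPU so each
        # process loads onto its own device.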
map_location = 'cuda:%d' % self.local_rank
checkpoint = torch.load(path, map_location={'cpu': map_location})
it = checkpoint['it']
network = checkpoint['network']
optimizer = checkpoint['optimizer']
scheduler = checkpoint['scheduler']
self.XMem.module.load_state_dict(network)
self.optimizer.load_state_dict(optimizer)
self.scheduler.load_state_dict(scheduler)
print('Network weights, optimizer states, and scheduler states loaded.')
return it
def load_network_in_memory(self, src_dict):
self.XMem.module.load_weights(src_dict)
        print('Network weights loaded from memory.')
def load_network(self, path):
# This method loads only the network weight and should be used to load a pretrained model
map_location = 'cuda:%d' % self.local_rank
src_dict = torch.load(path, map_location={'cpu': map_location})
self.load_network_in_memory(src_dict)
        print(f'Network weights loaded from {path}.')
def train(self):
self._is_train = True
self._do_log = True
self.integrator = self.train_integrator
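        # Intentionally eval(): keeps frozen BatchNorm layers from updating their
        # running statistics during training.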
self.XMem.eval()
return self
def val(self):
self._is_train = False
self._do_log = True
self.XMem.eval()
return self
def test(self):
self._is_train = False
self._do_log = False
self.XMem.eval()
return self
|