HikariDawn777's picture
feat: initial push
59b2a81
raw
history blame
10.2 kB
"""
trainer.py - wrapper and utility functions for network training
Compute loss, back-prop, update parameters, logging, etc.
"""
import datetime
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from network import XMem
from losses import LossComputer
from util.log_integrator import Integrator
from util.image_saver import pool_pairs
class XMemTrainer:
    """
    Training wrapper around the XMem network.

    Owns the DistributedDataParallel-wrapped network, the loss computer,
    the optimizer / LR scheduler (and the AMP grad scaler when
    config['amp'] is set), and takes care of periodic logging and saving.
    """

    def __init__(self, config, logger=None, save_path=None, local_rank=0, world_size=1):
        """
        config     - dict-like; keys read here: 'num_frames', 'num_ref_frames',
                     'deep_update_prob', 'lr', 'weight_decay', 'steps', 'gamma',
                     'amp', 'debug' and the four '*_interval' entries.
                     'hidden_dim' is read later in do_pass.
        logger     - logging object, or None to disable logging/saving on this rank
        save_path  - path prefix for saved weights/checkpoints, or None to disable saving
        local_rank - CUDA device index used for DDP and for map_location when loading
        world_size - forwarded to the Integrator
        """
        self.config = config
        self.num_frames = config['num_frames']
        self.num_ref_frames = config['num_ref_frames']
        self.deep_update_prob = config['deep_update_prob']

        self.local_rank = local_rank

        self.XMem = nn.parallel.DistributedDataParallel(
            XMem(config).cuda(),
            device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False)

        # Set up logger when local_rank=0
        self.logger = logger
        self.save_path = save_path
        # The timer is read on every rank at each text-log interval, so it has
        # to exist even when this rank has no logger.
        self.last_time = time.time()
        # Running average of the wall time per text-log interval (for the ETA)
        self._interval_time_avg = None
        if logger is not None:
            self.logger.log_string('model_size', str(sum(param.nelement() for param in self.XMem.parameters())))
        self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size)
        self.loss_computer = LossComputer(config)

        self.train()
        self.optimizer = optim.AdamW(filter(
            lambda p: p.requires_grad, self.XMem.parameters()), lr=config['lr'], weight_decay=config['weight_decay'])
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, config['steps'], config['gamma'])
        if config['amp']:
            self.scaler = torch.cuda.amp.GradScaler()

        # Logging info
        self.log_text_interval = config['log_text_interval']
        self.log_image_interval = config['log_image_interval']
        self.save_network_interval = config['save_network_interval']
        self.save_checkpoint_interval = config['save_checkpoint_interval']
        if config['debug']:
            self.log_text_interval = self.log_image_interval = 1

    def do_pass(self, data, max_it, it=0):
        """
        Run one forward pass over a batch and, in training mode, one
        optimization step.

        data   - dict; reads 'rgb', 'first_frame_gt', 'selector' and
                 data['info']['num_objects']. Every value that is not a
                 list/dict/int is moved to the GPU in place.
                 (presumably 'rgb' is batch x time x ...; confirm against the dataset)
        max_it - total number of iterations, only used for the ETA printout
        it     - current iteration, used for the logging/saving schedule
        """
        # No need to store the gradient outside training
        torch.set_grad_enabled(self._is_train)

        for k, v in data.items():
            if not isinstance(v, (list, dict, int)):
                data[k] = v.cuda(non_blocking=True)

        out = {}
        frames = data['rgb']
        first_frame_gt = data['first_frame_gt'].float()
        b = frames.shape[0]
        num_filled_objects = [o.item() for o in data['info']['num_objects']]
        num_objects = first_frame_gt.shape[2]
        selector = data['selector'].unsqueeze(2).unsqueeze(2)

        with torch.cuda.amp.autocast(enabled=self.config['amp']):
            # image features never change, compute once
            key, shrinkage, selection, f16, f8, f4 = self.XMem('encode_key', frames)

            hidden = torch.zeros((b, num_objects, self.config['hidden_dim'], *key.shape[-2:]))
            v16, hidden = self.XMem('encode_value', frames[:,0], f16[:,0], hidden, first_frame_gt[:,0])
            values = v16.unsqueeze(3) # add the time dimension

            for ti in range(1, self.num_frames):
                ref_keys, ref_shrinkage, ref_values = self._select_reference_memory(
                    key, shrinkage, values, ti, b)

                # Segment frame ti
                memory_readout = self.XMem('read_memory', key[:,:,ti], selection[:,:,ti] if selection is not None else None,
                                        ref_keys, ref_shrinkage, ref_values)
                hidden, logits, masks = self.XMem('segment', (f16[:,ti], f8[:,ti], f4[:,ti]), memory_readout,
                        hidden, selector, h_out=(ti < (self.num_frames-1)))

                # No need to encode the last frame
                if ti < (self.num_frames-1):
                    is_deep_update = np.random.rand() < self.deep_update_prob
                    v16, hidden = self.XMem('encode_value', frames[:,ti], f16[:,ti], hidden, masks, is_deep_update=is_deep_update)
                    values = torch.cat([values, v16.unsqueeze(3)], 3)

                out[f'masks_{ti}'] = masks
                out[f'logits_{ti}'] = logits

            if self._do_log or self._is_train:
                losses = self.loss_computer.compute({**data, **out}, num_filled_objects, it)

                # Logging
                if self._do_log:
                    self.integrator.add_dict(losses)
                    if self._is_train:
                        if it % self.log_image_interval == 0 and it != 0:
                            if self.logger is not None:
                                images = {**data, **out}
                                size = (384, 384)
                                self.logger.log_cv2('train/pairs', pool_pairs(images, size, num_filled_objects), it)

            if self._is_train:
                self._log_and_save(it, max_it)

        # Gradients are disabled outside training (and in test mode no loss is
        # computed at all), so there is nothing to optimize.
        if not self._is_train:
            return

        # Backward pass
        self.optimizer.zero_grad(set_to_none=True)
        if self.config['amp']:
            self.scaler.scale(losses['total_loss']).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            losses['total_loss'].backward()
            self.optimizer.step()

        self.scheduler.step()

    def _select_reference_memory(self, key, shrinkage, values, ti, b):
        """
        Pick the memory frames used to segment frame ti.

        Returns (ref_keys, ref_shrinkage, ref_values). While ti does not exceed
        num_ref_frames every earlier frame is used; afterwards frame 0 plus
        num_ref_frames-1 randomly chosen earlier frames are used, drawn
        independently for each of the b batch elements.
        """
        if ti <= self.num_ref_frames:
            ref_shrinkage = shrinkage[:,:,:ti] if shrinkage is not None else None
            return key[:,:,:ti], ref_shrinkage, values

        # pick num_ref_frames random frames
        # this is not very efficient but I think we would
        # need broadcasting in gather which we don't have
        filler_one = torch.zeros(1, dtype=torch.int64)
        indices = [
            torch.cat([filler_one, torch.randperm(ti-1)[:self.num_ref_frames-1]+1])
            for _ in range(b)]
        ref_values = torch.stack([
            values[bi, :, :, idx] for bi, idx in enumerate(indices)
        ], 0)
        ref_keys = torch.stack([
            key[bi, :, idx] for bi, idx in enumerate(indices)
        ], 0)
        ref_shrinkage = torch.stack([
            shrinkage[bi, :, idx] for bi, idx in enumerate(indices)
        ], 0) if shrinkage is not None else None
        return ref_keys, ref_shrinkage, ref_values

    def _log_and_save(self, it, max_it):
        """
        Periodic training bookkeeping: text logs + ETA, network weights and
        full checkpoints, each on its own interval. Nothing happens at it == 0.
        """
        if it % self.log_text_interval == 0 and it != 0:
            time_spent = time.time()-self.last_time
            if self.logger is not None:
                self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it)
                self.logger.log_metrics('train', 'time', time_spent/self.log_text_interval, it)
                print(f'ETA: {self._format_eta(time_spent, it, max_it)}')
            self.last_time = time.time()
            # The integrator is finalized on every rank, not just the logging one
            self.train_integrator.finalize('train', it)
            self.train_integrator.reset_except_hooks()

        if it % self.save_network_interval == 0 and it != 0:
            if self.logger is not None:
                self.save_network(it)

        if it % self.save_checkpoint_interval == 0 and it != 0:
            if self.logger is not None:
                self.save_checkpoint(it)

    def _format_eta(self, time_spent, it, max_it):
        """
        Update the running average of the time per text-log interval and
        return the estimated remaining time as an 'H:MM:SS' string.

        time_spent is the wall time of the last log_text_interval iterations.
        The average has to persist between calls, otherwise it would never
        accumulate more than a single sample.
        """
        if self._interval_time_avg is None:
            self._interval_time_avg = time_spent
        else:
            self._interval_time_avg = 0.5*self._interval_time_avg + 0.5*time_spent
        remaining_it = max(max_it - it, 0)
        eta_seconds = self._interval_time_avg * remaining_it / self.log_text_interval
        return str(datetime.timedelta(seconds=int(eta_seconds)))

    def _ensure_save_dir(self):
        """Create the directory part of save_path, if it has one."""
        save_dir = os.path.dirname(self.save_path)
        # dirname is '' for a bare file prefix; os.makedirs('') would raise
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

    def save_network(self, it):
        """Save only the network weights to '<save_path>_<it>.pth'."""
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        self._ensure_save_dir()
        model_path = f'{self.save_path}_{it}.pth'
        torch.save(self.XMem.module.state_dict(), model_path)
        print(f'Network saved to {model_path}.')

    def save_checkpoint(self, it):
        """
        Save iteration count, network, optimizer and scheduler states to
        '<save_path>_checkpoint_<it>.pth'.
        """
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        self._ensure_save_dir()
        checkpoint_path = f'{self.save_path}_checkpoint_{it}.pth'
        checkpoint = {
            'it': it,
            'network': self.XMem.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()}
        torch.save(checkpoint, checkpoint_path)
        print(f'Checkpoint saved to {checkpoint_path}.')

    def load_checkpoint(self, path):
        """
        Restore network, optimizer and scheduler states from a checkpoint
        written by save_checkpoint; returns the stored iteration count.
        """
        # This method loads everything and should be used to resume training
        # NOTE: torch.load unpickles the file; only load checkpoints you trust
        map_location = 'cuda:%d' % self.local_rank
        checkpoint = torch.load(path, map_location={'cpu': map_location})

        it = checkpoint['it']
        network = checkpoint['network']
        optimizer = checkpoint['optimizer']
        scheduler = checkpoint['scheduler']

        self.XMem.module.load_state_dict(network)
        self.optimizer.load_state_dict(optimizer)
        self.scheduler.load_state_dict(scheduler)

        print('Network weights, optimizer states, and scheduler states loaded.')

        return it

    def load_network_in_memory(self, src_dict):
        """Load network weights from an already-loaded state dict."""
        self.XMem.module.load_weights(src_dict)
        print('Network weight loaded from memory.')

    def load_network(self, path):
        """Load network weights (no optimizer/scheduler state) from a file."""
        # This method loads only the network weight and should be used to load a pretrained model
        # NOTE: torch.load unpickles the file; only load weights you trust
        map_location = 'cuda:%d' % self.local_rank
        src_dict = torch.load(path, map_location={'cpu': map_location})

        self.load_network_in_memory(src_dict)
        print(f'Network weight loaded from {path}')

    def train(self):
        """Switch to training mode (gradients + logging); returns self."""
        self._is_train = True
        self._do_log = True
        self.integrator = self.train_integrator
        # The module is deliberately kept in eval() mode while training
        # (presumably to freeze normalization statistics; confirm before changing)
        self.XMem.eval()
        return self

    def val(self):
        """Switch to validation mode (no gradients, logging on); returns self."""
        self._is_train = False
        self._do_log = True
        self.XMem.eval()
        return self

    def test(self):
        """Switch to test mode (no gradients, no logging); returns self."""
        self._is_train = False
        self._do_log = False
        self.XMem.eval()
        return self