""" | |
trainer.py - warpper and utility functions for network training | |
Compute loss, back-prop, update parameters, logging, etc. | |
""" | |
import datetime
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from model.network import XMem
from model.losses import LossComputer
from util.log_integrator import Integrator
from util.image_saver import pool_pairs
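
# The `config` dict consumed below is expected to provide at least the keys read
# in this file: "num_frames", "num_ref_frames", "deep_update_prob", "hidden_dim",
# "lr", "weight_decay", "steps", "gamma", "amp", "log_text_interval",
# "log_image_interval", "save_network_interval", "save_checkpoint_interval" and
# "debug". The actual values come from the training configuration elsewhere in
# the project; none are hard-coded here.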


class XMemTrainer:
    def __init__(self, config, logger=None, save_path=None, local_rank=0, world_size=1):
        self.config = config
        self.num_frames = config["num_frames"]
        self.num_ref_frames = config["num_ref_frames"]
        self.deep_update_prob = config["deep_update_prob"]
        self.local_rank = local_rank
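
        # Wrap the network in DistributedDataParallel on this rank's GPU.
        # broadcast_buffers=False avoids re-syncing buffers on every forward pass;
        # the raw module is reached through self.XMem.module when saving/loading.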
        self.XMem = nn.parallel.DistributedDataParallel(
            XMem(config).cuda(),
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=False,
        )

        # Set up logger when local_rank=0
        self.logger = logger
        self.save_path = save_path
        if logger is not None:
            self.last_time = time.time()
            self.logger.log_string(
                "model_size",
                str(sum(param.nelement() for param in self.XMem.parameters())),
            )
        self.train_integrator = Integrator(
            self.logger, distributed=True, local_rank=local_rank, world_size=world_size
        )
        self.loss_computer = LossComputer(config)

        self.train()
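
        # AdamW over trainable parameters only; MultiStepLR multiplies the learning
        # rate by config["gamma"] at each milestone in config["steps"]. A GradScaler
        # is created only when mixed-precision (AMP) training is enabled.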
        self.optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, self.XMem.parameters()),
            lr=config["lr"],
            weight_decay=config["weight_decay"],
        )
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, config["steps"], config["gamma"]
        )
        if config["amp"]:
            self.scaler = torch.cuda.amp.GradScaler()

        # Logging info
        self.log_text_interval = config["log_text_interval"]
        self.log_image_interval = config["log_image_interval"]
        self.save_network_interval = config["save_network_interval"]
        self.save_checkpoint_interval = config["save_checkpoint_interval"]
        if config["debug"]:
            self.log_text_interval = self.log_image_interval = 1
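
    # do_pass runs one iteration on a batch of short clips: encode the first frame
    # with its ground-truth mask, segment the remaining frames from memory,
    # accumulate the loss, and (in training mode) back-propagate and step the
    # optimizer and scheduler.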
    def do_pass(self, data, max_it, it=0):
        # No need to store the gradient outside training
        torch.set_grad_enabled(self._is_train)

        for k, v in data.items():
            if type(v) != list and type(v) != dict and type(v) != int:
                data[k] = v.cuda(non_blocking=True)

        out = {}
        frames = data["rgb"]
        first_frame_gt = data["first_frame_gt"].float()
        b = frames.shape[0]
        num_filled_objects = [o.item() for o in data["info"]["num_objects"]]
        num_objects = first_frame_gt.shape[2]
        selector = data["selector"].unsqueeze(2).unsqueeze(2)
        global_avg = 0
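
        # Shape conventions implied by the indexing below: `frames` keeps time on
        # dim 1 (B, T, C, H, W); `first_frame_gt` has one channel per (possibly
        # padded) object, hence num_objects = first_frame_gt.shape[2]; the
        # key/shrinkage/selection tensors keep time on dim 2; `selector`, unsqueezed
        # to broadcast over the spatial dims, marks which object slots are actually
        # filled in each sample.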

        with torch.cuda.amp.autocast(enabled=self.config["amp"]):
            # image features never change, compute once
            key, shrinkage, selection, f16, f8, f4 = self.XMem("encode_key", frames)

            filler_one = torch.zeros(1, dtype=torch.int64)
            hidden = torch.zeros(
                (b, num_objects, self.config["hidden_dim"], *key.shape[-2:])
            )
            v16, hidden = self.XMem(
                "encode_value", frames[:, 0], f16[:, 0], hidden, first_frame_gt[:, 0]
            )
            values = v16.unsqueeze(3)  # add the time dimension
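
            # Frame 0 (with its ground-truth mask) seeds the memory. Every later
            # frame is segmented against the memory built so far, and all frames
            # except the last are encoded back into the memory using the
            # predicted masks.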
            for ti in range(1, self.num_frames):
                if ti <= self.num_ref_frames:
                    ref_values = values
                    ref_keys = key[:, :, :ti]
                    ref_shrinkage = (
                        shrinkage[:, :, :ti] if shrinkage is not None else None
                    )
                else:
                    # pick num_ref_frames random frames
                    # this is not very efficient but I think we would
                    # need broadcasting in gather which we don't have
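                    # (`filler_one` keeps frame 0, the ground-truth frame, in every
                    # reference set; the remaining num_ref_frames - 1 indices are
                    # drawn without replacement from frames 1..ti-1, independently
                    # per batch element.)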
                    indices = [
                        torch.cat(
                            [
                                filler_one,
                                torch.randperm(ti - 1)[: self.num_ref_frames - 1] + 1,
                            ]
                        )
                        for _ in range(b)
                    ]
                    ref_values = torch.stack(
                        [values[bi, :, :, indices[bi]] for bi in range(b)], 0
                    )
                    ref_keys = torch.stack(
                        [key[bi, :, indices[bi]] for bi in range(b)], 0
                    )
                    ref_shrinkage = (
                        torch.stack(
                            [shrinkage[bi, :, indices[bi]] for bi in range(b)], 0
                        )
                        if shrinkage is not None
                        else None
                    )

                # Segment frame ti
                memory_readout = self.XMem(
                    "read_memory",
                    key[:, :, ti],
                    selection[:, :, ti] if selection is not None else None,
                    ref_keys,
                    ref_shrinkage,
                    ref_values,
                )
                hidden, logits, masks = self.XMem(
                    "segment",
                    (f16[:, ti], f8[:, ti], f4[:, ti]),
                    memory_readout,
                    hidden,
                    selector,
                    h_out=(ti < (self.num_frames - 1)),
                )
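
                # `hidden` is the per-object recurrent state; h_out is disabled on
                # the final frame because the updated state would never be read
                # again. With probability deep_update_prob, re-encoding the
                # predicted mask below also performs a deep update of that state.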
                # No need to encode the last frame
                if ti < (self.num_frames - 1):
                    is_deep_update = np.random.rand() < self.deep_update_prob
                    v16, hidden = self.XMem(
                        "encode_value",
                        frames[:, ti],
                        f16[:, ti],
                        hidden,
                        masks,
                        is_deep_update=is_deep_update,
                    )
                    values = torch.cat([values, v16.unsqueeze(3)], 3)

                out[f"masks_{ti}"] = masks
                out[f"logits_{ti}"] = logits

            if self._do_log or self._is_train:
                losses = self.loss_computer.compute(
                    {**data, **out}, num_filled_objects, it
                )

                # Logging
                if self._do_log:
                    self.integrator.add_dict(losses)
                    if self._is_train:
                        if it % self.log_image_interval == 0 and it != 0:
                            if self.logger is not None:
                                images = {**data, **out}
                                size = (384, 384)
                                self.logger.log_cv2(
                                    "train/pairs",
                                    pool_pairs(images, size, num_filled_objects),
                                    it,
                                )

            if self._is_train:
                if it % self.log_text_interval == 0 and it != 0:
                    time_spent = time.time() - self.last_time
                    if self.logger is not None:
                        self.logger.log_scalar(
                            "train/lr", self.scheduler.get_last_lr()[0], it
                        )
                        self.logger.log_metrics(
                            "train", "time", time_spent / self.log_text_interval, it
                        )
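                    # `global_avg` is a simple running blend of the wall-clock time
                    # spent per logging interval; the ETA estimate below divides by
                    # 100, which appears to assume a 100-iteration text-log interval.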
                    global_avg = 0.5 * global_avg + 0.5 * time_spent
                    eta_seconds = global_avg * (max_it - it) / 100
                    eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                    print(f"ETA: {eta_string}")
                    self.last_time = time.time()
                    self.train_integrator.finalize("train", it)
                    self.train_integrator.reset_except_hooks()

                if it % self.save_network_interval == 0 and it != 0:
                    if self.logger is not None:
                        self.save_network(it)

                if it % self.save_checkpoint_interval == 0 and it != 0:
                    if self.logger is not None:
                        self.save_checkpoint(it)

        # Backward pass
        self.optimizer.zero_grad(set_to_none=True)
        if self.config["amp"]:
            self.scaler.scale(losses["total_loss"]).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            losses["total_loss"].backward()
            self.optimizer.step()

        self.scheduler.step()
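
    # save_network stores only the model weights (taken from self.XMem.module so
    # the DDP "module." prefix is stripped), while save_checkpoint also stores the
    # optimizer, scheduler and iteration count needed to resume training.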
    def save_network(self, it):
        if self.save_path is None:
            print("Saving has been disabled.")
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        model_path = f"{self.save_path}_{it}.pth"
        torch.save(self.XMem.module.state_dict(), model_path)
        print(f"Network saved to {model_path}.")

    def save_checkpoint(self, it):
        if self.save_path is None:
            print("Saving has been disabled.")
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        checkpoint_path = f"{self.save_path}_checkpoint_{it}.pth"
        checkpoint = {
            "it": it,
            "network": self.XMem.module.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
        }
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}.")
    def load_checkpoint(self, path):
        # This method loads everything and should be used to resume training
        map_location = "cuda:%d" % self.local_rank
        checkpoint = torch.load(path, map_location={"cuda:0": map_location})

        it = checkpoint["it"]
        network = checkpoint["network"]
        optimizer = checkpoint["optimizer"]
        scheduler = checkpoint["scheduler"]

        self.XMem.module.load_state_dict(network)
        self.optimizer.load_state_dict(optimizer)
        self.scheduler.load_state_dict(scheduler)

        print("Network weights, optimizer states, and scheduler states loaded.")

        return it

    def load_network_in_memory(self, src_dict):
        self.XMem.module.load_weights(src_dict)
        print("Network weights loaded from memory.")

    def load_network(self, path):
        # This method loads only the network weights and should be used to load a pretrained model
        map_location = "cuda:%d" % self.local_rank
        src_dict = torch.load(path, map_location={"cuda:0": map_location})
        self.load_network_in_memory(src_dict)
        print(f"Network weights loaded from {path}.")

    def train(self):
        self._is_train = True
        self._do_log = True
        self.integrator = self.train_integrator
        self.XMem.eval()
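        # Note: the network stays in eval() mode even during training, so
        # normalization layers keep using their running statistics (and dropout,
        # if any, is disabled); gradients still flow normally.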
        return self

    def val(self):
        self._is_train = False
        self._do_log = True
        self.XMem.eval()
        return self

    def test(self):
        self._is_train = False
        self._do_log = False
        self.XMem.eval()
        return self
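

# A minimal usage sketch (not part of the original file). Names such as `cfg`,
# `logger`, `train_loader` and `total_iter` are placeholders; the real entry
# point builds the config, logger and distributed data loader elsewhere in the
# project:
#
#   torch.distributed.init_process_group(backend="nccl")
#   trainer = XMemTrainer(cfg, logger=logger, save_path="saves/xmem",
#                         local_rank=local_rank, world_size=world_size)
#   for it, data in enumerate(train_loader):
#       trainer.do_pass(data, total_iter, it)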