"""Implements training new models""" | |
import time | |
import copy | |
from collections import defaultdict | |
import numpy as np | |
import torch | |
import torchvision.transforms as transforms | |
from cirtorch.layers.loss import ContrastiveLoss | |
from cirtorch.datasets.datahelpers import collate_tuples | |
from cirtorch.datasets.traindataset import TuplesDataset | |
from cirtorch.datasets.genericdataset import ImagesFromList | |
from ..networks import how_net | |
from ..utils import data_helpers, io_helpers, logging, plots | |
from . import evaluate | |
def train(demo_train, training, validation, model, globals):
    """Demo training a network

    :param dict demo_train: Demo-related options
    :param dict training: Training options
    :param dict validation: Validation options
    :param dict model: Model options
    :param dict globals: Global options
    """
    logger = globals["logger"]
    (globals["exp_path"] / "epochs").mkdir(exist_ok=True)
    if (globals["exp_path"] / f"epochs/model_epoch{training['epochs']}.pth").exists():
        logger.info("Skipping network training, already trained")
        return

    # Global setup
    set_seed(0)
    globals["device"] = torch.device("cpu")
    if demo_train['gpu_id'] is not None:
        globals["device"] = torch.device("cuda:%s" % demo_train['gpu_id'])

    # Initialize network
    net = how_net.init_network(**model).to(globals["device"])
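    # Input images are converted to tensors and normalized with the per-channel
    # mean/std that the backbone expects, taken from net.runtime['mean_std'].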
globals["transform"] = transforms.Compose([transforms.ToTensor(), \ | |
transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std'])))]) | |
    with logging.LoggingStopwatch("initializing network whitening", logger.info, logger.debug):
        initialize_dim_reduction(net, globals, **training['initialize_dim_reduction'])

    # Initialize training
    optimizer, scheduler, criterion, train_loader = \
            initialize_training(net.parameter_groups(training["optimizer"]), training, globals)
    validation = Validation(validation, globals)
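
    # Each epoch is seeded deterministically; hard training tuples are remined at the
    # start of every epoch inside train_epoch (via create_epoch_tuples).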
    for epoch in range(training['epochs']):
        epoch1 = epoch + 1
        set_seed(epoch1)
        time0 = time.time()

        train_loss = train_epoch(train_loader, net, globals, criterion, optimizer, epoch1)

        validation.add_train_loss(train_loss, epoch1)
        validation.validate(net, epoch1)

        scheduler.step()

        io_helpers.save_checkpoint({
            'epoch': epoch1, 'meta': net.meta, 'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(), 'best_score': validation.best_score[1],
            'scores': validation.scores, 'net_params': model, '_version': 'how/2020',
        }, validation.best_score[0] == epoch1, epoch1 == training['epochs'], globals["exp_path"] / "epochs")

        logger.info(f"Epoch {epoch1} finished in {time.time() - time0:.1f}s")


def train_epoch(train_loader, net, globals, criterion, optimizer, epoch1):
    """Train for one epoch"""
    logger = globals['logger']
    batch_time = data_helpers.AverageMeter()
    data_time = data_helpers.AverageMeter()
    losses = data_helpers.AverageMeter()

    # Prepare epoch
    train_loader.dataset.create_epoch_tuples(net)
    net.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)

        optimizer.zero_grad()
        num_images = len(input[0])  # number of images per tuple
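        # Each tuple holds a query image with its positive and negatives, as assembled
        # by TuplesDataset and batched by collate_tuples. Descriptors of the tuple's
        # images are stacked column-wise into `output` and scored jointly by the
        # contrastive loss; gradients accumulate over all tuples in the batch before
        # a single optimizer step.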
        for inp, trg in zip(input, target):
            output = torch.zeros(net.meta['outputdim'], num_images).to(globals["device"])
            for imi in range(num_images):
                output[:, imi] = net(inp[imi].to(globals["device"])).squeeze()
            loss = criterion(output, trg.to(globals["device"]))
            loss.backward()
            losses.update(loss.item())
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if (i+1) % 20 == 0 or i == 0 or (i+1) == len(train_loader):
            logger.info(f'>> Train: [{epoch1}][{i+1}/{len(train_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        f'Loss {losses.val:.4f} ({losses.avg:.4f})')

    return losses.avg


def set_seed(seed):
    """Sets given seed globally in used libraries"""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)


def initialize_training(net_parameters, training, globals):
    """Initialize classes necessary for training"""
    # Need to check for keys because of defaults
    assert training['optimizer'].keys() == {"lr", "weight_decay"}
    assert training['lr_scheduler'].keys() == {"gamma"}
    assert training['loss'].keys() == {"margin"}
    assert training['dataset'].keys() == {"name", "mode", "imsize", "nnum", "qsize", "poolsize"}
    assert training['loader'].keys() == {"batch_size"}

    optimizer = torch.optim.Adam(net_parameters, **training["optimizer"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, **training["lr_scheduler"])
    criterion = ContrastiveLoss(**training["loss"]).to(globals["device"])
    train_dataset = TuplesDataset(**training['dataset'], transform=globals["transform"])
    train_loader = torch.utils.data.DataLoader(train_dataset, **training['loader'],
                                               pin_memory=True, drop_last=True, shuffle=True,
                                               collate_fn=collate_tuples,
                                               num_workers=how_net.NUM_WORKERS)
    return optimizer, scheduler, criterion, train_loader
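
# For orientation, a hypothetical `training` dict satisfying the assertions above might
# look as follows. The key names follow the checks in initialize_training and the usage
# in train(); the concrete values are illustrative only, not defaults of this code:
#
#   training = {
#       "epochs": 20,
#       "optimizer": {"lr": 1e-6, "weight_decay": 1e-6},
#       "lr_scheduler": {"gamma": 0.99},
#       "loss": {"margin": 0.7},
#       "dataset": {"name": "retrieval-SfM-120k", "mode": "train", "imsize": 1024,
#                   "nnum": 5, "qsize": 2000, "poolsize": 20000},
#       "loader": {"batch_size": 5},
#       "initialize_dim_reduction": {"images": 20000, "features_num": None},
#   }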


def extract_train_descriptors(net, globals, *, images, features_num):
    """Extract descriptors for a given number of images from the train set"""
    if features_num is None:
        features_num = net.runtime['features_num']

    images = data_helpers.load_dataset('train', data_root=globals['root_path'])[0][:images]
    dataset = ImagesFromList(root='', images=images, imsize=net.runtime['image_size'], bbxs=None,
                             transform=globals["transform"])
    des_train = how_net.extract_vectors_local(net, dataset, globals["device"],
                                              scales=net.runtime['training_scales'],
                                              features_num=features_num)[0]
    return des_train


def initialize_dim_reduction(net, globals, **kwargs):
    """Initialize dimensionality reduction by PCA whitening from 'images' number of descriptors"""
    if not net.dim_reduction:
        return

    print(">> Initializing dim reduction")
    des_train = extract_train_descriptors(net.copy_excluding_dim_reduction(), globals, **kwargs)
    net.dim_reduction.initialize_pca_whitening(des_train)


class Validation:
    """A convenient interface to validation, keeping historical values and plotting continuously

    :param dict validations: Options for each validation type (e.g. local_descriptor)
    :param dict globals: Global options
    """

    methods = {
        "global_descriptor": evaluate.eval_global,
        "local_descriptor": evaluate.eval_asmk,
    }

    def __init__(self, validations, globals):
        validations = copy.deepcopy(validations)
        self.frequencies = {x: y.pop("frequency") for x, y in validations.items()}
        self.validations = validations
        self.globals = globals
        self.scores = {x: defaultdict(list) for x in validations}
        self.scores["train_loss"] = []

    def add_train_loss(self, loss, epoch):
        """Store training loss for given epoch"""
        self.scores['train_loss'].append((epoch, loss))

        fig = plots.EpochFigure("train set", ylabel="loss")
        fig.plot(*list(zip(*self.scores["train_loss"])), 'o-', label='train')
        fig.save(self.globals['exp_path'] / "fig_train.jpg")

    def validate(self, net, epoch):
        """Perform validation of the network and store the resulting score for given epoch"""
        for name, frequency in self.frequencies.items():
            if frequency and epoch % frequency == 0:
                scores = self.methods[name](net, net.runtime, self.globals, **self.validations[name])
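                # Keep one score per dataset, preferring the medium-protocol mAP when the
                # evaluation reports it (e.g. the revisited Oxford/Paris setups).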
                for dataset, values in scores.items():
                    value = values['map_medium'] if "map_medium" in values else values['map']
                    self.scores[name][dataset].append((epoch, value))

                if "val_eccv20" in scores:
                    fig = plots.EpochFigure(f"val set - {name}", ylabel="mAP")
                    fig.plot(*list(zip(*self.scores[name]['val_eccv20'])), 'o-', label='val')
                    fig.save(self.globals['exp_path'] / f"fig_val_{name}.jpg")

                if scores.keys() - {"val_eccv20"}:
                    fig = plots.EpochFigure(f"test set - {name}", ylabel="mAP")
                    for dataset, value in self.scores[name].items():
                        if dataset != "val_eccv20":
                            fig.plot(*list(zip(*value)), 'o-', label=dataset)
                    fig.save(self.globals['exp_path'] / f"fig_test_{name}.jpg")

    @property
    def decisive_scores(self):
        """List of pairs (epoch, score) where score is decisive for comparing epochs"""
        for name in ["local_descriptor", "global_descriptor"]:
            if self.frequencies[name] and "val_eccv20" in self.scores[name]:
                return self.scores[name]['val_eccv20']
        return self.scores["train_loss"]

    @property
    def last_epoch(self):
        """Tuple (last epoch, last score) or (None, None) before decisive score is computed"""
        decisive_scores = self.decisive_scores
        if not decisive_scores:
            return None, None
        return decisive_scores[-1]

    @property
    def best_score(self):
        """Tuple (best epoch, best score) or (None, None) before decisive score is computed"""
        decisive_scores = self.decisive_scores
        if not decisive_scores:
            return None, None
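
        # When a validation mAP is available it is the decisive score and higher is better;
        # otherwise the training loss is decisive and lower is better.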
        aggr = min
        for name in ["local_descriptor", "global_descriptor"]:
            if self.frequencies[name] and "val_eccv20" in self.scores[name]:
                aggr = max
        return aggr(decisive_scores, key=lambda x: x[1])