# flake8: noqa: E501
# pylint: disable=not-callable
# E501: line too long

from collections import defaultdict
from datetime import datetime
import glob
import os
import tempfile

from boltons.cacheutils import cached, LRU
from boltons.fileutils import atomic_save, mkdir_p
from boltons.iterutils import windowed
from IPython import get_ipython
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pysaliency
from pysaliency.filter_datasets import iterate_crossvalidation
from pysaliency.plotting import visualize_distribution
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import yaml

from .data import ImageDataset, FixationDataset, ImageDatasetSampler, FixationMaskTransform
#from .loading import import_class, build_model, DeepGazeCheckpointModel, SharedPyTorchModel, _get_from_config
from .metrics import log_likelihood, nss, auc
from .modules import DeepGazeII

# Cache baseline information-gain computations: they are expensive and the same
# (model, dataset) combination is queried repeatedly during training.
baseline_performance = cached(LRU(max_size=3))(lambda model, *args, **kwargs: model.information_gain(*args, **kwargs))


def eval_epoch(model, dataset, baseline_information_gain, device, metrics=None):
    """Evaluate `model` on `dataset` and return weighted averages of the requested metrics."""
    model.eval()

    if metrics is None:
        metrics = ['LL', 'IG', 'NSS', 'AUC']

    metric_scores = {}
    metric_functions = {
        'LL': log_likelihood,
        'NSS': nss,
        'AUC': auc,
    }
    batch_weights = []

    with torch.no_grad():
        pbar = tqdm(dataset)
        for batch in pbar:
            image = batch.pop('image').to(device)
            centerbias = batch.pop('centerbias').to(device)
            fixation_mask = batch.pop('fixation_mask').to(device)
            x_hist = batch.pop('x_hist', torch.tensor([])).to(device)
            y_hist = batch.pop('y_hist', torch.tensor([])).to(device)
            weights = batch.pop('weight').to(device)
            durations = batch.pop('durations', torch.tensor([])).to(device)

            # Any remaining batch entries are forwarded to the model as keyword arguments.
            kwargs = {}
            for key, value in dict(batch).items():
                kwargs[key] = value.to(device)

            if isinstance(model, DeepGazeII):
                # DeepGazeII is a static model and takes no scanpath history.
                log_density = model(image, centerbias, **kwargs)
            else:
                log_density = model(image, centerbias, x_hist=x_hist, y_hist=y_hist, durations=durations, **kwargs)

            for metric_name, metric_fn in metric_functions.items():
                if metric_name not in metrics:
                    continue
                metric_scores.setdefault(metric_name, []).append(metric_fn(log_density, fixation_mask, weights=weights).detach().cpu().numpy())
            batch_weights.append(weights.detach().cpu().numpy().sum())

            # Show the first available metric in the progress bar.
            for display_metric in ['LL', 'NSS', 'AUC']:
                if display_metric in metrics:
                    pbar.set_description('{} {:.05f}'.format(display_metric, np.average(metric_scores[display_metric], weights=batch_weights)))
                    break

    data = {metric_name: np.average(scores, weights=batch_weights) for metric_name, scores in metric_scores.items()}

    if 'IG' in metrics:
        # Information gain is the model's log-likelihood relative to the baseline.
        data['IG'] = data['LL'] - baseline_information_gain

    return data
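
# Usage sketch (illustrative, not part of the original module): evaluating a
# model with eval_epoch. `model`, `val_loader` and `val_baseline_ll` are
# hypothetical names; in this codebase the loader would be built from the
# dataset classes in .data and the baseline log-likelihood from a pysaliency
# baseline model. Requesting 'IG' requires 'LL' to be among the metrics.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   scores = eval_epoch(model, val_loader, val_baseline_ll, device, metrics=['LL', 'IG', 'AUC'])
#   print(scores)  # e.g. {'LL': ..., 'AUC': ..., 'IG': ...}
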
def train_epoch(model, dataset, optimizer, device):
    """Train `model` for one epoch and return the weighted average training loss."""
    model.train()
    losses = []
    batch_weights = []

    pbar = tqdm(dataset)
    for batch in pbar:
        optimizer.zero_grad()

        image = batch.pop('image').to(device)
        centerbias = batch.pop('centerbias').to(device)
        fixation_mask = batch.pop('fixation_mask').to(device)
        x_hist = batch.pop('x_hist', torch.tensor([])).to(device)
        y_hist = batch.pop('y_hist', torch.tensor([])).to(device)
        weights = batch.pop('weight').to(device)
        durations = batch.pop('durations', torch.tensor([])).to(device)

        # Any remaining batch entries are forwarded to the model as keyword arguments.
        kwargs = {}
        for key, value in dict(batch).items():
            kwargs[key] = value.to(device)

        if isinstance(model, DeepGazeII):
            log_density = model(image, centerbias, **kwargs)
        else:
            log_density = model(image, centerbias, x_hist=x_hist, y_hist=y_hist, durations=durations, **kwargs)

        # Maximum-likelihood training: minimize the negative log-likelihood.
        loss = -log_likelihood(log_density, fixation_mask, weights=weights)
        losses.append(loss.detach().cpu().numpy())
        batch_weights.append(weights.detach().cpu().numpy().sum())

        pbar.set_description('{:.05f}'.format(np.average(losses, weights=batch_weights)))

        loss.backward()
        optimizer.step()

    return np.average(losses, weights=batch_weights)


def restore_from_checkpoint(model, optimizer, scheduler, path):
    """Load a checkpoint from `path`.

    Full training checkpoints (as written by `save_training_state`) restore
    model, optimizer, scheduler and RNG state and return `(step, loss)`;
    plain model checkpoints only restore the weights and return None.
    """
    print("Restoring from", path)
    data = torch.load(path)
    if 'optimizer' in data:
        # checkpoint contains training progress
        model.load_state_dict(data['model'])
        optimizer.load_state_dict(data['optimizer'])
        scheduler.load_state_dict(data['scheduler'])
        torch.set_rng_state(data['rng_state'])
        return data['step'], data['loss']
    else:
        # checkpoint contains just a model
        missing_keys, unexpected_keys = model.load_state_dict(data, strict=False)
        if missing_keys:
            print("WARNING! missing keys", missing_keys)
        if unexpected_keys:
            print("WARNING! Unexpected keys", unexpected_keys)


def save_training_state(model, optimizer, scheduler, step, loss, path):
    """Atomically write a full training checkpoint to `path`."""
    data = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'rng_state': torch.get_rng_state(),
        'step': step,
        'loss': loss,
    }

    with atomic_save(path, text_mode=False, overwrite_part=True) as f:
        torch.save(data, f)
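
# Checkpoint round-trip sketch (illustrative, not part of the original module):
# `save_training_state` writes everything `restore_from_checkpoint` needs to
# resume deterministically. All variable names here are hypothetical.
#
#   save_training_state(model, optimizer, scheduler, step=5, loss=1.23,
#                       path='experiment/step-0005.pth')
#   ...
#   step, loss = restore_from_checkpoint(model, optimizer, scheduler,
#                                        'experiment/step-0005.pth')
#
# Note that the `(step, loss)` return value is only available for full
# training checkpoints; for plain weight files the function returns None.
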
def _train(this_directory,
           model,
           train_loader, train_baseline_log_likelihood,
           val_loader, val_baseline_log_likelihood,
           optimizer, lr_scheduler,
           #optimizer_config, lr_scheduler_config,
           minimum_learning_rate,
           #initial_learning_rate, learning_rate_scheduler, learning_rate_decay, learning_rate_decay_epochs, learning_rate_backlook, learning_rate_reset_strategy, minimum_learning_rate,
           validation_metric='IG',
           validation_metrics=['IG', 'LL', 'AUC', 'NSS'],
           validation_epochs=1,
           startwith=None,
           device=None):
    """Run the training loop until the learning rate drops below `minimum_learning_rate`.

    Checkpoints, TensorBoard logs and a CSV progress log are written to
    `this_directory`; training resumes automatically from the newest checkpoint.
    """
    mkdir_p(this_directory)

    if os.path.isfile(os.path.join(this_directory, 'final.pth')):
        print("Training already finished")
        return

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Using device", device)

    model.to(device)

    val_metrics = defaultdict(list)

    if startwith is not None:
        restore_from_checkpoint(model, optimizer, lr_scheduler, startwith)

    writer = SummaryWriter(os.path.join(this_directory, 'log'), flush_secs=30)

    columns = ['epoch', 'timestamp', 'learning_rate', 'loss']
    print("validation metrics", validation_metrics)
    for metric in validation_metrics:
        columns.append(f'validation_{metric}')

    progress = pd.DataFrame(columns=columns)

    step = 0
    last_loss = np.nan

    def save_step():
        # Write a checkpoint, log training and validation scalars, update the
        # CSV progress log and prune old checkpoints (keeping the best one).
        save_training_state(
            model, optimizer, lr_scheduler, step, last_loss,
            '{}/step-{:04d}.pth'.format(this_directory, step),
        )

        #f = visualize(model, vis_data_loader)
        #display_if_in_IPython(f)
        #writer.add_figure('prediction', f, step)

        writer.add_scalar('training/loss', last_loss, step)
        writer.add_scalar('training/learning_rate', optimizer.state_dict()['param_groups'][0]['lr'], step)
        writer.add_scalar('parameters/sigma', model.finalizer.gauss.sigma.detach().cpu().numpy(), step)
        writer.add_scalar('parameters/center_bias_weight', model.finalizer.center_bias_weight.detach().cpu().numpy()[0], step)

        if step % validation_epochs == 0:
            _val_metrics = eval_epoch(model, val_loader, val_baseline_log_likelihood, device, metrics=validation_metrics)
        else:
            print("Skipping validation")
            _val_metrics = {}

        for key, value in _val_metrics.items():
            val_metrics[key].append(value)

        for key, value in _val_metrics.items():
            writer.add_scalar(f'validation/{key}', value, step)

        new_row = {
            'epoch': step,
            'timestamp': datetime.utcnow(),
            'learning_rate': optimizer.state_dict()['param_groups'][0]['lr'],
            'loss': last_loss,
            #'validation_ig': val_igs[-1]
        }
        for key, value in _val_metrics.items():
            new_row['validation_{}'.format(key)] = value
        progress.loc[step] = new_row

        print(progress.tail(n=2))
        print(progress[['validation_{}'.format(key) for key in val_metrics]].idxmax(axis=0))

        with atomic_save('{}/log.csv'.format(this_directory), text_mode=True, overwrite_part=True) as f:
            progress.to_csv(f)

        # Delete old checkpoints, but keep the one with the best validation score.
        for old_step in range(1, step):
            # only check if we are computing validation metrics...
            if validation_metric in val_metrics and val_metrics[validation_metric] and old_step == np.argmax(val_metrics[validation_metric]):
                continue
            for filename in glob.glob('{}/step-{:04d}.pth'.format(this_directory, old_step)):
                print("removing", filename)
                os.remove(filename)

    # Resume from the newest checkpoint if one exists.
    old_checkpoints = sorted(glob.glob(os.path.join(this_directory, 'step-*.pth')))
    if old_checkpoints:
        last_checkpoint = old_checkpoints[-1]
        print("Found old checkpoint", last_checkpoint)
        step, last_loss = restore_from_checkpoint(model, optimizer, lr_scheduler, last_checkpoint)
        print("Setting step to", step)

    if step == 0:
        print("Beginning training")
        save_step()
    else:
        print("Continuing from step", step)
        progress = pd.read_csv(os.path.join(this_directory, 'log.csv'), index_col=0)

        val_metrics = {}
        for column_name in progress.columns:
            if column_name.startswith('validation_'):
                val_metrics[column_name.split('validation_', 1)[1]] = list(progress[column_name])

        if step not in progress.epoch.values:
            print("Epoch not yet evaluated, evaluating...")
            save_step()

        # We have to make one scheduler step here, since we make the
        # scheduler step _after_ saving the checkpoint
        lr_scheduler.step()

    print(progress)

    while optimizer.state_dict()['param_groups'][0]['lr'] >= minimum_learning_rate:
        step += 1
        last_loss = train_epoch(model, train_loader, optimizer, device)
        save_step()
        lr_scheduler.step()

    #if learning_rate_reset_strategy == 'validation':
    #    best_step = np.argmax(val_metrics[validation_metric])
    #    print("Best previous validation in step {}, saving as final result".format(best_step))
    #    restore_from_checkpoint(model, optimizer, scheduler, os.path.join(this_directory, 'step-{:04d}.pth'.format(best_step)))
    #else:
    #    print("Not resetting to best validation epoch")

    torch.save(model.state_dict(), '{}/final.pth'.format(this_directory))

    # Training is done and the final weights are saved, so intermediate
    # checkpoints can be removed.
    for filename in glob.glob(os.path.join(this_directory, 'step-*')):
        print("removing", filename)
        os.remove(filename)
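
# End-to-end sketch (illustrative, not part of the original module): wiring up
# _train. The model, the loaders and the baseline log-likelihoods are
# hypothetical and would be built from the dataset classes in .data and from
# pysaliency baseline models in practice.
#
#   optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#   lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30])
#   _train(
#       'experiments/my_run',
#       model,
#       train_loader, train_baseline_ll,
#       val_loader, val_baseline_ll,
#       optimizer, lr_scheduler,
#       minimum_learning_rate=1e-7,
#   )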