# Spaces:
# Running
# Running
# flake8: noqa E501 | |
# pylint: disable=not-callable | |
# E501: line too long | |
from collections import defaultdict | |
from datetime import datetime | |
import glob | |
import os | |
import tempfile | |
from boltons.cacheutils import cached, LRU | |
from boltons.fileutils import atomic_save, mkdir_p | |
from boltons.iterutils import windowed | |
from IPython import get_ipython | |
from IPython.display import display | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import pysaliency | |
from pysaliency.filter_datasets import iterate_crossvalidation | |
from pysaliency.plotting import visualize_distribution | |
import torch | |
from torch.utils.tensorboard import SummaryWriter | |
from tqdm import tqdm | |
import yaml | |
from .data import ImageDataset, FixationDataset, ImageDatasetSampler, FixationMaskTransform | |
#from .loading import import_class, build_model, DeepGazeCheckpointModel, SharedPyTorchModel, _get_from_config | |
from .metrics import log_likelihood, nss, auc | |
from .modules import DeepGazeII | |
@cached(LRU(max_size=3))
def baseline_performance(model, *args, **kwargs):
    """Compute a baseline model's information gain, memoized.

    Results are kept in an LRU cache of size 3 keyed on the call
    arguments, so repeatedly evaluating the same baseline on the same
    dataset is free. Delegates to ``model.information_gain``.
    """
    return model.information_gain(*args, **kwargs)
def eval_epoch(model, dataset, baseline_information_gain, device, metrics=None):
    """Evaluate `model` over one pass of `dataset` without gradients.

    Parameters
    ----------
    model : torch.nn.Module
        Either a ``DeepGazeII`` instance (called without scanpath inputs)
        or a scanpath model that additionally receives
        ``x_hist``/``y_hist``/``durations``.
    dataset : iterable
        Yields batch dicts containing at least 'image', 'centerbias',
        'fixation_mask' and 'weight'; optionally 'x_hist', 'y_hist',
        'durations'; any further entries are forwarded to the model as
        keyword arguments.
    baseline_information_gain : float
        Baseline log-likelihood subtracted from the model's LL to obtain
        the information gain ('IG').
    device : torch.device
        Device to move the batch tensors to.
    metrics : list of str, optional
        Metric names to report ('LL', 'IG', 'NSS', 'AUC'); defaults to all.

    Returns
    -------
    dict
        Maps each requested metric name to its weight-averaged score
        over the dataset (batches weighted by their summed fixation weights).
    """
    model.eval()

    if metrics is None:
        metrics = ['LL', 'IG', 'NSS', 'AUC']

    metric_scores = {}
    metric_functions = {
        'LL': log_likelihood,
        'NSS': nss,
        'AUC': auc,
    }
    # 'IG' is derived from 'LL', so LL must be computed even when the caller
    # asks for IG alone (previously this raised a KeyError at the end).
    computed_metrics = set(metrics)
    if 'IG' in computed_metrics:
        computed_metrics.add('LL')

    batch_weights = []

    with torch.no_grad():
        pbar = tqdm(dataset)
        for batch in pbar:
            image = batch.pop('image').to(device)
            centerbias = batch.pop('centerbias').to(device)
            fixation_mask = batch.pop('fixation_mask').to(device)
            x_hist = batch.pop('x_hist', torch.tensor([])).to(device)
            y_hist = batch.pop('y_hist', torch.tensor([])).to(device)
            weights = batch.pop('weight').to(device)
            durations = batch.pop('durations', torch.tensor([])).to(device)

            # whatever remains in the batch is passed through to the model
            kwargs = {}
            for key, value in dict(batch).items():
                kwargs[key] = value.to(device)

            if isinstance(model, DeepGazeII):
                log_density = model(image, centerbias, **kwargs)
            else:
                log_density = model(image, centerbias, x_hist=x_hist, y_hist=y_hist, durations=durations, **kwargs)

            for metric_name, metric_fn in metric_functions.items():
                if metric_name not in computed_metrics:
                    continue
                metric_scores.setdefault(metric_name, []).append(metric_fn(log_density, fixation_mask, weights=weights).detach().cpu().numpy())
            batch_weights.append(weights.detach().cpu().numpy().sum())

            # show the first available metric as a running average in the
            # progress bar
            for display_metric in ['LL', 'NSS', 'AUC']:
                if display_metric in metrics:
                    pbar.set_description('{} {:.05f}'.format(display_metric, np.average(metric_scores[display_metric], weights=batch_weights)))
                    break

    averages = {metric_name: np.average(scores, weights=batch_weights) for metric_name, scores in metric_scores.items()}
    # report only what was requested (LL may have been computed solely
    # as the ingredient for IG)
    data = {metric_name: value for metric_name, value in averages.items() if metric_name in metrics}

    if 'IG' in metrics:
        data['IG'] = averages['LL'] - baseline_information_gain

    return data
def train_epoch(model, dataset, optimizer, device):
    """Run one optimization epoch over `dataset`.

    Performs a gradient step per batch on the negative log-likelihood and
    returns the weight-averaged training loss over the whole epoch.
    """
    model.train()

    loss_history = []
    weight_history = []

    progress = tqdm(dataset)
    for batch in progress:
        optimizer.zero_grad()

        image = batch.pop('image').to(device)
        centerbias = batch.pop('centerbias').to(device)
        fixation_mask = batch.pop('fixation_mask').to(device)
        x_hist = batch.pop('x_hist', torch.tensor([])).to(device)
        y_hist = batch.pop('y_hist', torch.tensor([])).to(device)
        weights = batch.pop('weight').to(device)
        durations = batch.pop('durations', torch.tensor([])).to(device)

        # everything left in the batch is forwarded as extra model inputs
        extra_inputs = {key: value.to(device) for key, value in dict(batch).items()}

        if isinstance(model, DeepGazeII):
            log_density = model(image, centerbias, **extra_inputs)
        else:
            log_density = model(
                image, centerbias,
                x_hist=x_hist, y_hist=y_hist, durations=durations,
                **extra_inputs,
            )

        loss = -log_likelihood(log_density, fixation_mask, weights=weights)

        loss_history.append(loss.detach().cpu().numpy())
        weight_history.append(weights.detach().cpu().numpy().sum())
        progress.set_description('{:.05f}'.format(np.average(loss_history, weights=weight_history)))

        loss.backward()
        optimizer.step()

    return np.average(loss_history, weights=weight_history)
def restore_from_checkpoint(model, optimizer, scheduler, path):
    """Restore training state (or just model weights) from `path`.

    Two checkpoint layouts are handled:

    * full training checkpoints (as written by ``save_training_state``):
      the dict contains 'model', 'optimizer', 'scheduler', 'rng_state',
      'step' and 'loss'; everything is restored and ``(step, loss)`` is
      returned.
    * plain model state dicts: only the weights are loaded (non-strictly);
      missing/unexpected keys are printed as warnings and ``None`` is
      returned implicitly.
    """
    print("Restoring from", path)
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # trusted checkpoint files. No map_location is passed, so tensors are
    # restored to the device they were saved from -- TODO confirm this is
    # intended when resuming on a different device.
    data = torch.load(path)
    if 'optimizer' in data:
        # checkpoint contains training progress
        model.load_state_dict(data['model'])
        optimizer.load_state_dict(data['optimizer'])
        scheduler.load_state_dict(data['scheduler'])
        # restore the global (CPU) RNG state so randomized parts of
        # training continue deterministically after resume
        torch.set_rng_state(data['rng_state'])
        return data['step'], data['loss']
    else:
        # checkpoint contains just a model
        missing_keys, unexpected_keys = model.load_state_dict(data, strict=False)
        if missing_keys:
            print("WARNING! missing keys", missing_keys)
        if unexpected_keys:
            print("WARNING! Unexpected keys", unexpected_keys)
def save_training_state(model, optimizer, scheduler, step, loss, path):
    """Atomically write a full training checkpoint to `path`.

    The checkpoint bundles model/optimizer/scheduler state dicts, the
    global RNG state, the current step number and the last loss, so that
    ``restore_from_checkpoint`` can resume training exactly.
    """
    checkpoint = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        scheduler=scheduler.state_dict(),
        rng_state=torch.get_rng_state(),
        step=step,
        loss=loss,
    )
    # atomic_save writes to a temporary part-file and renames it into
    # place, so an interrupted run never leaves a truncated checkpoint.
    with atomic_save(path, text_mode=False, overwrite_part=True) as f:
        torch.save(checkpoint, f)
def _train(this_directory,
           model,
           train_loader, train_baseline_log_likelihood,
           val_loader, val_baseline_log_likelihood,
           optimizer, lr_scheduler,
           #optimizer_config, lr_scheduler_config,
           minimum_learning_rate,
           #initial_learning_rate, learning_rate_scheduler, learning_rate_decay, learning_rate_decay_epochs, learning_rate_backlook, learning_rate_reset_strategy, minimum_learning_rate,
           validation_metric='IG',
           # NOTE(review): mutable default list -- shared across calls; safe
           # here only because it is never mutated inside this function.
           validation_metrics=['IG', 'LL', 'AUC', 'NSS'],
           validation_epochs=1,
           startwith=None,
           device=None):
    """Run the full, resumable training loop, checkpointing into `this_directory`.

    One epoch is trained per iteration until the optimizer's learning rate
    (as driven by `lr_scheduler`) drops below `minimum_learning_rate`.
    After every epoch a 'step-XXXX.pth' checkpoint is written, validation
    metrics are computed every `validation_epochs` epochs, and progress is
    appended both to 'log.csv' and to a TensorBoard log in 'log/'.
    Existing checkpoints and 'log.csv' are picked up on restart, so an
    interrupted run continues where it left off. On completion the model
    weights are saved as 'final.pth' and step checkpoints are removed.

    `startwith` may name a checkpoint to initialize from before training;
    `device` defaults to CUDA when available.
    """
    mkdir_p(this_directory)

    # 'final.pth' marks a completed run -- nothing to do.
    if os.path.isfile(os.path.join(this_directory, 'final.pth')):
        print("Training Already finished")
        return

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device", device)

    model.to(device)

    val_metrics = defaultdict(lambda: [])

    if startwith is not None:
        restore_from_checkpoint(model, optimizer, lr_scheduler, startwith)

    writer = SummaryWriter(os.path.join(this_directory, 'log'), flush_secs=30)

    columns = ['epoch', 'timestamp', 'learning_rate', 'loss']
    print("validation metrics", validation_metrics)
    for metric in validation_metrics:
        columns.append(f'validation_{metric}')

    progress = pd.DataFrame(columns=columns)

    step = 0
    last_loss = np.nan

    def save_step():
        # Closure over `step` and `last_loss`: reads their *current*
        # values from the enclosing scope on every call.

        save_training_state(
            model, optimizer, lr_scheduler, step, last_loss,
            '{}/step-{:04d}.pth'.format(this_directory, step),
        )

        #f = visualize(model, vis_data_loader)
        #display_if_in_IPython(f)

        #writer.add_figure('prediction', f, step)
        writer.add_scalar('training/loss', last_loss, step)
        writer.add_scalar('training/learning_rate', optimizer.state_dict()['param_groups'][0]['lr'], step)
        writer.add_scalar('parameters/sigma', model.finalizer.gauss.sigma.detach().cpu().numpy(), step)
        writer.add_scalar('parameters/center_bias_weight', model.finalizer.center_bias_weight.detach().cpu().numpy()[0], step)

        # Validation is only run every `validation_epochs` epochs.
        if step % validation_epochs == 0:
            _val_metrics = eval_epoch(model, val_loader, val_baseline_log_likelihood, device, metrics=validation_metrics)
        else:
            print("Skipping validation")
            _val_metrics = {}

        for key, value in _val_metrics.items():
            val_metrics[key].append(value)

        for key, value in _val_metrics.items():
            writer.add_scalar(f'validation/{key}', value, step)

        new_row = {
            'epoch': step,
            'timestamp': datetime.utcnow(),
            'learning_rate': optimizer.state_dict()['param_groups'][0]['lr'],
            'loss': last_loss,
            #'validation_ig': val_igs[-1]
        }
        for key, value in _val_metrics.items():
            new_row['validation_{}'.format(key)] = value
        progress.loc[step] = new_row

        print(progress.tail(n=2))
        print(progress[['validation_{}'.format(key) for key in val_metrics]].idxmax(axis=0))

        # Write the progress log atomically so a crash cannot corrupt it.
        with atomic_save('{}/log.csv'.format(this_directory), text_mode=True, overwrite_part=True) as f:
            progress.to_csv(f)

        # Prune earlier step checkpoints, keeping the one with the best
        # validation score so far (and implicitly the current step).
        for old_step in range(1, step):
            # only check if we are computing validation metrics...
            if validation_metric in val_metrics and val_metrics[validation_metric] and old_step == np.argmax(val_metrics[validation_metric]):
                continue
            for filename in glob.glob('{}/step-{:04d}.pth'.format(this_directory, old_step)):
                print("removing", filename)
                os.remove(filename)

    # Resume from the newest step checkpoint, if any exist.
    old_checkpoints = sorted(glob.glob(os.path.join(this_directory, 'step-*.pth')))
    if old_checkpoints:
        last_checkpoint = old_checkpoints[-1]
        print("Found old checkpoint", last_checkpoint)
        step, last_loss = restore_from_checkpoint(model, optimizer, lr_scheduler, last_checkpoint)
        print("Setting step to", step)

    if step == 0:
        print("Beginning training")
        save_step()
    else:
        print("Continuing from step", step)
        # Rebuild the in-memory progress table and validation history
        # from the on-disk log so pruning/best-step logic keeps working.
        progress = pd.read_csv(os.path.join(this_directory, 'log.csv'), index_col=0)
        val_metrics = {}
        for column_name in progress.columns:
            if column_name.startswith('validation_'):
                val_metrics[column_name.split('validation_', 1)[1]] = list(progress[column_name])

        if step not in progress.epoch.values:
            print("Epoch not yet evaluated, evaluating...")
            save_step()

        # We have to make one scheduler step here, since we make the
        # scheduler step _after_ saving the checkpoint
        lr_scheduler.step()

    print(progress)

    # Main loop: train one epoch per iteration until the learning rate
    # decays below the configured minimum.
    while optimizer.state_dict()['param_groups'][0]['lr'] >= minimum_learning_rate:
        step += 1
        last_loss = train_epoch(model, train_loader, optimizer, device)

        save_step()
        lr_scheduler.step()

    #if learning_rate_reset_strategy == 'validation':
    #    best_step = np.argmax(val_metrics[validation_metric])
    #    print("Best previous validation in step {}, saving as final result".format(best_step))
    #    restore_from_checkpoint(model, optimizer, scheduler, os.path.join(this_directory, 'step-{:04d}.pth'.format(best_step)))
    #else:
    #    print("Not resetting to best validation epoch")

    # Final weights only (no optimizer state); this also marks the run as
    # finished for the early-exit check at the top of this function.
    torch.save(model.state_dict(), '{}/final.pth'.format(this_directory))

    # All intermediate checkpoints can go now.
    for filename in glob.glob(os.path.join(this_directory, 'step-*')):
        print("removing", filename)
        os.remove(filename)