"""Functions for training and running EF prediction.""" import math import os import time import click import matplotlib.pyplot as plt import numpy as np import sklearn.metrics import torch import torchvision import tqdm import echonet @click.command("video") @click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None) @click.option("--output", type=click.Path(file_okay=False), default=None) @click.option("--task", type=str, default="EF") @click.option("--model_name", type=click.Choice( sorted(name for name in torchvision.models.video.__dict__ if name.islower() and not name.startswith("__") and callable(torchvision.models.video.__dict__[name]))), default="r2plus1d_18") @click.option("--pretrained/--random", default=True) @click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None) @click.option("--run_test/--skip_test", default=False) @click.option("--num_epochs", type=int, default=45) @click.option("--lr", type=float, default=1e-4) @click.option("--weight_decay", type=float, default=1e-4) @click.option("--lr_step_period", type=int, default=15) @click.option("--frames", type=int, default=32) @click.option("--period", type=int, default=2) @click.option("--num_train_patients", type=int, default=None) @click.option("--num_workers", type=int, default=4) @click.option("--batch_size", type=int, default=20) @click.option("--device", type=str, default=None) @click.option("--seed", type=int, default=0) def run( data_dir=None, output=None, task="EF", model_name="r2plus1d_18", pretrained=True, weights=None, run_test=False, num_epochs=45, lr=1e-4, weight_decay=1e-4, lr_step_period=15, frames=32, period=2, num_train_patients=None, num_workers=4, batch_size=20, device=None, seed=0, ): """Trains/tests EF prediction model. \b Args: data_dir (str, optional): Directory containing dataset. Defaults to `echonet.config.DATA_DIR`. output (str, optional): Directory to place outputs. Defaults to output/video/_/. task (str, optional): Name of task to predict. Options are the headers of FileList.csv. Defaults to ``EF''. model_name (str, optional): Name of model. One of ``mc3_18'', ``r2plus1d_18'', or ``r3d_18'' (options are torchvision.models.video.) Defaults to ``r2plus1d_18''. pretrained (bool, optional): Whether to use pretrained weights for model Defaults to True. weights (str, optional): Path to checkpoint containing weights to initialize model. Defaults to None. run_test (bool, optional): Whether or not to run on test. Defaults to False. num_epochs (int, optional): Number of epochs during training. Defaults to 45. lr (float, optional): Learning rate for SGD Defaults to 1e-4. weight_decay (float, optional): Weight decay for SGD Defaults to 1e-4. lr_step_period (int or None, optional): Period of learning rate decay (learning rate is decayed by a multiplicative factor of 0.1) Defaults to 15. frames (int, optional): Number of frames to use in clip Defaults to 32. period (int, optional): Sampling period for frames Defaults to 2. n_train_patients (int or None, optional): Number of training patients for ablations. Defaults to all patients. num_workers (int, optional): Number of subprocesses to use for data loading. If 0, the data will be loaded in the main process. Defaults to 4. device (str or None, optional): Name of device to run on. Options from https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device Defaults to ``cuda'' if available, and ``cpu'' otherwise. batch_size (int, optional): Number of samples to load per batch Defaults to 20. seed (int, optional): Seed for random number generator. Defaults to 0. """ # Seed RNGs np.random.seed(seed) torch.manual_seed(seed) # Set default output directory if output is None: output = os.path.join("output", "video", "{}_{}_{}_{}".format(model_name, frames, period, "pretrained" if pretrained else "random")) os.makedirs(output, exist_ok=True) # Set device for computations if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Set up model model = torchvision.models.video.__dict__[model_name](pretrained=pretrained) model.fc = torch.nn.Linear(model.fc.in_features, 1) model.fc.bias.data[0] = 55.6 if device.type == "cuda": model = torch.nn.DataParallel(model) model.to(device) if weights is not None: checkpoint = torch.load(weights) model.load_state_dict(checkpoint['state_dict']) # Set up optimizer optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay) if lr_step_period is None: lr_step_period = math.inf scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period) # Compute mean and std mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train")) kwargs = {"target_type": task, "mean": mean, "std": std, "length": frames, "period": period, } # Set up datasets and dataloaders dataset = {} dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs, pad=12) if num_train_patients is not None and len(dataset["train"]) > num_train_patients: # Subsample patients (used for ablation experiment) indices = np.random.choice(len(dataset["train"]), num_train_patients, replace=False) dataset["train"] = torch.utils.data.Subset(dataset["train"], indices) dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs) # Run training and testing loops with open(os.path.join(output, "log.csv"), "a") as f: epoch_resume = 0 bestLoss = float("inf") try: # Attempt to load checkpoint checkpoint = torch.load(os.path.join(output, "checkpoint.pt")) model.load_state_dict(checkpoint['state_dict']) optim.load_state_dict(checkpoint['opt_dict']) scheduler.load_state_dict(checkpoint['scheduler_dict']) epoch_resume = checkpoint["epoch"] + 1 bestLoss = checkpoint["best_loss"] f.write("Resuming from epoch {}\n".format(epoch_resume)) except FileNotFoundError: f.write("Starting run from scratch\n") for epoch in range(epoch_resume, num_epochs): print("Epoch #{}".format(epoch), flush=True) for phase in ['train', 'val']: start_time = time.time() for i in range(torch.cuda.device_count()): torch.cuda.reset_peak_memory_stats(i) ds = dataset[phase] dataloader = torch.utils.data.DataLoader( ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train")) loss, yhat, y = echonet.utils.video.run_epoch(model, dataloader, phase == "train", optim, device) f.write("{},{},{},{},{},{},{},{},{}\n".format(epoch, phase, loss, sklearn.metrics.r2_score(y, yhat), time.time() - start_time, y.size, sum(torch.cuda.max_memory_allocated() for i in range(torch.cuda.device_count())), sum(torch.cuda.max_memory_reserved() for i in range(torch.cuda.device_count())), batch_size)) f.flush() scheduler.step() # Save checkpoint save = { 'epoch': epoch, 'state_dict': model.state_dict(), 'period': period, 'frames': frames, 'best_loss': bestLoss, 'loss': loss, 'r2': sklearn.metrics.r2_score(y, yhat), 'opt_dict': optim.state_dict(), 'scheduler_dict': scheduler.state_dict(), } torch.save(save, os.path.join(output, "checkpoint.pt")) if loss < bestLoss: torch.save(save, os.path.join(output, "best.pt")) bestLoss = loss # Load best weights if num_epochs != 0: checkpoint = torch.load(os.path.join(output, "best.pt")) model.load_state_dict(checkpoint['state_dict']) f.write("Best validation loss {} from epoch {}\n".format(checkpoint["loss"], checkpoint["epoch"])) f.flush() if run_test: for split in ["val", "test"]: # Performance without test-time augmentation dataloader = torch.utils.data.DataLoader( echonet.datasets.Echo(root=data_dir, split=split, **kwargs), batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda")) loss, yhat, y = echonet.utils.video.run_epoch(model, dataloader, False, None, device) f.write("{} (one clip) R2: {:.3f} ({:.3f} - {:.3f})\n".format(split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score))) f.write("{} (one clip) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error))) f.write("{} (one clip) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error))))) f.flush() # Performance with test-time augmentation ds = echonet.datasets.Echo(root=data_dir, split=split, **kwargs, clips="all") dataloader = torch.utils.data.DataLoader( ds, batch_size=1, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda")) loss, yhat, y = echonet.utils.video.run_epoch(model, dataloader, False, None, device, save_all=True, block_size=batch_size) f.write("{} (all clips) R2: {:.3f} ({:.3f} - {:.3f})\n".format(split, *echonet.utils.bootstrap(y, np.array(list(map(lambda x: x.mean(), yhat))), sklearn.metrics.r2_score))) f.write("{} (all clips) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *echonet.utils.bootstrap(y, np.array(list(map(lambda x: x.mean(), yhat))), sklearn.metrics.mean_absolute_error))) f.write("{} (all clips) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, np.array(list(map(lambda x: x.mean(), yhat))), sklearn.metrics.mean_squared_error))))) f.flush() # Write full performance to file with open(os.path.join(output, "{}_predictions.csv".format(split)), "w") as g: for (filename, pred) in zip(ds.fnames, yhat): for (i, p) in enumerate(pred): g.write("{},{},{:.4f}\n".format(filename, i, p)) echonet.utils.latexify() yhat = np.array(list(map(lambda x: x.mean(), yhat))) # Plot actual and predicted EF fig = plt.figure(figsize=(3, 3)) lower = min(y.min(), yhat.min()) upper = max(y.max(), yhat.max()) plt.scatter(y, yhat, color="k", s=1, edgecolor=None, zorder=2) plt.plot([0, 100], [0, 100], linewidth=1, zorder=3) plt.axis([lower - 3, upper + 3, lower - 3, upper + 3]) plt.gca().set_aspect("equal", "box") plt.xlabel("Actual EF (%)") plt.ylabel("Predicted EF (%)") plt.xticks([10, 20, 30, 40, 50, 60, 70, 80]) plt.yticks([10, 20, 30, 40, 50, 60, 70, 80]) plt.grid(color="gainsboro", linestyle="--", linewidth=1, zorder=1) plt.tight_layout() plt.savefig(os.path.join(output, "{}_scatter.pdf".format(split))) plt.close(fig) # Plot AUROC fig = plt.figure(figsize=(3, 3)) plt.plot([0, 1], [0, 1], linewidth=1, color="k", linestyle="--") for thresh in [35, 40, 45, 50]: fpr, tpr, _ = sklearn.metrics.roc_curve(y > thresh, yhat) print(thresh, sklearn.metrics.roc_auc_score(y > thresh, yhat)) plt.plot(fpr, tpr) plt.axis([-0.01, 1.01, -0.01, 1.01]) plt.xlabel("False Positive Rate") plt.ylabel("True Positive Rate") plt.tight_layout() plt.savefig(os.path.join(output, "{}_roc.pdf".format(split))) plt.close(fig) def run_epoch(model, dataloader, train, optim, device, save_all=False, block_size=None): """Run one epoch of training/evaluation for segmentation. Args: model (torch.nn.Module): Model to train/evaulate. dataloder (torch.utils.data.DataLoader): Dataloader for dataset. train (bool): Whether or not to train model. optim (torch.optim.Optimizer): Optimizer device (torch.device): Device to run on save_all (bool, optional): If True, return predictions for all test-time augmentations separately. If False, return only the mean prediction. Defaults to False. block_size (int or None, optional): Maximum number of augmentations to run on at the same time. Use to limit the amount of memory used. If None, always run on all augmentations simultaneously. Default is None. """ model.train(train) total = 0 # total training loss n = 0 # number of videos processed s1 = 0 # sum of ground truth EF s2 = 0 # Sum of ground truth EF squared yhat = [] y = [] with torch.set_grad_enabled(train): with tqdm.tqdm(total=len(dataloader)) as pbar: for (X, outcome) in dataloader: y.append(outcome.numpy()) X = X.to(device) outcome = outcome.to(device) average = (len(X.shape) == 6) if average: batch, n_clips, c, f, h, w = X.shape X = X.view(-1, c, f, h, w) s1 += outcome.sum() s2 += (outcome ** 2).sum() if block_size is None: outputs = model(X) else: outputs = torch.cat([model(X[j:(j + block_size), ...]) for j in range(0, X.shape[0], block_size)]) if save_all: yhat.append(outputs.view(-1).to("cpu").detach().numpy()) if average: outputs = outputs.view(batch, n_clips, -1).mean(1) if not save_all: yhat.append(outputs.view(-1).to("cpu").detach().numpy()) loss = torch.nn.functional.mse_loss(outputs.view(-1), outcome) if train: optim.zero_grad() loss.backward() optim.step() total += loss.item() * X.size(0) n += X.size(0) pbar.set_postfix_str("{:.2f} ({:.2f}) / {:.2f}".format(total / n, loss.item(), s2 / n - (s1 / n) ** 2)) pbar.update() if not save_all: yhat = np.concatenate(yhat) y = np.concatenate(y) return total / n, yhat, y