Spaces:
Running
Running
"""Functions for training and running EF prediction.""" | |
import math | |
import os | |
import time | |
import click | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import sklearn.metrics | |
import torch | |
import torchvision | |
import tqdm | |
import echonet | |
def run(
    data_dir=None,
    output=None,
    task="EF",
    model_name="r2plus1d_18",
    pretrained=True,
    weights=None,
    run_test=False,
    num_epochs=45,
    lr=1e-4,
    weight_decay=1e-4,
    lr_step_period=15,
    frames=32,
    period=2,
    num_train_patients=None,
    num_workers=4,
    batch_size=20,
    device=None,
    seed=0,
):
    """Trains/tests EF prediction model.

    \b
    Args:
        data_dir (str, optional): Directory containing dataset. Defaults to
            `echonet.config.DATA_DIR`.
        output (str, optional): Directory to place outputs. Defaults to
            output/video/<model_name>_<frames>_<period>_<pretrained/random>/.
        task (str, optional): Name of task to predict. Options are the headers
            of FileList.csv. Defaults to ``EF''.
        model_name (str, optional): Name of model. One of ``mc3_18'',
            ``r2plus1d_18'', or ``r3d_18''
            (options are torchvision.models.video.<model_name>)
            Defaults to ``r2plus1d_18''.
        pretrained (bool, optional): Whether to use pretrained weights for model
            Defaults to True.
        weights (str, optional): Path to checkpoint containing weights to
            initialize model. Defaults to None.
        run_test (bool, optional): Whether or not to run on test.
            Defaults to False.
        num_epochs (int, optional): Number of epochs during training.
            Defaults to 45.
        lr (float, optional): Learning rate for SGD
            Defaults to 1e-4.
        weight_decay (float, optional): Weight decay for SGD
            Defaults to 1e-4.
        lr_step_period (int or None, optional): Period of learning rate decay
            (learning rate is decayed by a multiplicative factor of 0.1)
            Defaults to 15. None disables decay.
        frames (int, optional): Number of frames to use in clip
            Defaults to 32.
        period (int, optional): Sampling period for frames
            Defaults to 2.
        num_train_patients (int or None, optional): Number of training patients
            for ablations. Defaults to all patients.
        num_workers (int, optional): Number of subprocesses to use for data
            loading. If 0, the data will be loaded in the main process.
            Defaults to 4.
        device (str, torch.device or None, optional): Device to run on. Options from
            https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device
            Defaults to ``cuda'' if available, and ``cpu'' otherwise.
        batch_size (int, optional): Number of samples to load per batch
            Defaults to 20.
        seed (int, optional): Seed for random number generator. Defaults to 0.
    """
    # Seed RNGs
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Set default output directory
    if output is None:
        output = os.path.join("output", "video", "{}_{}_{}_{}".format(model_name, frames, period, "pretrained" if pretrained else "random"))
    os.makedirs(output, exist_ok=True)

    # Set device for computations. A string is accepted (as documented) and
    # normalised, because later code reads ``device.type``.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif not isinstance(device, torch.device):
        device = torch.device(device)

    # Set up model: replace the classifier head with a single regression output.
    model = torchvision.models.video.__dict__[model_name](pretrained=pretrained)
    model.fc = torch.nn.Linear(model.fc.in_features, 1)
    # NOTE(review): 55.6 presumably approximates the mean target value of the
    # training set, so initial predictions start near the mean -- confirm.
    model.fc.bias.data[0] = 55.6
    if device.type == "cuda":
        model = torch.nn.DataParallel(model)
    model.to(device)

    if weights is not None:
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        checkpoint = torch.load(weights, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])

    # Set up optimizer
    optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    if lr_step_period is None:
        lr_step_period = math.inf
    scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period)

    # Compute mean and std
    mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train"))
    kwargs = {"target_type": task,
              "mean": mean,
              "std": std,
              "length": frames,
              "period": period,
              }

    # Set up datasets and dataloaders
    dataset = {}
    dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs, pad=12)
    if num_train_patients is not None and len(dataset["train"]) > num_train_patients:
        # Subsample patients (used for ablation experiment)
        indices = np.random.choice(len(dataset["train"]), num_train_patients, replace=False)
        dataset["train"] = torch.utils.data.Subset(dataset["train"], indices)
    dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs)

    # Run training and testing loops
    with open(os.path.join(output, "log.csv"), "a") as f:
        epoch_resume = 0
        best_loss = float("inf")
        try:
            # Attempt to load checkpoint
            checkpoint = torch.load(os.path.join(output, "checkpoint.pt"), map_location=device)
            model.load_state_dict(checkpoint['state_dict'])
            optim.load_state_dict(checkpoint['opt_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_dict'])
            epoch_resume = checkpoint["epoch"] + 1
            # "best_loss" is recorded before that epoch's own loss is compared
            # against it, so fold the checkpoint's "loss" in; otherwise a worse
            # epoch after resuming could overwrite best.pt.
            best_loss = min(checkpoint["best_loss"], checkpoint["loss"])
            f.write("Resuming from epoch {}\n".format(epoch_resume))
        except FileNotFoundError:
            f.write("Starting run from scratch\n")

        for epoch in range(epoch_resume, num_epochs):
            print("Epoch #{}".format(epoch), flush=True)
            for phase in ['train', 'val']:
                start_time = time.time()
                num_devices = torch.cuda.device_count()
                for i in range(num_devices):
                    torch.cuda.reset_peak_memory_stats(i)

                ds = dataset[phase]
                dataloader = torch.utils.data.DataLoader(
                    ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))

                loss, yhat, y = echonet.utils.video.run_epoch(model, dataloader, phase == "train", optim, device)
                r2 = sklearn.metrics.r2_score(y, yhat)
                f.write("{},{},{},{},{},{},{},{},{}\n".format(
                    epoch,
                    phase,
                    loss,
                    r2,
                    time.time() - start_time,
                    y.size,
                    # Sum the peak memory of every visible GPU (0 on CPU-only machines).
                    sum(torch.cuda.max_memory_allocated(i) for i in range(num_devices)),
                    sum(torch.cuda.max_memory_reserved(i) for i in range(num_devices)),
                    batch_size))
                f.flush()
            scheduler.step()

            # Save checkpoint; loss/r2 here come from the last phase ("val").
            save = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'period': period,
                'frames': frames,
                'best_loss': best_loss,
                'loss': loss,
                'r2': r2,
                'opt_dict': optim.state_dict(),
                'scheduler_dict': scheduler.state_dict(),
            }
            torch.save(save, os.path.join(output, "checkpoint.pt"))
            if loss < best_loss:
                torch.save(save, os.path.join(output, "best.pt"))
                best_loss = loss

        # Load best weights
        if num_epochs != 0:
            checkpoint = torch.load(os.path.join(output, "best.pt"), map_location=device)
            model.load_state_dict(checkpoint['state_dict'])
            f.write("Best validation loss {} from epoch {}\n".format(checkpoint["loss"], checkpoint["epoch"]))
            f.flush()

        if run_test:
            for split in ["val", "test"]:
                # Performance without test-time augmentation
                dataloader = torch.utils.data.DataLoader(
                    echonet.datasets.Echo(root=data_dir, split=split, **kwargs),
                    batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"))
                loss, yhat, y = echonet.utils.video.run_epoch(model, dataloader, False, None, device)
                _write_bootstrap_metrics(f, split, "one clip", y, yhat)

                # Performance with test-time augmentation: with save_all=True and
                # batch_size=1, yhat holds one array of per-clip predictions per video.
                ds = echonet.datasets.Echo(root=data_dir, split=split, **kwargs, clips="all")
                dataloader = torch.utils.data.DataLoader(
                    ds, batch_size=1, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"))
                loss, yhat, y = echonet.utils.video.run_epoch(model, dataloader, False, None, device, save_all=True, block_size=batch_size)
                yhat_mean = np.array([clip_preds.mean() for clip_preds in yhat])
                _write_bootstrap_metrics(f, split, "all clips", y, yhat_mean)

                # Write full performance to file
                with open(os.path.join(output, "{}_predictions.csv".format(split)), "w") as g:
                    for (filename, pred) in zip(ds.fnames, yhat):
                        for (i, p) in enumerate(pred):
                            g.write("{},{},{:.4f}\n".format(filename, i, p))

                echonet.utils.latexify()
                _plot_scatter(y, yhat_mean, os.path.join(output, "{}_scatter.pdf".format(split)))
                _plot_roc(y, yhat_mean, os.path.join(output, "{}_roc.pdf".format(split)))


def _write_bootstrap_metrics(f, split, label, y, yhat):
    """Write bootstrapped R2, MAE and RMSE lines for one split to log file ``f``.

    ``label`` is inserted in parentheses after the split name (e.g. "one clip").
    """
    f.write("{} ({}) R2: {:.3f} ({:.3f} - {:.3f})\n".format(split, label, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score)))
    f.write("{} ({}) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(split, label, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error)))
    # RMSE: bootstrap the MSE, then take sqrt of each of the three returned values
    # (sqrt is monotonic, so interval bounds stay ordered).
    f.write("{} ({}) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(split, label, *map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error))))
    f.flush()


def _plot_scatter(y, yhat, path):
    """Scatter-plot actual vs. predicted EF with an identity line and save to ``path``."""
    fig = plt.figure(figsize=(3, 3))
    lower = min(y.min(), yhat.min())
    upper = max(y.max(), yhat.max())
    plt.scatter(y, yhat, color="k", s=1, edgecolor=None, zorder=2)
    plt.plot([0, 100], [0, 100], linewidth=1, zorder=3)
    plt.axis([lower - 3, upper + 3, lower - 3, upper + 3])
    plt.gca().set_aspect("equal", "box")
    plt.xlabel("Actual EF (%)")
    plt.ylabel("Predicted EF (%)")
    plt.xticks([10, 20, 30, 40, 50, 60, 70, 80])
    plt.yticks([10, 20, 30, 40, 50, 60, 70, 80])
    plt.grid(color="gainsboro", linestyle="--", linewidth=1, zorder=1)
    plt.tight_layout()
    plt.savefig(path)
    plt.close(fig)


def _plot_roc(y, yhat, path):
    """Plot ROC curves for classifying ``y > thresh`` (thresh in 35/40/45/50) and save to ``path``.

    Also prints each threshold's AUROC to stdout.
    """
    fig = plt.figure(figsize=(3, 3))
    plt.plot([0, 1], [0, 1], linewidth=1, color="k", linestyle="--")
    for thresh in [35, 40, 45, 50]:
        fpr, tpr, _ = sklearn.metrics.roc_curve(y > thresh, yhat)
        print(thresh, sklearn.metrics.roc_auc_score(y > thresh, yhat))
        plt.plot(fpr, tpr)
    plt.axis([-0.01, 1.01, -0.01, 1.01])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.tight_layout()
    plt.savefig(path)
    plt.close(fig)
def run_epoch(model, dataloader, train, optim, device, save_all=False, block_size=None):
    """Run one epoch of training/evaluation for video regression.

    Args:
        model (torch.nn.Module): Model to train/evaluate.
        dataloader (torch.utils.data.DataLoader): Dataloader yielding
            ``(X, outcome)`` batches. ``X`` may be 5-D (one clip per video) or
            6-D ``(batch, clips, c, f, h, w)`` (several clips per video, whose
            predictions are averaged).
        train (bool): Whether or not to train model.
        optim (torch.optim.Optimizer): Optimizer (only used when ``train`` is True).
        device (torch.device): Device to run on
        save_all (bool, optional): If True, return predictions for all
            test-time augmentations separately. If False, return only
            the mean prediction.
            Defaults to False.
        block_size (int or None, optional): Maximum number of augmentations
            to run on at the same time. Use to limit the amount of memory
            used. If None, always run on all augmentations simultaneously.
            Default is None.

    Returns:
        tuple ``(loss, yhat, y)``:
            loss (float): mean squared error averaged over videos.
            yhat: if ``save_all`` is False, a 1-D np.ndarray of per-video
                predictions. If True, a list with one flattened array of
                per-clip predictions per batch (one per video when the
                dataloader's batch size is 1).
            y (np.ndarray): ground-truth outcomes, concatenated over batches.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.train(train)

    total = 0.  # running sum of batch losses, weighted by videos per batch
    n = 0  # number of videos processed
    s1 = 0.  # sum of ground truth outcomes
    s2 = 0.  # sum of ground truth outcomes squared

    yhat = []
    y = []

    with torch.set_grad_enabled(train):
        with tqdm.tqdm(total=len(dataloader)) as pbar:
            for (X, outcome) in dataloader:
                y.append(outcome.numpy())
                X = X.to(device)
                outcome = outcome.to(device)

                # 6-D input carries several clips per video: fold the clips into
                # the batch dimension, then average predictions per video below.
                average = (len(X.shape) == 6)
                if average:
                    batch, n_clips, c, f, h, w = X.shape
                    X = X.view(-1, c, f, h, w)

                s1 += outcome.sum().item()
                s2 += (outcome ** 2).sum().item()

                if block_size is None:
                    outputs = model(X)
                else:
                    # Run in chunks of at most block_size clips to bound memory.
                    outputs = torch.cat([model(X[j:(j + block_size), ...]) for j in range(0, X.shape[0], block_size)])

                if save_all:
                    yhat.append(outputs.view(-1).detach().cpu().numpy())

                if average:
                    outputs = outputs.view(batch, n_clips, -1).mean(1)

                if not save_all:
                    yhat.append(outputs.view(-1).detach().cpu().numpy())

                loss = torch.nn.functional.mse_loss(outputs.view(-1), outcome)

                if train:
                    optim.zero_grad()
                    loss.backward()
                    optim.step()

                # Weight by the number of videos, not X.size(0): after the 6-D
                # reshape X holds batch * n_clips rows, but loss is a per-video mean.
                num_videos = outcome.size(0)
                total += loss.item() * num_videos
                n += num_videos

                # Postfix: running loss (current batch loss) / running variance of
                # the ground truth, i.e. the MSE of always predicting its mean.
                pbar.set_postfix_str("{:.2f} ({:.2f}) / {:.2f}".format(total / n, loss.item(), s2 / n - (s1 / n) ** 2))
                pbar.update()

    if n == 0:
        raise ValueError("run_epoch: dataloader yielded no batches")

    if not save_all:
        yhat = np.concatenate(yhat)
    y = np.concatenate(y)

    return total / n, yhat, y