Med-Real2Sim/dynamic/scripts/plot_simulated_noise.py
#!/usr/bin/env python3
"""Code to generate plots for Extended Data Fig. 6."""
import os
import pickle
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL
import sklearn
import torch
import torchvision
import echonet


def main(fig_root=os.path.join("figure", "noise"),
         video_output=os.path.join("output", "video", "r2plus1d_18_32_2_pretrained"),
         seg_output=os.path.join("output", "segmentation", "deeplabv3_resnet50_random"),
         NOISE=(0, 0.1, 0.2, 0.3, 0.4, 0.5)):
"""Generate plots for Extended Data Fig. 6."""
device = torch.device("cuda")
filename = os.path.join(fig_root, "data.pkl") # Cache of results
try:
# Attempt to load cache
with open(filename, "rb") as f:
Y, YHAT, INTER, UNION = pickle.load(f)
except FileNotFoundError:
# Generate results if no cache available
os.makedirs(fig_root, exist_ok=True)
# Load trained video model
model_v = torchvision.models.video.r2plus1d_18()
model_v.fc = torch.nn.Linear(model_v.fc.in_features, 1)
if device.type == "cuda":
model_v = torch.nn.DataParallel(model_v)
model_v.to(device)
checkpoint = torch.load(os.path.join(video_output, "checkpoint.pt"))
model_v.load_state_dict(checkpoint['state_dict'])
# Load trained segmentation model
model_s = torchvision.models.segmentation.deeplabv3_resnet50(aux_loss=False)
model_s.classifier[-1] = torch.nn.Conv2d(model_s.classifier[-1].in_channels, 1, kernel_size=model_s.classifier[-1].kernel_size)
if device.type == "cuda":
model_s = torch.nn.DataParallel(model_s)
model_s.to(device)
checkpoint = torch.load(os.path.join(seg_output, "checkpoint.pt"))
model_s.load_state_dict(checkpoint['state_dict'])
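
        # (Both checkpoints are assumed to have been saved from DataParallel-wrapped
        # models, which is why the models are wrapped before load_state_dict is
        # called; otherwise the 'module.'-prefixed keys would not match.)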

        # Run simulation
        dice = []
        mse = []
        r2 = []
        Y = []
        YHAT = []
        INTER = []
        UNION = []
        for noise in NOISE:
            Y.append([])
            YHAT.append([])
            INTER.append([])
            UNION.append([])

            dataset = echonet.datasets.Echo(split="test", noise=noise)
            PIL.Image.fromarray(dataset[0][0][:, 0, :, :].astype(np.uint8).transpose(1, 2, 0)).save(os.path.join(fig_root, "noise_{}.tif".format(round(100 * noise))))
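
            # (The Echo dataset's `noise` argument simulates noise by blacking out
            # that fraction of pixels, plotted below as "Pixels Removed (%)"; one
            # example frame per level is saved above for the figure's bottom panel.)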

            mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(split="train"))

            tasks = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"]
            kwargs = {
                "target_type": tasks,
                "mean": mean,
                "std": std,
                "noise": noise
            }
            dataset = echonet.datasets.Echo(split="test", **kwargs)
            dataloader = torch.utils.data.DataLoader(
                dataset, batch_size=16, num_workers=5, shuffle=True, pin_memory=(device.type == "cuda"))
            loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model_s, dataloader, "test", None, device)
            inter = np.concatenate((large_inter, small_inter)).sum()
            union = np.concatenate((large_union, small_union)).sum()
            dice.append(2 * inter / (union + inter))
            INTER[-1].extend(large_inter.tolist() + small_inter.tolist())
            UNION[-1].extend(large_union.tolist() + small_union.tolist())
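
            # The DSC appended above pools intersection and union pixel counts over
            # all large and small traces at this noise level; the per-frame counts
            # are also kept in INTER and UNION for the plotting step below.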
kwargs = {"target_type": "EF",
"mean": mean,
"std": std,
"length": 32,
"period": 2,
"noise": noise
}
dataset = echonet.datasets.Echo(split="test", **kwargs)
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=16, num_workers=5, shuffle=True, pin_memory=(device.type == "cuda"))
loss, yhat, y = echonet.utils.video.run_epoch(model_v, dataloader, "test", None, device)
mse.append(loss)
r2.append(sklearn.metrics.r2_score(y, yhat))
Y[-1].extend(y.tolist())
YHAT[-1].extend(yhat.tolist())
# Save results in cache
with open(filename, "wb") as f:
pickle.dump((Y, YHAT, INTER, UNION), f)
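
    # (Deleting the cached data.pkl under fig_root forces the simulation above to re-run.)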

    # Set up plot
    echonet.utils.latexify()
    NOISE = list(map(lambda x: round(100 * x), NOISE))
    fig = plt.figure(figsize=(6.50, 4.75))
    gs = matplotlib.gridspec.GridSpec(3, 1, height_ratios=[2.0, 2.0, 0.75])
    ax = (plt.subplot(gs[0]), plt.subplot(gs[1]), plt.subplot(gs[2]))

    # Plot EF prediction results (R^2)
    r2 = [sklearn.metrics.r2_score(y, yhat) for (y, yhat) in zip(Y, YHAT)]
    ax[0].plot(NOISE, r2, color="k", linewidth=1, marker=".")
    ax[0].set_xticks([])
    ax[0].set_ylabel("R$^2$")
    l, h = min(r2), max(r2)
    l, h = l - 0.1 * (h - l), h + 0.1 * (h - l)
    ax[0].axis([min(NOISE) - 5, max(NOISE) + 5, 0, 1])

    # Plot segmentation results (DSC)
    dice = [echonet.utils.dice_similarity_coefficient(inter, union) for (inter, union) in zip(INTER, UNION)]
    ax[1].plot(NOISE, dice, color="k", linewidth=1, marker=".")
    ax[1].set_xlabel("Pixels Removed (%)")
    ax[1].set_ylabel("DSC")
    l, h = min(dice), max(dice)
    l, h = l - 0.1 * (h - l), h + 0.1 * (h - l)
    ax[1].axis([min(NOISE) - 5, max(NOISE) + 5, 0, 1])

    # Add example images below
    for noise in NOISE:
        image = matplotlib.image.imread(os.path.join(fig_root, "noise_{}.tif".format(noise)))
        imagebox = matplotlib.offsetbox.OffsetImage(image, zoom=0.4)
        ab = matplotlib.offsetbox.AnnotationBbox(imagebox, (noise, 0.0), frameon=False)
        ax[2].add_artist(ab)
    ax[2].axis("off")
    ax[2].axis([min(NOISE) - 5, max(NOISE) + 5, -1, 1])

    fig.tight_layout()
    plt.savefig(os.path.join(fig_root, "noise.pdf"), dpi=1200)
    plt.savefig(os.path.join(fig_root, "noise.eps"), dpi=300)
    plt.savefig(os.path.join(fig_root, "noise.png"), dpi=600)
    plt.close(fig)


if __name__ == "__main__":
    main()
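

# Example invocation (illustrative; the default arguments above assume checkpoints
# produced by the corresponding video and segmentation training runs):
#
#     python plot_simulated_noise.py
#
# or, from an interactive session with a different (hypothetical) figure directory:
#
#     from plot_simulated_noise import main
#     main(fig_root=os.path.join("figure", "noise_rerun"),
#          video_output=os.path.join("output", "video", "r2plus1d_18_32_2_pretrained"),
#          seg_output=os.path.join("output", "segmentation", "deeplabv3_resnet50_random"))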