Spaces:
Runtime error
Runtime error
File size: 5,957 Bytes
a2dba58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import argparse
from pathlib import Path
import os
import numpy as np
import pandas as pd
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from torch import nn
from tqdm.auto import tqdm
from torch import autocast
from torch.utils.data import DataLoader
from einops.layers.torch import Rearrange
from einops import rearrange
import sys
# Put the directory two levels above the current working directory on the
# import path so the project-local imports that follow (models, trainers,
# dataloaders) can be found. NOTE(review): this depends on the directory the
# script is launched from, not on where the script file lives.
HEAD = Path(os.getcwd()).parent.parent
# sys.path entries must be strings: non-string entries (such as a Path object)
# are ignored by the import system, so the Path is converted explicitly.
sys.path.append(str(HEAD))
from models.datasetDM_model import DatasetDM
from trainers.train_baseline import dice, precision, recall
from dataloaders.JSRT import build_dataloaders
from dataloaders.NIH import NIHDataset
from dataloaders.Montgomery import MonDataset
# Locations of the external test sets evaluated alongside JSRT.
# "<PATH_TO_DATA>" is a literal placeholder; point it at the local data root
# before running this script.
NIHPATH, NIHFILE = "<PATH_TO_DATA>/NIH/", "correspondence_with_chestXray8.csv"  # csv saved in data
MONPATH, MONFILE = "<PATH_TO_DATA>/NLM/MontgomerySet/", "patient_data.csv"
def _print_metrics(label, output):
    """Print mean +/- std of the dice, precision and recall entries of *output*.

    ``label`` prefixes the first line (the dataset key, optionally followed by
    the diffusion step). The metric entries only need ``.mean()``/``.std()``;
    presumably they hold one value per image (depends on trainers.train_baseline).
    """
    print(f"{label} metrics: \n\tdice: {output['dice'].mean():.3}+/-{output['dice'].std():.3}")
    print(f"\tprecision: {output['precision'].mean():.3}+/-{output['precision'].std():.3}")
    print(f"\trecall: {output['recall'].mean():.3}+/-{output['recall'].std():.3}")


def _score(y_hat, y_star):
    """Threshold *y_hat* at 0.5 and return it, the targets and the metrics as a dict.

    The dict layout ('y_hat', 'y_star', 'dice', 'precision', 'recall') is what
    gets written to the *_predictions.pt files.
    """
    mask = y_hat > .5
    return {
        'y_hat': y_hat,
        'y_star': y_star,
        'dice': dice(mask, y_star),
        'precision': precision(mask, y_star),
        'recall': recall(mask, y_star),
    }


def _predict(model, loader, device, mixed_precision):
    """Run *model* over *loader* and return ``(sigmoid predictions, targets)``.

    Both are concatenated along dim 0; predictions are moved to the CPU.
    """
    # torch.autocast expects a bare device type ("cuda", "cpu"); the configured
    # device may carry an index such as "cuda:0", which autocast rejects.
    device_type = torch.device(device).type
    y_hats, y_star = [], []
    for x, y in tqdm(loader, desc='Validating'):
        x = x.to(device)
        with autocast(device_type=device_type, enabled=mixed_precision), torch.no_grad():
            # predictions for all diffusion steps ("depths") at once
            pred = torch.sigmoid(model(x))
        y_hats.append(pred.detach().cpu())
        y_star.append(y)
    return torch.cat(y_hats, 0), torch.cat(y_star, 0)


if __name__ == "__main__":
    # load config file and parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment', "-e", type=str, help='Experiment path', default="logs/JSRT_conditional/20230213_171633")
    parser.add_argument('--rerun', "-r", help='Run the test again', default=False, action="store_true")
    args = parser.parse_args()
    if not os.path.isdir(args.experiment):
        raise ValueError("Experiment path is not a directory")
    print("Experiment path identified as a directory")

    experiment_dir = Path(args.experiment)
    files = os.listdir(args.experiment)
    if {'JSRT_val_predictions.pt', 'JSRT_test_predictions.pt', 'NIH_predictions.pt', 'Montgomery_predictions.pt'} <= set(files) and not args.rerun:
        print("Experiment already tested")
        sys.exit(0)

    # The checkpoint is the first file whose name contains "model"; sorting
    # makes the choice independent of os.listdir's arbitrary order.
    torch_file = next((f for f in sorted(files) if "model" in f), None)
    if torch_file is None:
        raise ValueError("No checkpoint file found in experiment directory")
    print(f"Loading experiment from {torch_file}")
    # Load onto the CPU first so a checkpoint saved on a GPU can be opened on
    # any host; the model is moved to config.device below.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which may
    # refuse the pickled config object; pass weights_only=False for trusted
    # checkpoints if that happens.
    data = torch.load(experiment_dir / torch_file, map_location="cpu")
    config = data["config"]

    # pick model
    if config.experiment == "datasetDM":
        model = DatasetDM(config)
        model.classifier = nn.Sequential(
            # fold the per-step feature groups into the batch axis
            Rearrange('b (step act) h w -> (b step) act h w', step=len(model.steps)),
            nn.Conv2d(960, 128, 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 32, 1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            # NOTE(review): Conv2d's third positional argument is kernel_size,
            # so config.out_channels is used as the kernel size and the layer
            # always has 1 output channel (the rearrange below also assumes 1).
            # Kept as-is because it must match the saved state_dict - confirm.
            nn.Conv2d(32, 1, config.out_channels)
        )
    else:
        raise ValueError(f"Experiment {config.experiment} not recognized")
    model.load_state_dict(data['model_state_dict'])
    model.eval().to(config.device)

    # Load data. The external test sets are built lazily so that an already
    # tested (or locally unavailable) dataset is never touched.
    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
    )
    datasets_to_test = {
        "JSRT_val": lambda: dataloaders["val"],
        "JSRT_test": lambda: dataloaders["test"],
        "NIH": lambda: DataLoader(NIHDataset(NIHPATH, NIHPATH, NIHFILE, config.img_size),
                                  config.batch_size, num_workers=config.num_workers),
        "Montgomery": lambda: DataLoader(MonDataset(MONPATH, MONPATH, MONFILE, config.img_size),
                                         config.batch_size, num_workers=config.num_workers),
    }

    for dataset_key, make_loader in datasets_to_test.items():
        prediction_file = experiment_dir / f'{dataset_key}_predictions.pt'
        if f"{dataset_key}_predictions.pt" in files and not args.rerun:
            print(f"{dataset_key} already tested")
            _print_metrics(dataset_key, torch.load(prediction_file, map_location="cpu"))
            continue

        print(f"Testing {dataset_key} set")
        y_hats, y_star = _predict(model, make_loader(), config.device, config.mixed_precision)
        # Unfold the '(b step)' batch layout produced by the classifier so each
        # diffusion step can be scored on its own.
        y_hats = rearrange(y_hats, '(b step) 1 h w -> step b 1 h w', step=len(model.steps))
        for step, y_hat in zip(model.steps, y_hats):
            output = _score(y_hat, y_star)
            _print_metrics(f"{dataset_key} {step}", output)
            torch.save(output, experiment_dir / f'{dataset_key}_timestep{step}_predictions.pt')

        # Ensemble over steps: average the probabilities, then threshold.
        output = _score(y_hats.mean(0), y_star)
        _print_metrics(dataset_key, output)
        torch.save(output, prediction_file)
|