# Uploaded by: anonymous
# Commit a2dba58: "first commit without models" (file size: 7.3 kB)
import argparse
from pathlib import Path
import os
import torch
from tqdm.auto import tqdm
from torch import autocast
from torch.utils.data import DataLoader
import sys
# Directory two levels above the current working directory. It is not
# referenced anywhere else in the visible part of this script.
HEAD = Path(os.getcwd()).parent.parent
# Extend the module search path before the project imports that follow
# (models, trainers, dataloaders), presumably so they resolve from this
# checkout.
# NOTE(review): this is a hard-coded absolute path; HEAD above looks like it
# was meant to be used here instead -- confirm before changing.
sys.path.append("/vol/biomedic3/mmr12/projects/TEDM/")
from models.diffusion_model import DiffusionModel
from models.unet_model import Unet
from models.datasetDM_model import DatasetDM
from trainers.datasetDM_per_step import ModDatasetDM
from trainers.train_baseline import dice, precision, recall
from dataloaders.JSRT import build_dataloaders
from dataloaders.NIH import NIHDataset
from dataloaders.Montgomery import MonDataset
# Location of the NIH chest X-ray data. The script passes it to NIHDataset as
# both the first and the second positional argument.
NIHPATH = "/vol/biodata/data/chest_xray/NIH/"
# File name handed to NIHDataset together with NIHPATH (presumably a CSV
# listing the samples to use -- confirm against dataloaders.NIH).
NIHFILE = "correspondence_with_chestXray8.csv"
# Location of the Montgomery set. The script passes it to MonDataset as both
# the first and the second positional argument.
MONPATH = "/vol/biodata/data/chest_xray/NLM/MontgomerySet/"
# File name handed to MonDataset together with MONPATH (presumably a CSV with
# per-patient records -- confirm against dataloaders.Montgomery).
MONFILE = "patient_data.csv"
# Keys of the datasets evaluated by this script. Each key owns one
# "<key>_predictions.pt" file inside the experiment directory.
DATASET_KEYS = ("JSRT_val", "JSRT_test", "NIH", "Montgomery")
# Metrics stored in every predictions file and reported on stdout.
METRIC_NAMES = ("dice", "precision", "recall")
# Experiments whose prediction is computed as ``sigmoid(model(x))``.
DIRECT_EXPERIMENTS = (
    "baseline",
    "datasetDM",
    "simple_datasetDM",
    "global_finetune",
    "glob_loc_finetune",
)


def predictions_filename(dataset_key):
    """Return the name of the predictions file that belongs to *dataset_key*."""
    return f"{dataset_key}_predictions.pt"


def format_metrics(dataset_key, output):
    """Build the mean +/- std report for one dataset.

    Args:
        dataset_key: label printed in the report header.
        output: mapping holding one value per entry of ``METRIC_NAMES``; each
            value must provide ``mean()`` and ``std()``.

    Returns:
        The multi-line report as a single string (no trailing newline).
    """
    lines = [f"{dataset_key} metrics: "]
    for name in METRIC_NAMES:
        values = output[name]
        lines.append(f"\t{name}: {values.mean():.3}+/-{values.std():.3}")
    return "\n".join(lines)


def build_model(config):
    """Instantiate the (not yet loaded) model selected by ``config.experiment``.

    Raises:
        ValueError: if the experiment name is not handled here.
    """
    # NOTE(review): "conditional" is handled by predict_batch() but not here,
    # so such checkpoints currently stop with the ValueError below, and
    # DiffusionModel is imported at the top of the file without being used.
    # A branch constructing it is presumably missing -- confirm its
    # constructor signature before adding one.
    if config.experiment in ["baseline", "global_finetune", "glob_loc_finetune"]:
        return Unet(**vars(config))
    if config.experiment == "datasetDM":
        return DatasetDM(config)
    if config.experiment == "simple_datasetDM":
        return ModDatasetDM(config)
    raise ValueError(f"Experiment {config.experiment} not recognized")


def recompute_feature_statistics(model, train_dl, device):
    """Re-derive ``model.mean`` / ``model.std`` from the training loader.

    The statistics were not saved in the model dict, so they are accumulated
    again from ``model.extract_features``. The accumulators ``model.mean`` and
    ``model.mean_squared`` are used as found on the model (presumably
    zero-initialised by its constructor -- confirm).
    """
    # No gradients are needed here; without no_grad every accumulation could
    # keep the autograd graph of its batch alive.
    with torch.no_grad():
        for x, _ in tqdm(train_dl, desc="Calculating mean and variance"):
            x = x.to(device)
            features = model.extract_features(x)
            model.mean += features.sum(dim=0)
            model.mean_squared += (features ** 2).sum(dim=0)
    n_samples = len(train_dl.dataset)
    model.mean = model.mean / n_samples
    # The small constant keeps the standard deviation strictly positive.
    model.std = (model.mean_squared / n_samples - model.mean ** 2).sqrt() + 1e-6
    model.mean = model.mean.to(device)
    model.std = model.std.to(device)


def predict_batch(model, x, config, device_type, n_samples=5):
    """Return the CPU prediction for one input batch that is already on device.

    Args:
        model: network in eval mode.
        x: input batch on ``config.device``.
        config: experiment configuration (reads ``experiment``, ``device``,
            ``timesteps`` and ``mixed_precision``).
        device_type: device type string accepted by ``autocast``.
        n_samples: number of segmentations averaged for "conditional".

    Raises:
        ValueError: if the experiment name is not handled here.
    """
    if config.experiment == "conditional":
        # sample n_samples different segmentations and average them
        samples = []
        for _ in range(n_samples):
            img = torch.randn(x.shape, device=config.device)
            for t in tqdm(range(0, config.timesteps)[::-1]):
                # sample next timestep image (x_{t-1})
                with autocast(device_type=device_type, enabled=config.mixed_precision), torch.no_grad():
                    img = model.sample_timestep(img, t=t, cond=x)
            # rescale the final sample (presumably from [-1, 1] to [0, 1])
            samples.append(img.detach().cpu() / 2 + .5)
        return torch.stack(samples, -1).mean(-1)
    if config.experiment in DIRECT_EXPERIMENTS:
        with autocast(device_type=device_type, enabled=config.mixed_precision), torch.no_grad():
            pred = torch.sigmoid(model(x))
        return pred.detach().cpu()
    raise ValueError(f"Experiment {config.experiment} not recognized")


def evaluate(model, loader, config):
    """Run the model over *loader* and return predictions, targets and metrics.

    Returns:
        A dict with the keys ``y_hat``, ``y_star``, ``dice``, ``precision``
        and ``recall``; the metrics are computed on predictions thresholded
        at 0.5.
    """
    # autocast expects a bare device type ("cuda"), not an indexed device
    # such as "cuda:0".
    device_type = torch.device(config.device).type
    y_hat = []
    y_star = []
    # Wrap the loader itself so tqdm can read its length for the total.
    for x, y in tqdm(loader, desc='Validating'):
        x = x.to(config.device)
        y_hat.append(predict_batch(model, x, config, device_type))
        y_star.append(y)
    y_hat = torch.cat(y_hat, 0)
    y_star = torch.cat(y_star, 0)
    return {
        'y_hat': y_hat,
        'y_star': y_star,
        'dice': dice(y_hat > .5, y_star),
        'precision': precision(y_hat > .5, y_star),
        'recall': recall(y_hat > .5, y_star),
    }


def main():
    """Evaluate one experiment directory on all test sets and store the results.

    Existing ``<key>_predictions.pt`` files are only re-reported unless
    ``--rerun`` is given.

    Raises:
        ValueError: if the experiment path is not a directory, holds no
            checkpoint file, or names an unrecognised experiment.
    """
    # load config file and parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment', "-e", type=str, help='Experiment path', default="logs/JSRT_conditional/20230213_171633")
    parser.add_argument('--rerun', "-r", help='Run the test again', default=False, action="store_true")
    args = parser.parse_args()

    if not os.path.isdir(args.experiment):
        raise ValueError("Experiment path is not a directory")
    print("Experiment path identified as a directory")
    experiment_dir = Path(args.experiment)
    files = os.listdir(args.experiment)
    existing = set(files)

    # Everything already evaluated: report the stored metrics without
    # loading the model or any data.
    if not args.rerun and all(predictions_filename(key) in existing for key in DATASET_KEYS):
        print("Experiment already tested")
        for dataset_key in DATASET_KEYS:
            output = torch.load(experiment_dir / predictions_filename(dataset_key))
            # Use the full key so JSRT_val and JSRT_test stay distinguishable.
            print(format_metrics(dataset_key, output))
        return

    # The checkpoint is the first listed file whose name contains "model".
    torch_file = next((f for f in files if "model" in f), None)
    if torch_file is None:
        raise ValueError("No checkpoint file found in experiment directory")
    print(f"Loading experiment from {torch_file}")
    # Load onto the CPU so the checkpoint does not depend on the device it was
    # saved from; the model is moved to config.device below.
    data = torch.load(experiment_dir / torch_file, map_location="cpu")
    config = data["config"]

    # pick model
    model = build_model(config)
    model.load_state_dict(data['model_state_dict'])
    model.eval().to(config.device)

    # Load data
    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
    )
    datasets_to_test = {
        "JSRT_val": dataloaders["val"],
        "JSRT_test": dataloaders["test"],
        "NIH": DataLoader(NIHDataset(NIHPATH, NIHPATH, NIHFILE, config.img_size),
                          config.batch_size, num_workers=config.num_workers),
        "Montgomery": DataLoader(MonDataset(MONPATH, MONPATH, MONFILE, config.img_size),
                                 config.batch_size, num_workers=config.num_workers),
    }

    if config.experiment == "simple_datasetDM":
        # re-calculate mean and var as they were not saved in the model dict
        recompute_feature_statistics(model, dataloaders["train"], config.device)

    for dataset_key, loader in datasets_to_test.items():
        predictions_path = experiment_dir / predictions_filename(dataset_key)
        if predictions_path.name in existing and not args.rerun:
            print(f"{dataset_key} already tested")
            output = torch.load(predictions_path)
            print(format_metrics(dataset_key, output))
            continue
        print(f"Testing {dataset_key} set")
        output = evaluate(model, loader, config)
        print(format_metrics(dataset_key, output))
        # save predictions
        torch.save(output, predictions_path)


if __name__ == "__main__":
    main()