import argparse
import os
import sys
from pathlib import Path

import torch
from torch import autocast
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

HEAD = Path(os.getcwd()).parent.parent
sys.path.append("/vol/biomedic3/mmr12/projects/TEDM/")
from models.diffusion_model import DiffusionModel
from models.unet_model import Unet
from models.datasetDM_model import DatasetDM
from trainers.datasetDM_per_step import ModDatasetDM
from trainers.train_baseline import dice, precision, recall
from dataloaders.JSRT import build_dataloaders
from dataloaders.NIH import NIHDataset
from dataloaders.Montgomery import MonDataset

NIHPATH = "/vol/biodata/data/chest_xray/NIH/"
NIHFILE = "correspondence_with_chestXray8.csv"
MONPATH = "/vol/biodata/data/chest_xray/NLM/MontgomerySet/"
MONFILE = "patient_data.csv"

PREDICTION_FILES = [
    'JSRT_val_predictions.pt',
    'JSRT_test_predictions.pt',
    'NIH_predictions.pt',
    'Montgomery_predictions.pt',
]


def print_metrics(name, output):
    """Print mean +/- std of the stored per-sample metrics."""
    print(f"{name} metrics: \n\tdice: {output['dice'].mean():.3}+/-{output['dice'].std():.3}")
    print(f"\tprecision: {output['precision'].mean():.3}+/-{output['precision'].std():.3}")
    print(f"\trecall: {output['recall'].mean():.3}+/-{output['recall'].std():.3}")


if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment', "-e", type=str, help='Experiment path',
                        default="logs/JSRT_conditional/20230213_171633")
    parser.add_argument('--rerun', "-r", help='Run the test again', action="store_true")
    args = parser.parse_args()

    if os.path.isdir(args.experiment):
        print("Experiment path identified as a directory")
    else:
        raise ValueError("Experiment path is not a directory")
    files = os.listdir(args.experiment)

    # If every dataset has already been tested, report the stored metrics and exit.
    if set(PREDICTION_FILES) <= set(files) and not args.rerun:
        print("Experiment already tested")
        for file in PREDICTION_FILES:
            output = torch.load(Path(args.experiment) / file)
            # Strip the suffix (rather than splitting on "_") so that
            # JSRT_val and JSRT_test remain distinguishable in the report.
            dataset_key = file.replace("_predictions.pt", "")
            print_metrics(dataset_key, output)
        sys.exit(0)

    # Locate the checkpoint in the experiment directory.
    torch_file = None
    for f in files:
        if "model" in f:
            torch_file = f
            break
    if torch_file is None:
        raise ValueError("No checkpoint file found in experiment directory")
    print(f"Loading experiment from {torch_file}")
    data = torch.load(Path(args.experiment) / torch_file)
    config = data["config"]

    # Pick the model class matching the experiment type.
    if config.experiment in ["baseline", "global_finetune", "glob_loc_finetune"]:
        model = Unet(**vars(config))
    elif config.experiment == "datasetDM":
        model = DatasetDM(config)
    elif config.experiment == "simple_datasetDM":
        model = ModDatasetDM(config)
    elif config.experiment == "conditional":
        # The test loop below handles "conditional" checkpoints via
        # model.sample_timestep(), so they are mapped to the imported
        # DiffusionModel here; its constructor signature is assumed to
        # match the other wrappers.
        model = DiffusionModel(config)
    else:
        raise ValueError(f"Experiment {config.experiment} not recognized")
    model.load_state_dict(data['model_state_dict'])
    model.eval().to(config.device)

    # Load data
    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
    )
    datasets_to_test = {
        "JSRT_val": dataloaders["val"],
        "JSRT_test": dataloaders["test"],
        "NIH": DataLoader(NIHDataset(NIHPATH, NIHPATH, NIHFILE, config.img_size),
                          config.batch_size, num_workers=config.num_workers),
        "Montgomery": DataLoader(MonDataset(MONPATH, MONPATH, MONFILE, config.img_size),
                                 config.batch_size, num_workers=config.num_workers),
    }

    if config.experiment == "simple_datasetDM":
        # Re-calculate feature mean and std, as they were not saved in the model dict.
        train_dl = dataloaders["train"]
        for x, _ in tqdm(train_dl, desc="Calculating mean and variance"):
            x = x.to(config.device)
            # No gradient tracking is needed while accumulating statistics.
            with torch.no_grad():
                features = model.extract_features(x)
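            # One-pass statistics: accumulate sum(x) and sum(x^2) over the
            # training set, then E[X] = sum(x)/N and Var[X] = E[X^2] - E[X]^2.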
            model.mean += features.sum(dim=0)
            model.mean_squared += (features ** 2).sum(dim=0)
        model.mean = model.mean / len(train_dl.dataset)
        # std from Var[X] = E[X^2] - E[X]^2, plus a small epsilon for numerical stability.
        model.std = (model.mean_squared / len(train_dl.dataset) - model.mean ** 2).sqrt() + 1e-6
        model.mean = model.mean.to(config.device)
        model.std = model.std.to(config.device)

    for dataset_key in datasets_to_test:
        if f"{dataset_key}_predictions.pt" in files and not args.rerun:
            print(f"{dataset_key} already tested")
            output = torch.load(Path(args.experiment) / f'{dataset_key}_predictions.pt')
            print_metrics(dataset_key, output)
            continue
        print(f"Testing {dataset_key} set")
        y_hat = []
        y_star = []
        for x, y in tqdm(datasets_to_test[dataset_key], desc=f'Testing {dataset_key}'):
            x = x.to(config.device)
            if config.experiment == "conditional":
                # Sample n = 5 different segmentations per image.
                y_hats = []
                for _ in range(5):
                    img = torch.randn(x.shape, device=config.device)
                    # Reverse diffusion: denoise from t = T-1 down to t = 0,
                    # conditioning each step on the input image.
                    for t in tqdm(range(0, config.timesteps)[::-1]):
                        # Sample the next timestep image (x_{t-1}).
                        with autocast(device_type=config.device, enabled=config.mixed_precision):
                            with torch.no_grad():
                                img = model.sample_timestep(img, t=t, cond=x)
                    # Map the sample from [-1, 1] back to [0, 1].
                    y_hats.append(img.detach().cpu() / 2 + .5)
                # Take the average over the 5 samples.
                y_hats = torch.stack(y_hats, -1).mean(-1)
                # Record
                y_hat.append(y_hats)
                y_star.append(y)
            elif config.experiment in ["baseline", "datasetDM", "simple_datasetDM",
                                       "global_finetune", "glob_loc_finetune"]:
                with autocast(device_type=config.device, enabled=config.mixed_precision):
                    with torch.no_grad():
                        pred = torch.sigmoid(model(x))
                y_hat.append(pred.detach().cpu())
                y_star.append(y)
            else:
                raise ValueError(f"Experiment {config.experiment} not recognized")

        # Compute metrics on binarised predictions and save everything.
        y_hat = torch.cat(y_hat, 0)
        y_star = torch.cat(y_star, 0)
        output = {
            'y_hat': y_hat,
            'y_star': y_star,
            'dice': dice(y_hat > .5, y_star),
            'precision': precision(y_hat > .5, y_star),
            'recall': recall(y_hat > .5, y_star),
        }
        print_metrics(dataset_key, output)
        torch.save(output, Path(args.experiment) / f'{dataset_key}_predictions.pt')
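
# Example usage (the script name is hypothetical; the default experiment path
# is the one baked into the argparse defaults above):
#   python test_experiment.py -e logs/JSRT_conditional/20230213_171633
#   python test_experiment.py -e logs/JSRT_conditional/20230213_171633 --rerun
#
# Each run writes <dataset>_predictions.pt next to the checkpoint, containing
# the raw predictions, targets, and per-sample dice/precision/recall tensors:
#   out = torch.load("logs/JSRT_conditional/20230213_171633/JSRT_test_predictions.pt")
#   print(out['dice'].mean())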