import argparse
import os
import sys
from pathlib import Path

import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import autocast, nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Make the repository root importable when running this script directly.
HEAD = Path(os.getcwd()).parent.parent
sys.path.append(str(HEAD))  # sys.path entries must be str, not Path

from models.datasetDM_model import DatasetDM
from trainers.train_baseline import dice, precision, recall
from dataloaders.JSRT import build_dataloaders
from dataloaders.NIH import NIHDataset
from dataloaders.Montgomery import MonDataset

NIHPATH = "/NIH/"
NIHFILE = "correspondence_with_chestXray8.csv"  # saved in data
MONPATH = "/NLM/MontgomerySet/"
MONFILE = "patient_data.csv"

if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment', "-e", type=str, help='Experiment path',
                        default="logs/JSRT_conditional/20230213_171633")
    parser.add_argument('--rerun', "-r", help='Run the test again',
                        default=False, action="store_true")
    args = parser.parse_args()

    if os.path.isdir(args.experiment):
        print("Experiment path identified as a directory")
    else:
        raise ValueError("Experiment path is not a directory")

    files = os.listdir(args.experiment)

    # Skip the run entirely if every prediction file already exists.
    if {'JSRT_val_predictions.pt', 'JSRT_test_predictions.pt',
        'NIH_predictions.pt', 'Montgomery_predictions.pt'} <= set(files) and not args.rerun:
        print("Experiment already tested")
        sys.exit(0)

    # Find the checkpoint file in the experiment directory.
    torch_file = None
    for f in files:
        if "model" in f:
            torch_file = f
            break
    if torch_file is None:
        raise ValueError("No checkpoint file found in experiment directory")

    print(f"Loading experiment from {torch_file}")
    data = torch.load(Path(args.experiment) / torch_file)
    config = data["config"]

    # Pick model
    if config.experiment == "datasetDM":
        model = DatasetDM(config)
        # Pixel-classifier head. It must match the head used at training time
        # so the checkpoint's state dict loads without shape mismatches.
        model.classifier = nn.Sequential(
            Rearrange('b (step act) h w -> (b step) act h w', step=len(model.steps)),
            nn.Conv2d(960, 128, 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 32, 1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 1, config.out_channels),
        )
    else:
        raise ValueError(f"Experiment {config.experiment} not recognized")
    model.load_state_dict(data['model_state_dict'])

    # Gather model output
    model.eval().to(config.device)

    # Load data: in-distribution JSRT splits plus NIH and Montgomery
    # as external test sets.
    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
    )
    datasets_to_test = {
        "JSRT_val": dataloaders["val"],
        "JSRT_test": dataloaders["test"],
        "NIH": DataLoader(NIHDataset(NIHPATH, NIHPATH, NIHFILE, config.img_size),
                          config.batch_size, num_workers=config.num_workers),
        "Montgomery": DataLoader(MonDataset(MONPATH, MONPATH, MONFILE, config.img_size),
                                 config.batch_size, num_workers=config.num_workers),
    }

    for dataset_key in datasets_to_test:
        # Reuse saved predictions unless --rerun is set.
        if f"{dataset_key}_predictions.pt" in files and not args.rerun:
            print(f"{dataset_key} already tested")
            output = torch.load(Path(args.experiment) / f'{dataset_key}_predictions.pt')
            print(f"{dataset_key} metrics: \n\tdice: {output['dice'].mean():.3}+/-{output['dice'].std():.3}")
            print(f"\tprecision: {output['precision'].mean():.3}+/-{output['precision'].std():.3}")
            print(f"\trecall: {output['recall'].mean():.3}+/-{output['recall'].std():.3}")
            continue

        print(f"Testing {dataset_key} set")
        y_hats = []
        y_star = []
        for x, y in tqdm(datasets_to_test[dataset_key], desc='Validating'):
            x = x.to(config.device)
            with autocast(device_type=config.device, enabled=config.mixed_precision):
                with torch.no_grad():
                    # Predict at all timestep depths at once.
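                    # Added note: the Rearrange layer in the classifier head
                    # folds the timestep axis into the batch axis, so model(x)
                    # returns one logit map per (image, timestep) pair;
                    # sigmoid turns these into per-pixel mask probabilities.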
                    pred = torch.sigmoid(model(x))
            y_hats.append(pred.detach().cpu())
            y_star.append(y)

        # Save predictions: unfold the (batch * step) axis back into
        # (step, batch) so each timestep can be scored separately.
        y_star = torch.cat(y_star, 0)
        y_hats = torch.cat(y_hats, 0)
        y_hats = rearrange(y_hats, '(b step) 1 h w -> step b 1 h w', step=len(model.steps))

        # Per-timestep metrics
        for i, y_hat in enumerate(y_hats):
            output = {
                'y_hat': y_hat,
                'y_star': y_star,
                'dice': dice(y_hat > .5, y_star),
                'precision': precision(y_hat > .5, y_star),
                'recall': recall(y_hat > .5, y_star),
            }
            print(f"{dataset_key} {model.steps[i]} metrics: \n\tdice: {output['dice'].mean():.3}+/-{output['dice'].std():.3}")
            print(f"\tprecision: {output['precision'].mean():.3}+/-{output['precision'].std():.3}")
            print(f"\trecall: {output['recall'].mean():.3}+/-{output['recall'].std():.3}")
            torch.save(output, Path(args.experiment) / f'{dataset_key}_timestep{model.steps[i]}_predictions.pt')

        # Ensemble over timesteps: average the probability maps, then threshold.
        y_hat = y_hats.mean(0)
        output = {
            'y_hat': y_hat,
            'y_star': y_star,
            'dice': dice(y_hat > .5, y_star),
            'precision': precision(y_hat > .5, y_star),
            'recall': recall(y_hat > .5, y_star),
        }
        print(f"{dataset_key} metrics: \n\tdice: {output['dice'].mean():.3}+/-{output['dice'].std():.3}")
        print(f"\tprecision: {output['precision'].mean():.3}+/-{output['precision'].std():.3}")
        print(f"\trecall: {output['recall'].mean():.3}+/-{output['recall'].std():.3}")
        torch.save(output, Path(args.experiment) / f'{dataset_key}_predictions.pt')
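
# Example invocation (the script filename below is illustrative, not taken
# from the repo; substitute the name this file is actually saved under):
#   python test_datasetDM.py -e logs/JSRT_conditional/20230213_171633 --rerun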