"""Evaluate a trained DatasetDM checkpoint on the JSRT validation/test splits
and on the external NIH and Montgomery chest X-ray datasets, saving per-dataset
predictions and dice/precision/recall metrics into the experiment directory."""
import argparse
import os
import sys
from pathlib import Path

import torch
from torch import autocast, nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from einops import rearrange
from einops.layers.torch import Rearrange

# Make the repository root importable so the project-local packages resolve.
HEAD = Path(os.getcwd()).parent.parent
sys.path.append(str(HEAD))  # sys.path entries should be str, not Path
from models.datasetDM_model import DatasetDM
from trainers.train_baseline import dice, precision, recall
from dataloaders.JSRT import build_dataloaders
from dataloaders.NIH import NIHDataset
from dataloaders.Montgomery import MonDataset

# External test sets; replace the <PATH_TO_DATA> placeholders with the local data root.
NIHPATH = "<PATH_TO_DATA>/NIH/"
NIHFILE = "correspondence_with_chestXray8.csv"  # stored alongside the NIH data
MONPATH = "<PATH_TO_DATA>/NLM/MontgomerySet/"
MONFILE = "patient_data.csv"
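
# The per-dataset metric report is printed in three places below; a small
# helper keeps the output format in one spot (same format strings as the
# original inline prints).
def print_metrics(name: str, output: dict) -> None:
    """Print mean +/- std of dice, precision, and recall from a results dict."""
    print(f"{name} metrics: \n\tdice:      {output['dice'].mean():.3}+/-{output['dice'].std():.3}")
    print(f"\tprecision: {output['precision'].mean():.3}+/-{output['precision'].std():.3}")
    print(f"\trecall:    {output['recall'].mean():.3}+/-{output['recall'].std():.3}")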


if __name__ == "__main__":
    # parse command-line arguments; the config itself is read from the checkpoint
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment', "-e", type=str, help='Experiment directory', default="logs/JSRT_conditional/20230213_171633")
    parser.add_argument('--rerun', "-r", help='Run the test again even if predictions exist', default=False, action="store_true")
    args = parser.parse_args()

    if os.path.isdir(args.experiment):
        print("Experiment path identified as a directory")
    else:
        raise ValueError("Experiment path is not a directory")
    files = os.listdir(args.experiment)

    # skip entirely if every dataset has already been tested
    if {'JSRT_val_predictions.pt', 'JSRT_test_predictions.pt', 'NIH_predictions.pt', 'Montgomery_predictions.pt'} <= set(files) and not args.rerun:
        print("Experiment already tested")
        sys.exit(0)

    # locate the checkpoint file saved by the trainer
    torch_file = None
    for f in files:
        if "model" in f:
            torch_file = f
            break
    if torch_file is None:
        raise ValueError("No checkpoint file found in experiment directory")

    print(f"Loading experiment from {torch_file}")
    data = torch.load(Path(args.experiment) / torch_file)
    config = data["config"]

    # pick model
    if config.experiment == "datasetDM":
        model = DatasetDM(config)
        # Pixel-wise classification head on top of the extracted diffusion
        # features: activations from all sampled timesteps are folded into the
        # batch dimension, so the head predicts one mask per (image, timestep).
        model.classifier = nn.Sequential(
            Rearrange('b (step act) h w -> (b step) act h w', step=len(model.steps)),
            nn.Conv2d(960, 128, 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 32, 1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 1, config.out_channels),
        )
    else:
        raise ValueError(f"Experiment {config.experiment} not recognized")
    model.load_state_dict(data['model_state_dict'])

    # switch to evaluation mode and move to the target device
    model.eval().to(config.device)

    # Load data: in-domain JSRT splits plus the two external test sets
    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
    )
    datasets_to_test = {
        "JSRT_val": dataloaders["val"],
        "JSRT_test": dataloaders["test"],
        "NIH": DataLoader(NIHDataset(NIHPATH, NIHPATH, NIHFILE, config.img_size),
                          config.batch_size, num_workers=config.num_workers),
        "Montgomery": DataLoader(MonDataset(MONPATH, MONPATH, MONFILE, config.img_size),
                                 config.batch_size, num_workers=config.num_workers),
    }

    for dataset_key in datasets_to_test:
        # reuse cached predictions unless --rerun is set
        if f"{dataset_key}_predictions.pt" in files and not args.rerun:
            print(f"{dataset_key} already tested")
            output = torch.load(Path(args.experiment) / f'{dataset_key}_predictions.pt')
            print_metrics(dataset_key, output)
            continue

        print(f"Testing {dataset_key} set")
        y_hats = []
        y_star = []
        for i, (x, y) in tqdm(enumerate(datasets_to_test[dataset_key]), desc='Validating'):
            x = x.to(config.device)

            with autocast(device_type=config.device, enabled=config.mixed_precision):
                with torch.no_grad():
                    # all depths
                    pred = torch.sigmoid(model(x))
            y_hats.append(pred.detach().cpu())
            y_star.append(y)
        
        # stack predictions and unfold the timestep dimension:
        # concatenated ((b step), 1, h, w) -> (step, B, 1, h, w)
        y_star = torch.cat(y_star, 0)
        y_hats = torch.cat(y_hats, 0)
        y_hats = rearrange(y_hats, '(b step) 1 h w -> step b 1 h w', step=len(model.steps))
        # per-timestep metrics and predictions (binarized at 0.5)
        for i, y_hat in enumerate(y_hats):
            output = {
                'y_hat': y_hat,
                'y_star': y_star,
                'dice': dice(y_hat > .5, y_star),
                'precision': precision(y_hat > .5, y_star),
                'recall': recall(y_hat > .5, y_star),
            }
            print_metrics(f"{dataset_key} {model.steps[i]}", output)
            torch.save(output, Path(args.experiment) / f'{dataset_key}_timestep{model.steps[i]}_predictions.pt')
        # ensemble over timesteps: average the per-timestep probability maps
        y_hat = y_hats.mean(0)
        output = {
            'y_hat': y_hat,
            'y_star': y_star,
            'dice': dice(y_hat > .5, y_star),
            'precision': precision(y_hat > .5, y_star),
            'recall': recall(y_hat > .5, y_star),
        }
        print_metrics(dataset_key, output)
        torch.save(output, Path(args.experiment) / f'{dataset_key}_predictions.pt')
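
# Example of reusing the saved artifacts offline (the run directory below is
# hypothetical; the dict keys match the `output` dicts saved above):
#     preds = torch.load("logs/JSRT_conditional/<run_id>/NIH_predictions.pt")
#     print(preds["dice"].mean(), preds["y_hat"].shape)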