Code updates
Files changed:
- inference_brain2vec_PCA.py  (+222, -0)
- model.py  (+0, -115)
- requirements.txt  (+6, -3)
- brain2vec_PCA.py → train_brain2vec_PCA.py  (+145, -88)
inference_brain2vec_PCA.py
ADDED
@@ -0,0 +1,222 @@
#!/usr/bin/env python3

"""
inference_brain2vec_PCA.py

Loads a pre-trained PCA-based Brain2Vec model (saved with joblib) and performs
inference on one or more input images. Produces embeddings (and optional
reconstructions) for each image.

Example usage:

    python inference_brain2vec_PCA.py \
        --pca_model /path/to/pca_model.joblib \
        --input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
        --output_dir /path/to/out

Or, if you have a CSV with image paths:

    python inference_brain2vec_PCA.py \
        --pca_model /path/to/pca_model.joblib \
        --csv_input /path/to/images.csv \
        --output_dir /path/to/out
"""

import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from joblib import load
import pandas as pd

from monai.transforms import (
    Compose,
    CopyItemsD,
    LoadImageD,
    EnsureChannelFirstD,
    SpacingD,
    ResizeWithPadOrCropD,
    ScaleIntensityD,
)

# Global constants
RESOLUTION = 2
INPUT_SHAPE_AE = (80, 96, 80)
FLATTENED_DIM = INPUT_SHAPE_AE[0] * INPUT_SHAPE_AE[1] * INPUT_SHAPE_AE[2]

# Reusable MONAI pipeline for preprocessing
transforms_fn = Compose([
    CopyItemsD(keys={'image_path'}, names=['image']),
    LoadImageD(image_only=True, keys=['image']),
    EnsureChannelFirstD(keys=['image']),
    SpacingD(pixdim=RESOLUTION, keys=['image']),
    ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
    ScaleIntensityD(minv=0, maxv=1, keys=['image']),
])


def preprocess_mri(image_path: str) -> torch.Tensor:
    """
    Preprocess an MRI using MONAI transforms to produce
    a 5D Torch tensor: (batch=1, channel=1, D, H, W).

    Args:
        image_path (str): Path to the MRI (e.g., .nii.gz file).

    Returns:
        torch.Tensor: Preprocessed 5D tensor of shape (1, 1, D, H, W).
    """
    data_dict = {"image_path": image_path}
    output_dict = transforms_fn(data_dict)
    # shape => (1, D, H, W)
    image_tensor = output_dict["image"].unsqueeze(0)  # => (1, 1, D, H, W)
    return image_tensor.float()


class PCABrain2vec(nn.Module):
    """
    A PCA-based 'autoencoder' that mimics a typical VAE interface:
      - from_pretrained(...) to load a PCA model from disk
      - forward(...) returns (reconstruction, embedding, None)

    Steps:
      1. Flatten the input volume (N, 1, D, H, W) => (N, 614400).
      2. Transform -> embeddings => shape (N, n_components).
      3. Inverse transform -> recon => shape (N, 614400).
      4. Reshape => (N, 1, D, H, W).
    """

    def __init__(self, pca_model=None):
        super().__init__()
        self.pca_model = pca_model

    def forward(self, x: torch.Tensor):
        """
        Perform a forward pass of the PCA-based "autoencoder".

        Args:
            x (torch.Tensor): Input of shape (N, 1, D, H, W).

        Returns:
            tuple(torch.Tensor, torch.Tensor, None):
                - reconstruction: (N, 1, D, H, W)
                - embedding: (N, n_components)
                - None (to align with the typical VAE interface).
        """
        n_samples = x.shape[0]
        x_cpu = x.detach().cpu().numpy()       # (N, 1, D, H, W)
        x_flat = x_cpu.reshape(n_samples, -1)  # => (N, FLATTENED_DIM)

        # PCA transform => embeddings shape (N, n_components)
        embedding_np = self.pca_model.transform(x_flat)

        # PCA inverse_transform => recon shape (N, FLATTENED_DIM)
        recon_np = self.pca_model.inverse_transform(embedding_np)
        recon_np = recon_np.reshape(n_samples, 1, *INPUT_SHAPE_AE)

        # Convert back to torch
        reconstruction_torch = torch.from_numpy(recon_np).float()
        embedding_torch = torch.from_numpy(embedding_np).float()
        return reconstruction_torch, embedding_torch, None

    @staticmethod
    def from_pretrained(pca_path: str) -> "PCABrain2vec":
        """
        Load a pre-trained PCA model (pickled or joblib) from disk.

        Args:
            pca_path (str): File path to the PCA model.

        Returns:
            PCABrain2vec: An instance wrapping the loaded PCA model.
        """
        if not os.path.exists(pca_path):
            raise FileNotFoundError(f"Could not find PCA model at {pca_path}")

        pca_model = load(pca_path)
        return PCABrain2vec(pca_model=pca_model)


def main() -> None:
    """
    Main function to parse command-line arguments and run inference
    with a pre-trained PCA Brain2Vec model.
    """
    parser = argparse.ArgumentParser(
        description="PCA-based Brain2Vec Inference Script"
    )
    parser.add_argument(
        "--pca_model", type=str, required=True,
        help="Path to the saved PCA model (.joblib)."
    )
    parser.add_argument(
        "--output_dir", type=str, default="./pca_inference_outputs",
        help="Directory to save embeddings/reconstructions."
    )
    # Two ways to supply images: multiple files or a CSV
    parser.add_argument(
        "--input_images", type=str, nargs="*",
        help="One or more image paths for inference."
    )
    parser.add_argument(
        "--csv_input", type=str, default=None,
        help="Path to a CSV containing column 'image_path'."
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Build the PCA model
    pca_brain2vec = PCABrain2vec.from_pretrained(args.pca_model)
    pca_brain2vec.eval()

    # Gather image paths
    if args.csv_input:
        df = pd.read_csv(args.csv_input)
        if "image_path" not in df.columns:
            raise ValueError("CSV must contain a column named 'image_path'.")
        image_paths = df["image_path"].tolist()
    else:
        if not args.input_images:
            raise ValueError(
                "Must provide either --csv_input or --input_images."
            )
        image_paths = args.input_images

    # Inference loop
    all_embeddings = []
    for i, img_path in enumerate(image_paths):
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        # Preprocess
        img_tensor = preprocess_mri(img_path)

        # Forward pass
        with torch.no_grad():
            recon, embedding, _ = pca_brain2vec(img_tensor)

        # Convert to CPU numpy
        embedding_np = embedding.detach().cpu().numpy()
        recon_np = recon.detach().cpu().numpy()

        # Save (one embedding row per image)
        all_embeddings.append(embedding_np)

        # Optionally save or visualize reconstructions
        out_recon_path = os.path.join(
            args.output_dir, f"reconstruction_{i}.npy"
        )
        np.save(out_recon_path, recon_np)
        print(f"[INFO] Saved reconstruction to: {out_recon_path}")

    # Save all embeddings stacked
    stacked_embeddings = np.vstack(all_embeddings)  # (N, n_components)
    out_embed_path = os.path.join(args.output_dir, "all_pca_embeddings.npy")
    np.save(out_embed_path, stacked_embeddings)
    print(f"[INFO] Saved embeddings of shape {stacked_embeddings.shape} to: {out_embed_path}")


if __name__ == "__main__":
    main()
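The script's outputs are plain NumPy arrays: one reconstruction_<i>.npy per input image, plus a stacked all_pca_embeddings.npy with one row per image. A minimal sketch of loading them downstream (not part of the repo; the directory is the script's default --output_dir, and the shapes follow from the constants above):

    import numpy as np

    emb = np.load("pca_inference_outputs/all_pca_embeddings.npy")
    print(emb.shape)    # (N, n_components), one row per input image

    recon0 = np.load("pca_inference_outputs/reconstruction_0.npy")
    print(recon0.shape) # (1, 1, 80, 96, 80), matching INPUT_SHAPE_AE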
model.py
DELETED
@@ -1,115 +0,0 @@
# model.py
import os
import numpy as np
import torch
import torch.nn as nn

from monai.transforms import (
    Compose,
    CopyItemsD,
    LoadImageD,
    EnsureChannelFirstD,
    SpacingD,
    ResizeWithPadOrCropD,
    ScaleIntensityD,
)

# If you used joblib or pickle to save your PCA model:
from joblib import load  # or "import pickle"

#################################################
# Constants
#################################################
RESOLUTION = 2
INPUT_SHAPE_AE = (80, 96, 80)  # The typical shape from your pipelines
FLATTENED_DIM = INPUT_SHAPE_AE[0] * INPUT_SHAPE_AE[1] * INPUT_SHAPE_AE[2]


#################################################
# Define MONAI Transforms for Preprocessing
#################################################
transforms_fn = Compose([
    CopyItemsD(keys={'image_path'}, names=['image']),
    LoadImageD(image_only=True, keys=['image']),
    EnsureChannelFirstD(keys=['image']),
    SpacingD(pixdim=RESOLUTION, keys=['image']),
    ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
    ScaleIntensityD(minv=0, maxv=1, keys=['image']),
])


def preprocess_mri(image_path: str) -> torch.Tensor:
    """
    Preprocess an MRI using MONAI transforms to produce
    a 5D Torch tensor: (batch=1, channel=1, D, H, W).
    """
    data_dict = {"image_path": image_path}
    output_dict = transforms_fn(data_dict)
    # shape => (1, D, H, W)
    image_tensor = output_dict["image"].unsqueeze(0)  # => (batch=1, channel=1, D, H, W)
    return image_tensor.float()  # typically float32


#################################################
# PCA "Autoencoder" Wrapper
#################################################
class PCABrain2vec(nn.Module):
    """
    A PCA-based 'autoencoder' that mimics the old interface:
      - from_pretrained(...) to load a PCA model from disk
      - forward(...) returns (reconstruction, embedding, None)

    Under the hood, it:
      - takes in a torch tensor shape (N, 1, D, H, W)
      - flattens it (N, 614400)
      - uses PCA's transform(...) to get embeddings => shape (N, n_components)
      - uses inverse_transform(...) to get reconstructions => shape (N, 614400)
      - reshapes back to (N, 1, D, H, W)
    """

    def __init__(self, pca_model=None):
        super().__init__()
        # We'll store the fitted PCA model (from scikit-learn)
        self.pca_model = pca_model  # e.g., an instance of IncrementalPCA or PCA

    def forward(self, x: torch.Tensor):
        """
        Returns (reconstruction, embedding, None).

        1) Convert x => numpy array => flatten => (N, 614400)
        2) embedding = pca_model.transform(flat_x)
        3) reconstruction_np = pca_model.inverse_transform(embedding)
        4) reshape => (N, 1, 80, 96, 80)
        5) convert to torch => return (recon, embed, None)
        """
        # Expect x shape => (N, 1, D, H, W) => flatten to (N, D*H*W)
        n_samples = x.shape[0]
        # Convert to CPU np
        x_cpu = x.detach().cpu().numpy()       # shape: (N, 1, D, H, W)
        x_flat = x_cpu.reshape(n_samples, -1)  # shape: (N, 614400)

        # PCA transform => embeddings shape (N, n_components)
        embedding_np = self.pca_model.transform(x_flat)

        # PCA inverse_transform => recon shape (N, 614400)
        recon_np = self.pca_model.inverse_transform(embedding_np)
        # Reshape back => (N, 1, 80, 96, 80)
        recon_np = recon_np.reshape(n_samples, 1, *INPUT_SHAPE_AE)

        # Convert back to torch
        reconstruction_torch = torch.from_numpy(recon_np).float()
        embedding_torch = torch.from_numpy(embedding_np).float()
        return reconstruction_torch, embedding_torch, None

    @staticmethod
    def from_pretrained(pca_path: str):
        """
        Load a pre-trained PCA model (pickled or joblib).
        Returns an instance of PCABrain2vec with that model.
        """
        if not os.path.exists(pca_path):
            raise FileNotFoundError(f"Could not find PCA model at {pca_path}")
        # Example: pca_model = pickle.load(open(pca_path, 'rb'))
        # or use joblib:
        pca_model = load(pca_path)
        return PCABrain2vec(pca_model=pca_model)
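Everything model.py defined (the transform pipeline, preprocess_mri, and the PCABrain2vec wrapper) now lives in inference_brain2vec_PCA.py, so external code that imported from the deleted module needs a one-line change. A sketch, assuming both files are on the import path; the model and image paths are placeholders, not files in the repo:

    # Previously: from model import PCABrain2vec, preprocess_mri
    from inference_brain2vec_PCA import PCABrain2vec, preprocess_mri

    model = PCABrain2vec.from_pretrained("pca_model.joblib")  # placeholder path
    x = preprocess_mri("subject0.nii.gz")                     # placeholder image
    recon, embedding, _ = model(x)                            # embedding: (1, n_components)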
requirements.txt
CHANGED
@@ -1,12 +1,15 @@
 # requirements.txt

-# PyTorch (CUDA or CPU version).
+# PyTorch (CUDA or CPU version).
 torch>=1.12

-# MONAI
-monai-weekly
+# Install MONAI Generative first
 monai-generative

+# Now force reinstall MONAI Weekly so its (newer) MONAI version takes precedence
+--force-reinstall
+monai-weekly
+
 # For perceptual losses in MONAI's generative module.
 lpips
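If the reinstall ordering works as intended, the MONAI build left on the path after pip install -r requirements.txt should be the monai-weekly one rather than whatever monai-generative pinned. A quick runtime check (a generic sketch, nothing here is specific to this repo; weekly builds typically carry a .dev version suffix):

    # Confirm which MONAI build is actually importable after installation.
    from monai.config import print_config
    import monai

    print(monai.__version__)  # e.g. a ".devYYWW"-style version for weekly builds
    print_config()            # fuller report of MONAI and its optional dependencies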
brain2vec_PCA.py → train_brain2vec_PCA.py
RENAMED
@@ -1,101 +1,115 @@
 #!/usr/bin/env python3

 """
+train_brain2vec_PCA.py

+A PCA-based "autoencoder" script for brain MRI data, with support for both
+incremental PCA and standard PCA. Only scans labeled 'train' in the CSV
+(split == 'train') will be used for fitting.

 Example usage:
-    python brain2vec_PCA.py \
+    python train_brain2vec_PCA.py \
         --inputs_csv /path/to/inputs.csv \
         --output_dir ./pca_outputs \
         --pca_type standard \
-        --n_components
+        --n_components 1200
 """

 import os
 import argparse
 import numpy as np
 import pandas as pd
 import torch
 from torch.utils.data import DataLoader
 from monai import transforms
 from monai.data import Dataset, PersistentDataset
-# We'll import both PCA classes, and decide which to use based on CLI arg.
+from monai.transforms.transform import Transform
 from sklearn.decomposition import PCA, IncrementalPCA
+from typing import Optional, Union, Tuple

-###################################################################
-# Constants for your typical config
-###################################################################
+# voxel resolution
 RESOLUTION = 2
+
+# cropped image dimensions after transform
 INPUT_SHAPE_AE = (80, 96, 80)
+
 DEFAULT_N_COMPONENTS = 1200


+def get_dataset_from_pd(
+    df: pd.DataFrame,
+    transforms_fn: Transform,
+    cache_dir: Optional[str]
+) -> Union[Dataset, PersistentDataset]:
     """
+    Create a MONAI Dataset or PersistentDataset from the given DataFrame.
+
+    Args:
+        df (pd.DataFrame): DataFrame with at least 'image_path' column.
+        transforms_fn (Transform): MONAI transform pipeline.
+        cache_dir (Optional[str]): If provided, use PersistentDataset caching.
+
+    Returns:
+        Dataset|PersistentDataset: A dataset for training or inference.
     """
+    data_dicts = df.to_dict(orient='records')
     if cache_dir and cache_dir.strip():
         os.makedirs(cache_dir, exist_ok=True)
         dataset = PersistentDataset(
+            data=data_dicts,
+            transform=transforms_fn,
+            cache_dir=cache_dir
+        )
     else:
-        dataset = Dataset(data=df.to_dict(orient='records'),
-                          transform=transforms_fn)
+        dataset = Dataset(data=data_dicts, transform=transforms_fn)
     return dataset


-###################################################################
-# PCAAutoencoder
-###################################################################
 class PCAAutoencoder:
     """
     A PCA 'autoencoder' that can use either standard PCA or IncrementalPCA:
       - fit(X): trains the model
       - transform(X): get embeddings
       - inverse_transform(Z): reconstruct data from embeddings
-      - forward(X): returns (X_recon, Z)
+      - forward(X): returns (X_recon, Z).
+
+    If using standard PCA, a single call to .fit(X) is made.
+    If using incremental PCA, .partial_fit is called in batches.
     """
+
+    def __init__(
+        self,
+        n_components: int = DEFAULT_N_COMPONENTS,
+        batch_size: int = 128,
+        pca_type: str = 'standard'
+    ) -> None:
         """
+        Initialize the PCAAutoencoder.
+
         Args:
-            n_components (int):
-            batch_size (int):
-            pca_type (str): 'incremental' or 'standard'
+            n_components (int): Number of principal components to keep.
+            batch_size (int): Chunk size for partial_fit or chunked transform.
+            pca_type (str): Either 'incremental' or 'standard'.
         """
         self.n_components = n_components
         self.batch_size = batch_size
         self.pca_type = pca_type.lower()

-        if self.pca_type == 'standard':
-            self.ipca = PCA(n_components=self.n_components, svd_solver='randomized')
-        else:
-            # default to incremental
+        if self.pca_type == 'incremental':
             self.ipca = IncrementalPCA(n_components=self.n_components)
+        else:
+            # Default to standard PCA
+            self.ipca = PCA(n_components=self.n_components, svd_solver='randomized')

-    def fit(self, X: np.ndarray):
+    def fit(self, X: np.ndarray) -> None:
         """
-        Fit the PCA model. If incremental, calls partial_fit in batches
+        Fit the PCA model. If incremental PCA, calls partial_fit in batches;
+        otherwise calls .fit once on the entire data array.
+
+        Args:
+            X (np.ndarray): Shape (n_samples, n_features).
         """
         if self.pca_type == 'standard':
-            # Potentially large memory usage, so be sure your system can handle it.
             self.ipca.fit(X)
         else:
             # IncrementalPCA

@@ -107,7 +121,12 @@ class PCAAutoencoder:
     def transform(self, X: np.ndarray) -> np.ndarray:
         """
         Project data into the PCA latent space in batches for memory efficiency.
+
+        Args:
+            X (np.ndarray): Shape (n_samples, n_features).
+
+        Returns:
+            np.ndarray: Latent embeddings of shape (n_samples, n_components).
         """
         results = []
         n_samples = X.shape[0]

@@ -120,7 +139,12 @@ class PCAAutoencoder:
     def inverse_transform(self, Z: np.ndarray) -> np.ndarray:
         """
         Reconstruct data from PCA latent space in batches.
+
+        Args:
+            Z (np.ndarray): Latent embeddings of shape (n_samples, n_components).
+
+        Returns:
+            np.ndarray: Reconstructed data of shape (n_samples, n_features).
         """
         results = []
         n_samples = Z.shape[0]

@@ -130,80 +154,113 @@ class PCAAutoencoder:
             results.append(X_chunk)
         return np.vstack(results)

-    def forward(self, X: np.ndarray) ->
+    def forward(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
         """
+        Mimic a linear AE's forward() returning (X_recon, Z).
+
+        Args:
+            X (np.ndarray): Original data of shape (n_samples, n_features).
+
+        Returns:
+            tuple[np.ndarray, np.ndarray]: (X_recon, Z).
         """
         Z = self.transform(X)
         X_recon = self.inverse_transform(Z)
         return X_recon, Z


+def load_and_flatten_dataset(
+    csv_path: str,
+    cache_dir: str,
+    transforms_fn: Transform
+) -> np.ndarray:
     """
+    Load and flatten MRI volumes from the provided CSV.
+
     1) Reads CSV.
-    2) Filters rows if 'split' in columns => only keep
-    3) Applies transforms to each image, flattening them into a 1D vector
-    4) Returns a NumPy array X
+    2) Filters rows if 'split' in columns => only keep rows with split == 'train'.
+    3) Applies transforms to each image, flattening them into a 1D vector.
+    4) Returns a NumPy array X of shape (n_samples, 614400) after flattening.
+
+    Args:
+        csv_path (str): Path to a CSV containing at least 'image_path' column.
+            Optionally has a 'split' column.
+        cache_dir (str): Path to cache directory for MONAI PersistentDataset.
+        transforms_fn (Transform): MONAI transform pipeline.
+
+    Returns:
+        np.ndarray: Flattened image data of shape (n_samples, 614400).
     """
     df = pd.read_csv(csv_path)

+    # Keep only 'train' samples if split column exists
     if 'split' in df.columns:
         df = df[df['split'] == 'train']
-    # If there is no 'split' column, we assume the entire CSV is for training.

     dataset = get_dataset_from_pd(df, transforms_fn, cache_dir)
     loader = DataLoader(dataset, batch_size=1, num_workers=0)

-    # We'll store each flattened volume in a list, then stack
     X_list = []
     for batch in loader:
-        # batch["image"] shape
-        img = batch["image"].squeeze(0)  # => (1, 80, 96, 80)
-        img_np = img.numpy()
-        flattened = img_np.flatten()  # => (614400,)
+        # batch["image"] => shape (1, 1, 80, 96, 80)
+        img = batch["image"].squeeze(0)  # => shape (1, 80, 96, 80)
+        flattened = img.numpy().flatten()  # => (614400,)
         X_list.append(flattened)

+    if not X_list:
+        raise ValueError(
+            "No training samples found (split='train'). Check your CSV or 'split' values."
+        )

     X = np.vstack(X_list)
     return X


+def main() -> None:
+    """
+    Main function to parse command-line arguments and fit a PCA or IncrementalPCA model,
+    then save embeddings and reconstructions.
+    """
+    parser = argparse.ArgumentParser(
+        description="PCA Autoencoder with MONAI transforms and 'split' filtering."
+    )
+    parser.add_argument(
+        "--inputs_csv", type=str, required=True,
+        help="Path to CSV with at least 'image_path' column and optional 'split' column."
+    )
+    parser.add_argument(
+        "--cache_dir", type=str, default="",
+        help="Cache directory for MONAI PersistentDataset (optional)."
+    )
+    parser.add_argument(
+        "--output_dir", type=str, default="./pca_outputs",
+        help="Where to save PCA model and embeddings."
+    )
+    parser.add_argument(
+        "--batch_size_ipca", type=int, default=128,
+        help="Batch size for partial_fit or chunked transform."
+    )
+    parser.add_argument(
+        "--n_components", type=int, default=1200,
+        help="Number of PCA components to keep."
+    )
+    parser.add_argument(
+        "--pca_type", type=str, default="incremental",
+        choices=["incremental", "standard"],
+        help="Which PCA algorithm to use: 'incremental' or 'standard'."
+    )
     args = parser.parse_args()

     os.makedirs(args.output_dir, exist_ok=True)

-    # define transforms as in brain2vec_linearAE.py
     transforms_fn = transforms.Compose([
         transforms.CopyItemsD(keys={'image_path'}, names=['image']),
         transforms.LoadImageD(image_only=True, keys=['image']),
         transforms.EnsureChannelFirstD(keys=['image']),
         transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
-        transforms.ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
+        transforms.ResizeWithPadOrCropD(
+            spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']
+        ),
         transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image']),
     ])

@@ -225,10 +282,10 @@ def main():

     # Get embeddings & reconstruction
     X_recon, Z = model.forward(X)
-    print("Embeddings shape:", Z.shape)
-    print("Reconstruction shape:", X_recon.shape)
+    print("Embeddings shape:", Z.shape)
+    print("Reconstruction shape:", X_recon.shape)

-    # Save
+    # Save embeddings and reconstructions
     embeddings_path = os.path.join(args.output_dir, "pca_embeddings.npy")
     recons_path = os.path.join(args.output_dir, "pca_reconstructions.npy")
     np.save(embeddings_path, Z)
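Together the two scripts form a two-step workflow: train_brain2vec_PCA.py fits PCA on the flattened training volumes and writes the artifacts, then inference_brain2vec_PCA.py embeds new scans with the saved model. The core of both is just the transform/inverse_transform round trip; a minimal sketch on synthetic data so the shapes are easy to verify (tiny sizes for illustration only; real volumes flatten to 80*96*80 = 614,400 features):

    import numpy as np
    from sklearn.decomposition import IncrementalPCA

    # Stand-in for flattened MRI volumes: 16 samples x 64 features.
    X = np.random.rand(16, 64).astype(np.float32)

    ipca = IncrementalPCA(n_components=8)
    for start in range(0, X.shape[0], 8):   # batched partial_fit, as in PCAAutoencoder.fit
        ipca.partial_fit(X[start:start + 8])

    Z = ipca.transform(X)                   # embeddings, shape (16, 8)
    X_recon = ipca.inverse_transform(Z)     # reconstructions, shape (16, 64)
    print(Z.shape, X_recon.shape)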