jesseab commited on
Commit
0178b19
·
1 Parent(s): 546a8f8

Code updates

Browse files
inference_brain2vec_PCA.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ inference_brain2vec_PCA.py
5
+
6
+ Loads a pre-trained PCA-based Brain2Vec model (saved with joblib) and performs
7
+ inference on one or more input images. Produces embeddings (and optional
8
+ reconstructions) for each image.
9
+
10
+ Example usage:
11
+
12
+ python inference_brain2vec_PCA.py \
13
+ --pca_model /path/to/pca_model.joblib \
14
+ --input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
15
+ --output_dir /path/to/out
16
+
17
+ Or, if you have a CSV with image paths:
18
+
19
+ python inference_brain2vec_PCA.py \
20
+ --pca_model /path/to/pca_model.joblib \
21
+ --csv_input /path/to/images.csv \
22
+ --output_dir /path/to/out
23
+ """
24
+
25
+ import os
26
+ import argparse
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn as nn
30
+ from joblib import load
31
+ import pandas as pd
32
+
33
+ from monai.transforms import (
34
+ Compose,
35
+ CopyItemsD,
36
+ LoadImageD,
37
+ EnsureChannelFirstD,
38
+ SpacingD,
39
+ ResizeWithPadOrCropD,
40
+ ScaleIntensityD,
41
+ )
42
+
43
# Voxel resolution (mm) used when resampling input volumes.
RESOLUTION = 2

# Target spatial shape (D, H, W) after padding/cropping.
INPUT_SHAPE_AE = (80, 96, 80)

# Number of voxels in one flattened volume (80 * 96 * 80 = 614400).
FLATTENED_DIM = INPUT_SHAPE_AE[0] * INPUT_SHAPE_AE[1] * INPUT_SHAPE_AE[2]

# Reusable MONAI preprocessing pipeline: copy 'image_path' to 'image', load
# the volume, add a channel axis, resample to RESOLUTION, pad/crop to the
# fixed shape, and scale intensities into [0, 1].
_PREPROCESS_STEPS = [
    CopyItemsD(keys={'image_path'}, names=['image']),
    LoadImageD(image_only=True, keys=['image']),
    EnsureChannelFirstD(keys=['image']),
    SpacingD(pixdim=RESOLUTION, keys=['image']),
    ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
    ScaleIntensityD(minv=0, maxv=1, keys=['image']),
]
transforms_fn = Compose(_PREPROCESS_STEPS)
57
+
58
+
59
def preprocess_mri(image_path: str) -> torch.Tensor:
    """
    Preprocess an MRI with the shared MONAI pipeline into a 5D tensor.

    Args:
        image_path (str): Path to the MRI file (e.g., .nii.gz).

    Returns:
        torch.Tensor: Float tensor of shape (batch=1, channel=1, D, H, W).
    """
    transformed = transforms_fn({"image_path": image_path})
    # The pipeline yields a (1, D, H, W) volume; prepend a batch axis
    # and cast to float32 for downstream torch code.
    return transformed["image"].unsqueeze(0).float()
75
+
76
+
77
class PCABrain2vec(nn.Module):
    """
    A PCA-based 'autoencoder' that mimics a typical VAE interface:
      - from_pretrained(...) loads a fitted PCA model from disk
      - forward(...) returns (reconstruction, embedding, None)

    Pipeline:
      1. Flatten the input volume (N, 1, D, H, W) => (N, 614400).
      2. PCA transform => embeddings of shape (N, n_components).
      3. PCA inverse transform => reconstructions of shape (N, 614400).
      4. Reshape back => (N, 1, D, H, W).
    """

    def __init__(self, pca_model=None):
        super().__init__()
        # Fitted scikit-learn PCA/IncrementalPCA instance (or None).
        self.pca_model = pca_model

    def forward(self, x: torch.Tensor):
        """
        Run the PCA-based "autoencoder" on a batch of volumes.

        Args:
            x (torch.Tensor): Input of shape (N, 1, D, H, W).

        Returns:
            tuple(torch.Tensor, torch.Tensor, None):
                - reconstruction: (N, 1, D, H, W)
                - embedding: (N, n_components)
                - None (placeholder to match the usual VAE interface).
        """
        batch = x.shape[0]
        # scikit-learn works on CPU numpy arrays: flatten each volume.
        flat = x.detach().cpu().numpy().reshape(batch, -1)

        # Project into the PCA latent space, then reconstruct.
        codes = self.pca_model.transform(flat)
        recon = self.pca_model.inverse_transform(codes)
        recon = recon.reshape(batch, 1, *INPUT_SHAPE_AE)

        # Hand back torch tensors so callers can stay in torch-land.
        return (
            torch.from_numpy(recon).float(),
            torch.from_numpy(codes).float(),
            None,
        )

    @staticmethod
    def from_pretrained(pca_path: str) -> "PCABrain2vec":
        """
        Load a pre-trained PCA model (pickled or joblib) from disk.

        Args:
            pca_path (str): File path to the PCA model.

        Returns:
            PCABrain2vec: An instance wrapping the loaded PCA model.

        Raises:
            FileNotFoundError: If no file exists at ``pca_path``.
        """
        if not os.path.exists(pca_path):
            raise FileNotFoundError(f"Could not find PCA model at {pca_path}")
        return PCABrain2vec(pca_model=load(pca_path))
139
+
140
+
141
def main() -> None:
    """
    Parse command-line arguments and run inference with a pre-trained
    PCA Brain2Vec model, saving per-image reconstructions and a stacked
    embeddings array to ``--output_dir``.

    Raises:
        ValueError: If neither --csv_input nor --input_images is given,
            if the CSV lacks an 'image_path' column, or if no image paths
            are found (e.g., an empty CSV).
        FileNotFoundError: If any referenced image file does not exist.
    """
    parser = argparse.ArgumentParser(
        description="PCA-based Brain2Vec Inference Script"
    )
    parser.add_argument(
        "--pca_model", type=str, required=True,
        help="Path to the saved PCA model (.joblib)."
    )
    parser.add_argument(
        "--output_dir", type=str, default="./pca_inference_outputs",
        help="Directory to save embeddings/reconstructions."
    )
    # Two ways to supply images: multiple files or a CSV
    parser.add_argument(
        "--input_images", type=str, nargs="*",
        help="One or more image paths for inference."
    )
    parser.add_argument(
        "--csv_input", type=str, default=None,
        help="Path to a CSV containing column 'image_path'."
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Build the PCA model wrapper from the serialized scikit-learn model.
    pca_brain2vec = PCABrain2vec.from_pretrained(args.pca_model)
    pca_brain2vec.eval()

    # Gather image paths from either the CSV or the command line.
    if args.csv_input:
        df = pd.read_csv(args.csv_input)
        if "image_path" not in df.columns:
            raise ValueError("CSV must contain a column named 'image_path'.")
        image_paths = df["image_path"].tolist()
    else:
        if not args.input_images:
            raise ValueError(
                "Must provide either --csv_input or --input_images."
            )
        image_paths = args.input_images

    # Fail fast with a clear message: an empty CSV would otherwise skip the
    # loop and crash in np.vstack with an obscure concatenation error.
    if not image_paths:
        raise ValueError("No input images found; nothing to run inference on.")

    # Inference loop: one forward pass per image.
    all_embeddings = []
    for i, img_path in enumerate(image_paths):
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        # Preprocess into a (1, 1, D, H, W) float tensor.
        img_tensor = preprocess_mri(img_path)

        # No gradients needed for PCA-based inference.
        with torch.no_grad():
            recon, embedding, _ = pca_brain2vec(img_tensor)

        # Collect one embedding row per image for the stacked output.
        all_embeddings.append(embedding.detach().cpu().numpy())

        # Save each reconstruction as its own .npy file.
        out_recon_path = os.path.join(
            args.output_dir, f"reconstruction_{i}.npy"
        )
        np.save(out_recon_path, recon.detach().cpu().numpy())
        print(f"[INFO] Saved reconstruction to: {out_recon_path}")

    # Save all embeddings stacked into one (N, n_components) array.
    stacked_embeddings = np.vstack(all_embeddings)
    out_embed_path = os.path.join(args.output_dir, "all_pca_embeddings.npy")
    np.save(out_embed_path, stacked_embeddings)
    print(f"[INFO] Saved embeddings of shape {stacked_embeddings.shape} to: {out_embed_path}")


if __name__ == "__main__":
    main()
model.py DELETED
@@ -1,115 +0,0 @@
1
- # model.py
2
- import os
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
-
7
- from monai.transforms import (
8
- Compose,
9
- CopyItemsD,
10
- LoadImageD,
11
- EnsureChannelFirstD,
12
- SpacingD,
13
- ResizeWithPadOrCropD,
14
- ScaleIntensityD,
15
- )
16
-
17
- # If you used joblib or pickle to save your PCA model:
18
- from joblib import load # or "import pickle"
19
-
20
- #################################################
21
- # Constants
22
- #################################################
23
- RESOLUTION = 2
24
- INPUT_SHAPE_AE = (80, 96, 80) # The typical shape from your pipelines
25
- FLATTENED_DIM = INPUT_SHAPE_AE[0] * INPUT_SHAPE_AE[1] * INPUT_SHAPE_AE[2]
26
-
27
-
28
- #################################################
29
- # Define MONAI Transforms for Preprocessing
30
- #################################################
31
- transforms_fn = Compose([
32
- CopyItemsD(keys={'image_path'}, names=['image']),
33
- LoadImageD(image_only=True, keys=['image']),
34
- EnsureChannelFirstD(keys=['image']),
35
- SpacingD(pixdim=RESOLUTION, keys=['image']),
36
- ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
37
- ScaleIntensityD(minv=0, maxv=1, keys=['image']),
38
- ])
39
-
40
-
41
- def preprocess_mri(image_path: str) -> torch.Tensor:
42
- """
43
- Preprocess an MRI using MONAI transforms to produce
44
- a 5D Torch tensor: (batch=1, channel=1, D, H, W).
45
- """
46
- data_dict = {"image_path": image_path}
47
- output_dict = transforms_fn(data_dict)
48
- # shape => (1, D, H, W)
49
- image_tensor = output_dict["image"].unsqueeze(0) # => (batch=1, channel=1, D, H, W)
50
- return image_tensor.float() # typically float32
51
-
52
-
53
- #################################################
54
- # PCA "Autoencoder" Wrapper
55
- #################################################
56
- class PCABrain2vec(nn.Module):
57
- """
58
- A PCA-based 'autoencoder' that mimics the old interface:
59
- - from_pretrained(...) to load a PCA model from disk
60
- - forward(...) returns (reconstruction, embedding, None)
61
-
62
- Under the hood, it:
63
- - takes in a torch tensor shape (N, 1, D, H, W)
64
- - flattens it (N, 614400)
65
- - uses PCA's transform(...) to get embeddings => shape (N, n_components)
66
- - uses inverse_transform(...) to get reconstructions => shape (N, 614400)
67
- - reshapes back to (N, 1, D, H, W)
68
- """
69
-
70
- def __init__(self, pca_model=None):
71
- super().__init__()
72
- # We'll store the fitted PCA model (from scikit-learn)
73
- self.pca_model = pca_model # e.g., an instance of IncrementalPCA or PCA
74
-
75
- def forward(self, x: torch.Tensor):
76
- """
77
- Returns (reconstruction, embedding, None).
78
-
79
- 1) Convert x => numpy array => flatten => (N, 614400)
80
- 2) embedding = pca_model.transform(flat_x)
81
- 3) reconstruction_np = pca_model.inverse_transform(embedding)
82
- 4) reshape => (N, 1, 80, 96, 80)
83
- 5) convert to torch => return (recon, embed, None)
84
- """
85
- # Expect x shape => (N, 1, D, H, W) => flatten to (N, D*H*W)
86
- n_samples = x.shape[0]
87
- # Convert to CPU np
88
- x_cpu = x.detach().cpu().numpy() # shape: (N, 1, D, H, W)
89
- x_flat = x_cpu.reshape(n_samples, -1) # shape: (N, 614400)
90
-
91
- # PCA transform => embeddings shape (N, n_components)
92
- embedding_np = self.pca_model.transform(x_flat)
93
-
94
- # PCA inverse_transform => recon shape (N, 614400)
95
- recon_np = self.pca_model.inverse_transform(embedding_np)
96
- # Reshape back => (N, 1, 80, 96, 80)
97
- recon_np = recon_np.reshape(n_samples, 1, *INPUT_SHAPE_AE)
98
-
99
- # Convert back to torch
100
- reconstruction_torch = torch.from_numpy(recon_np).float()
101
- embedding_torch = torch.from_numpy(embedding_np).float()
102
- return reconstruction_torch, embedding_torch, None
103
-
104
- @staticmethod
105
- def from_pretrained(pca_path: str):
106
- """
107
- Load a pre-trained PCA model (pickled or joblib).
108
- Returns an instance of PCABrain2vec with that model.
109
- """
110
- if not os.path.exists(pca_path):
111
- raise FileNotFoundError(f"Could not find PCA model at {pca_path}")
112
- # Example: pca_model = pickle.load(open(pca_path, 'rb'))
113
- # or use joblib:
114
- pca_model = load(pca_path)
115
- return PCABrain2vec(pca_model=pca_model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,12 +1,15 @@
1
  # requirements.txt
2
 
3
- # PyTorch (CUDA or CPU version). For GPU install, see PyTorch docs for the correct wheel.
4
  torch>=1.12
5
 
6
- # MONAI v1.2+ has the 'generative' subpackage with AutoencoderKL, PatchDiscriminator, etc.
7
- monai-weekly
8
  monai-generative
9
 
 
 
 
 
10
  # For perceptual losses in MONAI's generative module.
11
  lpips
12
 
 
1
  # requirements.txt
2
 
3
+ # PyTorch (CUDA or CPU version).
4
  torch>=1.12
5
 
6
+ # Install MONAI Generative first
 
7
  monai-generative
8
 
9
+ # Now force reinstall MONAI Weekly so its (newer) MONAI version takes precedence
10
+ --force-reinstall
11
+ monai-weekly
12
+
13
  # For perceptual losses in MONAI's generative module.
14
  lpips
15
 
brain2vec_PCA.py → train_brain2vec_PCA.py RENAMED
@@ -1,101 +1,115 @@
1
  #!/usr/bin/env python3
2
 
3
  """
4
- pca_autoencoder.py
5
 
6
- Adjustments requested:
7
- 1. Only fit on scans with a 'train' label in the inputs.csv 'split' column.
8
- 2. An option to either run incremental PCA or standard PCA.
9
 
10
  Example usage:
11
- python pca_autoencoder.py \
12
  --inputs_csv /path/to/inputs.csv \
13
  --output_dir ./pca_outputs \
14
  --pca_type standard \
15
- --n_components 100
16
  """
17
 
18
  import os
19
  import argparse
20
  import numpy as np
21
  import pandas as pd
22
-
23
  import torch
24
  from torch.utils.data import DataLoader
25
-
26
  from monai import transforms
27
  from monai.data import Dataset, PersistentDataset
28
-
29
- # We'll import both PCA classes, and decide which to use based on CLI arg.
30
  from sklearn.decomposition import PCA, IncrementalPCA
 
31
 
32
-
33
- ###################################################################
34
- # Constants for your typical config
35
- ###################################################################
36
  RESOLUTION = 2
 
 
37
  INPUT_SHAPE_AE = (80, 96, 80)
 
38
  DEFAULT_N_COMPONENTS = 1200
39
 
40
 
41
- ###################################################################
42
- # Helper: get_dataset_from_pd (same as in brain2vec_linearAE.py)
43
- ###################################################################
44
- def get_dataset_from_pd(df: pd.DataFrame, transforms_fn, cache_dir: str):
 
45
  """
46
- Returns a monai.data.Dataset or monai.data.PersistentDataset
47
- if `cache_dir` is defined, to speed up loading.
 
 
 
 
 
 
 
48
  """
 
49
  if cache_dir and cache_dir.strip():
50
  os.makedirs(cache_dir, exist_ok=True)
51
- dataset = PersistentDataset(data=df.to_dict(orient='records'),
52
- transform=transforms_fn,
53
- cache_dir=cache_dir)
 
 
54
  else:
55
- dataset = Dataset(data=df.to_dict(orient='records'),
56
- transform=transforms_fn)
57
  return dataset
58
 
59
 
60
- ###################################################################
61
- # PCAAutoencoder
62
- ###################################################################
63
  class PCAAutoencoder:
64
  """
65
  A PCA 'autoencoder' that can use either standard PCA or IncrementalPCA:
66
  - fit(X): trains the model
67
  - transform(X): get embeddings
68
  - inverse_transform(Z): reconstruct data from embeddings
69
- - forward(X): returns (X_recon, Z)
70
-
71
- If using standard PCA, we do a single call to .fit(X).
72
- If using incremental PCA, we do .partial_fit on data in batches.
73
  """
74
- def __init__(self, n_components=DEFAULT_N_COMPONENTS, batch_size=128, pca_type='incremental'):
 
 
 
 
 
 
75
  """
 
 
76
  Args:
77
- n_components (int): number of principal components to keep
78
- batch_size (int): chunk size for either partial_fit or chunked .transform
79
- pca_type (str): 'incremental' or 'standard'
80
  """
81
  self.n_components = n_components
82
  self.batch_size = batch_size
83
  self.pca_type = pca_type.lower()
84
 
85
- if self.pca_type == 'standard':
86
- self.ipca = PCA(n_components=self.n_components, svd_solver='randomized')
87
- else:
88
- # default to incremental
89
  self.ipca = IncrementalPCA(n_components=self.n_components)
 
 
 
90
 
91
- def fit(self, X: np.ndarray):
92
  """
93
- Fit the PCA model. If incremental, calls partial_fit in batches.
94
- If standard, calls .fit once on the entire data matrix.
95
- X: shape (n_samples, n_features)
 
 
96
  """
97
  if self.pca_type == 'standard':
98
- # Potentially large memory usage, so be sure your system can handle it.
99
  self.ipca.fit(X)
100
  else:
101
  # IncrementalPCA
@@ -107,7 +121,12 @@ class PCAAutoencoder:
107
  def transform(self, X: np.ndarray) -> np.ndarray:
108
  """
109
  Project data into the PCA latent space in batches for memory efficiency.
110
- Returns Z with shape (n_samples, n_components)
 
 
 
 
 
111
  """
112
  results = []
113
  n_samples = X.shape[0]
@@ -120,7 +139,12 @@ class PCAAutoencoder:
120
  def inverse_transform(self, Z: np.ndarray) -> np.ndarray:
121
  """
122
  Reconstruct data from PCA latent space in batches.
123
- Returns X_recon with shape (n_samples, n_features).
 
 
 
 
 
124
  """
125
  results = []
126
  n_samples = Z.shape[0]
@@ -130,80 +154,113 @@ class PCAAutoencoder:
130
  results.append(X_chunk)
131
  return np.vstack(results)
132
 
133
- def forward(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
134
  """
135
- Mimics a linear AE's forward() returning (X_recon, Z).
 
 
 
 
 
 
136
  """
137
  Z = self.transform(X)
138
  X_recon = self.inverse_transform(Z)
139
  return X_recon, Z
140
 
141
 
142
- ###################################################################
143
- # Load and Flatten Data
144
- ###################################################################
145
- def load_and_flatten_dataset(csv_path: str, cache_dir: str, transforms_fn) -> np.ndarray:
 
146
  """
 
 
147
  1) Reads CSV.
148
- 2) Filters rows if 'split' in columns => only keep 'split' == 'train'.
149
- 3) Applies transforms to each image, flattening them into a 1D vector (614,400).
150
- 4) Returns a NumPy array X: shape (n_samples, 614400).
 
 
 
 
 
 
 
 
 
151
  """
152
  df = pd.read_csv(csv_path)
153
 
154
- # Filter only 'train' if the split column exists
155
  if 'split' in df.columns:
156
  df = df[df['split'] == 'train']
157
- # If there is no 'split' column, we assume the entire CSV is for training.
158
 
159
  dataset = get_dataset_from_pd(df, transforms_fn, cache_dir)
160
  loader = DataLoader(dataset, batch_size=1, num_workers=0)
161
 
162
- # We'll store each flattened volume in a list, then stack
163
  X_list = []
164
  for batch in loader:
165
- # batch["image"] shape => (1, 1, 80, 96, 80)
166
- img = batch["image"].squeeze(0) # => (1, 80, 96, 80)
167
- img_np = img.numpy()
168
- flattened = img_np.flatten() # => (614400,)
169
  X_list.append(flattened)
170
 
171
- if len(X_list) == 0:
172
- raise ValueError("No training samples found (split='train'). Check your CSV or 'split' values.")
 
 
173
 
174
  X = np.vstack(X_list)
175
  return X
176
 
177
 
178
- ###################################################################
179
- # Main
180
- ###################################################################
181
- def main():
182
- parser = argparse.ArgumentParser(description="PCA Autoencoder with MONAI transforms and 'split' filtering.")
183
- parser.add_argument("--inputs_csv", type=str, required=True,
184
- help="Path to CSV with at least 'image_path' column, optional 'split' column.")
185
- parser.add_argument("--cache_dir", type=str, default="",
186
- help="Cache directory for MONAI PersistentDataset (optional).")
187
- parser.add_argument("--output_dir", type=str, default="./pca_outputs",
188
- help="Where to save PCA model and embeddings.")
189
- parser.add_argument("--batch_size_ipca", type=int, default=128,
190
- help="Batch size for partial_fit or chunked transform.")
191
- parser.add_argument("--n_components", type=int, default=1200,
192
- help="Number of PCA components to keep.")
193
- parser.add_argument("--pca_type", type=str, default="incremental",
194
- choices=["incremental", "standard"],
195
- help="Which PCA algorithm to use: 'incremental' or 'standard'.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  args = parser.parse_args()
197
 
198
  os.makedirs(args.output_dir, exist_ok=True)
199
 
200
- # define transforms as in brain2vec_linearAE.py
201
  transforms_fn = transforms.Compose([
202
  transforms.CopyItemsD(keys={'image_path'}, names=['image']),
203
  transforms.LoadImageD(image_only=True, keys=['image']),
204
  transforms.EnsureChannelFirstD(keys=['image']),
205
  transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
206
- transforms.ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
 
 
207
  transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image']),
208
  ])
209
 
@@ -225,10 +282,10 @@ def main():
225
 
226
  # Get embeddings & reconstruction
227
  X_recon, Z = model.forward(X)
228
- print("Embeddings shape:", Z.shape) # (n_samples, n_components)
229
- print("Reconstruction shape:", X_recon.shape) # (n_samples, 614400)
230
 
231
- # Save
232
  embeddings_path = os.path.join(args.output_dir, "pca_embeddings.npy")
233
  recons_path = os.path.join(args.output_dir, "pca_reconstructions.npy")
234
  np.save(embeddings_path, Z)
 
1
  #!/usr/bin/env python3
2
 
3
  """
4
+ train_brain2vec_PCA.py
5
 
6
+ A PCA-based "autoencoder" script for brain MRI data, with support for both
7
+ incremental PCA and standard PCA. Only scans labeled 'train' in the CSV
8
+ (split == 'train') will be used for fitting.
9
 
10
  Example usage:
11
+ python train_brain2vec_PCA.py \
12
  --inputs_csv /path/to/inputs.csv \
13
  --output_dir ./pca_outputs \
14
  --pca_type standard \
15
+ --n_components 1200
16
  """
17
 
18
  import os
19
  import argparse
20
  import numpy as np
21
  import pandas as pd
 
22
  import torch
23
  from torch.utils.data import DataLoader
 
24
  from monai import transforms
25
  from monai.data import Dataset, PersistentDataset
26
+ from monai.transforms.transform import Transform
 
27
  from sklearn.decomposition import PCA, IncrementalPCA
28
+ from typing import Optional, Union, Tuple
29
 
30
+ # voxel resolution
 
 
 
31
  RESOLUTION = 2
32
+
33
+ # cropped image dimensions after transform
34
  INPUT_SHAPE_AE = (80, 96, 80)
35
+
36
  DEFAULT_N_COMPONENTS = 1200
37
 
38
 
39
+ def get_dataset_from_pd(
40
+ df: pd.DataFrame,
41
+ transforms_fn: Transform,
42
+ cache_dir: Optional[str]
43
+ ) -> Union[Dataset, PersistentDataset]:
44
  """
45
+ Create a MONAI Dataset or PersistentDataset from the given DataFrame.
46
+
47
+ Args:
48
+ df (pd.DataFrame): DataFrame with at least 'image_path' column.
49
+ transforms_fn (Transform): MONAI transform pipeline.
50
+ cache_dir (Optional[str]): If provided, use PersistentDataset caching.
51
+
52
+ Returns:
53
+ Dataset|PersistentDataset: A dataset for training or inference.
54
  """
55
+ data_dicts = df.to_dict(orient='records')
56
  if cache_dir and cache_dir.strip():
57
  os.makedirs(cache_dir, exist_ok=True)
58
+ dataset = PersistentDataset(
59
+ data=data_dicts,
60
+ transform=transforms_fn,
61
+ cache_dir=cache_dir
62
+ )
63
  else:
64
+ dataset = Dataset(data=data_dicts, transform=transforms_fn)
 
65
  return dataset
66
 
67
 
 
 
 
68
  class PCAAutoencoder:
69
  """
70
  A PCA 'autoencoder' that can use either standard PCA or IncrementalPCA:
71
  - fit(X): trains the model
72
  - transform(X): get embeddings
73
  - inverse_transform(Z): reconstruct data from embeddings
74
+ - forward(X): returns (X_recon, Z).
75
+
76
+ If using standard PCA, a single call to .fit(X) is made.
77
+ If using incremental PCA, .partial_fit is called in batches.
78
  """
79
+
80
+ def __init__(
81
+ self,
82
+ n_components: int = DEFAULT_N_COMPONENTS,
83
+ batch_size: int = 128,
84
+ pca_type: str = 'standard'
85
+ ) -> None:
86
  """
87
+ Initialize the PCAAutoencoder.
88
+
89
  Args:
90
+ n_components (int): Number of principal components to keep.
91
+ batch_size (int): Chunk size for partial_fit or chunked transform.
92
+ pca_type (str): Either 'incremental' or 'standard'.
93
  """
94
  self.n_components = n_components
95
  self.batch_size = batch_size
96
  self.pca_type = pca_type.lower()
97
 
98
+ if self.pca_type == 'incremental':
 
 
 
99
  self.ipca = IncrementalPCA(n_components=self.n_components)
100
+ else:
101
+ # Default to standard PCA
102
+ self.ipca = PCA(n_components=self.n_components, svd_solver='randomized')
103
 
104
+ def fit(self, X: np.ndarray) -> None:
105
  """
106
+ Fit the PCA model. If incremental PCA, calls partial_fit in batches;
107
+ otherwise calls .fit once on the entire data array.
108
+
109
+ Args:
110
+ X (np.ndarray): Shape (n_samples, n_features).
111
  """
112
  if self.pca_type == 'standard':
 
113
  self.ipca.fit(X)
114
  else:
115
  # IncrementalPCA
 
121
  def transform(self, X: np.ndarray) -> np.ndarray:
122
  """
123
  Project data into the PCA latent space in batches for memory efficiency.
124
+
125
+ Args:
126
+ X (np.ndarray): Shape (n_samples, n_features).
127
+
128
+ Returns:
129
+ np.ndarray: Latent embeddings of shape (n_samples, n_components).
130
  """
131
  results = []
132
  n_samples = X.shape[0]
 
139
  def inverse_transform(self, Z: np.ndarray) -> np.ndarray:
140
  """
141
  Reconstruct data from PCA latent space in batches.
142
+
143
+ Args:
144
+ Z (np.ndarray): Latent embeddings of shape (n_samples, n_components).
145
+
146
+ Returns:
147
+ np.ndarray: Reconstructed data of shape (n_samples, n_features).
148
  """
149
  results = []
150
  n_samples = Z.shape[0]
 
154
  results.append(X_chunk)
155
  return np.vstack(results)
156
 
157
+ def forward(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
158
  """
159
+ Mimic a linear AE's forward() returning (X_recon, Z).
160
+
161
+ Args:
162
+ X (np.ndarray): Original data of shape (n_samples, n_features).
163
+
164
+ Returns:
165
+ tuple[np.ndarray, np.ndarray]: (X_recon, Z).
166
  """
167
  Z = self.transform(X)
168
  X_recon = self.inverse_transform(Z)
169
  return X_recon, Z
170
 
171
 
172
+ def load_and_flatten_dataset(
173
+ csv_path: str,
174
+ cache_dir: str,
175
+ transforms_fn: Transform
176
+ ) -> np.ndarray:
177
  """
178
+ Load and flatten MRI volumes from the provided CSV.
179
+
180
  1) Reads CSV.
181
+ 2) Filters rows if 'split' in columns => only keep rows with split == 'train'.
182
+ 3) Applies transforms to each image, flattening them into a 1D vector.
183
+ 4) Returns a NumPy array X of shape (n_samples, 614400) after flattening.
184
+
185
+ Args:
186
+ csv_path (str): Path to a CSV containing at least 'image_path' column.
187
+ Optionally has a 'split' column.
188
+ cache_dir (str): Path to cache directory for MONAI PersistentDataset.
189
+ transforms_fn (Transform): MONAI transform pipeline.
190
+
191
+ Returns:
192
+ np.ndarray: Flattened image data of shape (n_samples, 614400).
193
  """
194
  df = pd.read_csv(csv_path)
195
 
196
+ # Keep only 'train' samples if split column exists
197
  if 'split' in df.columns:
198
  df = df[df['split'] == 'train']
 
199
 
200
  dataset = get_dataset_from_pd(df, transforms_fn, cache_dir)
201
  loader = DataLoader(dataset, batch_size=1, num_workers=0)
202
 
 
203
  X_list = []
204
  for batch in loader:
205
+ # batch["image"] => shape (1, 1, 80, 96, 80)
206
+ img = batch["image"].squeeze(0) # => shape (1, 80, 96, 80)
207
+ flattened = img.numpy().flatten() # => (614400,)
 
208
  X_list.append(flattened)
209
 
210
+ if not X_list:
211
+ raise ValueError(
212
+ "No training samples found (split='train'). Check your CSV or 'split' values."
213
+ )
214
 
215
  X = np.vstack(X_list)
216
  return X
217
 
218
 
219
+ def main() -> None:
220
+ """
221
+ Main function to parse command-line arguments and fit a PCA or IncrementalPCA model,
222
+ then save embeddings and reconstructions.
223
+ """
224
+ parser = argparse.ArgumentParser(
225
+ description="PCA Autoencoder with MONAI transforms and 'split' filtering."
226
+ )
227
+ parser.add_argument(
228
+ "--inputs_csv", type=str, required=True,
229
+ help="Path to CSV with at least 'image_path' column and optional 'split' column."
230
+ )
231
+ parser.add_argument(
232
+ "--cache_dir", type=str, default="",
233
+ help="Cache directory for MONAI PersistentDataset (optional)."
234
+ )
235
+ parser.add_argument(
236
+ "--output_dir", type=str, default="./pca_outputs",
237
+ help="Where to save PCA model and embeddings."
238
+ )
239
+ parser.add_argument(
240
+ "--batch_size_ipca", type=int, default=128,
241
+ help="Batch size for partial_fit or chunked transform."
242
+ )
243
+ parser.add_argument(
244
+ "--n_components", type=int, default=1200,
245
+ help="Number of PCA components to keep."
246
+ )
247
+ parser.add_argument(
248
+ "--pca_type", type=str, default="incremental",
249
+ choices=["incremental", "standard"],
250
+ help="Which PCA algorithm to use: 'incremental' or 'standard'."
251
+ )
252
  args = parser.parse_args()
253
 
254
  os.makedirs(args.output_dir, exist_ok=True)
255
 
 
256
  transforms_fn = transforms.Compose([
257
  transforms.CopyItemsD(keys={'image_path'}, names=['image']),
258
  transforms.LoadImageD(image_only=True, keys=['image']),
259
  transforms.EnsureChannelFirstD(keys=['image']),
260
  transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
261
+ transforms.ResizeWithPadOrCropD(
262
+ spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']
263
+ ),
264
  transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image']),
265
  ])
266
 
 
282
 
283
  # Get embeddings & reconstruction
284
  X_recon, Z = model.forward(X)
285
+ print("Embeddings shape:", Z.shape)
286
+ print("Reconstruction shape:", X_recon.shape)
287
 
288
+ # Save embeddings and reconstructions
289
  embeddings_path = os.path.join(args.output_dir, "pca_embeddings.npy")
290
  recons_path = os.path.join(args.output_dir, "pca_reconstructions.npy")
291
  np.save(embeddings_path, Z)