BrainAgeNeXt / BrainAge_estimation.py
FrancescoLR's picture
Upload BrainAge_estimation.py
6760104 verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Francesco La Rosa
"""
import sys
import pandas as pd
import torch
from torch.utils.data import DataLoader
from monai.transforms import Compose, LoadImaged, ScaleIntensityd, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
from monai.data import CacheDataset
import numpy as np
import os
import torchio
import torch.nn as nn
import matplotlib.pyplot as plt
from nnunet_mednext import create_mednext_v1, create_mednext_encoder_v1
class MedNeXtEncReg(nn.Module):
def __init__(self, *args, **kwargs):
super(MedNeXtEncReg, self).__init__()
self.mednextv1 = create_mednext_encoder_v1(num_input_channels=1, num_classes=1, model_id='B', kernel_size=3, deep_supervision=True)
self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
self.regression_fc = nn.Sequential(
nn.Linear(512, 64),
nn.ReLU(),
nn.Dropout(0.0),
nn.Linear(64, 1)
)
def forward(self, x):
mednext_out = self.mednextv1(x)
x = mednext_out
x = self.global_avg_pool(x)
x = torch.flatten(x, start_dim=1)
age_estimate = self.regression_fc(x)
return age_estimate.squeeze()
def prepare_transforms():
x, y, z = (160, 192, 160)
p = 1.0
monai_transforms = [
LoadImaged(keys=["image"], ensure_channel_first=True),
Spacingd(keys=["image"], pixdim=(p, p, p)),
CropForegroundd(keys=["image"], allow_smaller=True, source_key="image"),
SpatialPadd(keys=["image"], spatial_size=(x, y, z)),
CenterSpatialCropd(keys=["image"], roi_size=(x, y, z))
]
val_torchio_transforms = torchio.transforms.Compose(
[torchio.transforms.ZNormalization(masking_method=lambda x: x > 0, keys=["image"], include=['image'])]
)
return Compose(monai_transforms + [val_torchio_transforms])
def load_data(csv_file):
df = pd.read_csv(csv_file)
df.dropna(subset=['Path'], inplace=True)
df.dropna(subset=['Age'], inplace=True)
data_dicts = [{'image': row['Path'], 'label': row['Age']} for index, row in df.iterrows()]
return df, data_dicts
def create_dataloader(data_dicts, transforms):
dataset = CacheDataset(data=data_dicts, transform=transforms, cache_rate=0.2, num_workers=4)
dataloader = DataLoader(dataset, batch_size=1, num_workers=4, shuffle=False, pin_memory=torch.cuda.is_available())
return dataloader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def initialize_model():
torch.cuda.empty_cache()
return MedNeXtEncReg().to(device)
def run_predictions(model_path, dataloader):
model = initialize_model()
model.load_state_dict(torch.load(model_path))
model.eval()
predictions = []
with torch.no_grad():
for batch_data in dataloader:
images = batch_data['image'].to(device)
pred = model(images)
predictions.append(pred.cpu().numpy())
del model
torch.cuda.empty_cache()
return np.array(predictions)
def main(csv_file):
df, data_dicts = load_data(csv_file)
transforms = prepare_transforms()
dataloader = create_dataloader(data_dicts, transforms)
model_paths = [
os.path.join(os.path.dirname(__file__), f'BrainAge_{i}.pth') for i in range(1, 6)
]
predictions_list = [run_predictions(model_path, dataloader) for model_path in model_paths]
average_predictions = np.median(np.stack(predictions_list), axis=0)
CA = df['Age'].values
BA = average_predictions.flatten()
BA_corr = np.where(CA > 18, BA + (CA * 0.062) - 2.96, BA)
BAD_corr = BA_corr - CA
df['Predicted_Brain_Age'] = BA_corr
df['Brain_Age_Difference'] = BAD_corr
df.to_csv(csv_file.replace('.csv', '_with_predictions.csv'), index=False)
print('Updated CSV file saved.')
if __name__ == '__main__':
if len(sys.argv) != 2:
print("Error: No .csv file provided.")
print("Usage: python script.py <path_to_csv_file>")
print("Please provide the path to a .csv file as the argument. This file should contain columns for 'Path' and 'Age' for all subjects.")
sys.exit(1)
csv_file = sys.argv[1]
main(csv_file)