Upload 4 files
Browse files- utils/BiasFieldCorrection.py +91 -0
- utils/__init__.py +0 -0
- utils/data_splitter.py +79 -0
- utils/validation.py +85 -0
utils/BiasFieldCorrection.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import glob
|
4 |
+
import ants
|
5 |
+
|
6 |
+
def abp_n4(image, intensity_truncation=(0.025, 0.975, 256), mask=None,
|
7 |
+
use_n3=False):
|
8 |
+
"""
|
9 |
+
Apply intensity truncation and bias correction to an image.
|
10 |
+
|
11 |
+
Parameters:
|
12 |
+
- image: Image to be processed (ANTsImage object).
|
13 |
+
- intensity_truncation: Tuple (lower quantile, upper quantile, number of bins).
|
14 |
+
- mask: Optional mask for bias correction (ANTsImage object).
|
15 |
+
- use_n3: Use N3 bias correction instead of N4 if True.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
- Processed image (ANTsImage object).
|
19 |
+
"""
|
20 |
+
if not isinstance(intensity_truncation, (list, tuple)) or \
|
21 |
+
len(intensity_truncation) != 3:
|
22 |
+
raise ValueError("intensity_truncation must be list/tuple with 3 values")
|
23 |
+
|
24 |
+
# Apply intensity truncation
|
25 |
+
truncated_image = ants.iMath(image, "TruncateIntensity",
|
26 |
+
intensity_truncation[0], intensity_truncation[1],
|
27 |
+
intensity_truncation[2])
|
28 |
+
|
29 |
+
# Apply bias correction
|
30 |
+
if use_n3:
|
31 |
+
corrected_image = ants.n3_bias_field_correction(truncated_image, mask=mask)
|
32 |
+
else:
|
33 |
+
corrected_image = ants.n4_bias_field_correction(truncated_image, mask=mask)
|
34 |
+
|
35 |
+
return corrected_image
|
36 |
+
|
37 |
+
|
38 |
+
def preprocess_image(file_path, output_path, preprocess_type):
|
39 |
+
"""
|
40 |
+
Read, preprocess, and write an image based on the specified method.
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
- file_path: Path to the input image file.
|
44 |
+
- output_path: Path to the output folder.
|
45 |
+
- preprocess_type: Preprocessing method (n3, n4, or abp).
|
46 |
+
"""
|
47 |
+
image = ants.image_read(file_path)
|
48 |
+
|
49 |
+
if preprocess_type == "n3":
|
50 |
+
processed_image = ants.n3_bias_field_correction(image)
|
51 |
+
elif preprocess_type == "n4":
|
52 |
+
processed_image = ants.n4_bias_field_correction(image)
|
53 |
+
elif preprocess_type == "abp":
|
54 |
+
processed_image = abp_n4(image)
|
55 |
+
else:
|
56 |
+
raise ValueError("Invalid preprocess type")
|
57 |
+
|
58 |
+
ants.image_write(processed_image, os.path.join(output_path,
|
59 |
+
os.path.basename(file_path)))
|
60 |
+
|
61 |
+
|
62 |
+
def main():
|
63 |
+
"""
|
64 |
+
Main function to handle command line arguments and process images.
|
65 |
+
"""
|
66 |
+
parser = argparse.ArgumentParser(
|
67 |
+
description="Image Preprocessing Script for Bias Field Correction")
|
68 |
+
parser.add_argument('--input', type=str, required=True,
|
69 |
+
help="Input file or folder path")
|
70 |
+
parser.add_argument('--output', type=str, default='output',
|
71 |
+
help="Output folder path")
|
72 |
+
parser.add_argument('--type', type=str, default='abp',
|
73 |
+
choices=['n3', 'n4', 'abp'],
|
74 |
+
help="Type of preprocessing (n3, n4, abp)")
|
75 |
+
|
76 |
+
args = parser.parse_args()
|
77 |
+
|
78 |
+
if not os.path.exists(args.output):
|
79 |
+
os.makedirs(args.output)
|
80 |
+
|
81 |
+
files = [args.input] if os.path.isfile(args.input) else \
|
82 |
+
glob.glob(os.path.join(args.input, '*'))
|
83 |
+
|
84 |
+
for file in files:
|
85 |
+
preprocess_image(file, args.output, args.type)
|
86 |
+
|
87 |
+
print("----- Bias Field Correction Process completed -----")
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == '__main__':
|
91 |
+
main()
|
utils/__init__.py
ADDED
File without changes
|
utils/data_splitter.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import random
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
class DataSplitter:
|
7 |
+
"""Class to split Nii or nii.gz files into training and validation datasets."""
|
8 |
+
|
9 |
+
def __init__(self, input_folder, train_ratio, train_folder, val_folder):
|
10 |
+
"""
|
11 |
+
Initialize the DataSplitter with folder paths and data ratio.
|
12 |
+
|
13 |
+
:param input_folder: Path to the folder containing the Nii files.
|
14 |
+
:param train_ratio: Ratio of the dataset to be used as training data.
|
15 |
+
:param train_folder: Name of the folder to store training data.
|
16 |
+
:param val_folder: Name of the folder to store validation data.
|
17 |
+
"""
|
18 |
+
self.input_folder = input_folder
|
19 |
+
self.train_ratio = train_ratio
|
20 |
+
self.train_folder = os.path.join(input_folder, train_folder)
|
21 |
+
self.val_folder = os.path.join(input_folder, val_folder)
|
22 |
+
|
23 |
+
def split_data(self):
|
24 |
+
"""
|
25 |
+
Split the data into training and validation datasets and move files
|
26 |
+
to the respective folders.
|
27 |
+
"""
|
28 |
+
# List all Nii or nii.gz files in the input folder
|
29 |
+
files = [f for f in os.listdir(self.input_folder)
|
30 |
+
if f.endswith('.nii') or f.endswith('.nii.gz')]
|
31 |
+
random.shuffle(files) # Shuffle files for random splitting
|
32 |
+
|
33 |
+
# Determine the split index based on the training ratio
|
34 |
+
split_index = int(len(files) * self.train_ratio)
|
35 |
+
train_files = files[:split_index]
|
36 |
+
val_files = files[split_index:]
|
37 |
+
|
38 |
+
# Create training and validation folders if they don't exist
|
39 |
+
os.makedirs(self.train_folder, exist_ok=True)
|
40 |
+
os.makedirs(self.val_folder, exist_ok=True)
|
41 |
+
|
42 |
+
# Move files to the respective folders
|
43 |
+
for file in train_files:
|
44 |
+
shutil.move(os.path.join(self.input_folder, file),
|
45 |
+
self.train_folder)
|
46 |
+
|
47 |
+
for file in val_files:
|
48 |
+
shutil.move(os.path.join(self.input_folder, file),
|
49 |
+
self.val_folder)
|
50 |
+
|
51 |
+
print(f"Files split into {len(train_files)} training "
|
52 |
+
f"and {len(val_files)} validation.")
|
53 |
+
|
54 |
+
def main():
|
55 |
+
"""
|
56 |
+
Main function to handle command line arguments and initiate data splitting.
|
57 |
+
"""
|
58 |
+
parser = argparse.ArgumentParser(
|
59 |
+
description='Split Nii files into training and validation datasets.')
|
60 |
+
|
61 |
+
# Define command line arguments
|
62 |
+
parser.add_argument('--input', type=str, required=True,
|
63 |
+
help='Input folder with Nii files.')
|
64 |
+
parser.add_argument('--train-ratio', type=float, default=0.8,
|
65 |
+
help='Training data ratio (default: 0.8).')
|
66 |
+
parser.add_argument('--train-folder', type=str, default='training',
|
67 |
+
help='Folder for training data (default: "training").')
|
68 |
+
parser.add_argument('--val-folder', type=str, default='validation',
|
69 |
+
help='Folder for validation data (default: "validation").')
|
70 |
+
|
71 |
+
args = parser.parse_args()
|
72 |
+
|
73 |
+
# Create a DataSplitter instance and split the data
|
74 |
+
splitter = DataSplitter(args.input, args.train_ratio,
|
75 |
+
args.train_folder, args.val_folder)
|
76 |
+
splitter.split_data()
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
main()
|
utils/validation.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from skimage.metrics import structural_similarity as compare_ssim
|
5 |
+
import math
|
6 |
+
|
7 |
+
class ImageQualityMetrics:
|
8 |
+
"""Class to compute image quality metrics like SSIM, PSNR, and MSE."""
|
9 |
+
|
10 |
+
@staticmethod
|
11 |
+
@staticmethod
|
12 |
+
def ssim_3d(img1, img2):
|
13 |
+
"""Calculate SSIM for each 2D slice in the 3D image and return the average."""
|
14 |
+
ssim_vals = []
|
15 |
+
for i in range(img1.shape[1]): # Depth
|
16 |
+
slice1 = img1[0, i, :, :]
|
17 |
+
slice2 = img2[0, i, :, :]
|
18 |
+
ssim_val = compare_ssim(slice1, slice2, data_range=slice1.max() - slice1.min())
|
19 |
+
ssim_vals.append(ssim_val)
|
20 |
+
return np.mean(ssim_vals)
|
21 |
+
|
22 |
+
@staticmethod
|
23 |
+
def psnr(img1, img2):
|
24 |
+
"""Calculate PSNR (Peak Signal-to-Noise Ratio) between two images."""
|
25 |
+
mse = torch.mean((img1 - img2) ** 2)
|
26 |
+
if mse == 0:
|
27 |
+
return math.inf
|
28 |
+
return 20 * math.log10(img1.max() - img1.min()) - 10 * math.log10(mse)
|
29 |
+
|
30 |
+
@staticmethod
|
31 |
+
def mse(img1, img2):
|
32 |
+
"""Calculate MSE (Mean Squared Error) between two images."""
|
33 |
+
return torch.mean((img1 - img2) ** 2)
|
34 |
+
|
35 |
+
|
36 |
+
class ValidationRecorder:
|
37 |
+
"""Class to handle validation process and record the metrics."""
|
38 |
+
|
39 |
+
def __init__(self, csv_file_path):
|
40 |
+
"""Initialize the recorder with the path to the CSV file."""
|
41 |
+
self.csv_file_path = csv_file_path
|
42 |
+
|
43 |
+
def initialize_csv(self):
|
44 |
+
"""Initialize the CSV file with headers."""
|
45 |
+
with open(self.csv_file_path, mode='w', newline='') as file:
|
46 |
+
writer = csv.writer(file)
|
47 |
+
writer.writerow(['Epoch', 'Loss', 'SSIM', 'PSNR', 'MSE'])
|
48 |
+
|
49 |
+
def validate_and_record(self, epoch, dataloader, device, generator,
|
50 |
+
criterion_g):
|
51 |
+
"""Validate the model and record the metrics in the CSV file."""
|
52 |
+
generator.eval()
|
53 |
+
total_loss, total_ssim, total_psnr, total_mse = 0.0, 0.0, 0.0, 0.0
|
54 |
+
|
55 |
+
with torch.no_grad():
|
56 |
+
for _, (low_res, high_res) in enumerate(dataloader):
|
57 |
+
low_res, high_res = low_res.to(device), high_res.to(device)
|
58 |
+
fake_images = generator(low_res)
|
59 |
+
|
60 |
+
loss = criterion_g(fake_images, high_res)
|
61 |
+
total_loss += loss.item()
|
62 |
+
|
63 |
+
for j in range(high_res.size(0)):
|
64 |
+
ssim_val = ImageQualityMetrics.ssim_3d(
|
65 |
+
high_res[j].cpu().numpy(), fake_images[j].cpu().numpy())
|
66 |
+
psnr_val = ImageQualityMetrics.psnr(
|
67 |
+
high_res[j], fake_images[j])
|
68 |
+
mse_val = ImageQualityMetrics.mse(
|
69 |
+
high_res[j], fake_images[j])
|
70 |
+
total_ssim += ssim_val
|
71 |
+
total_psnr += psnr_val
|
72 |
+
total_mse += mse_val.item()
|
73 |
+
|
74 |
+
avg_loss = total_loss / len(dataloader)
|
75 |
+
avg_ssim = total_ssim / (len(dataloader) * dataloader.batch_size)
|
76 |
+
avg_psnr = total_psnr / (len(dataloader) * dataloader.batch_size)
|
77 |
+
avg_mse = total_mse / (len(dataloader) * dataloader.batch_size)
|
78 |
+
|
79 |
+
self._write_to_csv(epoch, avg_loss, avg_ssim, avg_psnr, avg_mse)
|
80 |
+
|
81 |
+
def _write_to_csv(self, epoch, avg_loss, avg_ssim, avg_psnr, avg_mse):
|
82 |
+
"""Write the validation metrics to the CSV file."""
|
83 |
+
with open(self.csv_file_path, mode='a', newline='') as file:
|
84 |
+
writer = csv.writer(file)
|
85 |
+
writer.writerow([epoch, avg_loss, avg_ssim, avg_psnr, avg_mse])
|