hwonheo commited on
Commit
604568a
·
verified ·
1 Parent(s): f83098b

Upload 4 files

Browse files
utils/BiasFieldCorrection.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import glob
4
+ import ants
5
+
6
+ def abp_n4(image, intensity_truncation=(0.025, 0.975, 256), mask=None,
7
+ use_n3=False):
8
+ """
9
+ Apply intensity truncation and bias correction to an image.
10
+
11
+ Parameters:
12
+ - image: Image to be processed (ANTsImage object).
13
+ - intensity_truncation: Tuple (lower quantile, upper quantile, number of bins).
14
+ - mask: Optional mask for bias correction (ANTsImage object).
15
+ - use_n3: Use N3 bias correction instead of N4 if True.
16
+
17
+ Returns:
18
+ - Processed image (ANTsImage object).
19
+ """
20
+ if not isinstance(intensity_truncation, (list, tuple)) or \
21
+ len(intensity_truncation) != 3:
22
+ raise ValueError("intensity_truncation must be list/tuple with 3 values")
23
+
24
+ # Apply intensity truncation
25
+ truncated_image = ants.iMath(image, "TruncateIntensity",
26
+ intensity_truncation[0], intensity_truncation[1],
27
+ intensity_truncation[2])
28
+
29
+ # Apply bias correction
30
+ if use_n3:
31
+ corrected_image = ants.n3_bias_field_correction(truncated_image, mask=mask)
32
+ else:
33
+ corrected_image = ants.n4_bias_field_correction(truncated_image, mask=mask)
34
+
35
+ return corrected_image
36
+
37
+
38
+ def preprocess_image(file_path, output_path, preprocess_type):
39
+ """
40
+ Read, preprocess, and write an image based on the specified method.
41
+
42
+ Parameters:
43
+ - file_path: Path to the input image file.
44
+ - output_path: Path to the output folder.
45
+ - preprocess_type: Preprocessing method (n3, n4, or abp).
46
+ """
47
+ image = ants.image_read(file_path)
48
+
49
+ if preprocess_type == "n3":
50
+ processed_image = ants.n3_bias_field_correction(image)
51
+ elif preprocess_type == "n4":
52
+ processed_image = ants.n4_bias_field_correction(image)
53
+ elif preprocess_type == "abp":
54
+ processed_image = abp_n4(image)
55
+ else:
56
+ raise ValueError("Invalid preprocess type")
57
+
58
+ ants.image_write(processed_image, os.path.join(output_path,
59
+ os.path.basename(file_path)))
60
+
61
+
62
+ def main():
63
+ """
64
+ Main function to handle command line arguments and process images.
65
+ """
66
+ parser = argparse.ArgumentParser(
67
+ description="Image Preprocessing Script for Bias Field Correction")
68
+ parser.add_argument('--input', type=str, required=True,
69
+ help="Input file or folder path")
70
+ parser.add_argument('--output', type=str, default='output',
71
+ help="Output folder path")
72
+ parser.add_argument('--type', type=str, default='abp',
73
+ choices=['n3', 'n4', 'abp'],
74
+ help="Type of preprocessing (n3, n4, abp)")
75
+
76
+ args = parser.parse_args()
77
+
78
+ if not os.path.exists(args.output):
79
+ os.makedirs(args.output)
80
+
81
+ files = [args.input] if os.path.isfile(args.input) else \
82
+ glob.glob(os.path.join(args.input, '*'))
83
+
84
+ for file in files:
85
+ preprocess_image(file, args.output, args.type)
86
+
87
+ print("----- Bias Field Correction Process completed -----")
88
+
89
+
90
+ if __name__ == '__main__':
91
+ main()
utils/__init__.py ADDED
File without changes
utils/data_splitter.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import random
4
+ import argparse
5
+
6
+ class DataSplitter:
7
+ """Class to split Nii or nii.gz files into training and validation datasets."""
8
+
9
+ def __init__(self, input_folder, train_ratio, train_folder, val_folder):
10
+ """
11
+ Initialize the DataSplitter with folder paths and data ratio.
12
+
13
+ :param input_folder: Path to the folder containing the Nii files.
14
+ :param train_ratio: Ratio of the dataset to be used as training data.
15
+ :param train_folder: Name of the folder to store training data.
16
+ :param val_folder: Name of the folder to store validation data.
17
+ """
18
+ self.input_folder = input_folder
19
+ self.train_ratio = train_ratio
20
+ self.train_folder = os.path.join(input_folder, train_folder)
21
+ self.val_folder = os.path.join(input_folder, val_folder)
22
+
23
+ def split_data(self):
24
+ """
25
+ Split the data into training and validation datasets and move files
26
+ to the respective folders.
27
+ """
28
+ # List all Nii or nii.gz files in the input folder
29
+ files = [f for f in os.listdir(self.input_folder)
30
+ if f.endswith('.nii') or f.endswith('.nii.gz')]
31
+ random.shuffle(files) # Shuffle files for random splitting
32
+
33
+ # Determine the split index based on the training ratio
34
+ split_index = int(len(files) * self.train_ratio)
35
+ train_files = files[:split_index]
36
+ val_files = files[split_index:]
37
+
38
+ # Create training and validation folders if they don't exist
39
+ os.makedirs(self.train_folder, exist_ok=True)
40
+ os.makedirs(self.val_folder, exist_ok=True)
41
+
42
+ # Move files to the respective folders
43
+ for file in train_files:
44
+ shutil.move(os.path.join(self.input_folder, file),
45
+ self.train_folder)
46
+
47
+ for file in val_files:
48
+ shutil.move(os.path.join(self.input_folder, file),
49
+ self.val_folder)
50
+
51
+ print(f"Files split into {len(train_files)} training "
52
+ f"and {len(val_files)} validation.")
53
+
54
+ def main():
55
+ """
56
+ Main function to handle command line arguments and initiate data splitting.
57
+ """
58
+ parser = argparse.ArgumentParser(
59
+ description='Split Nii files into training and validation datasets.')
60
+
61
+ # Define command line arguments
62
+ parser.add_argument('--input', type=str, required=True,
63
+ help='Input folder with Nii files.')
64
+ parser.add_argument('--train-ratio', type=float, default=0.8,
65
+ help='Training data ratio (default: 0.8).')
66
+ parser.add_argument('--train-folder', type=str, default='training',
67
+ help='Folder for training data (default: "training").')
68
+ parser.add_argument('--val-folder', type=str, default='validation',
69
+ help='Folder for validation data (default: "validation").')
70
+
71
+ args = parser.parse_args()
72
+
73
+ # Create a DataSplitter instance and split the data
74
+ splitter = DataSplitter(args.input, args.train_ratio,
75
+ args.train_folder, args.val_folder)
76
+ splitter.split_data()
77
+
78
+ if __name__ == "__main__":
79
+ main()
utils/validation.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import torch
3
+ import numpy as np
4
+ from skimage.metrics import structural_similarity as compare_ssim
5
+ import math
6
+
7
+ class ImageQualityMetrics:
8
+ """Class to compute image quality metrics like SSIM, PSNR, and MSE."""
9
+
10
+ @staticmethod
11
+ @staticmethod
12
+ def ssim_3d(img1, img2):
13
+ """Calculate SSIM for each 2D slice in the 3D image and return the average."""
14
+ ssim_vals = []
15
+ for i in range(img1.shape[1]): # Depth
16
+ slice1 = img1[0, i, :, :]
17
+ slice2 = img2[0, i, :, :]
18
+ ssim_val = compare_ssim(slice1, slice2, data_range=slice1.max() - slice1.min())
19
+ ssim_vals.append(ssim_val)
20
+ return np.mean(ssim_vals)
21
+
22
+ @staticmethod
23
+ def psnr(img1, img2):
24
+ """Calculate PSNR (Peak Signal-to-Noise Ratio) between two images."""
25
+ mse = torch.mean((img1 - img2) ** 2)
26
+ if mse == 0:
27
+ return math.inf
28
+ return 20 * math.log10(img1.max() - img1.min()) - 10 * math.log10(mse)
29
+
30
+ @staticmethod
31
+ def mse(img1, img2):
32
+ """Calculate MSE (Mean Squared Error) between two images."""
33
+ return torch.mean((img1 - img2) ** 2)
34
+
35
+
36
+ class ValidationRecorder:
37
+ """Class to handle validation process and record the metrics."""
38
+
39
+ def __init__(self, csv_file_path):
40
+ """Initialize the recorder with the path to the CSV file."""
41
+ self.csv_file_path = csv_file_path
42
+
43
+ def initialize_csv(self):
44
+ """Initialize the CSV file with headers."""
45
+ with open(self.csv_file_path, mode='w', newline='') as file:
46
+ writer = csv.writer(file)
47
+ writer.writerow(['Epoch', 'Loss', 'SSIM', 'PSNR', 'MSE'])
48
+
49
+ def validate_and_record(self, epoch, dataloader, device, generator,
50
+ criterion_g):
51
+ """Validate the model and record the metrics in the CSV file."""
52
+ generator.eval()
53
+ total_loss, total_ssim, total_psnr, total_mse = 0.0, 0.0, 0.0, 0.0
54
+
55
+ with torch.no_grad():
56
+ for _, (low_res, high_res) in enumerate(dataloader):
57
+ low_res, high_res = low_res.to(device), high_res.to(device)
58
+ fake_images = generator(low_res)
59
+
60
+ loss = criterion_g(fake_images, high_res)
61
+ total_loss += loss.item()
62
+
63
+ for j in range(high_res.size(0)):
64
+ ssim_val = ImageQualityMetrics.ssim_3d(
65
+ high_res[j].cpu().numpy(), fake_images[j].cpu().numpy())
66
+ psnr_val = ImageQualityMetrics.psnr(
67
+ high_res[j], fake_images[j])
68
+ mse_val = ImageQualityMetrics.mse(
69
+ high_res[j], fake_images[j])
70
+ total_ssim += ssim_val
71
+ total_psnr += psnr_val
72
+ total_mse += mse_val.item()
73
+
74
+ avg_loss = total_loss / len(dataloader)
75
+ avg_ssim = total_ssim / (len(dataloader) * dataloader.batch_size)
76
+ avg_psnr = total_psnr / (len(dataloader) * dataloader.batch_size)
77
+ avg_mse = total_mse / (len(dataloader) * dataloader.batch_size)
78
+
79
+ self._write_to_csv(epoch, avg_loss, avg_ssim, avg_psnr, avg_mse)
80
+
81
+ def _write_to_csv(self, epoch, avg_loss, avg_ssim, avg_psnr, avg_mse):
82
+ """Write the validation metrics to the CSV file."""
83
+ with open(self.csv_file_path, mode='a', newline='') as file:
84
+ writer = csv.writer(file)
85
+ writer.writerow([epoch, avg_loss, avg_ssim, avg_psnr, avg_mse])