# zy7_oldserver
# 1
# fd601de
# raw
# history blame
# 36 kB
import nibabel as nib
from monai.transforms import (
Compose,
EnsureChannelFirst,
Rotate90,
ResizeWithPadOrCrop,
)
from monai.transforms import SaveImage
import numpy as np
import os
import torch
# save validation images
'''nib.save(
nib.Nifti1Image(val_outputs.astype(np.uint8), original_affine), os.path.join(output_directory, img_name)
)'''
## some functions for GAN training
# output_train_log: to save training loss log to a text file every epoch
# output_val_log: to save validation metrics to a text file every epoch
import monai
from torch import nn
from torch.utils.data import DataLoader
from torchmetrics import MeanAbsoluteError
from torchmetrics.image import StructuralSimilarityIndexMeasure,PeakSignalNoiseRatio
import numpy as np
import os
import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from monai.transforms.utils import allow_missing_keys_mode
from synthrad_conversion.utils.image_metrics import ImageMetrics
class InferenceMetrics:
    """Running accumulator for SSIM, MAE and PSNR values.

    Call :meth:`update` once per evaluated item and :meth:`get_averages` to
    read the arithmetic mean of everything recorded since the last
    :meth:`reset`.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all running sums and the step counter."""
        self.ssim_sum = 0
        self.mae_sum = 0
        self.psnr_sum = 0
        self.steps = 0

    def update(self, ssim, mae, psnr):
        """Add one set of metric values and count it as one step."""
        self.ssim_sum += ssim
        self.mae_sum += mae
        self.psnr_sum += psnr
        self.steps += 1

    def get_averages(self):
        """Return a dict with the mean 'ssim', 'mae' and 'psnr'.

        If no step has been recorded yet, every entry is NaN instead of the
        call failing with ``ZeroDivisionError``.
        """
        if self.steps == 0:
            # Nothing to average: report "no data" rather than dividing by zero.
            nan = float('nan')
            return {'ssim': nan, 'mae': nan, 'psnr': nan}
        return {
            'ssim': self.ssim_sum / self.steps,
            'mae': self.mae_sum / self.steps,
            'psnr': self.psnr_sum / self.steps,
        }
class InferenceLogger:
    """Builds inference-log file paths inside one folder and appends to them."""

    def __init__(self, log_folder):
        self.log_folder = log_folder

    def get_log_single_set_file_path(self, val_step, epoch, unreversed=False):
        """Return the log path for one validation set at a given epoch."""
        tag = 'reversed_'
        if unreversed:
            tag = 'unreversed_'
        file_name = f"{tag}infer_log_valset_{val_step}_epoch_{epoch}.txt"
        return os.path.join(self.log_folder, file_name)

    def get_log_file_total_sets_path(self, epoch, unreversed=False):
        """Return the log path that covers all validation sets of an epoch."""
        tag = 'reversed_'
        if unreversed:
            tag = 'unreversed_'
        file_name = f"{tag}infer_log_epoch_{epoch}.txt"
        return os.path.join(self.log_folder, file_name)

    def write_log(self, message, val_step, epoch, unreversed=False):
        """Append ``message`` plus a newline to the single-set log file."""
        log_path = self.get_log_single_set_file_path(val_step, epoch, unreversed)
        with open(log_path, 'a') as handle:
            handle.write(message + '\n')
class Postprocessfactory:
    """Helpers that undo preprocessing on model outputs.

    Keeps the transform pipeline whose ``inverse`` is applied to predictions,
    together with per-sample statistics of the untransformed dataset that are
    collected once, at construction time, by ``calculate_reverse_info``.
    """

    def __init__(self, untransformed_dataset, transforms):
        self.untransformed_loader = DataLoader(untransformed_dataset, batch_size=1)
        self.transforms = transforms
        self.all_reverse_info = calculate_reverse_info(self.untransformed_loader)

    def get_reverse_info(self):
        """Return the statistics dict gathered from the untransformed data."""
        return self.all_reverse_info

    def reverseTransform(self, val_output, val_labels, val_images, val_masks):
        """Invert ``self.transforms`` on the first batch item of output and mask.

        ``val_images`` is accepted but not used. Returns the inverted
        ``(target, mask)`` pair.
        """
        # Copy the label's recorded operations onto the prediction before
        # inverting (presumably so inverse() knows what to undo -- confirm
        # against the transform library's documentation).
        val_output.applied_operations = val_labels.applied_operations
        first_item = {
            "target": val_output[0, :, :, :, :],
            "mask": val_masks[0, :, :, :, :],
        }
        with allow_missing_keys_mode(self.transforms):
            restored = self.transforms.inverse(first_item)
        return restored["target"], restored["mask"]

    def reverseNormalization(self, val_output, normalize, val_set_idx):
        """Undo intensity normalization using the stored CT statistics.

        Modes 'none', 'inputonlyminmax' and 'inputonlyzscore' leave the
        tensor untouched; every other mode is forwarded to
        ``reverse_normalize_data`` with the statistics of sample
        ``val_set_idx``.
        """
        stats = self.all_reverse_info
        passthrough_modes = ('none', 'inputonlyminmax', 'inputonlyzscore')
        if normalize in passthrough_modes:
            return val_output
        return reverse_normalize_data(
            val_output,
            mean=stats['CT_mean'][val_set_idx],
            std=stats['CT_std'][val_set_idx],
            min_val=stats['CT_min'][val_set_idx],
            max_val=stats['CT_max'][val_set_idx],
            mode=normalize,
        )

    def reverseRotate(self, data):
        """Swap the first two axes of the squeezed tensor, then re-add a leading axis.

        E.g. [1, 452, 315, 104] -> squeeze -> [452, 315, 104]
        -> permute -> [315, 452, 104] -> unsqueeze -> [1, 315, 452, 104].
        """
        squeezed = data.squeeze()
        swapped = squeezed.permute(1, 0, 2)
        return swapped.unsqueeze(0)

    def resizeOutput(self, data, spatial_size=(512, 512, None)):
        """Pad or crop ``data`` to ``spatial_size`` with mode "minimum"."""
        from monai.transforms import ResizeWithPadOrCrop
        resizer = ResizeWithPadOrCrop(spatial_size=spatial_size, mode="minimum")
        return resizer(data)

    def compareInfo(self, fake_imgs, idx):
        """Print stored statistics of original CT ``idx`` next to those of ``fake_imgs``."""
        stats = self.all_reverse_info
        # statistics of the original CT
        print("mean of original CT:", stats['CT_mean'][idx],
              "std of original CT:", stats['CT_std'][idx],
              "min of original CT:", stats['CT_min'][idx],
              "max of original CT:", stats['CT_max'][idx])
        # statistics of the generated CT
        print("mean of fake CT:", torch.mean(fake_imgs),
              "std of fake CT:", torch.std(fake_imgs),
              'min of fake:', torch.min(fake_imgs),
              'max of fake:', torch.max(fake_imgs))
def calculate_val_metrices(val_output, val_labels, log_file_single_set, log_file_overall, val_step):
    """Compute slice-wise SSIM / MAE / PSNR for one volume and log them.

    The last axis of ``val_labels`` is iterated. Each slice of prediction and
    label gets two leading singleton axes before it is handed to the
    torchmetrics modules. One line per slice is appended to
    ``log_file_single_set`` and one line with the means over all slices is
    appended to ``log_file_overall``.

    Args:
        val_output: predicted volume (tensor indexed as ``[:, :, i]``).
        val_labels: reference volume, indexed the same way.
        log_file_single_set: path of the per-slice log (opened in append mode).
        log_file_overall: path of the per-volume log (opened in append mode).
        val_step: identifier used in printed and logged messages.

    Returns:
        dict with keys 'ssim', 'mae', 'psnr' holding the means over slices.
    """
    slice_number = val_labels.shape[-1]
    device = val_output.device
    val_ssim_sum, val_mae_sum, val_psnr_sum = 0, 0, 0
    # Open the per-slice log once instead of once per slice.
    with open(log_file_single_set, 'a') as slice_log:
        for i in range(slice_number):
            slice_output = val_output[None, None, :, :, i]
            slice_label = val_labels[None, None, :, :, i]
            # Metrics are modules: place them on the inputs' device *before*
            # calling them, otherwise non-CPU inputs meet CPU metric state.
            val_ssim = StructuralSimilarityIndexMeasure().to(device)(slice_output, slice_label)
            val_mae = MeanAbsoluteError().to(device)(slice_output, slice_label)
            val_psnr = PeakSignalNoiseRatio().to(device)(slice_output, slice_label)
            val_ssim_sum += val_ssim
            val_mae_sum += val_mae
            val_psnr_sum += val_psnr
            slice_log.write(
                f'mean metrics for slice, step {i}, SSIM: {val_ssim}, MAE: {val_mae}, PSNR: {val_psnr}\n')
    val_metrices = {
        'ssim': val_ssim_sum / slice_number,
        'mae': val_mae_sum / slice_number,
        'psnr': val_psnr_sum / slice_number
    }
    print(f"mean ssim of val set {val_step}: {val_metrices['ssim']}")
    print(f"mean mae of val set {val_step}: {val_metrices['mae']}")
    print(f"mean psnr of val set {val_step}: {val_metrices['psnr']}")
    ssim = val_metrices.get('ssim', 0)
    mae = val_metrices.get('mae', 0)
    psnr = val_metrices.get('psnr', 0)
    with open(log_file_overall, 'a') as f:
        f.write(f'mean metrics for patient {val_step}, SSIM: {ssim}, MAE: {mae}, PSNR: {psnr}\n')
    return val_metrices
def calculate_mask_metrices(val_output, val_labels, val_masks,
                            log_file_overall, val_step, dynamic_range=None, printoutput=False):
    """Compute SSIM / MAE / PSNR for a whole volume via ``ImageMetrics``.

    Args:
        val_output: predicted tensor.
        val_labels: reference tensor.
        val_masks: optional mask tensor; when ``None`` the metric methods are
            called without a mask argument.
        log_file_overall: path of the log file (opened in append mode).
        val_step: identifier used in printed and logged messages.
        dynamic_range: two-element intensity range handed to ``ImageMetrics``;
            defaults to ``[-1024., 3000.]`` (a fresh list per call).
        printoutput: when true, also print the three metrics.

    Returns:
        dict with keys 'ssim', 'mae', 'psnr'.
    """
    if dynamic_range is None:
        # Build the default per call so no list is shared between calls.
        dynamic_range = [-1024., 3000.]
    metricsCalc = ImageMetrics(dynamic_range)

    # Convert each tensor once; detach/cpu makes this work for tensors that
    # track gradients or live on an accelerator.
    output_np = val_output.detach().cpu().numpy()
    labels_np = val_labels.detach().cpu().numpy()
    mask_args = () if val_masks is None else (val_masks.detach().cpu().numpy(),)

    val_metrices = {
        'ssim': metricsCalc.ssim(output_np, labels_np, *mask_args),
        'mae': metricsCalc.mae(output_np, labels_np, *mask_args),
        'psnr': metricsCalc.psnr(output_np, labels_np, *mask_args),
    }
    if printoutput:
        print(f"mean ssim {val_step}: {val_metrices['ssim']}")
        print(f"mean mae {val_step}: {val_metrices['mae']}")
        print(f"mean psnr {val_step}: {val_metrices['psnr']}")
    ssim = val_metrices.get('ssim', 0)
    mae = val_metrices.get('mae', 0)
    psnr = val_metrices.get('psnr', 0)
    with open(log_file_overall, 'a') as f:
        f.write(f'mean metrics {val_step}, SSIM: {ssim}, MAE: {mae}, PSNR: {psnr}\n')
    return val_metrices
def process_and_save_images(input_imgs,
                            label_imgs,
                            fake_imgs,
                            unreversed_val_source,
                            unreversed_targets,
                            unreversed_output,
                            val_step,
                            epoch,
                            model_name,
                            folder,
                            slice_range):
    """Save image sets for every slice index in ``[slice_range['min'], slice_range['max'])``.

    For each index the slice ``[:, :, idx]`` of the three volumes is passed
    to ``save_image_slice``, followed by the same for the three
    "unreversed" volumes with ``unreversed=True``.
    """
    first, stop = slice_range["min"], slice_range["max"]
    for slice_idx in range(first, stop):
        # positional arguments shared by both calls below
        shared = (slice_idx, val_step, epoch, model_name, folder)
        save_image_slice(input_imgs[:, :, slice_idx],
                         label_imgs[:, :, slice_idx],
                         fake_imgs[:, :, slice_idx],
                         *shared)
        save_image_slice(unreversed_val_source[:, :, slice_idx],
                         unreversed_targets[:, :, slice_idx],
                         unreversed_output[:, :, slice_idx],
                         *shared, unreversed=True)
# Define function to save images
def save_single_image(input_imgs, filename, imgformat, dpi=300):
    """Render one image in grayscale, without axes or padding, and save it.

    Note: this file defines ``save_single_image`` a second time further down
    with the same body; the later definition is the one bound after import.

    Args:
        input_imgs: image data accepted by ``imshow``.
        filename: destination path.
        imgformat: file format passed to ``savefig``.
        dpi: output resolution.
    """
    fig = plt.figure()
    try:
        ax = fig.gca()
        ax.set_axis_off()
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0,
                            hspace=0, wspace=0)
        ax.margins(0, 0)
        ax.imshow(input_imgs, cmap='gray')
        fig.savefig(filename, format=f'{imgformat}',
                    bbox_inches='tight', pad_inches=0, dpi=dpi)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
def save_image_slice(input_img,
                     label_img,
                     fake_img,
                     slice_idx,
                     val_step,
                     epoch,
                     model_name,
                     folder,
                     x_lower_limit=-1,
                     x_upper_limit=3,
                     y_lower_limit=0,
                     y_upper_limit=15000,
                     dpi=500,
                     unreversed=False):
    """Write all per-slice outputs for one source/target/fake triple.

    Produces, inside ``folder``: the three single images, a side-by-side
    comparison (all as jpg) and a histogram figure (png). File names are
    prefixed with "unreversed_" when ``unreversed`` is true. The x/y limits
    are forwarded to ``arrange_3_histograms``.
    """
    imgformat = 'jpg'
    prefix = "unreversed_" if unreversed else ""

    def _target_path(kind, extension):
        # shared naming scheme for every file written for this slice
        return os.path.join(
            folder, f"{prefix}{kind}_{val_step}_idx_{slice_idx}_epoch_{epoch}.{extension}")

    singles = (("source", input_img), ("target", label_img), ("fake", fake_img))
    for kind, image in singles:
        save_single_image(image, _target_path(kind, imgformat),
                          imgformat=imgformat, dpi=dpi)

    arrange_images(input_img, label_img, fake_img, model_name=model_name,
                   saved_name=_target_path("compare", imgformat),
                   imgformat=imgformat, dpi=dpi)

    arrange_3_histograms(input_img.numpy(), label_img.numpy(), fake_img.numpy(),
                         saved_name=_target_path("histograms", "png"), dpi=dpi,
                         x_lower_limit=x_lower_limit, x_upper_limit=x_upper_limit,
                         y_lower_limit=y_lower_limit, y_upper_limit=y_upper_limit)
# group consecutive samples that share the same target shape
def group_labels(test_labels):
    """Split a sequence of samples into runs of equal ``['target'].shape``.

    Only *consecutive* samples are grouped: whenever the shape differs from
    the previous sample's shape a new group is started, so the same shape
    may appear in several groups.

    Args:
        test_labels: iterable of mappings, each with a 'target' entry that
            exposes ``.shape``.

    Returns:
        ``(labels_groups, size_of_labels)`` -- the list of groups and the
        shape associated with each group, in order. Both are empty for empty
        input.
    """
    labels_groups = []
    size_of_labels = []
    for label in test_labels:
        size = label['target'].shape
        if not size_of_labels or size != size_of_labels[-1]:
            # shape changed (or first sample): open a new group
            size_of_labels.append(size)
            labels_groups.append([])
        labels_groups[-1].append(label)
    return labels_groups, size_of_labels
# divide the different patients from val_outputs
def write_nifti(val_outputs, output_dir=r'.\logs', filename='val'):
    """Stack each group's 'target' entries along a new axis 3 and save them.

    ``val_outputs`` is grouped with ``group_labels``; for group ``i`` the
    stacked tensor is written through ``SaveImage`` with the postfix
    ``f'{filename}_{i}'``. The stacked shape is printed for every group.
    """
    labels_groups, _group_sizes = group_labels(val_outputs)
    for patient_idx, patient_slices in enumerate(labels_groups):
        # add an axis at position 3 to every slice, then join along it
        expanded = [entry['target'].unsqueeze(3) for entry in patient_slices]
        concatenated_outputs = torch.cat(expanded, dim=3)
        print(concatenated_outputs.shape)
        saver = SaveImage(output_dir=output_dir,
                          output_postfix=f'{filename}_{patient_idx}',
                          resample=True)
        saver(concatenated_outputs.detach().cpu())
def write_nifti_volume(val_outputs, output_dir=r'.\logs', filename='val'):
    """Save one volume through ``SaveImage`` using ``filename`` as the postfix.

    The tensor is detached and moved to CPU before it is handed to the saver.
    """
    saver = SaveImage(output_dir=output_dir, output_postfix=f'{filename}', resample=True)
    cpu_volume = val_outputs.detach().cpu()
    saver(cpu_volume)
def reverse_transforms(output_images, orig_images, transforms):
    """Invert ``transforms`` on the first batch item of ``output_images``.

    The ``applied_operations`` attribute of ``orig_images`` is copied onto
    ``output_images`` (mutating it) before the inverse is applied. Only index
    0 of the batch axis is used, so callers are expected to run with a
    validation batch size of 1.

    Returns:
        The inverted "target" entry.
    """
    output_images.applied_operations = orig_images.applied_operations
    first_item = {"target": output_images[0, :, :, :, :]}
    with allow_missing_keys_mode(transforms):
        inverted = transforms.inverse(first_item)
    return inverted["target"]
def calculate_ssim(pred, target):
    """Return the SSIM of ``pred`` vs ``target`` from a fresh metric on ``pred``'s device."""
    metric = StructuralSimilarityIndexMeasure()
    metric = metric.to(pred.device)
    return metric(pred, target)
def calculate_mae(pred, target):
    """Return the mean absolute error of ``pred`` vs ``target`` from a fresh metric on ``pred``'s device."""
    metric = MeanAbsoluteError()
    metric = metric.to(pred.device)
    return metric(pred, target)
def calculate_psnr(pred, target):
    """Return the PSNR of ``pred`` vs ``target`` from a fresh metric on ``pred``'s device."""
    metric = PeakSignalNoiseRatio()
    metric = metric.to(pred.device)
    return metric(pred, target)
def val_log(epoch, step, gen_image, orig_image, saved_path):
    """Compute SSIM/MAE/PSNR for one image pair, print them and append them to a log.

    The log file is ``infer_log.txt`` inside ``saved_path``.

    Returns:
        ``(val_metrices, infer_log_file)`` -- the metrics dict and the log path.
    """
    ssim_value = calculate_ssim(gen_image, orig_image)
    mae_value = calculate_mae(gen_image, orig_image)
    psnr_value = calculate_psnr(gen_image, orig_image)
    print(f"val_ssim: {ssim_value}, val_mae: {mae_value}, val_psnr: {psnr_value}.")

    val_metrices = {'ssim': ssim_value, 'mae': mae_value, 'psnr': psnr_value}
    infer_log_file = os.path.join(saved_path, "infer_log.txt")
    output_val_log(epoch, step, infer_log_file, val_metrices)
    return val_metrices, infer_log_file
def output_val_log(epoch, val_step, val_log_file=r'.\logs\val_log.txt', val_metrices=None):
    """Append one line of validation metrics to a text log.

    Args:
        epoch: epoch identifier written into the line.
        val_step: validation-set identifier written into the line.
        val_log_file: path of the log file (opened in append mode).
        val_metrices: mapping that may contain 'ssim', 'mae' and 'psnr';
            missing entries (or a missing mapping) are logged as 0.
    """
    if val_metrices is None:
        # Avoid a mutable dict as the default argument.
        val_metrices = {}
    ssim = val_metrices.get('ssim', 0)
    mae = val_metrices.get('mae', 0)
    psnr = val_metrices.get('psnr', 0)
    with open(val_log_file, 'a') as f:  # append mode
        f.write(f'epoch {epoch}, val set {val_step}, SSIM: {ssim}, MAE: {mae}, PSNR: {psnr}\n')
def calculate_reverse_info(untransformed_loader):
    """Collect per-sample statistics of the untransformed CT and MRI data.

    Every item yielded by ``untransformed_loader`` must provide 'target'
    (stored under the CT_* keys) and 'source' (stored under the MRI_* keys).
    Mean and std are computed on a float copy; min and max on the data as
    given. The data objects themselves are also kept, so the result holds a
    reference to every sample.

    Returns:
        dict of lists, one list entry per sample, with the keys
        CT_mean, CT_std, MRI_mean, MRI_std, CT_shape, MRI_shape,
        CT_min, CT_max, MRI_min, MRI_max, CT_data, MRI_data.
    """
    ordered_keys = ("CT_mean", "CT_std", "MRI_mean", "MRI_std",
                    "CT_shape", "MRI_shape",
                    "CT_min", "CT_max", "MRI_min", "MRI_max",
                    "CT_data", "MRI_data")
    all_reverse_info = {key: [] for key in ordered_keys}

    for sample in untransformed_loader:
        per_modality = (("CT", sample['target']), ("MRI", sample['source']))
        for modality, volume in per_modality:
            as_float = volume.float()
            all_reverse_info[f"{modality}_mean"].append(torch.mean(as_float))
            all_reverse_info[f"{modality}_std"].append(torch.std(as_float))
            all_reverse_info[f"{modality}_shape"].append(volume.shape)
            all_reverse_info[f"{modality}_min"].append(torch.min(volume))
            all_reverse_info[f"{modality}_max"].append(torch.max(volume))
            all_reverse_info[f"{modality}_data"].append(volume)
    return all_reverse_info
# Define function to reverse normalization
def reverse_normalize_data(tensor,
                           mean=None,
                           std=None,
                           min_val=None,
                           max_val=None,
                           mode='zscore'):
    """Map normalized intensities back to their original range.

    Args:
        tensor: normalized values (anything supporting ``*``, ``+``, ``-``, ``/``).
        mean, std: statistics for ``mode='zscore'``; if either is ``None``
            the input is returned unchanged.
        min_val, max_val: range for ``mode='minmax'``, which treats the input
            as lying in [-1, 1]; if either is ``None`` the input is returned
            unchanged.
        mode: one of 'zscore', 'minmax', 'inputonlyminmax', 'none',
            'inputonlyzscore', 'scale1000', 'scale4000', 'scale2000',
            'nonegative', 'norm_mr', 'norm_mr_scale'.

    Returns:
        The de-normalized values.

    Raises:
        ValueError: if ``mode`` is not one of the supported names (instead of
            silently returning ``None``).
    """
    if mode == 'zscore':
        if mean is None or std is None:
            return tensor
        return tensor * std + mean
    if mode == 'minmax':
        if min_val is None or max_val is None:
            return tensor
        # [-1, 1] -> [0, 1] -> [min_val, max_val]
        return (tensor + 1) / 2 * (max_val - min_val) + min_val
    if mode in ('inputonlyminmax', 'none', 'inputonlyzscore'):
        # the target was never normalized in these modes
        return tensor
    if mode == 'scale1000':
        return tensor * 1000 - 1024
    if mode == 'scale4000':
        return tensor * 4000 - 1024
    if mode == 'scale2000':
        return tensor * 2000 - 1000
    if mode == 'nonegative':
        return tensor - 1024
    if mode in ('norm_mr', 'norm_mr_scale'):
        return tensor * 255
    raise ValueError(f"unsupported normalization mode: {mode!r}")
# Define function to normalize data
def normalize_data(tensor, mean=None, std=None, min_val=None, max_val=None, mode='zscore'):
    """Normalize ``tensor`` by z-score or min-max; otherwise return it unchanged.

    * 'zscore': ``(tensor - mean) / std`` when both statistics are given.
    * 'minmax': ``(tensor - min_val) / (max_val - min_val)`` when both bounds
      are given, i.e. the result spans [0, 1].
      NOTE(review): ``reverse_normalize_data`` in this file inverts a
      [-1, 1] scaling for 'minmax', so the two are not exact inverses --
      confirm which range is intended.
    * any other mode, or missing statistics: the input is returned as is.
    """
    if mode == 'zscore' and mean is not None and std is not None:
        return (tensor - mean) / std
    if mode == 'minmax' and min_val is not None and max_val is not None:
        return (tensor - min_val) / (max_val - min_val)
    return tensor
def save_val_images(val_outputs, val_slice_num, val_names, epoch, saved_img_folder):
    """Split stacked validation outputs per patient and write each with ``write_nifti``.

    ``val_slice_num`` lists how many leading-axis entries belong to each
    patient (e.g. ``[200, 200, 150, 230]``). If the counts do not add up to
    ``val_outputs.shape[0]`` nothing is written and a diagnostic is printed.
    Output names are ``pred_<name>_epoch_<epoch + 1>``.

    NOTE(review): each per-patient chunk is a slice of ``val_outputs``, while
    ``write_nifti`` reads ``['target']`` from the items it receives -- confirm
    that the chunk type supports that lookup.
    """
    expected_total = sum(val_slice_num)
    if val_outputs.shape[0] != expected_total:
        print(val_outputs.shape[0])
        print(expected_total)
        print("something wrong with validation set, please check")
        return

    # carve consecutive chunks off the leading axis, one per patient
    per_patient = []
    start = 0
    for count in val_slice_num:
        per_patient.append(val_outputs[start:start + count, :, :, :])
        start += count

    for patient_idx, patient_data in enumerate(per_patient):
        file_name = f'pred_{val_names[patient_idx]}_epoch_{epoch+1}'
        write_nifti(patient_data, saved_img_folder, file_name)
def compare_imgs(input_imgs, target_imgs, fake_imgs,
                 saved_name,
                 imgformat='jpg',
                 dpi=500,
                 model_name='DDPM',):
    """Save a three-panel MRI / CT / model figure plus the three single images.

    Each tensor is squeezed, scaled by 255 and converted to an 8-bit image.
    The combined figure goes to ``saved_name``; the single images go to the
    same path with ``_mri``, ``_ct`` and ``_fake`` inserted before the file
    extension.

    Args:
        input_imgs, target_imgs, fake_imgs: tensors expected to hold values
            in [0, 1]; values outside that range are clipped rather than
            wrapping around in the 8-bit conversion.
        saved_name: path of the combined figure.
        imgformat: file format passed to ``savefig``.
        dpi: output resolution.
        model_name: title of the third panel.
    """
    from PIL import Image

    def _to_uint8_image(tensor_img):
        # detach() lets tensors that track gradients be converted as well
        arr = tensor_img.squeeze().detach().cpu().numpy()
        arr = np.clip(arr * 255, 0, 255).astype(np.uint8)
        return Image.fromarray(arr)

    panels = [_to_uint8_image(input_imgs),
              _to_uint8_image(target_imgs),
              _to_uint8_image(fake_imgs)]
    titles = ['MRI', 'CT', model_name]

    fig, axs = plt.subplots(1, 3, figsize=(12, 5))
    try:
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0,
                            hspace=0, wspace=0.1)
        axs[-1].margins(0, 0)
        for ax, panel, title in zip(axs, panels, titles):
            ax.imshow(panel, cmap='gray')
            ax.set_title(title)
            ax.axis('off')
        fig.savefig(saved_name, format=f'{imgformat}', bbox_inches='tight', pad_inches=0, dpi=dpi)
    finally:
        plt.close(fig)

    # Derive the single-image names from the extension so they can never
    # collide with (and overwrite) the combined figure.
    root, ext = os.path.splitext(saved_name)
    for tag, panel in zip(('mri', 'ct', 'fake'), panels):
        save_single_image(panel, f'{root}_{tag}{ext}', imgformat, dpi=dpi)
# Define function to save images
def save_single_image(input_imgs, filename, imgformat, dpi=300):
    """Save ``input_imgs`` as a borderless grayscale image.

    Note: an identical definition of this function appears earlier in this
    file; being later, this is the one that remains bound after import.

    Args:
        input_imgs: image data accepted by ``imshow``.
        filename: destination path.
        imgformat: file format passed to ``savefig``.
        dpi: output resolution.
    """
    fig = plt.figure()
    try:
        ax = fig.gca()
        ax.set_axis_off()
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0,
                            hspace=0, wspace=0)
        ax.margins(0, 0)
        ax.imshow(input_imgs, cmap='gray')
        fig.savefig(filename, format=f'{imgformat}',
                    bbox_inches='tight', pad_inches=0, dpi=dpi)
    finally:
        # Close the figure even if rendering or saving raised.
        plt.close(fig)
class ImageProcessor:
    """Converts tensors to 8-bit images and saves single or side-by-side figures.

    Attributes are plain configuration: ``model_name`` (title of the
    generated-image panel), ``img_format`` and ``dpi`` (passed to ``savefig``).
    """

    def __init__(self, model_name='DDPM', img_format='jpg', dpi=500):
        self.model_name = model_name
        self.img_format = img_format
        self.dpi = dpi

    def convert_to_image(self, tensor_img):
        """Return ``tensor_img`` (squeezed, scaled by 255) as an 8-bit PIL image.

        Values are expected in [0, 1]; anything outside is clipped instead of
        wrapping around in the 8-bit conversion.
        """
        from PIL import Image
        np_img = tensor_img.squeeze().detach().cpu().numpy()
        np_img = np.clip(np_img * 255, 0, 255).astype(np.uint8)
        return Image.fromarray(np_img)

    def save_image(self, img, filename):
        """Save ``img`` in grayscale without axes to ``filename``."""
        fig = plt.figure()
        try:
            ax = fig.gca()
            ax.imshow(img, cmap='gray')
            ax.axis('off')
            fig.savefig(filename, format=self.img_format, bbox_inches='tight',
                        pad_inches=0, dpi=self.dpi)
        finally:
            plt.close(fig)

    def compare_images(self, input_imgs, target_imgs, fake_imgs, saved_name):
        """Save an MRI / CT / model comparison figure and the three single images.

        The combined figure is written to ``saved_name``; the single images
        use the same path with ``_mri``, ``_ct`` and ``_fake`` inserted before
        the extension (the layout used by ``compare_imgs`` in this file).
        """
        panels = [self.convert_to_image(input_imgs),
                  self.convert_to_image(target_imgs),
                  self.convert_to_image(fake_imgs)]
        titles = ['MRI', 'CT', self.model_name]

        fig, axs = plt.subplots(1, 3, figsize=(12, 5))
        try:
            fig.subplots_adjust(top=1, bottom=0, right=1, left=0,
                                hspace=0, wspace=0.1)
            for ax, panel, title in zip(axs, panels, titles):
                ax.imshow(panel, cmap='gray')
                ax.set_title(title)
                ax.axis('off')
            fig.savefig(saved_name, format=self.img_format, bbox_inches='tight',
                        pad_inches=0, dpi=self.dpi)
        finally:
            plt.close(fig)

        root, ext = os.path.splitext(saved_name)
        for tag, panel in zip(('mri', 'ct', 'fake'), panels):
            self.save_image(panel, f'{root}_{tag}{ext}')
def arrange_images(input_imgs,
                   label_imgs,
                   fake_imgs,
                   model_name,
                   saved_name,
                   imgformat='jpg',
                   dpi=500):
    """Save the three images side by side, titled 'MRI', 'CT' and ``model_name``.

    The figure is written to ``saved_name`` in ``imgformat`` at ``dpi`` and
    closed afterwards.
    """
    fig, axs = plt.subplots(1, 3, figsize=(12, 5))
    # the figure's current axes is the last panel created
    last_ax = fig.gca()
    last_ax.set_axis_off()
    fig.subplots_adjust(top=1, bottom=0, right=1, left=0,
                        hspace=0, wspace=0.1)
    last_ax.margins(0, 0)

    panels = ((input_imgs, 'MRI'), (label_imgs, 'CT'), (fake_imgs, model_name))
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    fig.savefig(saved_name, format=f'{imgformat}', bbox_inches='tight', pad_inches=0, dpi=dpi)
    plt.close(fig)
# Define function to plot histograms
def plot_histogram(data, title, ax, color='blue', alpha=0.7,
                   x_lower_limit=-1, x_upper_limit=3, y_lower_limit=0, y_upper_limit=15000):
    """Draw a 256-bin histogram of the flattened ``data`` on ``ax``.

    The x limits restrict the binned value range, the y limits fix the
    displayed frequency range. Axis labels and ``title`` are set on ``ax``.
    """
    value_range = (x_lower_limit, x_upper_limit)
    flat_values = data.flatten()
    ax.hist(flat_values, bins=256, range=value_range, color=color, alpha=alpha)
    ax.set_ylim([y_lower_limit, y_upper_limit])
    ax.set_title(title)
    ax.set_xlabel('Pixel intensity')
    ax.set_ylabel('Frequency')
def arrange_1_histogram(original, saved_name, title='Histogram', color='blue', alpha=0.7, dpi=300,
                        x_lower_limit=-1, x_upper_limit=3, y_lower_limit=0, y_upper_limit=15000):
    """Plot one histogram of ``original`` and save the figure to ``saved_name``."""
    limits = dict(x_lower_limit=x_lower_limit, x_upper_limit=x_upper_limit,
                  y_lower_limit=y_lower_limit, y_upper_limit=y_upper_limit)
    fig, ax = plt.subplots(figsize=(10, 6))
    plot_histogram(original, title, ax, color=color, alpha=alpha, **limits)
    # lay out and write the figure, then release it
    fig.tight_layout()
    fig.savefig(saved_name, dpi=dpi)
    plt.close(fig)
# Arrange two histograms
def arrange_histograms(original, reversed, saved_name, titles=('target', 'prediction'), dpi=300,
                       x_lower_limit=-1, x_upper_limit=3, y_lower_limit=0, y_upper_limit=15000):
    """Plot two stacked histograms and save the figure to ``saved_name``.

    Args:
        original: data for the upper (red) histogram.
        reversed: data for the lower (green) histogram.
        saved_name: destination path.
        titles: two names used in the panel titles ("Histogram for <name>");
            an immutable default replaces the former shared list.
        dpi: output resolution.
        x_lower_limit, x_upper_limit, y_lower_limit, y_upper_limit: forwarded
            to ``plot_histogram``.
    """
    limits = dict(x_lower_limit=x_lower_limit, x_upper_limit=x_upper_limit,
                  y_lower_limit=y_lower_limit, y_upper_limit=y_upper_limit)
    fig, axs = plt.subplots(2, 1, figsize=(10, 8))
    plot_histogram(original, f'Histogram for {titles[0]}', axs[0], color='red', **limits)
    plot_histogram(reversed, f'Histogram for {titles[1]}', axs[1], color='green', **limits)
    fig.tight_layout()
    fig.savefig(saved_name, dpi=dpi)
    plt.close(fig)
# Arrange three histograms
def arrange_3_histograms(source, target, output, saved_name, dpi=300,
                         x_lower_limit=-1, x_upper_limit=3, y_lower_limit=0, y_upper_limit=15000):
    """Save a three-histogram figure and a companion boxplot figure.

    The histogram figure (source / target / output, top to bottom) is written
    to ``saved_name``. The boxplot of the three flattened arrays is written
    to ``saved_name`` with 'histogram' replaced by 'boxplot'; if the name
    does not contain 'histogram', ``_boxplot`` is inserted before the
    extension so the boxplot never overwrites the histogram file.

    Args:
        source, target, output: arrays exposing ``flatten()``.
        saved_name: path of the histogram figure.
        dpi: output resolution of both figures.
        x_lower_limit, x_upper_limit, y_lower_limit, y_upper_limit: forwarded
            to ``plot_histogram``.
    """
    limits = dict(x_lower_limit=x_lower_limit, x_upper_limit=x_upper_limit,
                  y_lower_limit=y_lower_limit, y_upper_limit=y_upper_limit)

    # histograms
    fig, axs = plt.subplots(3, 1, figsize=(10, 8))
    try:
        plot_histogram(source, f'Histogram for source', axs[0], color='red', **limits)
        plot_histogram(target, f'Histogram for target', axs[1], color='green', **limits)
        plot_histogram(output, f'Histogram for output', axs[2], color='blue', **limits)
        fig.tight_layout()
        fig.savefig(saved_name, dpi=dpi)
    finally:
        plt.close(fig)

    boxplot_name = saved_name.replace('histogram', 'boxplot')
    if boxplot_name == saved_name:
        root, ext = os.path.splitext(saved_name)
        boxplot_name = f'{root}_boxplot{ext}'

    # boxplot, on its own explicitly created figure
    box_fig = plt.figure()
    try:
        data = [source.flatten(), target.flatten(), output.flatten()]
        plt.boxplot(data, autorange=True)
        plt.xticks([1, 2, 3], ['Source', 'Target', 'Fake'])
        plt.title('Pixel Value Distribution')
        plt.xlabel('Image Type')
        plt.ylabel('Pixel Values')
        plt.tight_layout()
        plt.savefig(boxplot_name, dpi=dpi)
    finally:
        plt.close(box_fig)
def arrange_4_histograms(real1, fake1, real2, fake2, saved_name, dpi=300):
    """Plot four stacked histograms (real1, fake1 in red; real2, fake2 in green) and save them.

    Each panel uses the default limits of ``plot_histogram``.
    """
    series = ((real1, 'real1', 'red'),
              (fake1, 'fake1', 'red'),
              (real2, 'real2', 'green'),
              (fake2, 'fake2', 'green'))
    fig, axs = plt.subplots(4, 1, figsize=(10, 8))
    for ax, (values, label, colour) in zip(axs, series):
        plot_histogram(values, f'Histogram for {label}', ax, color=colour)
    fig.tight_layout()
    fig.savefig(saved_name, dpi=dpi)
    plt.close(fig)
# save output images
def sample_images(model, input, label, slice_idx, epoch, batch_i, saved_folder, model_name='model'):
    """Run ``model`` on ``input`` and save comparison images for one batch item.

    Takes ``[slice_idx, 0, :, :]`` of the input, the label and the model
    output, then writes into ``saved_folder`` (created if missing):
    ``<epoch>_<batch_i>.jpg`` with the three panels 'MRI', 'CT',
    'Translated', plus ``..._mri.jpg``, ``..._ct.jpg`` and ``..._fake.jpg``
    with one panel each (the last one titled ``model_name``).
    """
    fake = model(input)
    planes = [batch.cpu().detach().numpy()[slice_idx, 0, :, :].squeeze()
              for batch in (input, label, fake)]
    gen_imgs = np.stack(planes)

    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(saved_folder, exist_ok=True)
    stem = os.path.join(saved_folder, f"{epoch}_{batch_i}")

    fig, axs = plt.subplots(1, 3, figsize=(20, 4))
    for ax, image, title in zip(axs, gen_imgs, ('MRI', 'CT', 'Translated')):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.savefig(f"{stem}.jpg")
    plt.close(fig)

    # Individual images. Names are built from the stem rather than with
    # str.replace, which would also rewrite a '.jpg' inside the folder path.
    singles = zip(('mri', 'ct', 'fake'), gen_imgs, ('MRI', 'CT', model_name))
    for tag, image, title in singles:
        single_fig, single_ax = plt.subplots(1, 1)
        single_ax.imshow(image.squeeze(), cmap='gray')
        single_ax.set_title(title)
        single_ax.axis('off')
        single_fig.savefig(f"{stem}_{tag}.jpg")
        plt.close(single_fig)
def save_images(input_imgs, label_imgs, fake_imgs,
                slice_idx,
                saved_name='./test.jpg',
                imgformat='jpg',
                dpi=1000,
                model_name='model'):
    """Save slice ``[:, :, slice_idx]`` of three volumes as a comparison plus single images.

    The three-panel figure ('MRI', 'CT', ``model_name``) is written to
    ``saved_name``; the single images are written next to it with ``_mri``,
    ``_ct`` and ``_fake`` inserted before the file extension, so they cannot
    overwrite the combined figure even when ``saved_name`` does not end in
    ``.<imgformat>``.

    Args:
        input_imgs, label_imgs, fake_imgs: volumes indexable as ``[:, :, i]``
            whose slices expose ``squeeze()``.
        slice_idx: index along the third axis.
        saved_name: path of the combined figure.
        imgformat: file format passed to ``savefig``.
        dpi: output resolution.
        model_name: title of the third panel.
    """
    # squeeze all three slices alike (the label panel used to be left unsqueezed)
    panels = [volume[:, :, slice_idx].squeeze()
              for volume in (input_imgs, label_imgs, fake_imgs)]
    titles = ['MRI', 'CT', model_name]

    fig, axs = plt.subplots(1, 3, figsize=(12, 5))
    try:
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0,
                            hspace=0, wspace=0.1)
        axs[-1].margins(0, 0)
        for ax, panel, title in zip(axs, panels, titles):
            ax.imshow(panel, cmap='gray')
            ax.set_title(title)
            ax.axis('off')
        fig.savefig(saved_name, format=f'{imgformat}', bbox_inches='tight', pad_inches=0, dpi=dpi)
    finally:
        plt.close(fig)

    root, ext = os.path.splitext(saved_name)
    for tag, panel in zip(('mri', 'ct', 'fake'), panels):
        save_single_image(panel, f'{root}_{tag}{ext}', imgformat, dpi=dpi)
# save output images
def sample_images2(model, input, label, slice_idx, epoch, batch_i, saved_folder):
    """Run ``model`` on ``input`` and save one three-panel comparison image.

    Takes ``[slice_idx, 0, :, :]`` of the input, the label and the model
    output and writes them as 'MRI', 'CT', 'Translated' panels to
    ``<saved_folder>/<epoch>_<batch_i>.jpg``; the folder is created if
    missing.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(saved_folder, exist_ok=True)
    saved_name = f"{epoch}_{batch_i}.jpg"

    fake = model(input)
    planes = [batch.cpu().detach().numpy()[slice_idx, 0, :, :].squeeze()
              for batch in (input, label, fake)]
    gen_imgs = np.stack(planes)

    fig, axs = plt.subplots(1, 3, figsize=(20, 4))
    for ax, image, title in zip(axs, gen_imgs, ('MRI', 'CT', 'Translated')):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.savefig(os.path.join(saved_folder, saved_name))
    plt.close(fig)
def sample_images_3D(model, input, label, epoch, batch_i, saved_folder):
    """Run ``model`` on a 3D ``input`` and save one slice as a three-panel image.

    The arrays are indexed as ``[0, 0, :, :, k]``. ``k`` is 50 when every
    array has more than 50 entries on its last axis, otherwise 10 when every
    array has more than 10, otherwise the middle index of the shortest one.
    The figure ('MRI', 'CT', 'Translated') is written to
    ``<saved_folder>/<epoch>_<batch_i>.jpg``; the folder is created if
    missing.
    """
    fake = model(input)
    input_imgs = input.cpu().detach().numpy()
    target_imgs = label.cpu().detach().numpy()
    fake_imgs = fake.cpu().detach().numpy()

    # Choose the slice up front instead of relying on a bare `except:` around
    # the indexing, which also swallowed unrelated errors.
    depth = min(input_imgs.shape[-1], target_imgs.shape[-1], fake_imgs.shape[-1])
    if depth > 50:
        slice_idx = 50
    elif depth > 10:
        slice_idx = 10
    else:
        slice_idx = depth // 2

    gen_imgs = np.stack([
        input_imgs[0, 0, :, :, slice_idx].squeeze(),
        target_imgs[0, 0, :, :, slice_idx].squeeze(),
        fake_imgs[0, 0, :, :, slice_idx].squeeze(),
    ])

    fig, axs = plt.subplots(1, 3, figsize=(20, 4))
    for ax, image, title in zip(axs, gen_imgs, ('MRI', 'CT', 'Translated')):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    os.makedirs(saved_folder, exist_ok=True)
    saved_name = f"{epoch}_{batch_i}.jpg"
    fig.savefig(os.path.join(saved_folder, saved_name))
    plt.close(fig)