Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| """Transforms for AVRA inference (subset used by load_transform).""" | |
| from __future__ import division | |
| import torch | |
| import numpy as np | |
class SwapAxes(object):
    """Exchange two axes of an array via ``np.swapaxes``.

    Args:
        axis1: first axis to exchange.
        axis2: second axis to exchange.
    """

    def __init__(self, axis1, axis2):
        self.axis1 = axis1
        self.axis2 = axis2

    def __call__(self, image):
        pair = (self.axis1, self.axis2)
        return np.swapaxes(image, *pair)
class Return5D(object):
    """Insert a channel dimension at position 1, optionally tiling it ``nc`` times.

    For a 3-D tensor of shape ``(N, H, W)`` the result has shape
    ``(N, nc, H, W)``.

    Args:
        nc: number of channels in the output; when not 1 the single inserted
            channel is repeated along dim 1.
    """

    def __init__(self, nc=1):
        self.nc = nc

    def __call__(self, image_1):
        with_channel = image_1.unsqueeze(1)
        if self.nc == 1:
            return with_channel
        return with_channel.repeat(1, self.nc, 1, 1)
class ReturnStackedPA(object):
    """Cut three slice stacks out of a 3-D tensor and concatenate them along dim 0.

    Three views of the input are used:
      * ``ax``  - the input as given,
      * ``cor`` - ``input.permute(2, 1, 0)``,
      * ``sag`` - ``input.permute(1, 0, 2)``.
    Each view is restricted to ``[lim[0]:lim[1]]`` along its first dim, then
    sub-sampled with the matching step from ``stepsize``, and given a channel
    dim at position 1 (repeated ``nc`` times when ``nc != 1``).

    ``torch.cat`` requires the last two dims of all three views to agree, which
    in practice means the input volume must be cubic.

    Args:
        nc: number of channels per slice.
        ax_lim, cor_lim, sag_lim: ``(start, stop)`` slice bounds for each view.
        stepsize: ``(ax_step, cor_step, sag_step)`` sub-sampling steps.
        rnn: when False, ``squeeze(1)`` is applied to the result; this removes
            the channel dim only when it has size 1 (i.e. ``nc == 1``).

    Returns:
        Tensor of shape ``(n_ax + n_cor + n_sag, nc, H, W)`` (or without the
        channel dim when ``rnn`` is False and ``nc == 1``).
    """

    # Tuples instead of lists so the defaults are not shared mutable objects.
    def __init__(self, nc=1, ax_lim=(50, -5), cor_lim=(20, 75), sag_lim=(30, -30),
                 stepsize=(2, 2, 2), rnn=True):
        self.nc = nc
        self.ax_lim = ax_lim
        self.cor_lim = cor_lim
        self.sag_lim = sag_lim
        self.stepsize = stepsize
        self.rnn = rnn

    def __call__(self, ax):
        ax_s, cor_s, sag_s = self.stepsize
        # Build the permuted views from the full input before it is cropped.
        sag = ax.permute(1, 0, 2)
        cor = ax.permute(2, 1, 0)
        ax = ax[self.ax_lim[0]:self.ax_lim[1], :, :]
        cor = cor[self.cor_lim[0]:self.cor_lim[1], :, :]
        sag = sag[self.sag_lim[0]:self.sag_lim[1], :, :]
        ax = ax[0::ax_s, :, :]
        cor = cor[0::cor_s, :, :]
        sag = sag[0::sag_s, :, :]
        ax = ax.unsqueeze(1)
        cor = cor.unsqueeze(1)
        sag = sag.unsqueeze(1)
        if self.nc != 1:
            ax = ax.repeat(1, self.nc, 1, 1)
            sag = sag.repeat(1, self.nc, 1, 1)
            cor = cor.repeat(1, self.nc, 1, 1)
        img = torch.cat((ax, cor, sag), dim=0)
        if not self.rnn:
            img = img.squeeze(1)
        return img
class ReduceSlices(object):
    """Sub-sample the first three axes with fixed strides.

    Args:
        factor_hw: stride applied to both of the first two axes.
        factor_d: stride applied to the third axis.
    """

    def __init__(self, factor_hw, factor_d):
        self.f_h = factor_hw
        self.f_w = factor_hw
        self.f_d = factor_d

    def __call__(self, image):
        strided = (
            slice(0, None, self.f_h),
            slice(0, None, self.f_w),
            slice(0, None, self.f_d),
        )
        return image[strided]
class CenterCrop(object):
    """Crop a box around the (optionally offset) middle of a volume and rescale it.

    Intensities are shifted so the whole-volume minimum becomes 0. The crop is
    then divided by the mean of that shifted *whole* volume; the mean is taken
    before cropping.

    Args:
        output_x, output_y, output_z: requested crop size along the first three
            axes (cast to ``int``).
        offset_x, offset_y, offset_z: shift of the crop centre, in voxels,
            relative to the middle of each axis (cast to ``int``).

    Along any axis where the requested window does not fit, the full extent of
    that axis is kept instead.
    """

    def __init__(self, output_x, output_y, output_z, offset_x=0, offset_y=0, offset_z=0):
        self.output_x = int(output_x)
        self.output_y = int(output_y)
        self.output_z = int(output_z)
        self.offset_x = int(offset_x)
        self.offset_y = int(offset_y)
        self.offset_z = int(offset_z)

    @staticmethod
    def _axis_window(orig, size, offset):
        """Return ``(start, length)`` of the crop window along an axis of length *orig*."""
        start = int(int(orig / 2.) + offset - round(size / 2.))
        # A negative start would wrap around when used as a slice index and
        # yield an empty or wrong crop, so clamp it to the first voxel.
        start = max(start, 0)
        if start + size > orig:
            # The window does not fit along this axis: keep the whole axis.
            return 0, orig
        return start, size

    def __call__(self, image):
        image = image - image.min()
        img_mean = image.mean()
        x_orig, y_orig, z_orig = image.shape[:3]
        x, new_x = self._axis_window(x_orig, self.output_x, self.offset_x)
        y, new_y = self._axis_window(y_orig, self.output_y, self.offset_y)
        z, new_z = self._axis_window(z_orig, self.output_z, self.offset_z)
        image = image[x:x + new_x, y:y + new_y, z:z + new_z]
        # After the min-shift the mean is 0 only for a constant volume; dividing
        # by it would give 0/0 = NaN, so such a crop is left unscaled (all zeros).
        scale = img_mean if img_mean != 0 else 1
        return image / scale
class RandomCrop(object):
    """Crop a box of fixed size at a uniformly random position (first three axes).

    Args:
        output_x, output_y, output_z: crop size along each axis. When a size
            exceeds the axis length, the crop starts at 0 and is truncated by
            slicing to the available extent.
    """

    def __init__(self, output_x, output_y, output_z):
        self.output_x = output_x
        self.output_y = output_y
        self.output_z = output_z

    @staticmethod
    def _random_start(orig, size):
        """Draw a start index uniformly from ``0 .. max(0, orig - size)`` inclusive."""
        # randint's upper bound is exclusive, so +1 makes the last valid start
        # (orig - size) reachable; max(0, ...) keeps the range non-empty when
        # the crop is larger than the axis.
        return np.random.randint(0, max(0, orig - size) + 1)

    def __call__(self, image):
        x_orig, y_orig, z_orig = image.shape[:3]
        new_x, new_y, new_z = self.output_x, self.output_y, self.output_z
        x = self._random_start(x_orig, new_x)
        y = self._random_start(y_orig, new_y)
        z = self._random_start(z_orig, new_z)
        return image[x:x + new_x, y:y + new_y, z:z + new_z]
class RandomMirrorLR(object):
    """Flip an array along ``axis`` when a standard-normal draw is positive.

    The flipped result is copied so it owns contiguous memory.

    Args:
        axis: axis passed to ``np.flip``.
    """

    def __init__(self, axis):
        self.axis = axis

    def __call__(self, image):
        should_flip = np.random.randn() > 0
        if not should_flip:
            return image
        return np.flip(image, self.axis).copy()
class RandomNoise(object):
    """With probability ``p``, add zero-mean Gaussian noise to a tensor.

    The noise is ``torch.randn(image.shape)`` multiplied by a factor drawn
    uniformly from ``[0, noise_var)``. Note that this factor scales a standard
    normal, so it acts as a standard deviation despite the ``var`` naming.

    Args:
        noise_var: upper bound of the random noise scale.
        p: probability of applying the noise; otherwise the input is returned
            unchanged.
    """

    def __init__(self, noise_var=0.1, p=0.5):
        self.noise_var = noise_var
        self.p = p

    def __call__(self, image):
        apply_noise = bool(torch.rand(1)[0] < self.p)
        if not apply_noise:
            return image
        scale = torch.rand(1)[0] * self.noise_var
        noise = torch.randn(image.shape)
        return image + noise * scale
class PerImageNormalization(object):
    """Standardise an image to zero mean and unit standard deviation.

    NOTE: the ``std`` convention follows the input type. NumPy's ``std`` is the
    population std, while torch's default ``std`` is the unbiased (sample) std.
    A constant image (std 0) is returned centred but unscaled, i.e. all zeros,
    instead of NaN.
    """

    def __call__(self, image):
        image = image - image.mean()
        std = image.std()
        # Dividing by a zero spread would produce 0/0 = NaN for constant images.
        if std != 0:
            image = image / std
        return image
class ToTensorFSL(object):
    """Move the last of three array axes to the front and wrap it as a tensor.

    ``(A, B, C)`` becomes ``(C, A, B)``. ``torch.from_numpy`` shares memory with
    the (transposed) NumPy array rather than copying it.
    """

    def __call__(self, image):
        return torch.from_numpy(image.transpose((2, 0, 1)))