import torch
import torch.nn as nn
import numpy as np
from monai.transforms import Compose
from typing import List, Tuple


def __define_coord_channels__(x_dim, y_dim, z_dim=None):
    """
    Returns coordinate channels ranging from 0 to x_dim-1 and from 0 to y_dim-1
    (and from 0 to z_dim-1 when z_dim is given).
    """
    if z_dim is None:
        # 2d case, adapted from the original CAM-Convs implementation
        # (https://github.com/jmfacil/camconvs/blob/894add0858343e00da52143231bd30cda0f9f385/python/CAM/blocks/camconvs.py)
        xx_ones = torch.ones([y_dim], dtype=torch.long)
        xx_ones = torch.unsqueeze(xx_ones, -1)
        xx_range = torch.tile(torch.unsqueeze(torch.arange(x_dim), 0), [1])
        xx_range = torch.unsqueeze(xx_range, 1)
        xx_channel = torch.matmul(xx_ones, xx_range)

        yy_ones = torch.ones([x_dim], dtype=torch.long)
        yy_ones = torch.unsqueeze(yy_ones, 0)
        yy_range = torch.tile(torch.unsqueeze(torch.arange(y_dim), 0), [1])
        yy_range = torch.unsqueeze(yy_range, -1)
        yy_channel = torch.matmul(yy_range, yy_ones)
        return xx_channel.float(), yy_channel.float()
    else:
        # simplified 3d version: broadcast each coordinate along the other two axes
        x = torch.unsqueeze(torch.arange(y_dim)[None, :, None].repeat(x_dim, 1, z_dim), 0).float()
        y = torch.unsqueeze(torch.arange(x_dim)[:, None, None].repeat(1, y_dim, z_dim), 0).float()
        z = torch.unsqueeze(torch.arange(z_dim)[None, None, :].repeat(x_dim, y_dim, 1), 0).float()
        return x, y, z


class PositionalEncoding3D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding3D, self).__init__()
        # Number of channels per spatial axis (even, so the sin and cos halves match).
        channels = int(np.ceil(channels / 6) * 2)
        if channels % 2:
            channels += 1
        self.channels = channels
        inv_freq = 1. / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, tensor):
        """
        :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch)
        """
        if len(tensor.shape) != 5:
            raise RuntimeError("The input tensor has to be 5d!")
        batch_size, x, y, z, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
        pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type())
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq)
        emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1).unsqueeze(1).unsqueeze(1)
        emb_y = torch.cat((sin_inp_y.sin(), sin_inp_y.cos()), dim=-1).unsqueeze(1)
        emb_z = torch.cat((sin_inp_z.sin(), sin_inp_z.cos()), dim=-1)
        emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(tensor.type())
        emb[:, :, :, :self.channels] = emb_x
        emb[:, :, :, self.channels:2 * self.channels] = emb_y
        emb[:, :, :, 2 * self.channels:] = emb_z
        # Keep only as many channels as the input had and repeat across the batch.
        return emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1)


class PositionalEncodingPermute3D(nn.Module):
    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch)
        """
        super(PositionalEncodingPermute3D, self).__init__()
        self.penc = PositionalEncoding3D(channels)

    def forward(self, tensor):
        tensor = tensor.permute(0, 2, 3, 4, 1)
        enc = self.penc(tensor)
        return enc.permute(0, 4, 1, 2, 3)
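
# The helper below is not part of the original module; it is a minimal,
# hypothetical sanity check illustrating the channel-first layout produced by
# PositionalEncodingPermute3D, which matches the coordinate channels added by
# AddCoordinateChannels further down.
def _sinusoidal_encoding_shape_example():
    pep = PositionalEncodingPermute3D(3)
    # The encoding depends only on the shape, dtype and device of the input,
    # never on its values, so a zero tensor is enough to build it.
    enc = pep(torch.zeros(1, 3, 4, 4, 4))  # (batch, ch, x, y, z)
    assert enc.shape == (1, 3, 4, 4, 4)
    return enc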

class AddCoordinateChannels:
    """Dictionary transform that adds normalized coordinate channels under `to_key`."""

    def __init__(self, to_key: str, input_size: int, input_dim: int, sinusodal: bool = False):
        self.to_key = to_key
        self.input_size = input_size
        self.input_dim = input_dim
        self.sinusodal = sinusodal

    def get_normalized_coordinate_channels(self) -> np.ndarray:
        if not self.sinusodal:
            # CoordConv-style channels, rescaled from [0, input_size-1] to [-1, 1]
            channels = __define_coord_channels__(*((self.input_size,) * self.input_dim))
            normalize = lambda channel: (channel / (self.input_size - 1)) * 2.0 - 1.0
            channels = [normalize(c) for c in channels]
            return torch.cat(channels).numpy()
        else:
            # Sinusoidal positional encoding with one channel per spatial dimension
            pep = PositionalEncodingPermute3D(self.input_dim)
            channels = pep(torch.zeros(1, self.input_dim, *((self.input_size,) * self.input_dim)))
            return channels.squeeze(0).numpy()

    def __call__(self, data):
        d = dict(data)
        return {**d, self.to_key: self.get_normalized_coordinate_channels()}


def get_normalized_coordinates_transform(hparams, loaded_keys) -> Tuple[Compose, List[str]]:
    if hparams.coordinates:
        return Compose([
            AddCoordinateChannels("coordinates", hparams.input_size, hparams.input_dim, sinusodal=True)
        ]), loaded_keys + ["coordinates"]
    else:
        return Compose([]), loaded_keys
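
# Minimal usage sketch, not part of the original pipeline: `hparams` is assumed
# to be any object exposing `coordinates`, `input_size` and `input_dim`
# attributes; a SimpleNamespace stands in for it here.
if __name__ == "__main__":
    from types import SimpleNamespace

    hparams = SimpleNamespace(coordinates=True, input_size=8, input_dim=3)
    transform, keys = get_normalized_coordinates_transform(hparams, loaded_keys=["image", "label"])
    sample = transform({"image": np.zeros((1, 8, 8, 8), dtype=np.float32)})
    print(keys)                          # ['image', 'label', 'coordinates']
    print(sample["coordinates"].shape)   # (3, 8, 8, 8)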