import torch
import torch.nn as nn
import numpy as np
from monai.transforms import Compose
from typing import List, Tuple
def __define_coord_channels__(x_dim, y_dim, z_dim=None):
    """
    Returns coordinate channels counting from 0 to x_dim-1 and from 0 to y_dim-1
    (and, if z_dim is given, from 0 to z_dim-1).
    """
    if z_dim is None:
        # 2d case, adapted from the original CamConvs implementation
        # (https://github.com/jmfacil/camconvs/blob/894add0858343e00da52143231bd30cda0f9f385/python/CAM/blocks/camconvs.py)
        xx_ones = torch.ones([y_dim, 1], dtype=torch.long)                # (y_dim, 1)
        xx_range = torch.unsqueeze(torch.arange(x_dim), 0)                # (1, x_dim)
        # outer product of ones and the index range spans the full coordinate grid
        xx_channel = torch.unsqueeze(torch.matmul(xx_ones, xx_range), 0)  # (1, y_dim, x_dim)
        yy_ones = torch.ones([1, x_dim], dtype=torch.long)                # (1, x_dim)
        yy_range = torch.unsqueeze(torch.arange(y_dim), -1)               # (y_dim, 1)
        yy_channel = torch.unsqueeze(torch.matmul(yy_range, yy_ones), 0)  # (1, y_dim, x_dim)
        return xx_channel.float(), yy_channel.float()
    else:
        # simplified 3d version: each channel enumerates the index along one axis,
        # broadcast over the other two; shape (1, x_dim, y_dim, z_dim)
        x = torch.unsqueeze(torch.arange(y_dim)[None, :, None].repeat(x_dim, 1, z_dim), 0).float()
        y = torch.unsqueeze(torch.arange(x_dim)[:, None, None].repeat(1, y_dim, z_dim), 0).float()
        z = torch.unsqueeze(torch.arange(z_dim)[None, None, :].repeat(x_dim, y_dim, 1), 0).float()
        return x, y, z
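
# Usage sketch (illustrative only):
#   x, y, z = __define_coord_channels__(4, 4, 4)
#   x.shape            # torch.Size([1, 4, 4, 4]) -- same for y and z
#   x[0, 0, :, 0]      # tensor([0., 1., 2., 3.]), the index along the second spatial axis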

class PositionalEncoding3D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding3D, self).__init__()
        # round up to an even number of sine/cosine features per spatial axis,
        # e.g. channels=8 -> 4 features per axis (12 in total, truncated back to 8 in forward)
        channels = int(np.ceil(channels / 6) * 2)
        if channels % 2:
            channels += 1
        self.channels = channels
        inv_freq = 1. / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, tensor):
        """
        :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch)
        """
        if len(tensor.shape) != 5:
            raise RuntimeError("The input tensor has to be 5d!")
        batch_size, x, y, z, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
        pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type())
        # outer product of positions and inverse frequencies, one row per position
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq)
        # per-axis sine/cosine features, unsqueezed so they broadcast over the other axes
        emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1).unsqueeze(1).unsqueeze(1)
        emb_y = torch.cat((sin_inp_y.sin(), sin_inp_y.cos()), dim=-1).unsqueeze(1)
        emb_z = torch.cat((sin_inp_z.sin(), sin_inp_z.cos()), dim=-1)
        emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(tensor.type())
        emb[:, :, :, :self.channels] = emb_x
        emb[:, :, :, self.channels:2 * self.channels] = emb_y
        emb[:, :, :, 2 * self.channels:] = emb_z
        # truncate to the requested number of channels and repeat over the batch
        return emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1)

class PositionalEncodingPermute3D(nn.Module):
    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch)
        """
        super(PositionalEncodingPermute3D, self).__init__()
        self.penc = PositionalEncoding3D(channels)

    def forward(self, tensor):
        tensor = tensor.permute(0, 2, 3, 4, 1)
        enc = self.penc(tensor)
        return enc.permute(0, 4, 1, 2, 3)
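
# Usage sketch (illustrative only): the encoding depends only on the shape (and device/dtype)
# of the input, so a zero tensor is enough to obtain it.
#   penc = PositionalEncodingPermute3D(8)
#   penc(torch.zeros(2, 8, 16, 16, 16)).shape   # torch.Size([2, 8, 16, 16, 16])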

class AddCoordinateChannels:
    """
    Dictionary transform that adds coordinate channels under `to_key`, either as
    normalized linear coordinates or as a sinusoidal positional encoding.
    """
    def __init__(self, to_key: str, input_size: int, input_dim: int, sinusodal: bool = False):
        self.to_key = to_key
        self.input_size = input_size
        self.input_dim = input_dim
        self.sinusodal = sinusodal

    def get_normalized_coordinate_channels(self) -> np.ndarray:
        if not self.sinusodal:
            # linear coordinate channels, rescaled from [0, input_size-1] to [-1, 1]
            channels = __define_coord_channels__(*((self.input_size,) * self.input_dim))
            normalize = lambda channel: (channel / (self.input_size - 1)) * 2.0 - 1.0
            channels = [normalize(c) for c in channels]
            return torch.cat(channels).numpy()
        else:
            # sinusoidal positional encoding with one channel per spatial dimension
            pep = PositionalEncodingPermute3D(self.input_dim)
            channels = pep(torch.zeros(1, self.input_dim, *((self.input_size,) * self.input_dim)))
            return channels.squeeze(0).numpy()

    def __call__(self, data):
        d = dict(data)
        return {**d, self.to_key: self.get_normalized_coordinate_channels()}
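
# Usage sketch (illustrative only), as a plain dictionary transform:
#   add_coords = AddCoordinateChannels("coordinates", input_size=32, input_dim=3, sinusodal=False)
#   sample = add_coords({"image": torch.zeros(1, 32, 32, 32)})
#   sample["coordinates"].shape   # (3, 32, 32, 32), values in [-1, 1]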

def get_normalized_coordinates_transform(hparams, loaded_keys) -> Tuple[Compose, List[str]]:
    if hparams.coordinates:
        return Compose([
            AddCoordinateChannels("coordinates", hparams.input_size, hparams.input_dim, sinusodal=True)
        ]), loaded_keys + ["coordinates"]
    else:
        return Compose([]), loaded_keys
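

if __name__ == "__main__":
    # Minimal sanity-check sketch (not part of the training pipeline): the hparams values
    # below are hypothetical; only the attributes read above (coordinates, input_size,
    # input_dim) are assumed.
    from types import SimpleNamespace

    hparams = SimpleNamespace(coordinates=True, input_size=32, input_dim=3)
    transform, keys = get_normalized_coordinates_transform(hparams, loaded_keys=["image"])
    sample = transform({"image": torch.zeros(1, 32, 32, 32)})
    print(keys)                          # ['image', 'coordinates']
    print(sample["coordinates"].shape)   # (3, 32, 32, 32)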