File size: 5,050 Bytes
92f0e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
import torch.nn as nn
import numpy as np
from monai.transforms import Compose
from typing import List, Tuple

def __define_coord_channels__(x_dim, y_dim, z_dim=None):
    """
    Build un-normalized coordinate grids as float tensors.

    2D case (z_dim is None): returns (xx_channel, yy_channel), each of shape
    (1, y_dim, x_dim). xx_channel[0, i, j] == j (values 0 .. x_dim-1) and
    yy_channel[0, i, j] == i (values 0 .. y_dim-1).

    3D case: returns (x, y, z), each of shape (1, x_dim, y_dim, z_dim).
    x[0, i, j, k] == j, y[0, i, j, k] == i and z[0, i, j, k] == k.

    :param x_dim: size of the first requested dimension
    :param y_dim: size of the second requested dimension
    :param z_dim: size of the third dimension, or None for the 2D grids
    :return: tuple of two (2D) or three (3D) float tensors
    """

    if z_dim is None:
        # Same grid layout as the original implementation
        # (https://github.com/jmfacil/camconvs/blob/894add0858343e00da52143231bd30cda0f9f385/python/CAM/blocks/camconvs.py).
        # The grids are built by broadcasting instead of matmul: the previous
        # matmul of a (1, y_dim, 1) range with an (x_dim, 1) ones matrix only
        # had matching inner dimensions when x_dim == 1.
        xx_channel = torch.arange(x_dim).repeat(y_dim, 1).unsqueeze(0)
        yy_channel = torch.arange(y_dim).unsqueeze(-1).repeat(1, x_dim).unsqueeze(0)

        return xx_channel.float(), yy_channel.float()

    else:
        # simplified 3d version
        # NOTE: `x` counts along axis 1 (range(y_dim)) and `y` along axis 0
        # (range(x_dim)); all three outputs share the same shape.
        x = torch.unsqueeze(torch.arange(y_dim)[None, :, None].repeat(x_dim, 1, z_dim), 0).float()
        y = torch.unsqueeze(torch.arange(x_dim)[:, None, None].repeat(1, y_dim, z_dim), 0).float()
        z = torch.unsqueeze(torch.arange(z_dim)[None, None, :].repeat(x_dim, y_dim, 1), 0).float()

        return x, y, z

class PositionalEncoding3D(nn.Module):
    """
    Sinusoidal positional encoding for channel-last 5d tensors.

    Each of the three spatial axes gets its own sin/cos block of
    ``self.channels`` features; the blocks are concatenated along the last
    axis and cut down to the channel count of the input.
    """

    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super().__init__()
        # Features per axis. ceil(channels / 6) * 2 is always even, so no
        # extra parity correction is required.
        per_axis = int(np.ceil(channels / 6) * 2)
        self.channels = per_axis
        exponents = torch.arange(0, per_axis, 2).float() / per_axis
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, tensor):
        """
        :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch)
        :raises RuntimeError: if the input is not 5-dimensional
        """
        if tensor.ndim != 5:
            raise RuntimeError("The input tensor has to be 5d!")
        batch_size, x, y, z, orig_ch = tensor.shape

        def axis_embedding(length):
            # (length, self.channels): sin half followed by cos half.
            positions = torch.arange(length, device=tensor.device).type(self.inv_freq.type())
            angles = torch.einsum("i,j->ij", positions, self.inv_freq)
            return torch.cat((angles.sin(), angles.cos()), dim=-1)

        ch = self.channels
        emb = torch.zeros((x, y, z, ch * 3), device=tensor.device).type(tensor.type())
        # Broadcast each per-axis table across the two other spatial axes.
        emb[..., :ch] = axis_embedding(x)[:, None, None, :]
        emb[..., ch:2 * ch] = axis_embedding(y)[:, None, :]
        emb[..., 2 * ch:] = axis_embedding(z)

        # Keep only as many features as the input has, then copy per batch item.
        return emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1)

class PositionalEncodingPermute3D(nn.Module):
    """
    Channel-first wrapper around PositionalEncoding3D.

    Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch)
    and returns the encoding in the same channel-first layout.
    """

    def __init__(self, channels):
        """
        :param channels: forwarded unchanged to PositionalEncoding3D
        """
        super().__init__()
        self.penc = PositionalEncoding3D(channels)

    def forward(self, tensor):
        """
        :param tensor: A 5d tensor of size (batch_size, ch, x, y, z)
        :return: Positional Encoding Matrix of size (batch_size, ch, x, y, z)
        """
        # Move channels last for the wrapped encoder, then back to position 1.
        channel_last = tensor.permute(0, 2, 3, 4, 1)
        return self.penc(channel_last).permute(0, 4, 1, 2, 3)

class AddCoordinateChannels:
    """
    Dictionary transform that stores coordinate channels under ``to_key``.

    The channels are either linear coordinates rescaled from
    [0, input_size - 1] to [-1, 1], or (with ``sinusodal=True``) a sinusoidal
    positional encoding.
    """

    def __init__(self, to_key: str, input_size: int, input_dim: int, sinusodal: bool = False):
        """
        :param to_key: dictionary key the coordinate array is written to
        :param input_size: edge length used for every spatial dimension
        :param input_dim: number of spatial dimensions
        :param sinusodal: use the sinusoidal encoding instead of linear coordinates
        """
        self.to_key, self.input_size = to_key, input_size
        self.input_dim, self.sinusodal = input_dim, sinusodal

    def get_normalized_coordinate_channels(self) -> np.ndarray:
        """
        Compute the coordinate channels as a numpy array.

        :return: array whose first axis indexes the coordinate channels
        """
        spatial_shape = (self.input_size,) * self.input_dim

        if self.sinusodal:
            # The encoder only looks at the shape/device/dtype of its input, so
            # an all-zero (1, input_dim, *spatial_shape) tensor is sufficient.
            # NOTE(review): the wrapped encoder rejects non-5d input, so this
            # path appears to require input_dim == 3, and the encoding is cut
            # to input_dim channels - confirm this is intended.
            encoder = PositionalEncodingPermute3D(self.input_dim)
            dummy = torch.zeros(1, self.input_dim, *spatial_shape)
            return encoder(dummy).squeeze(0).numpy()

        grids = __define_coord_channels__(*spatial_shape)
        # Map raw indices 0 .. input_size-1 onto [-1, 1].
        rescaled = [(grid / (self.input_size - 1)) * 2.0 - 1.0 for grid in grids]
        return torch.cat(rescaled).numpy()

    def __call__(self, data):
        """
        :param data: mapping of the current sample
        :return: shallow copy of ``data`` as a dict with ``to_key`` set to the
            coordinate channels
        """
        result = dict(data)
        result[self.to_key] = self.get_normalized_coordinate_channels()
        return result

def get_normalized_coordinates_transform(hparams, loaded_keys) -> Tuple[Compose, List[str]]:
    """
    Build the optional coordinate-channel transform.

    :param hparams: object providing ``coordinates``, ``input_size`` and ``input_dim``
    :param loaded_keys: list of keys loaded so far
    :return: (transform, keys). With ``hparams.coordinates`` falsy this is an
        empty Compose and ``loaded_keys`` itself; otherwise a Compose holding a
        sinusoidal AddCoordinateChannels and a new list with "coordinates" appended.
    """
    if not hparams.coordinates:
        return Compose([]), loaded_keys

    coordinate_key = "coordinates"
    add_coordinates = AddCoordinateChannels(
        coordinate_key, hparams.input_size, hparams.input_dim, sinusodal=True
    )
    return Compose([add_coordinates]), loaded_keys + [coordinate_key]