File size: 3,776 Bytes
14ce5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F


class Conv(nn.Module):
    """Convolution wrapper with a per-frame 2D mode and a chunked causal 3D mode.

    ``cnn_type="2d"`` wraps ``nn.Conv2d``. A 5-D input is unpacked as
    ``(B, C, T, H, W)``; its frames are folded into the batch axis, convolved
    independently, and unfolded again. Any other input is passed straight to
    the convolution.

    ``cnn_type="3d"`` wraps ``nn.Conv3d`` built with ``padding=0``; padding is
    applied manually in ``forward``. Height and width are padded symmetrically
    by ``padding``, while the temporal axis (dim 2) is padded only at the
    front, so an output frame never depends on later input frames. The
    temporal axis is processed in chunks of at most ``slice_seq_len`` frames,
    and each chunk after the first is prefixed with the ``kernel_size[0] - 1``
    input frames that precede it.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size. In 3D mode an int is expanded to a 3-tuple
            whose first entry is the temporal extent.
        stride: Integer stride. In 3D mode it is used for height and width,
            and for time only when ``temporal_down`` is truthy.
        padding: Spatial padding (height and width).
        cnn_type: ``"2d"`` or ``"3d"``.
        causal_offset: Added to the leading temporal padding of the first
            chunk only (3D mode).
        temporal_down: If truthy, the temporal stride equals ``stride``;
            otherwise it is 1 (3D mode).

    Raises:
        ValueError: If ``cnn_type`` is neither ``"2d"`` nor ``"3d"``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        cnn_type="2d",
        causal_offset=0,
        temporal_down=False,
    ):
        super().__init__()
        # Fail fast: an unknown type would otherwise leave the module without
        # a convolution and make forward() silently return None.
        if cnn_type not in ("2d", "3d"):
            raise ValueError(
                f"unsupported cnn_type {cnn_type!r}; expected '2d' or '3d'"
            )
        self.cnn_type = cnn_type
        # Maximum number of input frames convolved per chunk in 3D mode.
        self.slice_seq_len = 17

        if cnn_type == "2d":
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride=stride, padding=padding
            )
        else:
            temporal_stride = stride if temporal_down else 1
            stride = (temporal_stride, stride, stride)
            # padding=0 here: forward() pads each chunk explicitly.
            self.conv = nn.Conv3d(
                in_channels, out_channels, kernel_size, stride=stride, padding=0
            )
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size, kernel_size)
            self.padding = (
                kernel_size[0] - 1 + causal_offset,  # Temporal causal padding
                padding,  # Height padding
                padding,  # Width padding
            )
        self.causal_offset = causal_offset
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x):
        """Apply the convolution to ``x``.

        Args:
            x: Input tensor. In 3D mode dim 2 is treated as time and the last
                two dims as height and width.

        Returns:
            The convolved tensor. In 3D mode the per-chunk outputs are
            concatenated along dim 2 and returned on ``x``'s device.
        """
        if self.cnn_type == "2d":
            if x.ndim != 5:
                return self.conv(x)
            # Fold time into the batch axis so every frame is convolved alone.
            num_frames = x.shape[2]
            x = rearrange(x, "B C T H W -> (B T) C H W")
            x = self.conv(x)
            return rearrange(x, "(B T) C H W -> B C T H W", T=num_frames)

        # 3D mode (__init__ only admits "2d" or "3d").
        temporal_stride = self.stride[0]
        assert (
            temporal_stride == 1 or temporal_stride == 2
        ), "only temporal stride = 1 or 2 are supported"

        first_pad_t, pad_h, pad_w = self.padding
        context = self.kernel_size[0] - 1
        total_frames = x.shape[2]
        # With temporal stride 2 the chunk starts advance by one extra frame;
        # the frame in between reaches the next chunk through its context.
        step = self.slice_seq_len + temporal_stride - 1

        outputs = []
        for start in range(0, total_frames, step):
            end = min(start + self.slice_seq_len, total_frames)
            chunk = x[:, :, start:end, :, :]
            if start == 0:
                # F.pad order: (W left, W right, H top, H bottom, T front, T back).
                chunk = F.pad(
                    chunk, (pad_w, pad_w, pad_h, pad_h, first_pad_t, 0)
                )
            else:
                chunk = F.pad(chunk, (pad_w, pad_w, pad_h, pad_h, context, 0))
                # Adding onto the zero padding copies the `context` frames
                # that precede this chunk into its leading temporal slots,
                # leaving the spatial border zero.
                chunk[
                    :,
                    :,
                    :context,
                    pad_h : chunk.shape[-2] - pad_h,
                    pad_w : chunk.shape[-1] - pad_w,
                ] += x[:, :, start - context : start, :, :]
            outputs.append(self.conv(chunk))
            del chunk  # do not keep the padded input alive past this step

        try:
            return torch.cat(outputs, dim=2)
        except RuntimeError:
            # Concatenation failed (PyTorch reports allocation failures as
            # RuntimeError): retry by assembling the result in pinned host
            # memory and moving it back to the original device afterwards.
            device = x.device
            del x
            outputs = [out.cpu().pin_memory() for out in outputs]
            torch.cuda.empty_cache()
            return torch.cat(outputs, dim=2).to(device=device)