|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange |
|
import torch.nn.functional as F |
|
|
|
|
|
class Conv(nn.Module):
    """Convolution wrapper supporting a plain 2D mode and a causal, chunked 3D mode.

    * ``cnn_type="2d"``: a standard ``nn.Conv2d``. 5-dim inputs ``(B,C,T,H,W)``
      are handled by folding time into the batch dimension.
    * ``cnn_type="3d"``: a ``nn.Conv3d`` made causal along time by manual
      front-padding. To bound peak memory, the input is processed in temporal
      chunks of ``slice_seq_len`` frames; later chunks are front-padded with the
      last ``kernel-1`` real frames of the preceding chunk so the chunked output
      is identical to a single full-length causal convolution.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        kernel_size: int or 3-tuple kernel size.
        stride: spatial stride (and temporal stride when ``temporal_down``).
        padding: spatial padding (applied manually in the 3D path).
        cnn_type: ``"2d"`` or ``"3d"``.
        causal_offset: extra frames of causal front padding in the 3D path.
        temporal_down: if True, apply ``stride`` temporally as well (3D only).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        cnn_type="2d",
        causal_offset=0,
        temporal_down=False,
    ):
        super().__init__()
        self.cnn_type = cnn_type
        # Temporal chunk length for the 3D path; bounds peak activation memory.
        self.slice_seq_len = 17

        if cnn_type == "2d":
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride=stride, padding=padding
            )
        if cnn_type == "3d":
            if not temporal_down:
                stride = (1, stride, stride)
            else:
                stride = (stride, stride, stride)
            # padding=0 here: padding is applied manually in forward() so the
            # temporal padding can be causal (front-only).
            self.conv = nn.Conv3d(
                in_channels, out_channels, kernel_size, stride=stride, padding=0
            )
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size, kernel_size)
            # (temporal front pad, height pad, width pad). A front pad of
            # kernel-1 (+offset) frames makes the temporal conv causal.
            self.padding = (
                kernel_size[0] - 1 + causal_offset,
                padding,
                padding,
            )
            self.causal_offset = causal_offset
            self.stride = stride
            self.kernel_size = kernel_size

    def forward(self, x):
        """Apply the convolution.

        2D mode accepts ``(B,C,H,W)`` or ``(B,C,T,H,W)`` (time folded into
        batch). 3D mode expects ``(B,C,T,H,W)`` and returns the causally
        convolved result, computed in temporal chunks.
        """
        if self.cnn_type == "2d":
            if x.ndim == 5:
                # Fold time into batch: (B,C,T,H,W) -> (B*T,C,H,W), B-major.
                B, C, T, H, W = x.shape
                x = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
                x = self.conv(x)
                # Unfold: (B*T,C',H',W') -> (B,C',T,H',W').
                _, C2, H2, W2 = x.shape
                x = x.reshape(B, T, C2, H2, W2).permute(0, 2, 1, 3, 4)
                return x
            else:
                return self.conv(x)
        if self.cnn_type == "3d":
            assert (
                self.stride[0] == 1 or self.stride[0] == 2
            ), "only temporal stride = 1 or 2 are supported"
            xs = []
            for i in range(0, x.shape[2], self.slice_seq_len + self.stride[0] - 1):
                st = i
                en = min(i + self.slice_seq_len, x.shape[2])
                _x = x[:, :, st:en, :, :]
                if i == 0:
                    # First chunk: zero-pad spatially and causally (front of
                    # the temporal axis only).
                    _x = F.pad(
                        _x,
                        (
                            self.padding[2],
                            self.padding[2],
                            self.padding[1],
                            self.padding[1],
                            self.padding[0],
                            0,
                        ),
                    )
                else:
                    # Later chunks: zero-pad, then overwrite the temporal pad
                    # region with the last kernel-1 real frames of the previous
                    # chunk so the chunked conv matches the full-length conv.
                    padding_0 = self.kernel_size[0] - 1
                    _x = F.pad(
                        _x,
                        (
                            self.padding[2],
                            self.padding[2],
                            self.padding[1],
                            self.padding[1],
                            padding_0,
                            0,
                        ),
                    )
                    _x[
                        :,
                        :,
                        :padding_0,
                        self.padding[1] : _x.shape[-2] - self.padding[1],
                        self.padding[2] : _x.shape[-1] - self.padding[2],
                    ] += x[:, :, i - padding_0 : i, :, :]
                _x = self.conv(_x)
                xs.append(_x)
            try:
                x = torch.cat(xs, dim=2)
            except RuntimeError:
                # On-device concatenation ran out of memory (CUDA OOM raises
                # RuntimeError): free the input, concatenate on the CPU, then
                # move the result back to the original device.
                device = x.device
                del x
                xs = [_x.cpu().pin_memory() for _x in xs]
                torch.cuda.empty_cache()
                x = torch.cat(xs, dim=2).to(device=device)
            return x
|
|