import random

import torch
import torch.nn as nn

from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv


class MultidilatedConv(nn.Module):
    """Several parallel convolutions with exponentially growing dilations,
    combined by summation ('sum') and/or channel concatenation on the input
    side ('cat_in'), the output side ('cat_out'), or both ('cat_both')."""

    def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True,
                 shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False,
                 use_depthwise=False, **kwargs):
        super().__init__()
        convs = []
        self.equal_dim = equal_dim
        assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode
        if comb_mode in ('cat_out', 'cat_both'):
            self.cat_out = True
            if equal_dim:
                assert out_dim % dilation_num == 0
                out_dims = [out_dim // dilation_num] * dilation_num
                # Permutation that interleaves the concatenated branch outputs
                # channel-by-channel instead of branch-by-branch.
                self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)]
                                  for i in range(out_dims[0])], [])
            else:
                # Each successive branch gets half of the remaining channel
                # budget; the last branch takes whatever is left.
                out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
                out_dims.append(out_dim - sum(out_dims))
                index = []
                starts = [0] + out_dims[:-1]
                lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)]
                for i in range(out_dims[-1]):
                    for j in range(dilation_num):
                        index += list(range(starts[j], starts[j] + lengths[j]))
                        starts[j] += lengths[j]
                self.index = index
                assert len(index) == out_dim
            self.out_dims = out_dims
        else:
            self.cat_out = False
            self.out_dims = [out_dim] * dilation_num

        if comb_mode in ('cat_in', 'cat_both'):
            if equal_dim:
                assert in_dim % dilation_num == 0
                in_dims = [in_dim // dilation_num] * dilation_num
            else:
                in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
                in_dims.append(in_dim - sum(in_dims))
            self.in_dims = in_dims
            self.cat_in = True
        else:
            self.cat_in = False
            self.in_dims = [in_dim] * dilation_num

        conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d
        dilation = min_dilation
        for i in range(dilation_num):
            # Scale the base padding with the dilation so the spatial size is
            # preserved (when padding == (kernel_size - 1) // 2); an explicit
            # per-branch padding sequence is used as-is.
            if isinstance(padding, int):
                cur_padding = padding * dilation
            else:
                cur_padding = padding[i]
            convs.append(conv_type(
                self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs
            ))
            if i > 0 and shared_weights:
                # Tie every branch to the first branch's parameters.
                convs[-1].weight = convs[0].weight
                convs[-1].bias = convs[0].bias
            dilation *= 2
        self.convs = nn.ModuleList(convs)

        self.shuffle_in_channels = shuffle_in_channels
        if self.shuffle_in_channels:
            # shuffle list as shuffling of tensors is nondeterministic
            in_channels_permute = list(range(in_dim))
            random.shuffle(in_channels_permute)
            # save as buffer so it is saved and loaded with checkpoint
            self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute))

    def forward(self, x):
        if self.shuffle_in_channels:
            x = x[:, self.in_channels_permute]

        outs = []
        if self.cat_in:
            # Split the input channels across the branches.
            if self.equal_dim:
                x = x.chunk(len(self.convs), dim=1)
            else:
                new_x = []
                start = 0
                for dim in self.in_dims:
                    new_x.append(x[:, start:start + dim])
                    start += dim
                x = new_x
        for i, conv in enumerate(self.convs):
            inp = x[i] if self.cat_in else x
            outs.append(conv(inp))
        if self.cat_out:
            out = torch.cat(outs, dim=1)[:, self.index]
        else:
            out = sum(outs)
        return out
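

# The block below is a minimal usage sketch, not part of the original module:
# it assumes toy shapes (batch 2, 6 input channels, 32x32 feature maps) and
# checks that every combination mode yields the requested number of output
# channels at the same resolution. use_depthwise=False (the default) means
# DepthWiseSeperableConv is never instantiated.
if __name__ == '__main__':
    x = torch.randn(2, 6, 32, 32)
    for mode in ('sum', 'cat_in', 'cat_out', 'cat_both'):
        conv = MultidilatedConv(in_dim=6, out_dim=9, kernel_size=3,
                                dilation_num=3, comb_mode=mode)
        y = conv(x)
        assert y.shape == (2, 9, 32, 32), (mode, y.shape)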