import torch
import torch.nn.functional as F
from torch import nn


class DropBlock2D(nn.Module):
    r"""Randomly zeroes 2D spatial blocks of the input tensor.

    As described in the paper
    `DropBlock: A regularization method for convolutional networks`_ ,
    dropping whole contiguous blocks of a feature map removes spatially
    correlated (semantic) information more effectively than regular,
    element-wise dropout.

    The same block mask is shared by every channel of a sample. After
    masking, the output is rescaled by ``total_positions / kept_positions``
    (counted over the whole batch) so the expected activation magnitude is
    preserved. In evaluation mode, or when ``drop_prob == 0``, the input is
    returned unchanged.

    Args:
        drop_prob (float): nominal probability of an element being dropped.
            Each spatial position seeds a block with probability
            ``drop_prob / block_size**2``; because blocks may overlap, the
            realised drop rate is at most roughly ``drop_prob``.
        block_size (int): side length of the square block to drop.

    Shape:
        - Input: `(N, C, H, W)`
        - Output: `(N, C, H, W)`

    .. _DropBlock: A regularization method for convolutional networks: https://arxiv.org/abs/1810.12890

    """

    def __init__(self, drop_prob, block_size):
        super().__init__()
        self.drop_prob = drop_prob
        self.block_size = block_size

    def forward(self, x):
        """Apply DropBlock to ``x`` in training mode; return ``x`` unchanged otherwise."""
        # shape: (bsize, channels, height, width)
        assert x.dim() == 4, "Expected input with 4 dimensions (bsize, channels, height, width)"
        if not self.training or self.drop_prob == 0.0:
            return x
        return self._drop_blocks(x)

    def _drop_blocks(self, x):
        """Sample a block mask, apply it to ``x`` and rescale the result.

        Shared by the 2D and 3D variants: only ``_compute_gamma`` and
        ``_compute_block_mask`` depend on the number of spatial dimensions.
        """
        gamma = self._compute_gamma(x)
        # One Bernoulli(gamma) seed per (sample, spatial position); all channels
        # share it. Sampling on x's device avoids a host->device copy each step.
        mask = (torch.rand(x.shape[0], *x.shape[2:], device=x.device) < gamma).float()
        # block_mask: 1 where activations are kept, 0 inside dropped blocks.
        block_mask = self._compute_block_mask(mask)
        # Cast to x's dtype so half/bfloat16 inputs are not promoted to float32.
        out = x * block_mask.unsqueeze(1).to(x.dtype)
        # Rescale to preserve the expected activation magnitude. The count stays
        # in float32 (no fp16 overflow for large batches), and clamping avoids
        # 0 / 0 -> NaN when every position happens to be dropped (the output is
        # then all zeros anyway).
        kept = block_mask.sum().clamp(min=1.0)
        return out * (block_mask.numel() / kept)

    def _compute_block_mask(self, mask):
        """Expand seed positions into square blocks and return a keep-mask.

        ``mask`` has shape ``(N, H, W)`` with 1 at seed positions; the result
        has the same shape, with 0 inside dropped blocks and 1 elsewhere.
        """
        block_mask = F.max_pool2d(
            input=mask[:, None, :, :],
            kernel_size=(self.block_size, self.block_size),
            stride=(1, 1),
            padding=self.block_size // 2,
        )
        # An even kernel with padding block_size // 2 yields one extra row and
        # column; crop them to restore the input's spatial size.
        if self.block_size % 2 == 0:
            block_mask = block_mask[:, :, :-1, :-1]
        return 1 - block_mask.squeeze(1)

    def _compute_gamma(self, x):
        """Return the per-position seed probability (``x`` is not used)."""
        # Simplified rate: the paper additionally multiplies by
        # feat_size**2 / (feat_size - block_size + 1)**2 to account for blocks
        # that would not fit entirely inside the feature map.
        return self.drop_prob / (self.block_size ** 2)


class DropBlock3D(DropBlock2D):
    r"""Randomly zeroes 3D spatial blocks of the input tensor.

    An extension of the idea in the paper
    `DropBlock: A regularization method for convolutional networks`_ to
    volumetric feature maps: dropping whole contiguous cubes removes
    spatially correlated (semantic) information more effectively than
    regular, element-wise dropout.

    The same block mask is shared by every channel of a sample, and the
    output is rescaled by ``total_positions / kept_positions`` (counted over
    the whole batch). In evaluation mode, or when ``drop_prob == 0``, the
    input is returned unchanged.

    Args:
        drop_prob (float): nominal probability of an element being dropped.
            Each spatial position seeds a block with probability
            ``drop_prob / block_size**3``.
        block_size (int): side length of the cubic block to drop.

    Shape:
        - Input: `(N, C, D, H, W)`
        - Output: `(N, C, D, H, W)`

    .. _DropBlock: A regularization method for convolutional networks: https://arxiv.org/abs/1810.12890

    """

    def forward(self, x):
        """Apply DropBlock to ``x`` in training mode; return ``x`` unchanged otherwise."""
        # shape: (bsize, channels, depth, height, width)
        assert x.dim() == 5, "Expected input with 5 dimensions (bsize, channels, depth, height, width)"
        if not self.training or self.drop_prob == 0.0:
            return x
        return self._drop_blocks(x)

    def _compute_block_mask(self, mask):
        """Expand seed positions into cubic blocks and return a keep-mask.

        ``mask`` has shape ``(N, D, H, W)`` with 1 at seed positions; the
        result has the same shape, with 0 inside dropped blocks and 1 elsewhere.
        """
        block_mask = F.max_pool3d(
            input=mask[:, None, :, :, :],
            kernel_size=(self.block_size, self.block_size, self.block_size),
            stride=(1, 1, 1),
            padding=self.block_size // 2,
        )
        # An even kernel overshoots by one along each spatial axis; crop back.
        if self.block_size % 2 == 0:
            block_mask = block_mask[:, :, :-1, :-1, :-1]
        return 1 - block_mask.squeeze(1)

    def _compute_gamma(self, x):
        """Return the per-position seed probability (``x`` is not used)."""
        return self.drop_prob / (self.block_size ** 3)