# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Custom operators."""

import torch
import torch.nn as nn


class Swish(nn.Module):
    """Swish activation function: x * sigmoid(x).

    Thin module wrapper that delegates to `SwishEfficient`.
    """

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        """Apply Swish element-wise to `x` and return the result."""
        return SwishEfficient.apply(x)


class SwishEfficient(torch.autograd.Function):
    """Swish activation function: x * sigmoid(x).

    Custom autograd function that stores only the input for the backward
    pass and recomputes sigmoid(x) there.
    """

    @staticmethod
    def forward(ctx, x):
        """
        Compute x * sigmoid(x).

        Args:
            ctx: autograd context; the input is saved on it for backward.
            x (Tensor): the input tensor.

        Returns:
            Tensor: x * sigmoid(x).
        """
        result = x * torch.sigmoid(x)
        ctx.save_for_backward(x)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """
        Compute the gradient of Swish with respect to its input.

        Args:
            ctx: autograd context holding the tensor saved in `forward`.
            grad_output (Tensor): gradient w.r.t. the output of `forward`.

        Returns:
            Tensor: gradient w.r.t. the input `x`.
        """
        # `saved_tensors` is the supported accessor for tensors stored via
        # `save_for_backward`; `saved_variables` is a deprecated alias.
        (x,) = ctx.saved_tensors
        sigmoid_x = torch.sigmoid(x)
        # d/dx [x * s(x)] = s(x) + x * s(x) * (1 - s(x))
        #                 = s(x) * (1 + x * (1 - s(x)))
        return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))


class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block: AvgPool, FC, ReLU/Swish, FC, Sigmoid."""

    def _round_width(self, width, multiplier, min_width=8, divisor=8):
        """
        Round width of filters based on width multiplier.

        Args:
            width (int): the channel dimensions of the input.
            multiplier (float): the multiplication factor. If falsy (e.g. 0
                or None), `width` is returned unchanged.
            min_width (int): the minimum width after multiplication. If
                falsy, `divisor` is used as the minimum.
            divisor (int): the new width should be divisible by divisor.

        Returns:
            int: the rounded width (or `width` itself when `multiplier` is
                falsy).
        """
        if not multiplier:
            return width

        width *= multiplier
        min_width = min_width or divisor
        # Round to the nearest multiple of `divisor`, but never below
        # `min_width`.
        width_out = max(
            min_width, int(width + divisor / 2) // divisor * divisor
        )
        # Rounding down must not remove more than 10% of the scaled width.
        if width_out < 0.9 * width:
            width_out += divisor
        return int(width_out)

    def __init__(self, dim_in, ratio, relu_act=True):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            ratio (float): the channel reduction ratio for squeeze.
            relu_act (bool): if True (default), use ReLU as the activation
                between the two FC layers; otherwise use Swish.
        """
        super(SE, self).__init__()
        # NOTE: `forward` applies the child modules in registration order,
        # so the order of the assignments below defines the computation.
        self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        dim_fc = self._round_width(dim_in, ratio)
        self.fc1 = nn.Conv3d(dim_in, dim_fc, 1, bias=True)
        self.fc1_act = nn.ReLU() if relu_act else Swish()
        self.fc2 = nn.Conv3d(dim_fc, dim_in, 1, bias=True)
        self.fc2_sig = nn.Sigmoid()

    def forward(self, x):
        """
        Rescale the channels of `x` by the learned excitation weights.

        Args:
            x (Tensor): input with `dim_in` channels, in a layout accepted
                by `nn.Conv3d`.

        Returns:
            Tensor: `x` multiplied by the gating tensor produced by the
                block (broadcast over the pooled dimensions).
        """
        x_in = x
        # Runs avg_pool -> fc1 -> fc1_act -> fc2 -> fc2_sig, i.e. the
        # children in the order they were registered in `__init__`.
        for module in self.children():
            x = module(x)
        return x_in * x