"""Custom operators.""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class Swish(nn.Module):
    """Swish activation function: x * sigmoid(x)."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        # Delegate to the custom autograd implementation defined below.
        return SwishEfficient.apply(x)


class SwishEfficient(torch.autograd.Function):
    """Swish activation function, x * sigmoid(x), as a custom autograd Function."""

    @staticmethod
    def forward(ctx, x):
        result = x * torch.sigmoid(x)
        # Save only the input; the sigmoid is recomputed in backward.
        ctx.save_for_backward(x)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        sigmoid_x = torch.sigmoid(x)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
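

# A small sanity-check sketch (illustrative, not part of the original module):
# it uses torch.autograd.gradcheck to compare the hand-written backward above
# against numerical finite differences. The helper name is hypothetical.
def _swish_gradcheck_example():
    x = torch.randn(8, dtype=torch.double, requires_grad=True)
    # gradcheck returns True if the analytic and numerical gradients agree.
    return torch.autograd.gradcheck(SwishEfficient.apply, (x,))

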
class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid."""

    def _round_width(self, width, multiplier, min_width=8, divisor=8):
        """
        Round width of filters based on width multiplier.
        Args:
            width (int): the channel dimensions of the input.
            multiplier (float): the multiplication factor.
            min_width (int): the minimum width after multiplication.
            divisor (int): the new width should be divisible by divisor.
        """
        if not multiplier:
            return width

        width *= multiplier
        min_width = min_width or divisor
        width_out = max(
            min_width, int(width + divisor / 2) // divisor * divisor
        )
        # Avoid rounding down by more than 10% of the scaled width.
        if width_out < 0.9 * width:
            width_out += divisor
        return int(width_out)

    def __init__(self, dim_in, ratio, relu_act=True):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            ratio (float): the channel reduction ratio for the squeeze layers.
            relu_act (bool): whether to use ReLU (default) instead of Swish
                as the activation after the first fully connected layer.
        """
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        dim_fc = self._round_width(dim_in, ratio)
        self.fc1 = nn.Conv3d(dim_in, dim_fc, 1, bias=True)
        self.fc1_act = nn.ReLU() if relu_act else Swish()
        self.fc2 = nn.Conv3d(dim_fc, dim_in, 1, bias=True)
        self.fc2_sig = nn.Sigmoid()

    def forward(self, x):
        x_in = x
        # Run the squeeze-and-excitation branch; child modules are applied in
        # registration order: avg_pool -> fc1 -> fc1_act -> fc2 -> fc2_sig.
        for module in self.children():
            x = module(x)
        # Rescale the input channel-wise by the excitation weights.
        return x_in * x
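

# Minimal usage sketch (illustrative, not part of the original module): the SE
# block expects a 5D input of shape (N, C, T, H, W) and returns a tensor of the
# same shape, rescaled channel-wise. The shapes and ratio below are assumptions.
if __name__ == "__main__":
    se = SE(dim_in=16, ratio=1 / 16, relu_act=False)  # use Swish inside the block
    clip = torch.randn(2, 16, 4, 8, 8)  # (N, C, T, H, W)
    out = se(clip)
    print(out.shape)  # torch.Size([2, 16, 4, 8, 8])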