mshukor
init
3eb682b
raw
history blame
2.53 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Custom operators."""
import torch
import torch.nn as nn
class Swish(nn.Module):
    """Module wrapper around the Swish activation, ``x * sigmoid(x)``.

    The computation itself is delegated to ``SwishEfficient``, a custom
    autograd function defined in this file.
    """

    def __init__(self):
        """Create the activation; it holds no parameters or submodules."""
        super().__init__()

    def forward(self, x):
        """Apply Swish element-wise to ``x`` and return the result."""
        activated = SwishEfficient.apply(x)
        return activated
class SwishEfficient(torch.autograd.Function):
    """Swish activation, ``x * sigmoid(x)``, with a hand-written backward.

    Only the input tensor is saved in ``forward``; ``backward`` recomputes
    the sigmoid from it rather than storing an extra intermediate.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``x * sigmoid(x)`` and save ``x`` for the backward pass.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            x (Tensor): the input tensor.
        """
        result = x * torch.sigmoid(x)
        ctx.save_for_backward(x)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to the input.

        With ``s = sigmoid(x)``:
            d/dx [x * s] = s + x * s * (1 - s) = s * (1 + x * (1 - s))

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output (Tensor): gradient of the loss w.r.t. the output.
        """
        # `saved_tensors` is the supported accessor for tensors stored with
        # `save_for_backward`; the old `saved_variables` name is deprecated.
        (x,) = ctx.saved_tensors
        sigmoid_x = torch.sigmoid(x)
        return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block.

    Pipeline: AvgPool -> FC -> ReLU or Swish -> FC -> Sigmoid. The resulting
    gate is multiplied with the block's input.
    """

    def __init__(self, dim_in, ratio, relu_act=True):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            ratio (float): the channel reduction ratio for squeeze.
            relu_act (bool): if True (default) the first FC is followed by
                ReLU; if False it is followed by Swish.
        """
        super().__init__()
        # NOTE: `forward` walks the children in registration order, so the
        # order of the assignments below defines the computation.
        self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        squeezed_dim = self._round_width(dim_in, ratio)
        self.fc1 = nn.Conv3d(dim_in, squeezed_dim, 1, bias=True)
        if relu_act:
            self.fc1_act = nn.ReLU()
        else:
            self.fc1_act = Swish()
        self.fc2 = nn.Conv3d(squeezed_dim, dim_in, 1, bias=True)
        self.fc2_sig = nn.Sigmoid()

    def _round_width(self, width, multiplier, min_width=8, divisor=8):
        """
        Scale a channel width and round it to a multiple of ``divisor``.

        Args:
            width (int): the channel dimensions of the input.
            multiplier (float): the multiplication factor; a falsy value
                (0 or None) returns ``width`` unchanged.
            min_width (int): the minimum width after multiplication; a falsy
                value falls back to ``divisor``.
            divisor (int): the new width should be divisible by divisor.

        Returns:
            int: the rounded width (or ``width`` itself when no multiplier
            is given).
        """
        if not multiplier:
            return width

        scaled = width * multiplier
        lower_bound = min_width or divisor
        # Round to the nearest multiple of `divisor`, then apply the floor.
        nearest = int(scaled + divisor / 2) // divisor * divisor
        rounded = max(lower_bound, nearest)
        # Do not let rounding shrink the width by more than 10%.
        if rounded < 0.9 * scaled:
            rounded += divisor
        return int(rounded)

    def forward(self, x):
        """Return ``x`` scaled by the gate computed from ``x`` itself."""
        gate = x
        # Children run in registration order:
        # avg_pool -> fc1 -> fc1_act -> fc2 -> fc2_sig.
        for layer in self.children():
            gate = layer(gate)
        return x * gate