import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from utils.general import *

try:
    from mish_cuda import MishCuda as Mish
except ImportError:
    class Mish(nn.Module):  # https://github.com/digantamisra98/Mish
        def forward(self, x):
            return x * F.softplus(x).tanh()
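
def _demo_mish():
    """Minimal sketch of the fallback above: mish(x) = x * tanh(softplus(x)),
    so mish(0) == 0 and mish(x) approaches x for large positive x. Assumes the
    pure-PyTorch class; mish_cuda computes the same function on CUDA tensors."""
    m = Mish()
    x = torch.tensor([-1.0, 0.0, 1.0, 5.0])
    y = m(x)
    assert y.shape == x.shape and float(y[1]) == 0.0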

try:
    from pytorch_wavelets import DWTForward, DWTInverse

    class DWT(nn.Module):  # single-level Haar wavelet downsampling
        def __init__(self):
            super(DWT, self).__init__()
            self.xfm = DWTForward(J=1, wave='db1', mode='zero')

        def forward(self, x):
            b, c, h, w = x.shape
            yl, yh = self.xfm(x)  # yl: (b, c, h/2, w/2), yh[0]: (b, c, 3, h/2, w/2)
            return torch.cat([yl / 2., yh[0].view(b, -1, h // 2, w // 2) / 2. + .5], 1)

except ImportError:  # fall back to Reorg (space-to-depth) instead
    class DWT(nn.Module):
        def forward(self, x):
            return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)


class Reorg(nn.Module):  # space-to-depth: (b, c, h, w) -> (b, 4c, h/2, w/2)
    def forward(self, x):
        return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
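
def _demo_dwt_reorg():
    """Sketch with an illustrative even-sized input: both DWT (wavelet) and the
    Reorg fallback map (b, c, h, w) -> (b, 4c, h/2, w/2), trading spatial
    resolution for channels."""
    x = torch.randn(1, 64, 32, 32)
    assert DWT()(x).shape == (1, 256, 16, 16)
    assert Reorg()(x).shape == (1, 256, 16, 16)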

def make_divisible(v, divisor):
    # Round v up to the nearest multiple of divisor so layer channel counts stay divisible
    # https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    return math.ceil(v / divisor) * divisor
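
def _demo_make_divisible():
    """Sketch with illustrative numbers: make_divisible rounds a channel count
    up to the next multiple of the divisor."""
    assert make_divisible(30, 8) == 32
    assert make_divisible(64, 8) == 64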

class Flatten(nn.Module):
    # Use after nn.AdaptiveAvgPool2d(1) to remove the last 2 dimensions
    def forward(self, x):
        return x.view(x.size(0), -1)


class Concat(nn.Module):
    # Concatenate a list of tensors along a dimension
    def __init__(self, dimension=1):
        super(Concat, self).__init__()
        self.d = dimension

    def forward(self, x):
        return torch.cat(x, self.d)

class FeatureConcat(nn.Module):
    # Concatenate the cached outputs of the given layer indices along the channel dimension
    def __init__(self, layers):
        super(FeatureConcat, self).__init__()
        self.layers = layers  # layer indices
        self.multiple = len(layers) > 1  # multiple layers flag

    def forward(self, x, outputs):
        return torch.cat([outputs[i] for i in self.layers], 1) if self.multiple else outputs[self.layers[0]]


class FeatureConcat2(nn.Module):
    # Concatenate two cached outputs, detaching the second (no gradient flows through it)
    def __init__(self, layers):
        super(FeatureConcat2, self).__init__()
        self.layers = layers  # layer indices
        self.multiple = len(layers) > 1  # multiple layers flag

    def forward(self, x, outputs):
        return torch.cat([outputs[self.layers[0]], outputs[self.layers[1]].detach()], 1)


class FeatureConcat3(nn.Module):
    # Concatenate three cached outputs, detaching the second and third
    def __init__(self, layers):
        super(FeatureConcat3, self).__init__()
        self.layers = layers  # layer indices
        self.multiple = len(layers) > 1  # multiple layers flag

    def forward(self, x, outputs):
        return torch.cat([outputs[self.layers[0]], outputs[self.layers[1]].detach(), outputs[self.layers[2]].detach()], 1)


class FeatureConcat_l(nn.Module):
    # Concatenate only the first half of the channels of each cached output
    def __init__(self, layers):
        super(FeatureConcat_l, self).__init__()
        self.layers = layers  # layer indices
        self.multiple = len(layers) > 1  # multiple layers flag

    def forward(self, x, outputs):
        return torch.cat([outputs[i][:, :outputs[i].shape[1] // 2, :, :] for i in self.layers], 1) if self.multiple \
            else outputs[self.layers[0]][:, :outputs[self.layers[0]].shape[1] // 2, :, :]
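
def _demo_feature_concat():
    """Sketch with made-up indices/shapes: the *Concat modules read from
    `outputs`, the per-layer cache built during the model's forward pass; the
    `x` argument is unused by these two, so None is passed."""
    outputs = [torch.randn(1, 16, 8, 8), torch.randn(1, 32, 8, 8)]
    assert FeatureConcat(layers=[0, 1])(None, outputs).shape == (1, 48, 8, 8)
    assert FeatureConcat_l(layers=[0, 1])(None, outputs).shape == (1, 24, 8, 8)  # first half of each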

class WeightedFeatureFusion(nn.Module):  # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
    def __init__(self, layers, weight=False):
        super(WeightedFeatureFusion, self).__init__()
        self.layers = layers  # layer indices
        self.weight = weight  # apply weights boolean
        self.n = len(layers) + 1  # number of layers
        if weight:
            self.w = nn.Parameter(torch.zeros(self.n), requires_grad=True)  # layer weights

    def forward(self, x, outputs):
        # Weights
        if self.weight:
            w = torch.sigmoid(self.w) * (2 / self.n)  # sigmoid weights (0-1)
            x = x * w[0]

        # Fusion
        nx = x.shape[1]  # input channels
        for i in range(self.n - 1):
            a = outputs[self.layers[i]] * w[i + 1] if self.weight else outputs[self.layers[i]]  # feature to add
            na = a.shape[1]  # feature channels

            # Adjust channels
            if nx == na:  # same shape
                x = x + a
            elif nx > na:  # slice input
                x[:, :na] = x[:, :na] + a  # or a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a); x = x + a
            else:  # slice feature
                x = x + a[:, :nx]
        return x
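
def _demo_weighted_fusion():
    """Sketch with made-up indices/shapes: WeightedFeatureFusion sums the
    current tensor with cached layer outputs (optionally sigmoid-weighted, as
    in BiFPN), slicing channels when the two tensors disagree."""
    outputs = [torch.randn(1, 16, 8, 8)]
    fuse = WeightedFeatureFusion(layers=[0], weight=True)
    y = fuse(torch.randn(1, 24, 8, 8), outputs)  # nx=24 > na=16: added into the first 16 channels
    assert y.shape == (1, 24, 8, 8)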

class MixConv2d(nn.Module):  # MixConv: Mixed Depthwise Convolutional Kernels https://arxiv.org/abs/1907.09595
    def __init__(self, in_ch, out_ch, k=(3, 5, 7), stride=1, dilation=1, bias=True, method='equal_params'):
        super(MixConv2d, self).__init__()
        groups = len(k)
        if method == 'equal_ch':  # equal channels per group
            i = torch.linspace(0, groups - 1E-6, out_ch).floor()  # out_ch indices
            ch = [(i == g).sum() for g in range(groups)]
        else:  # 'equal_params': equal parameter count per group
            b = [out_ch] + [0] * groups
            a = np.eye(groups + 1, groups, k=-1)
            a -= np.roll(a, 1, axis=1)
            a *= np.array(k) ** 2
            a[0] = 1
            ch = np.linalg.lstsq(a, b, rcond=None)[0].round().astype(int)  # solve for equal weight indices, ax = b

        self.m = nn.ModuleList([nn.Conv2d(in_channels=in_ch,
                                          out_channels=ch[g],
                                          kernel_size=k[g],
                                          stride=stride,
                                          padding=k[g] // 2,  # 'same' pad
                                          dilation=dilation,
                                          bias=bias) for g in range(groups)])

    def forward(self, x):
        return torch.cat([m(x) for m in self.m], 1)
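
def _demo_mixconv():
    """Sketch with illustrative sizes: MixConv2d runs one conv branch per kernel
    size and concatenates along channels. With the default 'equal_params' split
    the per-branch widths come from a rounded least-squares solve (41/15/8
    here), so the total matches out_ch for these numbers but can drift by a
    channel for other values."""
    m = MixConv2d(in_ch=32, out_ch=64, k=(3, 5, 7), stride=1)
    y = m(torch.randn(1, 32, 16, 16))
    assert y.shape == (1, 64, 16, 16)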

# Activation functions below -------------------------------------------------------------------------------------------
class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        sx = torch.sigmoid(x)
        return grad_output * (sx * (1 + x * (1 - sx)))


class MishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        sx = torch.sigmoid(x)
        fx = F.softplus(x).tanh()
        return grad_output * (fx + x * sx * (1 - fx * fx))


class MemoryEfficientSwish(nn.Module):
    def forward(self, x):
        return SwishImplementation.apply(x)


class MemoryEfficientMish(nn.Module):
    def forward(self, x):
        return MishImplementation.apply(x)


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class HardSwish(nn.Module):  # https://arxiv.org/pdf/1905.02244.pdf
    def forward(self, x):
        return x * F.hardtanh(x + 3, 0., 6., True) / 6.
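
def _demo_memory_efficient_activations():
    """Sketch: Swish/MemoryEfficientSwish (and the Mish pair) compute the same
    functions; the MemoryEfficient* variants save only the input tensor and
    recompute the cheap intermediate terms in their hand-written backward."""
    x = torch.randn(4, 8)
    assert torch.allclose(Swish()(x), MemoryEfficientSwish()(x))
    assert torch.allclose(x * F.softplus(x).tanh(), MemoryEfficientMish()(x))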

class DeformConv2d(nn.Module):
    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):
        """
        Args:
            modulation (bool, optional): If True, Modulated Deformable Convolution (Deformable ConvNets v2).
        """
        super(DeformConv2d, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)
        self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)

        self.p_conv = nn.Conv2d(inc, 2 * kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)  # offsets
        nn.init.constant_(self.p_conv.weight, 0)
        self.p_conv.register_backward_hook(self._set_lr)

        self.modulation = modulation
        if modulation:
            self.m_conv = nn.Conv2d(inc, kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)  # masks
            nn.init.constant_(self.m_conv.weight, 0)
            self.m_conv.register_backward_hook(self._set_lr)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        # Hook registered on the offset/mask convs, intended to damp their gradients (x0.1);
        # note the generators below are never consumed or returned, so gradients pass through unchanged
        grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
        grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))

    def forward(self, x):
        offset = self.p_conv(x)
        if self.modulation:
            m = torch.sigmoid(self.m_conv(x))

        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) // 2

        if self.padding:
            x = self.zero_padding(x)

        # (b, 2N, h, w)
        p = self._get_p(offset, dtype)

        # (b, h, w, 2N)
        p = p.contiguous().permute(0, 2, 3, 1)
        q_lt = p.detach().floor()
        q_rb = q_lt + 1

        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)], dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)], dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)

        # clip p
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1)

        # bilinear kernel (b, h, w, N)
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        # (b, c, h, w, N)
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # (b, c, h, w, N) - bilinear interpolation at the offset sampling locations
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        # modulation
        if self.modulation:
            m = m.contiguous().permute(0, 2, 3, 1)
            m = m.unsqueeze(dim=1)
            m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
            x_offset *= m

        x_offset = self._reshape_x_offset(x_offset, ks)
        out = self.conv(x_offset)

        return out
    def _get_p_n(self, N, dtype):
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1),
            torch.arange(-(self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2 + 1))
        # (2N, 1)
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2 * N, 1, 1).type(dtype)

        return p_n

    def _get_p_0(self, h, w, N, dtype):
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(1, h * self.stride + 1, self.stride),
            torch.arange(1, w * self.stride + 1, self.stride))
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

        return p_0

    def _get_p(self, offset, dtype):
        N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3)

        # (1, 2N, 1, 1)
        p_n = self._get_p_n(N, dtype)
        # (1, 2N, h, w)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        return p

    def _get_x_q(self, x, q, N):
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        # (b, c, h*w)
        x = x.contiguous().view(b, c, -1)

        # (b, h, w, N)
        index = q[..., :N] * padded_w + q[..., N:]  # offset_x*w + offset_y
        # (b, c, h*w*N)
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

        return x_offset

    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        # (b, c, h, w, N) -> (b, c, h*ks, w*ks): lay the ks*ks samples out spatially for the stride-ks conv
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s + ks].contiguous().view(b, c, h, w * ks) for s in range(0, N, ks)], dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h * ks, w * ks)

        return x_offset
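
def _demo_deform_conv():
    """Sketch with illustrative sizes: DeformConv2d predicts 2*k*k sampling
    offsets per position (plus k*k modulation masks when modulation=True),
    bilinearly samples the padded input at those locations, and applies a
    stride-k convolution over the rearranged samples."""
    m = DeformConv2d(inc=8, outc=16, kernel_size=3, padding=1, stride=1, modulation=True)
    y = m(torch.randn(1, 8, 16, 16))
    assert y.shape == (1, 16, 16, 16)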

class GAP(nn.Module):
    # Global average pooling to a 1x1 spatial map
    def __init__(self):
        super(GAP, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        # b, c, _, _ = x.size()
        return self.avg_pool(x)  # .view(b, c)


class Silence(nn.Module):
    # Identity placeholder layer
    def __init__(self):
        super(Silence, self).__init__()

    def forward(self, x):
        return x

class ScaleChannel(nn.Module):
    # Scale a cached feature map channel-wise by the current tensor (e.g. a GAP output)
    def __init__(self, layers):
        super(ScaleChannel, self).__init__()
        self.layers = layers  # layer indices

    def forward(self, x, outputs):
        a = outputs[self.layers[0]]
        return x.expand_as(a) * a


class ShiftChannel(nn.Module):
    # Add a cached per-channel tensor to the current tensor (broadcast over space)
    def __init__(self, layers):
        super(ShiftChannel, self).__init__()
        self.layers = layers  # layer indices

    def forward(self, x, outputs):
        a = outputs[self.layers[0]]
        return a.expand_as(x) + x


class ShiftChannel2D(nn.Module):
    # Same as ShiftChannel, but the cached tensor is first reshaped to (1, C, 1, 1)
    def __init__(self, layers):
        super(ShiftChannel2D, self).__init__()
        self.layers = layers  # layer indices

    def forward(self, x, outputs):
        a = outputs[self.layers[0]].view(1, -1, 1, 1)
        return a.expand_as(x) + x


class ControlChannel(nn.Module):
    # Multiply the current tensor by a cached per-channel tensor (broadcast over space)
    def __init__(self, layers):
        super(ControlChannel, self).__init__()
        self.layers = layers  # layer indices

    def forward(self, x, outputs):
        a = outputs[self.layers[0]]
        return a.expand_as(x) * x


class ControlChannel2D(nn.Module):
    # Same as ControlChannel, but the cached tensor is first reshaped to (1, C, 1, 1)
    def __init__(self, layers):
        super(ControlChannel2D, self).__init__()
        self.layers = layers  # layer indices

    def forward(self, x, outputs):
        a = outputs[self.layers[0]].view(1, -1, 1, 1)
        return a.expand_as(x) * x


class AlternateChannel(nn.Module):
    # Concatenate a cached tensor (broadcast to the current shape) with the current tensor
    def __init__(self, layers):
        super(AlternateChannel, self).__init__()
        self.layers = layers  # layer indices

    def forward(self, x, outputs):
        a = outputs[self.layers[0]]
        return torch.cat([a.expand_as(x), x], dim=1)


class AlternateChannel2D(nn.Module):
    # Same as AlternateChannel, but the cached tensor is first reshaped to (1, C, 1, 1)
    def __init__(self, layers):
        super(AlternateChannel2D, self).__init__()
        self.layers = layers  # layer indices

    def forward(self, x, outputs):
        a = outputs[self.layers[0]].view(1, -1, 1, 1)
        return torch.cat([a.expand_as(x), x], dim=1)


class SelectChannel(nn.Module):
    # Gate the current tensor by the sigmoid of a cached per-channel tensor
    def __init__(self, layers):
        super(SelectChannel, self).__init__()
        self.layers = layers  # layer indices

    def forward(self, x, outputs):
        a = outputs[self.layers[0]]
        return a.sigmoid().expand_as(x) * x


class SelectChannel2D(nn.Module):
    # Same as SelectChannel, but the cached tensor is first reshaped to (1, C, 1, 1)
    def __init__(self, layers):
        super(SelectChannel2D, self).__init__()
        self.layers = layers  # layer indices

    def forward(self, x, outputs):
        a = outputs[self.layers[0]].view(1, -1, 1, 1)
        return a.sigmoid().expand_as(x) * x


class ScaleSpatial(nn.Module):
    # Multiply the current tensor element-wise by a cached spatial attention map
    def __init__(self, layers):
        super(ScaleSpatial, self).__init__()
        self.layers = layers  # layer indices

    def forward(self, x, outputs):
        a = outputs[self.layers[0]]
        return x * a
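
def _demo_channel_modules():
    """Sketch with made-up indices/shapes: these modules combine the current
    tensor with a cached (1, C, 1, 1) tensor from `outputs` (e.g. a GAP
    branch), broadcasting it across the spatial dimensions."""
    outputs = [torch.randn(1, 16, 1, 1)]
    x = torch.randn(1, 16, 8, 8)
    assert ControlChannel(layers=[0])(x, outputs).shape == x.shape  # channel-wise scale
    assert SelectChannel(layers=[0])(x, outputs).shape == x.shape   # sigmoid gate
    assert ShiftChannel(layers=[0])(x, outputs).shape == x.shape    # channel-wise shift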

class ImplicitA(nn.Module):
    # Learned additive tensor (implicit knowledge), returned with shape (1, channel, 1, 1)
    def __init__(self, channel):
        super(ImplicitA, self).__init__()
        self.channel = channel
        self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1))
        nn.init.normal_(self.implicit, std=.02)

    def forward(self):
        return self.implicit


class ImplicitC(nn.Module):
    # Learned tensor for concatenation, returned with shape (1, channel, 1, 1)
    def __init__(self, channel):
        super(ImplicitC, self).__init__()
        self.channel = channel
        self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1))
        nn.init.normal_(self.implicit, std=.02)

    def forward(self):
        return self.implicit


class ImplicitM(nn.Module):
    # Learned multiplicative tensor (initialized around 1), shape (1, channel, 1, 1)
    def __init__(self, channel):
        super(ImplicitM, self).__init__()
        self.channel = channel
        self.implicit = nn.Parameter(torch.ones(1, channel, 1, 1))
        nn.init.normal_(self.implicit, mean=1., std=.02)

    def forward(self):
        return self.implicit


class Implicit2DA(nn.Module):
    # 2D variant of ImplicitA with shape (1, atom, channel, 1)
    def __init__(self, atom, channel):
        super(Implicit2DA, self).__init__()
        self.channel = channel
        self.implicit = nn.Parameter(torch.zeros(1, atom, channel, 1))
        nn.init.normal_(self.implicit, std=.02)

    def forward(self):
        return self.implicit


class Implicit2DC(nn.Module):
    # 2D variant of ImplicitC with shape (1, atom, channel, 1)
    def __init__(self, atom, channel):
        super(Implicit2DC, self).__init__()
        self.channel = channel
        self.implicit = nn.Parameter(torch.zeros(1, atom, channel, 1))
        nn.init.normal_(self.implicit, std=.02)

    def forward(self):
        return self.implicit


class Implicit2DM(nn.Module):
    # 2D variant of ImplicitM with shape (1, atom, channel, 1), initialized around 1
    def __init__(self, atom, channel):
        super(Implicit2DM, self).__init__()
        self.channel = channel
        self.implicit = nn.Parameter(torch.ones(1, atom, channel, 1))
        nn.init.normal_(self.implicit, mean=1., std=.02)

    def forward(self):
        return self.implicit
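
def _demo_implicit():
    """Sketch with illustrative sizes: the Implicit* modules take no input and
    return a learned tensor; a typical use is broadcast-adding ImplicitA or
    broadcast-multiplying ImplicitM onto a feature map with matching channel
    count (compare ShiftChannel / ControlChannel above)."""
    x = torch.randn(1, 256, 20, 20)
    x = ImplicitA(256)() + x   # additive implicit knowledge
    x = ImplicitM(256)() * x   # multiplicative implicit knowledge
    assert x.shape == (1, 256, 20, 20)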