Spaces:
Sleeping
Sleeping
File size: 4,120 Bytes
b1fdcc2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import torch
from torch import nn
import torch.nn.functional as F
def crop_center(h1, h2):
    """Center-crop ``h1`` on axis 3 (the time axis) so its width matches ``h2``.

    Only axis 3 is touched; every other axis of ``h1`` is kept whole.

    Args:
        h1: 4-D tensor to crop.
        h2: 4-D tensor whose axis-3 length is the target width.

    Returns:
        ``h1`` itself (not a copy) when the widths already match, otherwise a
        slice of ``h1`` of width ``h2.size()[3]``. When the width difference is
        odd, the extra column is dropped from the right-hand side.

    Raises:
        ValueError: if ``h1`` is narrower than ``h2`` on axis 3.
    """
    src_width = h1.size()[3]
    dst_width = h2.size()[3]
    if src_width < dst_width:
        raise ValueError("h1_shape[3] must be greater than h2_shape[3]")
    if src_width == dst_width:
        return h1
    start = (src_width - dst_width) // 2
    return h1[:, :, :, start:start + dst_width]
class Conv2DBNActiv(nn.Module):
    """Conv2d -> BatchNorm2d -> activation, wrapped in a single ``nn.Sequential``.

    Args:
        nin: number of input channels.
        nout: number of output channels.
        ksize: convolution kernel size.
        stride: convolution stride.
        pad: zero padding applied by the convolution.
        dilation: convolution dilation.
        activ: activation *class*; it is instantiated with no arguments.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
        super().__init__()
        # The attribute name ``conv`` and the Sequential indices (0, 1, 2) define
        # the state_dict keys (``conv.0.weight``, ``conv.1.running_mean``, ...);
        # keep them stable so existing checkpoints still load.
        self.conv = nn.Sequential(
            nn.Conv2d(
                nin,
                nout,
                kernel_size=ksize,
                stride=stride,
                padding=pad,
                dilation=dilation,
                # No conv bias: the BatchNorm2d right after it removes any
                # per-channel constant offset, so a bias would be redundant.
                bias=False,
            ),
            nn.BatchNorm2d(nout),
            activ(),
        )

    def forward(self, x):
        # Defined as ``forward`` (not an ``__call__`` override) so that
        # nn.Module.__call__ still runs registered forward hooks/pre-hooks.
        return self.conv(x)
class Encoder(nn.Module):
    """Two stacked Conv2DBNActiv layers; only the first one applies ``stride``.

    Args:
        nin: number of input channels.
        nout: number of output channels (of both layers).
        ksize: kernel size of both convolutions.
        stride: stride of the first convolution; the second always uses 1.
        pad: padding of both convolutions.
        activ: activation class passed to both layers.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
        super().__init__()
        self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
        self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)

    def forward(self, x):
        # ``forward`` rather than an ``__call__`` override keeps nn.Module hooks working.
        h = self.conv1(x)
        h = self.conv2(h)
        return h
class Decoder(nn.Module):
    """Upsample 2x, optionally concatenate a skip connection, then Conv2DBNActiv.

    Args:
        nin: channels entering ``conv1``, i.e. channels of ``x`` plus channels
            of ``skip`` when a skip tensor is passed to ``forward``.
        nout: number of output channels.
        ksize: convolution kernel size.
        stride: accepted but unused; the convolution always runs with stride 1.
            Presumably kept for signature parity with ``Encoder``.
        pad: convolution padding.
        activ: activation class.
        dropout: when True, ``nn.Dropout2d(0.1)`` is applied to the output.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
        super().__init__()
        self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
        self.dropout = nn.Dropout2d(0.1) if dropout else None

    def forward(self, x, skip=None):
        """Upsample ``x`` by 2 (bilinear) and fuse it with the optional ``skip``."""
        x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
        if skip is not None:
            # Trim the skip's axis 3 to the upsampled width before concatenating
            # along channels.
            skip = crop_center(skip, x)
            x = torch.cat([x, skip], dim=1)
        h = self.conv1(x)
        if self.dropout is not None:
            h = self.dropout(h)
        return h
class ASPPModule(nn.Module):
    """Atrous spatial pyramid pooling: five parallel branches fused by a 1x1 conv.

    Each branch maps ``nin`` channels to ``nout`` channels:
      * ``conv1``: average the height axis (dim 2) down to a single row, apply
        a 1x1 conv, then bilinearly resize back to the input's (height, width);
      * ``conv2``: a plain 1x1 conv;
      * ``conv3``..``conv5``: 3x3 convs whose dilation and padding are
        ``dilations[0]``, ``[1]`` and ``[2]``. Padding equal to the dilation
        keeps the spatial size unchanged.
    The five outputs are concatenated on the channel axis and reduced back to
    ``nout`` channels by ``bottleneck``. ``nn.Dropout2d(0.1)`` follows when
    ``dropout`` is True.

    ``dilations`` needs at least three entries; any extra entries are ignored.
    """

    def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
        super(ASPPModule, self).__init__()
        # Registration order (conv1..conv5, bottleneck, dropout) fixes parameter
        # order and state_dict keys; keep it.
        self.conv1 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, None)),
            Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ),
        )
        self.conv2 = Conv2DBNActiv(nin, nout, ksize=1, stride=1, pad=0, activ=activ)
        d_small, d_mid, d_large = dilations[0], dilations[1], dilations[2]
        self.conv3 = Conv2DBNActiv(nin, nout, ksize=3, stride=1, pad=d_small, dilation=d_small, activ=activ)
        self.conv4 = Conv2DBNActiv(nin, nout, ksize=3, stride=1, pad=d_mid, dilation=d_mid, activ=activ)
        self.conv5 = Conv2DBNActiv(nin, nout, ksize=3, stride=1, pad=d_large, dilation=d_large, activ=activ)
        self.bottleneck = Conv2DBNActiv(nout * 5, nout, ksize=1, stride=1, pad=0, activ=activ)
        self.dropout = nn.Dropout2d(0.1) if dropout else None

    def forward(self, x):
        _, _, height, width = x.size()
        pooled = self.conv1(x)
        # The pooled branch has height 1; stretch it back to the input grid.
        branches = [F.interpolate(pooled, size=(height, width), mode="bilinear", align_corners=True)]
        branches.extend(branch(x) for branch in (self.conv2, self.conv3, self.conv4, self.conv5))
        fused = self.bottleneck(torch.cat(branches, dim=1))
        return fused if self.dropout is None else self.dropout(fused)
class LSTMModule(nn.Module):
    """Bidirectional LSTM run along the frame axis (dim 3) of a feature map.

    The input ``(N, nin_conv, nbins, nframes)`` is squeezed to one channel by a
    1x1 Conv2DBNActiv. A bidirectional LSTM then runs over the frames, and the
    result is projected back to ``nin_lstm`` features per frame. The output has
    shape ``(N, 1, nbins, nframes)``.

    Args:
        nin_conv: number of channels of the input feature map.
        nin_lstm: LSTM input size; must equal ``nbins`` (dim 2 of the input),
            because the projected features are reshaped back into that axis.
        nout_lstm: total bidirectional LSTM output width. Each direction gets
            ``nout_lstm // 2`` hidden units.
    """

    def __init__(self, nin_conv, nin_lstm, nout_lstm):
        super().__init__()
        self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)
        hidden = nout_lstm // 2
        self.lstm = nn.LSTM(input_size=nin_lstm, hidden_size=hidden, bidirectional=True)
        # The bidirectional LSTM emits 2 * hidden features per step. Sizing the
        # Linear from that (rather than nout_lstm) stops an odd nout_lstm from
        # failing with a shape mismatch in forward. For even nout_lstm the two
        # values are identical, so existing checkpoints are unaffected.
        self.dense = nn.Sequential(
            nn.Linear(2 * hidden, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU()
        )

    def forward(self, x):
        N, _, nbins, nframes = x.size()
        h = self.conv(x)[:, 0]  # (N, nbins, nframes)
        # nn.LSTM is constructed without batch_first, so it expects (seq, batch, feature).
        h = h.permute(2, 0, 1)  # (nframes, N, nbins)
        h, _ = self.lstm(h)  # (nframes, N, 2 * hidden)
        h = self.dense(h.reshape(-1, h.size(-1)))  # (nframes * N, nin_lstm)
        h = h.reshape(nframes, N, 1, nbins)
        return h.permute(1, 2, 3, 0)  # (N, 1, nbins, nframes)
|