import torch
import torch.nn as nn

EPS = torch.finfo(torch.get_default_dtype()).eps


class SingleRNN(nn.Module):
    """Container module for a single RNN layer.

    Args:
        rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
        input_size: int, dimension of the input feature. The input should have shape
            (batch, seq_len, input_size).
        hidden_size: int, dimension of the hidden state.
        dropout: float, dropout ratio. Default is 0.
        bidirectional: bool, whether the RNN layers are bidirectional. Default is False.
    """

    def __init__(
        self, rnn_type, input_size, hidden_size, dropout=0, bidirectional=False
    ):
        super().__init__()

        rnn_type = rnn_type.upper()
        assert rnn_type in [
            "RNN",
            "LSTM",
            "GRU",
        ], f"Only support 'RNN', 'LSTM' and 'GRU', current type: {rnn_type}"

        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_direction = int(bidirectional) + 1

        # Single RNN layer; the feature dimension doubles when bidirectional.
        self.rnn = getattr(nn, rnn_type)(
            input_size,
            hidden_size,
            1,
            batch_first=True,
            bidirectional=bidirectional,
        )
        self.dropout = nn.Dropout(p=dropout)
        # Project the RNN output back to the input dimension so the layer
        # preserves the input shape and can be used residually.
        self.proj = nn.Linear(hidden_size * self.num_direction, input_size)

    def forward(self, input):
        # input shape: (batch, seq_len, input_size); output has the same shape.
        output = input
        rnn_output, _ = self.rnn(output)
        rnn_output = self.dropout(rnn_output)
        rnn_output = self.proj(
            rnn_output.contiguous().view(-1, rnn_output.shape[2])
        ).view(output.shape)
        return rnn_output
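
# Minimal usage sketch for SingleRNN: the layer keeps the
# (batch, seq_len, input_size) shape. All sizes below are arbitrary example
# values, not values prescribed by this module.
if __name__ == "__main__":
    _rnn = SingleRNN("LSTM", input_size=64, hidden_size=128, bidirectional=True)
    _x = torch.randn(2, 100, 64)  # (batch, seq_len, input_size)
    assert _rnn(_x).shape == _x.shape  # shape is preserved

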
class DPRNN(nn.Module):
    """Deep dual-path RNN.

    Args:
        rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
        input_size: int, dimension of the input feature. The input should have shape
            (batch, seq_len, input_size).
        hidden_size: int, dimension of the hidden state.
        output_size: int, dimension of the output size.
        dropout: float, dropout ratio. Default is 0.
        num_layers: int, number of stacked RNN layers. Default is 1.
        bidirectional: bool, whether the RNN layers are bidirectional. Default is True.
    """

    def __init__(
        self,
        rnn_type,
        input_size,
        hidden_size,
        output_size,
        dropout=0,
        num_layers=1,
        bidirectional=True,
    ):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size

        # Each dual-path block has an intra-segment (row) RNN and an
        # inter-segment (column) RNN, each followed by a group norm.
        self.row_rnn = nn.ModuleList([])
        self.col_rnn = nn.ModuleList([])
        self.row_norm = nn.ModuleList([])
        self.col_norm = nn.ModuleList([])
        for i in range(num_layers):
            # The intra-segment RNN is always bidirectional; the inter-segment
            # RNN follows the `bidirectional` argument.
            self.row_rnn.append(
                SingleRNN(
                    rnn_type, input_size, hidden_size, dropout, bidirectional=True
                )
            )
            self.col_rnn.append(
                SingleRNN(
                    rnn_type,
                    input_size,
                    hidden_size,
                    dropout,
                    bidirectional=bidirectional,
                )
            )
            self.row_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
            self.col_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))

        self.output = nn.Sequential(nn.PReLU(), nn.Conv2d(input_size, output_size, 1))

    def forward(self, input):
        # input shape: (batch, input_size, dim1, dim2), where dim1 is the
        # segment size and dim2 is the number of segments.
        batch_size, _, dim1, dim2 = input.shape
        output = input
        for i in range(len(self.row_rnn)):
            # Intra-segment (row) processing: run the RNN along dim1, with a
            # residual connection.
            row_input = (
                output.permute(0, 3, 2, 1)
                .contiguous()
                .view(batch_size * dim2, dim1, -1)
            )
            row_output = self.row_rnn[i](row_input)
            row_output = (
                row_output.view(batch_size, dim2, dim1, -1)
                .permute(0, 3, 2, 1)
                .contiguous()
            )
            row_output = self.row_norm[i](row_output)
            output = output + row_output

            # Inter-segment (column) processing: run the RNN along dim2, with a
            # residual connection.
            col_input = (
                output.permute(0, 2, 3, 1)
                .contiguous()
                .view(batch_size * dim1, dim2, -1)
            )
            col_output = self.col_rnn[i](col_input)
            col_output = (
                col_output.view(batch_size, dim1, dim2, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
            )
            col_output = self.col_norm[i](col_output)
            output = output + col_output

        output = self.output(output)

        return output
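
# Minimal usage sketch for DPRNN: the model consumes a segmented tensor of
# shape (batch, input_size, dim1, dim2) and returns
# (batch, output_size, dim1, dim2). All sizes below are arbitrary example
# values, not values prescribed by this module.
if __name__ == "__main__":
    _dprnn = DPRNN(
        "LSTM", input_size=64, hidden_size=128, output_size=64, num_layers=2
    )
    _x = torch.randn(2, 64, 50, 24)  # (batch, input_size, segment_size, n_segments)
    print(_dprnn(_x).shape)  # torch.Size([2, 64, 50, 24])

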
def _pad_segment(input, segment_size):
    # Zero-pad the (batch, dim, seq_len) input so that it can be split into
    # half-overlapping segments of length `segment_size`. Returns the padded
    # input and the number of padded samples at the end (`rest`).
    batch_size, dim, seq_len = input.shape
    segment_stride = segment_size // 2

    rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
    if rest > 0:
        pad = torch.zeros(
            batch_size, dim, rest, dtype=input.dtype, device=input.device
        )
        input = torch.cat([input, pad], 2)

    pad_aux = torch.zeros(
        batch_size, dim, segment_stride, dtype=input.dtype, device=input.device
    )
    input = torch.cat([pad_aux, input, pad_aux], 2)

    return input, rest


def split_feature(input, segment_size):
    # Split the (batch, dim, seq_len) input into half-overlapping segments of
    # length `segment_size`. Returns a tensor of shape
    # (batch, dim, segment_size, n_segments) and the amount of padding added.
    input, rest = _pad_segment(input, segment_size)
    batch_size, dim, seq_len = input.shape
    segment_stride = segment_size // 2

    segments1 = (
        input[:, :, :-segment_stride]
        .contiguous()
        .view(batch_size, dim, -1, segment_size)
    )
    segments2 = (
        input[:, :, segment_stride:]
        .contiguous()
        .view(batch_size, dim, -1, segment_size)
    )
    segments = (
        torch.cat([segments1, segments2], 3)
        .view(batch_size, dim, -1, segment_size)
        .transpose(2, 3)
    )

    return segments.contiguous(), rest


def merge_feature(input, rest):
    # Overlap-add the segmented (batch, dim, segment_size, n_segments) tensor
    # back into a (batch, dim, seq_len) sequence, removing the `rest` padding
    # that `split_feature` added.
    batch_size, dim, segment_size, _ = input.shape
    segment_stride = segment_size // 2
    input = (
        input.transpose(2, 3).contiguous().view(batch_size, dim, -1, segment_size * 2)
    )

    input1 = (
        input[:, :, :, :segment_size]
        .contiguous()
        .view(batch_size, dim, -1)[:, :, segment_stride:]
    )
    input2 = (
        input[:, :, :, segment_size:]
        .contiguous()
        .view(batch_size, dim, -1)[:, :, :-segment_stride]
    )

    output = input1 + input2
    if rest > 0:
        output = output[:, :, :-rest]

    return output.contiguous()
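
# Minimal round-trip sketch for split_feature/merge_feature: with 50% overlap,
# every sample is covered by exactly two segments, so merging an unmodified
# split yields twice the original sequence. Sizes below are arbitrary example
# values, not values prescribed by this module.
if __name__ == "__main__":
    _seq = torch.randn(2, 64, 1000)  # (batch, dim, seq_len)
    _segments, _rest = split_feature(_seq, segment_size=50)
    print(_segments.shape)  # (batch, dim, segment_size, n_segments)
    _merged = merge_feature(_segments, _rest)
    assert torch.allclose(_merged, 2 * _seq)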