import torch
from torch import nn
from torch.nn.modules import rnn
from torch.utils.checkpoint import checkpoint_sequential


class TimeFrequencyModellingModule(nn.Module):
    """Base class for modules that model a (band, time, embedding) representation."""

    def __init__(self) -> None:
        super().__init__()


class ResidualRNN(nn.Module):
    """A single RNN block with pre-norm and a residual connection.

    The input is normalised with LayerNorm, run through a one-layer RNN along
    the third axis, projected back to ``emb_dim``, and added to the residual.
    Only ``use_layer_norm=True`` and ``use_batch_trick=True`` are supported;
    the flags are kept for interface compatibility.
    """

    def __init__(
        self,
        emb_dim: int,
        rnn_dim: int,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        use_batch_trick: bool = True,
        use_layer_norm: bool = True,
    ) -> None:
        super().__init__()
        assert use_layer_norm, "only the LayerNorm path is implemented"
        assert use_batch_trick, "only the batched (reshape) path is implemented"
        self.use_layer_norm = use_layer_norm
        self.use_batch_trick = use_batch_trick

        self.norm = nn.LayerNorm(emb_dim)
        # Look up the RNN class (e.g. nn.LSTM or nn.GRU) by name.
        self.rnn = getattr(rnn, rnn_type)(
            input_size=emb_dim,
            hidden_size=rnn_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )
        self.fc = nn.Linear(
            in_features=rnn_dim * (2 if bidirectional else 1),
            out_features=emb_dim,
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: (batch, n_uncrossed, n_across, emb_dim)
        z0 = torch.clone(z)
        z = self.norm(z)

        # "Batch trick": fold the uncrossed axis into the batch axis so the
        # RNN processes all sequences in a single call.
        batch, n_uncrossed, n_across, emb_dim = z.shape
        z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
        z = self.rnn(z)[0]
        z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))

        z = self.fc(z)  # (batch, n_uncrossed, n_across, emb_dim)
        return z + z0
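

# Illustrative usage (not part of the original file; shapes are assumptions
# chosen for demonstration). The block maps a 4-D activation back to the same
# shape, with the RNN running along the ``n_across`` axis:
#
#     block = ResidualRNN(emb_dim=128, rnn_dim=256)
#     out = block(torch.randn(2, 4, 100, 128))  # -> torch.Size([2, 4, 100, 128])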


class Transpose(nn.Module):
    """Swap two dimensions of the input tensor (usable inside nn.Sequential)."""

    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return z.transpose(self.dim0, self.dim1)


class SeqBandModellingModule(TimeFrequencyModellingModule):
    """Stack of ResidualRNN blocks applied alternately across time and bands.

    In sequential mode (default), the stack interleaves ResidualRNN and
    Transpose modules so consecutive blocks alternate between modelling the
    time axis and the band axis. In parallel mode, each block applies a
    time-RNN and a band-RNN to the same input and sums the results.
    """

    def __init__(
        self,
        n_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        parallel_mode: bool = False,
    ) -> None:
        super().__init__()
        self.n_modules = n_modules

        if parallel_mode:
            self.seqband = nn.ModuleList([])
            for _ in range(n_modules):
                self.seqband.append(
                    nn.ModuleList(
                        [
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            ),
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            ),
                        ]
                    )
                )
        else:
            seqband = []
            for _ in range(2 * n_modules):
                seqband += [
                    ResidualRNN(
                        emb_dim=emb_dim,
                        rnn_dim=rnn_dim,
                        bidirectional=bidirectional,
                        rnn_type=rnn_type,
                    ),
                    Transpose(1, 2),
                ]
            self.seqband = nn.Sequential(*seqband)

        self.parallel_mode = parallel_mode

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: (batch, n_bands, n_time, emb_dim)
        if self.parallel_mode:
            for sbm_t, sbm_f in self.seqband:
                zt = sbm_t(z)  # over time:  (batch, n_bands, n_time, emb_dim)
                zf = sbm_f(z.transpose(1, 2))  # over bands: (batch, n_time, n_bands, emb_dim)
                z = zt + zf.transpose(1, 2)
        else:
            # Gradient checkpointing: recompute each segment's activations in
            # the backward pass to trade compute for memory.
            z = checkpoint_sequential(
                self.seqband, self.n_modules, z, use_reentrant=False
            )
        return z  # (batch, n_bands, n_time, emb_dim)
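

# Minimal smoke test (illustrative; not part of the original module). The
# hyperparameters and shapes below are assumptions chosen for a quick check,
# not values mandated by the model.
if __name__ == "__main__":
    x = torch.randn(2, 8, 50, 128)  # (batch, n_bands, n_time, emb_dim)

    seq = SeqBandModellingModule(n_modules=2, emb_dim=128, rnn_dim=64)
    par = SeqBandModellingModule(n_modules=2, emb_dim=128, rnn_dim=64, parallel_mode=True)

    print(seq(x).shape)  # torch.Size([2, 8, 50, 128])
    print(par(x).shape)  # torch.Size([2, 8, 50, 128])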