# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
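"""Feed-forward Transformer (FFT) building blocks: sinusoidal positional embeddings,
a position-wise convolutional feed-forward layer, multi-head self-attention with an
optional relative-position variant, and the FFTransformer stack that combines them."""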
import torch
import torch.nn as nn
import torch.nn.functional as F
from conformer.conformer import Attention as RelAttention
from einops import rearrange
class PositionalEmbedding(nn.Module):
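    """Sinusoidal positional embedding, used when relative attention is disabled."""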
def __init__(self, demb):
super().__init__()
self.demb = demb
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
self.register_buffer("inv_freq", inv_freq)
def forward(self, pos_seq, bsz=None):
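        # pos_seq: (time,). Returns (1, time, demb), expanded to (bsz, time, demb) if bsz is given.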
sinusoid_inp = torch.matmul(torch.unsqueeze(pos_seq, -1), torch.unsqueeze(self.inv_freq, 0))
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
if bsz is not None:
return pos_emb[None, :, :].expand(bsz, -1, -1)
else:
return pos_emb[None, :, :]
class PositionwiseConvFF(nn.Module):
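    """Position-wise feed-forward block built from 1-D convolutions, with a residual connection and layer normalization (pre- or post-norm)."""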
def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
super().__init__()
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.CoreNet = nn.Sequential(
nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
nn.ReLU(),
# nn.Dropout(dropout), # worse convergence
nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
nn.Dropout(dropout),
)
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
def forward(self, inp):
return self._forward(inp)
def _forward(self, inp):
if self.pre_lnorm:
# layer normalization + positionwise feed-forward
# core_out = inp
core_out = self.CoreNet(self.layer_norm(inp).transpose(1, 2))
core_out = core_out.transpose(1, 2)
# residual connection
output = core_out + inp
else:
# positionwise feed-forward
core_out = inp.transpose(1, 2)
core_out = self.CoreNet(core_out)
core_out = core_out.transpose(1, 2)
# residual connection + layer normalization
output = self.layer_norm(inp + core_out).to(inp.dtype)
return output
class MultiHeadAttn(nn.Module):
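    """Multi-head self-attention; uses relative-position attention (`RelAttention`) when `rel_attention` is set, otherwise plain scaled dot-product attention."""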
def __init__(
self, n_head, d_model, d_head, dropout, rel_attention, dropatt=0.1, pre_lnorm=True, rel_window_size=10
):
super().__init__()
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.scale = 1 / (d_head**0.5)
self.pre_lnorm = pre_lnorm
self.rel_attention = rel_attention
if rel_attention:
self.attn = RelAttention(d_model, n_head, d_head, dropout, max_pos_emb=rel_window_size)
else:
self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, inp, attn_mask=None):
return self._forward(inp, attn_mask)
def _forward(self, inp, attn_mask=None):
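        # inp: (batch, time, d_model). In the non-relative branch, attn_mask is a (batch, time)
        # boolean tensor whose True entries are filled with -inf before the softmax.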
residual = inp
if self.pre_lnorm:
# layer normalization
inp = self.layer_norm(inp)
if not self.rel_attention:
n_head, d_head = self.n_head, self.d_head
head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)
q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
attn_score = torch.bmm(q, k.transpose(1, 2))
attn_score.mul_(self.scale)
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
attn_score.masked_fill_(attn_mask.to(torch.bool), -float("inf"))
attn_prob = F.softmax(attn_score, dim=2)
attn_prob = self.dropatt(attn_prob)
attn_vec = torch.bmm(attn_prob, v)
attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(inp.size(0), inp.size(1), n_head * d_head)
# linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
else:
attn_out = self.attn(inp, mask=attn_mask)
if self.pre_lnorm:
# residual connection
output = residual + attn_out
else:
# residual connection + layer normalization
output = self.layer_norm(residual + attn_out)
output = output.to(attn_out.dtype)
return output
class TransformerLayer(nn.Module):
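    """Single FFT block: masked multi-head self-attention followed by a position-wise convolutional feed-forward block."""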
def __init__(self, n_head, d_model, d_head, d_inner, kernel_size, dropout, **kwargs):
super().__init__()
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout, pre_lnorm=kwargs.get("pre_lnorm"))
def forward(self, dec_inp, mask=None):
output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
output *= mask
output = self.pos_ff(output)
output *= mask
return output
class FFTransformer(nn.Module):
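    """Stack of FFT blocks; adds absolute sinusoidal positions when relative attention is disabled."""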
def __init__(
self,
n_layer,
n_head,
hidden_channels,
d_head,
d_inner,
kernel_size,
dropout,
dropatt,
dropemb=0.0,
embed_input=False,
n_embed=None,
d_embed=None,
padding_idx=0,
pre_lnorm=True,
rel_attention=True,
rel_window_size=10,
):
super().__init__()
self.d_model = hidden_channels
self.n_head = n_head
self.d_head = d_head
self.padding_idx = padding_idx
if embed_input:
self.word_emb = nn.Embedding(n_embed, d_embed or hidden_channels, padding_idx=self.padding_idx)
else:
self.word_emb = None
self.rel_attention = rel_attention
if not rel_attention:
self.pos_emb = PositionalEmbedding(self.d_model)
self.drop = nn.Dropout(dropemb)
self.layers = nn.ModuleList()
for _ in range(n_layer):
self.layers.append(
TransformerLayer(
n_head,
hidden_channels,
d_head,
d_inner,
kernel_size,
dropout,
dropatt=dropatt,
pre_lnorm=pre_lnorm,
rel_attention=rel_attention,
rel_window_size=rel_window_size,
)
)
def forward(self, dec_inp, mask=None, conditioning=0):
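        # dec_inp: (batch, channels, time); mask: (batch, 1, time) with 1 for valid frames.
        # conditioning: scalar 0 or a tensor broadcastable to (batch, time, channels).
        # Returns (batch, channels, time).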
inp = dec_inp.transpose(1, 2)
mask = mask.bool().squeeze(1).unsqueeze(2)
# if self.word_emb is None:
# inp = dec_inp
# mask = sequence_mask(seq_lens, inp.shape[1], device=seq_lens.device, dtype=seq_lens.dtype).unsqueeze(2)
# else:
# inp = self.word_emb(dec_inp)
# # [bsz x L x 1]
# mask = (dec_inp != self.padding_idx).unsqueeze(2)
if not self.rel_attention:
pos_seq = torch.arange(inp.size(1), device=inp.device).to(inp.dtype)
pos_emb = self.pos_emb(pos_seq) * mask
else:
pos_emb = 0
out = self.drop(inp + pos_emb + conditioning)
for layer in self.layers:
out = layer(out, mask=mask)
# out = self.drop(out)
return rearrange(out, "b l h -> b h l")
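

# Minimal usage sketch: build a small FFTransformer and run a forward pass to check shapes.
# The hyperparameters below are arbitrary illustrative values, not recommended settings.
if __name__ == "__main__":
    model = FFTransformer(
        n_layer=2,
        n_head=2,
        hidden_channels=64,
        d_head=32,
        d_inner=256,
        kernel_size=3,
        dropout=0.1,
        dropatt=0.1,
        rel_attention=False,  # exercise the plain scaled dot-product branch for this check
    )
    x = torch.randn(4, 64, 50)   # (batch, hidden_channels, time)
    mask = torch.ones(4, 1, 50)  # (batch, 1, time); 1 marks valid frames
    out = model(x, mask=mask)
    print(out.shape)  # expected: torch.Size([4, 64, 50])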