from typing import Optional

import torch
from torch import nn


class PreNormTransformerEncoderLayer(nn.TransformerEncoderLayer):
    r"""
    A variant of :class:`torch.nn.TransformerEncoderLayer` where layer
    normalization is included inside the residual branch, and performed
    before the self-attention and feedforward layers.

    Refer to the documentation of :class:`torch.nn.TransformerEncoderLayer`
    for more details on the API.
    """

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # fmt: off
        # We use the members (modules) from the super-class; only the order
        # of operations is changed here. Layernorm first, then self-attention.
        src2 = self.norm1(src)
        src2 = self.self_attn(
            src2, src2, src2, attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)

        # Layernorm first, then transformation through the feedforward network.
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src
        # fmt: on


class PreNormTransformerDecoderLayer(nn.TransformerDecoderLayer):
    r"""
    A variant of :class:`torch.nn.TransformerDecoderLayer` where layer
    normalization is included inside the residual branch, and performed
    before the self-attention, cross-attention, and feedforward layers.

    Refer to the documentation of :class:`torch.nn.TransformerDecoderLayer`
    for more details on the API.
    """

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # fmt: off
        # We use the members (modules) from the super-class; only the order
        # of operations is changed here. Layernorm first, then self-attention.
        tgt2 = self.norm1(tgt)
        tgt2, _ = self.self_attn(
            tgt2, tgt2, tgt2, attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + self.dropout1(tgt2)

        # Layernorm first, then cross-attention over the encoder memory.
        tgt2 = self.norm2(tgt)
        tgt2, _ = self.multihead_attn(
            tgt2, memory, memory, attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask
        )
        tgt = tgt + self.dropout2(tgt2)

        # Layernorm first, then transformation through the feedforward network.
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
        # fmt: on
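

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module API above):
# it calls the two pre-norm layers directly on random tensors. Hyperparameter
# values and tensor sizes here are arbitrary assumptions; shapes follow the
# default ``torch.nn.Transformer`` convention of
# (sequence_length, batch_size, d_model).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    d_model, nhead = 512, 8

    encoder_layer = PreNormTransformerEncoderLayer(d_model, nhead, dim_feedforward=2048)
    decoder_layer = PreNormTransformerDecoderLayer(d_model, nhead, dim_feedforward=2048)

    src = torch.randn(16, 4, d_model)  # (src_len, batch, d_model)
    tgt = torch.randn(10, 4, d_model)  # (tgt_len, batch, d_model)

    # Additive causal mask: -inf strictly above the diagonal, so each target
    # position attends only to itself and earlier positions.
    tgt_len = tgt.size(0)
    causal_mask = torch.triu(
        torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1
    )

    memory = encoder_layer(src)
    out = decoder_layer(tgt, memory, tgt_mask=causal_mask)
    print(out.shape)  # torch.Size([10, 4, 512])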