Spaces:
Runtime error
Runtime error
from typing import Optional | |
import torch | |
from torch import nn | |
class PreNormTransformerEncoderLayer(nn.TransformerEncoderLayer): | |
r""" | |
A variant of :class:`torch.nn.TransformerEncoderLayer` where layer | |
normalization is included inside the residual branch, and performed before | |
self-attention and feedforward layers. | |
Refer documentation of :class:`torch.nn.TransformerEncoderLayer` for more | |
details on the API. | |
""" | |
def forward( | |
self, | |
src: torch.Tensor, | |
src_mask: Optional[torch.Tensor] = None, | |
src_key_padding_mask: Optional[torch.Tensor] = None | |
) -> torch.Tensor: | |
# fmt: off | |
# We use the members (modules) from super-class, just the order of | |
# operations is changed here. First layernorm, then attention. | |
src2 = self.norm1(src) | |
src2 = self.self_attn(src2, src2, src2, attn_mask=src_mask, | |
key_padding_mask=src_key_padding_mask)[0] | |
src = src + self.dropout1(src2) | |
# Layernorm first, then transformation through feedforward network. | |
src2 = self.norm2(src) | |
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) | |
src = src + self.dropout2(src2) | |
return src | |
class PreNormTransformerDecoderLayer(nn.TransformerDecoderLayer): | |
r""" | |
A variant of :class:`torch.nn.TransformerDecoderLayer` where layer | |
normalization is included inside the residual branch, and performed before | |
self-attention and feedforward layers. | |
Refer documentation of :class:`torch.nn.TransformerDecoderLayer` for more | |
details on the API. | |
""" | |
def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, | |
tgt_key_padding_mask=None, memory_key_padding_mask=None): | |
# fmt: off | |
# We use the members (modules) from super-class, just the order of | |
# operations is changed here. First layernorm, then attention. | |
tgt2 = self.norm1(tgt) | |
tgt2, _ = self.self_attn( | |
tgt2, tgt2, tgt2, attn_mask=tgt_mask, | |
key_padding_mask=tgt_key_padding_mask | |
) | |
tgt = tgt + self.dropout1(tgt2) | |
# Layernorm first, then decoder attention. | |
tgt2 = self.norm2(tgt) | |
tgt2, _ = self.multihead_attn( | |
tgt2, memory, memory, attn_mask=memory_mask, | |
key_padding_mask=memory_key_padding_mask | |
) | |
tgt = tgt + self.dropout2(tgt2) | |
# Layernorm first, then transformation through feedforward network. | |
tgt2 = self.norm3(tgt) | |
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) | |
tgt = tgt + self.dropout3(tgt2) | |
return tgt | |