from functools import partial
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Dropout, LayerNorm, Linear, Module, MultiheadAttention
from torch.nn.modules.transformer import *
from torch.nn.modules.transformer import _get_activation_fn
from torch.utils.checkpoint import checkpoint


class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.

    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.
        pre_norm: if ``True``, apply layer normalization before the attention and
            feedforward blocks (pre-LN) instead of after them (post-LN). Default: ``False``.
        recompute_attn: if ``True``, recompute attention in the backward pass via
            gradient checkpointing to save memory. Default: ``False``.

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """
    __constants__ = ['batch_first']

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5, batch_first=False, pre_norm=False,
                 device=None, dtype=None, recompute_attn=False) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
                                            batch_first=batch_first, **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.pre_norm = pre_norm
        self.recompute_attn = recompute_attn

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
""" if self.pre_norm: src_ = self.norm1(src) else: src_ = src if isinstance(src_mask, tuple): # global attention setup assert not self.self_attn.batch_first assert src_key_padding_mask is None global_src_mask, trainset_src_mask, valset_src_mask = src_mask num_global_tokens = global_src_mask.shape[0] num_train_tokens = trainset_src_mask.shape[0] global_tokens_src = src_[:num_global_tokens] train_tokens_src = src_[num_global_tokens:num_global_tokens+num_train_tokens] global_and_train_tokens_src = src_[:num_global_tokens+num_train_tokens] eval_tokens_src = src_[num_global_tokens+num_train_tokens:] attn = partial(checkpoint, self.self_attn) if self.recompute_attn else self.self_attn global_tokens_src2 = attn(global_tokens_src, global_and_train_tokens_src, global_and_train_tokens_src, None, True, global_src_mask)[0] train_tokens_src2 = attn(train_tokens_src, global_tokens_src, global_tokens_src, None, True, trainset_src_mask)[0] eval_tokens_src2 = attn(eval_tokens_src, src_, src_, None, True, valset_src_mask)[0] src2 = torch.cat([global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0) else: if self.recompute_attn: src2 = checkpoint(self.self_attn, src_, src_, src_, src_key_padding_mask, True, src_mask)[0] else: src2 = self.self_attn(src_, src_, src_, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] src = src + self.dropout1(src2) if not self.pre_norm: src = self.norm1(src) if self.pre_norm: src_ = self.norm2(src) else: src_ = src src2 = self.linear2(self.dropout(self.activation(self.linear1(src_)))) src = src + self.dropout2(src2) if not self.pre_norm: src = self.norm2(src) return src