# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.drop import build_dropout

from .layer_scale import LayerScale
from .norm import build_norm_layer


class SwiGLUFFN(nn.Module):
    """SwiGLU FFN layer.

    Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py

    Args:
        embed_dims (int): Dimension of the input (and, by default, output)
            features.
        feedforward_channels (int, optional): Hidden dimension of the FFN.
            Defaults to ``embed_dims``.
        out_dims (int, optional): Output dimension. Defaults to
            ``embed_dims``.
        layer_scale_init_value (float): Initial value of LayerScale. A value
            of 0 disables LayerScale. Defaults to 0.
        bias (bool): Whether the linear layers use bias. Defaults to True.
        dropout_layer (dict, optional): Config of the dropout layer applied
            to the output. Defaults to None (no dropout).
        norm_cfg (dict, optional): Config of the normalization layer applied
            to the gated hidden features. Defaults to None (no normalization).
        add_identity (bool): Whether to add the identity (residual)
            connection. Defaults to True.
    """  # noqa

    def __init__(
        self,
        embed_dims: int,
        feedforward_channels: Optional[int] = None,
        out_dims: Optional[int] = None,
        layer_scale_init_value: float = 0.,
        bias: bool = True,
        dropout_layer: Optional[dict] = None,
        norm_cfg: Optional[dict] = None,
        add_identity: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dims = embed_dims
        self.out_dims = out_dims or embed_dims
        hidden_dims = feedforward_channels or embed_dims

        # Single projection producing both the gate (x1) and value (x2)
        # branches of SwiGLU.
        self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, hidden_dims)
        else:
            self.norm = nn.Identity()

        self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias)

        if layer_scale_init_value > 0:
            self.gamma2 = LayerScale(
                dim=embed_dims, layer_scale_init_value=layer_scale_init_value)
        else:
            self.gamma2 = nn.Identity()

        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()
        self.add_identity = add_identity

    def forward(self,
                x: torch.Tensor,
                identity: Optional[torch.Tensor] = None) -> torch.Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        # SwiGLU gating: SiLU(x1) modulates the value branch x2.
        hidden = F.silu(x1) * x2
        hidden = self.norm(hidden)
        out = self.w3(hidden)
        out = self.gamma2(out)
        out = self.dropout_layer(out)

        if self.out_dims != self.embed_dims or not self.add_identity:
            # Skip the residual connection when the output dimension differs
            # from the input or when the user disables it.
            return out

        if identity is None:
            identity = x
        return identity + out


class SwiGLUFFNFused(SwiGLUFFN):
    """SwiGLU FFN layer with fused hidden-dimension sizing.

    Following the DINOv2 implementation, the hidden dimension is set to 2/3
    of ``feedforward_channels`` and rounded up to a multiple of 8.

    Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
    """  # noqa

    def __init__(
        self,
        embed_dims: int,
        feedforward_channels: Optional[int] = None,
        out_dims: Optional[int] = None,
        layer_scale_init_value: float = 0.,
        bias: bool = True,
    ) -> None:
        out_dims = out_dims or embed_dims
        feedforward_channels = feedforward_channels or embed_dims
        # Scale the hidden width by 2/3 to keep the parameter count
        # comparable to a standard FFN (SwiGLU uses two input projections),
        # then round up to a multiple of 8 for hardware-friendly sizes.
        feedforward_channels = (int(feedforward_channels * 2 / 3) + 7) // 8 * 8
        super().__init__(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            out_dims=out_dims,
            layer_scale_init_value=layer_scale_init_value,
            bias=bias,
        )
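
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes
# this file is imported from its package (e.g. exported via
# ``mmpretrain.models.utils``) so that the relative imports above resolve;
# the tensor and channel sizes below are arbitrary.
#
#     import torch
#     from mmpretrain.models.utils import SwiGLUFFN, SwiGLUFFNFused
#
#     x = torch.randn(2, 197, 768)
#
#     ffn = SwiGLUFFN(embed_dims=768, feedforward_channels=3072)
#     assert ffn(x).shape == x.shape  # residual branch keeps the shape
#
#     fused = SwiGLUFFNFused(embed_dims=768, feedforward_channels=3072)
#     # hidden width: int(3072 * 2 / 3) = 2048 (already a multiple of 8),
#     # so w12 projects 768 -> 2 * 2048 = 4096 channels
#     assert fused.w12.out_features == 4096
#     assert fused(x).shape == x.shape
# ---------------------------------------------------------------------------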