# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.drop import build_dropout

from .layer_scale import LayerScale
from .norm import build_norm_layer


class SwiGLUFFN(nn.Module):
    """SwiGLU FFN layer.

    Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
    """  # noqa

    def __init__(
        self,
        embed_dims: int,
        feedforward_channels: Optional[int] = None,
        out_dims: Optional[int] = None,
        layer_scale_init_value: float = 0.,
        bias: bool = True,
        dropout_layer: Optional[dict] = None,
        norm_cfg: Optional[dict] = None,
        add_identity: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dims = embed_dims
        self.out_dims = out_dims or embed_dims
        hidden_dims = feedforward_channels or embed_dims

        # A single projection produces both the gate branch and the value
        # branch, which are split apart in ``forward``.
        self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, hidden_dims)
        else:
            self.norm = nn.Identity()

        self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias)

        if layer_scale_init_value > 0:
            self.gamma2 = LayerScale(
                dim=embed_dims, layer_scale_init_value=layer_scale_init_value)
        else:
            self.gamma2 = nn.Identity()

        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

    def forward(self,
                x: torch.Tensor,
                identity: Optional[torch.Tensor] = None) -> torch.Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        # SwiGLU gating: the SiLU-activated branch gates the linear branch.
        hidden = F.silu(x1) * x2
        hidden = self.norm(hidden)
        out = self.w3(hidden)
        out = self.gamma2(out)
        out = self.dropout_layer(out)

        if self.out_dims != self.embed_dims or not self.add_identity:
            # Skip the residual connection when the output dimension differs
            # from the input dimension or the user disabled it.
            return out

        if identity is None:
            identity = x
        return identity + out
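
# A minimal usage sketch for SwiGLUFFN. The tensor shape and the norm_cfg
# dict below are illustrative assumptions, not values taken from this file:
#
#     ffn = SwiGLUFFN(embed_dims=768, norm_cfg=dict(type='LN'))
#     tokens = torch.randn(2, 197, 768)   # (batch, num_tokens, embed_dims)
#     out = ffn(tokens)                   # residual added, same shape as input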


class SwiGLUFFNFused(SwiGLUFFN):
    """SwiGLU FFN layer with fusing.

    Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
    """  # noqa

    def __init__(
        self,
        embed_dims: int,
        feedforward_channels: Optional[int] = None,
        out_dims: Optional[int] = None,
        layer_scale_init_value: float = 0.,
        bias: bool = True,
    ) -> None:
        out_dims = out_dims or embed_dims
        feedforward_channels = feedforward_channels or embed_dims
        # Shrink the hidden width to 2/3 of the requested size (SwiGLU uses
        # three weight matrices instead of two, so this keeps the parameter
        # count comparable) and round it up to a multiple of 8.
        feedforward_channels = (int(feedforward_channels * 2 / 3) + 7) // 8 * 8
        super().__init__(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            out_dims=out_dims,
            layer_scale_init_value=layer_scale_init_value,
            bias=bias,
        )
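

# A minimal smoke test for the fused variant; the dimensions below are
# illustrative assumptions. Because of the relative imports, run it as a
# module (e.g. ``python -m <package>.swiglu_ffn``) rather than as a script.
if __name__ == '__main__':
    ffn = SwiGLUFFNFused(embed_dims=384, feedforward_channels=4 * 384)
    # Hidden width: int(1536 * 2 / 3) = 1024, already a multiple of 8.
    x = torch.randn(1, 16, 384)
    assert ffn(x).shape == x.shape  # the residual path keeps the token shape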