# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.drop import build_dropout

from .layer_scale import LayerScale
from .norm import build_norm_layer


class SwiGLUFFN(nn.Module):
    """SwiGLU FFN layer.

    Projects the input to two hidden branches, gates one with SiLU and
    multiplies it with the other, then projects back to ``out_dims``.

    Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
    """  # noqa
    def __init__(
        self,
        embed_dims: int,
        feedforward_channels: Optional[int] = None,
        out_dims: Optional[int] = None,
        layer_scale_init_value: float = 0.,
        bias: bool = True,
        dropout_layer: Optional[dict] = None,
        norm_cfg: Optional[dict] = None,
        add_identity: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dims = embed_dims
        self.out_dims = out_dims or embed_dims
        hidden_dims = feedforward_channels or embed_dims

        # Fused projection producing both the gate and value branches.
        self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, hidden_dims)
        else:
            self.norm = nn.Identity()

        self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias)

        if layer_scale_init_value > 0:
            # LayerScale multiplies the ``w3`` output, whose last dimension
            # is ``out_dims`` (equal to ``embed_dims`` unless overridden).
            self.gamma2 = LayerScale(
                dim=self.out_dims,
                layer_scale_init_value=layer_scale_init_value)
        else:
            self.gamma2 = nn.Identity()

        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()
        self.add_identity = add_identity

    def forward(self,
                x: torch.Tensor,
                identity: Optional[torch.Tensor] = None) -> torch.Tensor:
        # SwiGLU gating: split the fused projection into a gate branch and a
        # value branch, then combine them as SiLU(x1) * x2.
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2

        hidden = self.norm(hidden)
        out = self.w3(hidden)
        out = self.gamma2(out)
        out = self.dropout_layer(out)

        if self.out_dims != self.embed_dims or not self.add_identity:
            # Skip the residual connection, either due to the dimension
            # inconsistency or because the user disabled it.
            return out

        if identity is None:
            identity = x
        return identity + out
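

# Usage sketch (illustrative; not part of the upstream module). The shapes
# below are assumptions mimicking a ViT-Base block; with ``out_dims`` equal
# to ``embed_dims`` (the default), the residual keeps the token shape:
#
#   >>> ffn = SwiGLUFFN(embed_dims=768, feedforward_channels=3072)
#   >>> x = torch.randn(2, 197, 768)  # (batch, num_tokens, embed_dims)
#   >>> ffn(x).shape                  # out = x + w3(silu(x1) * x2)
#   torch.Size([2, 197, 768])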


class SwiGLUFFNFused(SwiGLUFFN):
    """SwiGLU FFN layer with fusing.

    Rescales the hidden width so the two-projection SwiGLU block has roughly
    the same parameter count as a plain FFN of ``feedforward_channels``.

    Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
    """  # noqa

    def __init__(
        self,
        embed_dims: int,
        feedforward_channels: Optional[int] = None,
        out_dims: Optional[int] = None,
        layer_scale_init_value: float = 0.,
        bias: bool = True,
    ) -> None:
        out_dims = out_dims or embed_dims
        feedforward_channels = feedforward_channels or embed_dims
        # Shrink the hidden width to 2/3 of the requested size (SwiGLU has
        # two input projections), then round up to a multiple of 8 for
        # hardware-friendly matmul shapes.
        feedforward_channels = (int(feedforward_channels * 2 / 3) + 7) // 8 * 8
        super().__init__(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            out_dims=out_dims,
            layer_scale_init_value=layer_scale_init_value,
            bias=bias,
        )
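

# Usage sketch (illustrative; not part of the upstream module). With the
# assumed ``feedforward_channels=4096``, the hidden width becomes
# int(4096 * 2 / 3) = 2730, rounded up to 2736, so ``w12`` emits
# 2 * 2736 = 5472 features while the output keeps ``embed_dims``:
#
#   >>> ffn = SwiGLUFFNFused(embed_dims=1024, feedforward_channels=4096)
#   >>> ffn.w12.out_features
#   5472
#   >>> x = torch.randn(2, 197, 1024)
#   >>> ffn(x).shape
#   torch.Size([2, 197, 1024])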