# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.drop import build_dropout

from .layer_scale import LayerScale
from .norm import build_norm_layer


class SwiGLUFFN(nn.Module):
    """SwiGLU FFN layer.

    Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
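
    Example (a minimal usage sketch; the argument values and tensor shapes
    below are illustrative, not part of the original code)::

        >>> import torch
        >>> ffn = SwiGLUFFN(embed_dims=768, feedforward_channels=2048)
        >>> x = torch.randn(2, 197, 768)
        >>> ffn(x).shape
        torch.Size([2, 197, 768])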
""" # noqa
def __init__(
self,
embed_dims: int,
feedforward_channels: Optional[int] = None,
out_dims: Optional[int] = None,
layer_scale_init_value: float = 0.,
bias: bool = True,
dropout_layer: Optional[dict] = None,
norm_cfg: Optional[dict] = None,
add_identity: bool = True,
) -> None:
super().__init__()
self.embed_dims = embed_dims
self.out_dims = out_dims or embed_dims
hidden_dims = feedforward_channels or embed_dims
self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias)
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, hidden_dims)
else:
self.norm = nn.Identity()
self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias)
if layer_scale_init_value > 0:
self.gamma2 = LayerScale(
dim=embed_dims, layer_scale_init_value=layer_scale_init_value)
else:
self.gamma2 = nn.Identity()
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else torch.nn.Identity()
self.add_identity = add_identity
def forward(self,
x: torch.Tensor,
identity: Optional[torch.Tensor] = None) -> torch.Tensor:
x12 = self.w12(x)
x1, x2 = x12.chunk(2, dim=-1)
hidden = F.silu(x1) * x2
hidden = self.norm(hidden)
out = self.w3(hidden)
out = self.gamma2(out)
out = self.dropout_layer(out)
if self.out_dims != self.embed_dims or not self.add_identity:
# due to the dimension inconsistence or user setting
# not to apply residual operation
return out
if identity is None:
identity = x
return identity + out


class SwiGLUFFNFused(SwiGLUFFN):
    """SwiGLU FFN layer with fusing.

    Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
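
    Example (a minimal usage sketch; the values below are illustrative, not
    part of the original code)::

        >>> import torch
        >>> # A requested hidden size of 3072 is rescaled to
        >>> # int(3072 * 2 / 3) rounded up to a multiple of 8, i.e. 2048.
        >>> ffn = SwiGLUFFNFused(embed_dims=768, feedforward_channels=3072)
        >>> ffn(torch.randn(2, 197, 768)).shape
        torch.Size([2, 197, 768])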
""" # noqa
def __init__(
self,
embed_dims: int,
feedforward_channels: Optional[int] = None,
out_dims: Optional[int] = None,
layer_scale_init_value: float = 0.,
bias: bool = True,
) -> None:
out_dims = out_dims or embed_dims
feedforward_channels = feedforward_channels or embed_dims
feedforward_channels = (int(feedforward_channels * 2 / 3) + 7) // 8 * 8
super().__init__(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
out_dims=out_dims,
layer_scale_init_value=layer_scale_init_value,
bias=bias,
)