# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable

from torch import Tensor, nn

from .attention import Attention, LayerScale, Mlp


class Block(nn.Module):
    """Pre-norm transformer block: a self-attention sub-layer followed by an MLP
    sub-layer, each wrapped in a residual connection with optional LayerScale
    (enabled when ``init_values`` is set)."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            rope=rope,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        # drop_path is accepted for interface compatibility but is not used here;
        # equivalent to always having drop_path=0.
        self.sample_drop_ratio = 0.0

    def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
        def attn_residual_func(x: Tensor, pos=None, attn_mask=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        # Stochastic depth is disabled (drop_path is always 0), so both residual
        # branches are always applied.
        x = x + attn_residual_func(x, pos=pos, attn_mask=attn_mask)
        x = x + ffn_residual_func(x)
        return x
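

# --- Minimal usage sketch (illustrative addition, not part of the original module) ---
# Assumes this file lives in a package whose sibling ``attention`` module provides
# the Attention, LayerScale, and Mlp classes imported above, and that Attention
# accepts pos=None / attn_mask=None when no rotary embedding (rope) is configured.
# Because of the relative import, run it as part of the package,
# e.g. ``python -m <package>.block``.
if __name__ == "__main__":
    import torch

    # Hypothetical ViT-S-sized block: 384-dim tokens, 6 attention heads,
    # LayerScale enabled via init_values.
    block = Block(dim=384, num_heads=6, mlp_ratio=4.0, init_values=1e-5)
    x = torch.randn(2, 197, 384)  # (batch, tokens, dim)
    out = block(x)                # residual structure preserves the input shape
    assert out.shape == x.shape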