# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

from typing import Tuple

import torch
import torch.nn as nn

from common.cache import Cache

from ..mlp import get_mlp
from ..mm import MMArg, MMModule
from ..modulation import ada_layer_type
from ..normalization import norm_layer_type
from .attention.mmattn import NaSwinAttention


class NaMMSRTransformerBlock(nn.Module):
    """Joint video-text transformer block operating on packed (flattened) sequences.

    Both branches go through a pre-norm attention sub-block and a pre-norm MLP
    sub-block, each wrapped with adaptive modulation (``ada``) conditioned on
    ``emb``. On the last layer the MLP path and its modulation are applied to
    the video branch only (``vid_only=is_last_layer``).
    """

    def __init__(
        self,
        *,
        vid_dim: int,
        txt_dim: int,
        emb_dim: int,
        heads: int,
        head_dim: int,
        expand_ratio: int,
        norm: norm_layer_type,
        norm_eps: float,
        ada: ada_layer_type,
        qk_bias: bool,
        qk_norm: norm_layer_type,
        mlp_type: str,
        shared_weights: bool,
        rope_type: str,
        rope_dim: int,
        is_last_layer: bool,
        **kwargs,
    ):
        super().__init__()
        dim = MMArg(vid_dim, txt_dim)
        self.attn_norm = MMModule(
            norm,
            dim=dim,
            eps=norm_eps,
            elementwise_affine=False,
            shared_weights=shared_weights,
        )
        self.attn = NaSwinAttention(
            vid_dim=vid_dim,
            txt_dim=txt_dim,
            heads=heads,
            head_dim=head_dim,
            qk_bias=qk_bias,
            qk_norm=qk_norm,
            qk_norm_eps=norm_eps,
            rope_type=rope_type,
            rope_dim=rope_dim,
            shared_weights=shared_weights,
            window=kwargs.pop("window", None),
            window_method=kwargs.pop("window_method", None),
        )
        self.mlp_norm = MMModule(
            norm,
            dim=dim,
            eps=norm_eps,
            elementwise_affine=False,
            shared_weights=shared_weights,
            vid_only=is_last_layer,
        )
        self.mlp = MMModule(
            get_mlp(mlp_type),
            dim=dim,
            expand_ratio=expand_ratio,
            shared_weights=shared_weights,
            vid_only=is_last_layer,
        )
        self.ada = MMModule(
            ada,
            dim=dim,
            emb_dim=emb_dim,
            layers=["attn", "mlp"],
            shared_weights=shared_weights,
            vid_only=is_last_layer,
        )
        self.is_last_layer = is_last_layer

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        txt: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,  # b 3
        txt_shape: torch.LongTensor,  # b 1
        emb: torch.FloatTensor,
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.LongTensor,
        torch.LongTensor,
    ]:
        # Per-sample packed sequence lengths, memoized in the cache across blocks.
        hid_len = MMArg(
            cache("vid_len", lambda: vid_shape.prod(-1)),
            cache("txt_len", lambda: txt_shape.prod(-1)),
        )
        ada_kwargs = {
            "emb": emb,
            "hid_len": hid_len,
            "cache": cache,
            "branch_tag": MMArg("vid", "txt"),
        }

        # Attention sub-block: pre-norm -> ada(in) -> attention -> ada(out) -> residual.
        vid_attn, txt_attn = self.attn_norm(vid, txt)
        vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
        vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
        vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
        vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)

        # MLP sub-block: pre-norm -> ada(in) -> MLP -> ada(out) -> residual.
        vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
        vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs)
        vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp)
        vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs)
        vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn)

        return vid_mlp, txt_mlp, vid_shape, txt_shape
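

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module). The block
# consumes packed sequences: `vid` and `txt` are flattened token tensors of
# shape (l, c), while `vid_shape` (b, 3) and `txt_shape` (b, 1) record each
# sample's token grid so per-sample lengths can be recovered with `.prod(-1)`,
# exactly as done in `forward`. The constructor arguments (`norm`, `ada`,
# `mlp_type`, ...) come from factories elsewhere in this repo and are elided;
# the call pattern below is a hypothetical example of the expected signature
# (kept as comments because this module uses relative imports and is not meant
# to be executed directly):
#
#   block = NaMMSRTransformerBlock(vid_dim=..., txt_dim=..., emb_dim=..., ...)
#   vid, txt, vid_shape, txt_shape = block(vid, txt, vid_shape, txt_shape, emb, cache)
#
# Shape bookkeeping with assumed example values:
#
#   >>> vid_shape = torch.tensor([[4, 8, 8], [2, 8, 8]])  # b 3, e.g. (t, h, w)
#   >>> vid_shape.prod(-1)                                 # per-sample video token counts
#   tensor([256, 128])
#   >>> txt_shape = torch.tensor([[77], [32]])             # b 1
#   >>> txt_shape.prod(-1)                                 # per-sample text token counts
#   tensor([77, 32])
# ---------------------------------------------------------------------------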