import math
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.vision_transformer import _cfg, _load_weights

from mamba_ssm.modules.mamba_simple import Mamba

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768,
                 norm_layer=None, flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # The stride may differ from the patch size, so size the token grid the same way
        # the strided projection convolution does.
        self.grid_size = ((img_size[0] - patch_size[0]) // stride + 1,
                          (img_size[1] - patch_size[1]) // stride + 1)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x

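# Usage sketch (illustrative only, not part of the model definition): with the default
# 224x224 input, 16x16 patches, and stride 16, the projection yields a 14x14 grid, i.e.
# 196 tokens of dimension embed_dim.
#
#     patch_embed = PatchEmbed(img_size=224, patch_size=16, stride=16, embed_dim=768)
#     tokens = patch_embed(torch.randn(2, 3, 224, 224))
#     print(tokens.shape)  # torch.Size([2, 196, 768])

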
class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False, drop_path=0.,
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        if self.fused_add_norm:
            assert RMSNorm is not None, "fused_add_norm requires the Triton RMSNorm kernels, which failed to import"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            residual = (residual + self.drop_path(hidden_states)) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            hidden_states, residual = fused_add_norm_fn(
                hidden_states if residual is None else self.drop_path(hidden_states),
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

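# The unfused branch of Block.forward is equivalent to the following sketch (drop path and
# dtype casts omitted), which makes the Add -> Norm -> Mixer ordering explicit:
#
#     residual = hidden_states if residual is None else residual + hidden_states
#     hidden_states = mixer(norm(residual))
#
# The fused branch performs the same add + normalization in a single Triton kernel call.

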
def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    drop_path=0.,
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    layer_idx=None,
    bimamba=True,
    device=None,
    dtype=None,
):
    factory_kwargs = {"device": device, "dtype": dtype}
    if ssm_cfg is None:
        ssm_cfg = {}
    mixer_cls = partial(Mamba, layer_idx=layer_idx, bimamba=bimamba, **ssm_cfg, **factory_kwargs)
    norm_cls = partial(nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon)
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        drop_path=drop_path,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


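# Usage sketch (illustrative; a GPU is assumed because the Mamba mixer runs on its fused
# CUDA kernels): build one bidirectional Mamba block for a 192-dim token stream and feed it
# a (batch, seq_len, d_model) tensor, e.g. 196 patch tokens plus one cls token.
#
#     block = create_block(d_model=192, layer_idx=0).cuda()
#     hidden, residual = block(torch.randn(2, 197, 192, device="cuda"))

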
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # only used for nn.Embedding below
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # 2 if the block also adds an MLP residual
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # GPT-2-style rescaling: the projections that feed the residual stream are divided
        # by sqrt(total number of residual additions), so the variance of the residual sum
        # stays bounded as depth grows.
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


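# Worked example of the rescaling above (assuming the defaults used by the tiny model):
# with n_layer=24 and n_residuals_per_layer=1, every matching out_proj.weight is divided
# by sqrt(24) ~ 4.90, i.e. scaled by roughly 0.204.

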
def segm_init_weights(m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)


class VisionMamba(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        stride=16,
        depth=24,
        embed_dim=192,
        channels=3,
        num_classes=1000,
        drop_rate=0.,
        drop_path_rate=0.1,
        ssm_cfg=None,
        norm_epsilon=1e-5,
        initializer_cfg=None,
        fused_add_norm=True,
        rms_norm=True,
        residual_in_fp32=True,
        bimamba=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm

        self.num_classes = num_classes
        self.d_model = self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, stride=stride, in_chans=channels, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learnable cls token plus one positional embedding per token (patches + cls).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # Stochastic depth: per-block drop-path rates increase linearly up to drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        inter_dpr = [0.0] + dpr
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        self.layers = nn.ModuleList(
            [
                create_block(
                    embed_dim,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    bimamba=bimamba,
                    drop_path=inter_dpr[i],
                    **factory_kwargs,
                )
                for i in range(depth)
            ]
        )

        # Final normalization applied to the residual stream before the classifier head.
        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(embed_dim, eps=norm_epsilon, **factory_kwargs)

        # ViT-style initialization for embeddings and head, then Mamba-specific init.
        self.apply(segm_init_weights)
        self.head.apply(segm_init_weights)
        trunc_normal_(self.pos_embed, std=.02)

        self.apply(
            partial(
                _init_weights,
                n_layer=depth,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

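    # Illustrative note: with depth=24 and drop_path_rate=0.1, torch.linspace(0, 0.1, 24)
    # spaces the drop-path rates evenly between 0 and 0.1; prepending 0.0 (inter_dpr)
    # shifts the schedule by one block so the first block never drops its path.
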
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=""):
        _load_weights(self, checkpoint_path, prefix)

    def forward_features(self, x, inference_params=None):
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        # The Mamba blocks keep the hidden states and the residual stream separate.
        residual = None
        hidden_states = x
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )

        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # prenorm=False: only the normalized output is needed here, not the residual.
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                self.drop_path(hidden_states),
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )

        # Classification uses only the cls token (position 0).
        return hidden_states[:, 0, :]

    def forward(self, x, inference_params=None):
        x = self.forward_features(x, inference_params)
        x = self.head(x)
        return x


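# Usage sketch (illustrative; the Mamba mixer and the fused norms need their CUDA/Triton
# kernels, so a GPU is assumed):
#
#     model = VisionMamba(img_size=224, patch_size=16, embed_dim=192, depth=24).cuda()
#     logits = model(torch.randn(2, 3, 224, 224, device="cuda"))
#     print(logits.shape)  # torch.Size([2, 1000])

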
def videomamba_image_tiny(**kwargs):
    model = VisionMamba(
        patch_size=16,
        embed_dim=192,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        **kwargs
    )
    model.default_cfg = _cfg()
    return model


def videomamba_image_small(**kwargs):
    model = VisionMamba(
        patch_size=16,
        embed_dim=384,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        **kwargs
    )
    model.default_cfg = _cfg()
    return model


def videomamba_image_middle(**kwargs):
    model = VisionMamba(
        patch_size=16,
        embed_dim=576,
        depth=32,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        **kwargs
    )
    model.default_cfg = _cfg()
    return model
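

# Minimal smoke test (a sketch, not part of the original module): build the tiny variant
# and report its parameter count. Guarded by __main__ so importing this file stays
# side-effect free; it assumes mamba_ssm and its compiled kernels are installed, and a GPU
# is only needed once a forward pass is run.
if __name__ == "__main__":
    model = videomamba_image_tiny(num_classes=1000)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"videomamba_image_tiny parameters: {n_params / 1e6:.1f}M")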