# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from typing import Tuple, Union

import torch
import torch.nn as nn

# The only projector type whose forward() takes a (high_res_x, low_res_x) pair.
_HYBRID_PROJECTOR = "low_high_hybrid_split_mlp_gelu"


class MlpProjector(nn.Module):
    """Project vision-encoder features into the language model's embedding space.

    The architecture is selected by ``cfg.projector_type``:

    * ``"identity"`` -- features pass through unchanged.
    * ``"linear"`` -- a single ``Linear(input_dim, n_embed)``.
    * ``"mlp_gelu"`` -- ``Linear(input_dim, n_embed)`` followed by
      ``depth - 1`` repetitions of ``GELU -> Linear(n_embed, n_embed)``.
    * ``"low_high_hybrid_split_mlp_gelu"`` -- two parallel
      ``Linear(input_dim, n_embed // 2)`` up-projections (``high_up_proj`` and
      ``low_up_proj``) whose outputs are concatenated on the last dimension,
      followed by ``depth - 1`` repetitions of ``GELU -> Linear(n_embed, n_embed)``.

    Args:
        cfg: config object read via attribute access (``projector_type``,
            ``input_dim``, ``n_embed``) and via ``cfg.get("depth", 1)``,
            e.g. an ``AttrDict``.

    Raises:
        ValueError: if ``cfg.projector_type`` is unknown, or if the hybrid
            projector is configured with an odd ``n_embed``.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        projector_type = cfg.projector_type

        if projector_type == "identity":
            modules = nn.Identity()

        elif projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif projector_type == "mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            modules = nn.Sequential(
                nn.Linear(cfg.input_dim, cfg.n_embed),
                *self._gelu_linear_blocks(cfg.n_embed, mlp_depth),
            )

        elif projector_type == _HYBRID_PROJECTOR:
            # Each branch produces n_embed // 2 channels; with an odd n_embed the
            # concatenation would be one channel short of what `layers` expects.
            if cfg.n_embed % 2:
                raise ValueError(
                    f"{_HYBRID_PROJECTOR} requires an even n_embed, got {cfg.n_embed}"
                )
            mlp_depth = cfg.get("depth", 1)
            self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            # With depth == 1 this Sequential is empty and forwards its input as-is.
            modules = nn.Sequential(*self._gelu_linear_blocks(cfg.n_embed, mlp_depth))

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        self.layers = modules

    @staticmethod
    def _gelu_linear_blocks(width: int, depth: int) -> list:
        """Return ``depth - 1`` repetitions of ``[GELU, Linear(width, width)]``, flattened.

        The flat order matters: it fixes the ``layers.<i>`` indices used as
        ``state_dict`` keys, so it must not change.
        """
        blocks = []
        for _ in range(1, depth):
            blocks.extend((nn.GELU(), nn.Linear(width, width)))
        return blocks

    def forward(
        self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
    ) -> torch.Tensor:
        """Project encoder features.

        Args:
            x_or_tuple: a ``(high_res_x, low_res_x)`` tuple from the hybrid
                vision encoder, or a single feature tensor from a single vision
                encoder. The tuple form is accepted only by the
                ``"low_high_hybrid_split_mlp_gelu"`` projector, and that
                projector accepts only the tuple form.

        Returns:
            torch.Tensor: projected features with the same leading dimensions
            as the input; the last dimension is ``cfg.n_embed`` (unchanged for
            ``"identity"``).

        Raises:
            TypeError: if the input form does not match ``cfg.projector_type``.
        """
        is_hybrid = self.cfg.projector_type == _HYBRID_PROJECTOR

        if isinstance(x_or_tuple, tuple):
            if not is_hybrid:
                raise TypeError(
                    f"projector type {self.cfg.projector_type!r} expects a single "
                    "tensor, got a tuple"
                )
            high_x, low_x = x_or_tuple
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.cat([high_x, low_x], dim=-1)
        else:
            if is_hybrid:
                # Without this check the up-projections would be skipped silently.
                raise TypeError(
                    f"projector type {_HYBRID_PROJECTOR!r} expects a "
                    f"(high_res_x, low_res_x) tuple, got {type(x_or_tuple).__name__}"
                )
            x = x_or_tuple

        return self.layers(x)


if __name__ == "__main__":
    # Imported here so that importing this module does not require attrdict,
    # which is only needed to build the demo config below.
    from attrdict import AttrDict

    cfg = AttrDict(
        input_dim=1024,
        n_embed=2048,
        depth=2,
        projector_type="low_high_hybrid_split_mlp_gelu",
    )
    inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024))
    m = MlpProjector(cfg)
    out = m(inputs)
    print(out.shape)