Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2023-2024 DeepSeek. | |
# | |
# Permission is hereby granted, free of charge, to any person obtaining a copy of | |
# this software and associated documentation files (the "Software"), to deal in | |
# the Software without restriction, including without limitation the rights to | |
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of | |
# the Software, and to permit persons to whom the Software is furnished to do so, | |
# subject to the following conditions: | |
# | |
# The above copyright notice and this permission notice shall be included in all | |
# copies or substantial portions of the Software. | |
# | |
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS | |
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR | |
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER | |
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN | |
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |
from typing import Tuple, Union | |
import torch | |
import torch.nn as nn | |
from attrdict import AttrDict | |
class MlpProjector(nn.Module):
    """Project vision-encoder features into the language model's embedding space.

    The architecture is selected by ``cfg.projector_type``:

    * ``"identity"`` -- features are returned unchanged.
    * ``"linear"`` -- a single ``nn.Linear(cfg.input_dim, cfg.n_embed)``.
    * ``"mlp_gelu"`` -- ``Linear(input_dim, n_embed)`` followed by
      ``depth - 1`` repetitions of ``GELU -> Linear(n_embed, n_embed)``.
    * ``"low_high_hybrid_split_mlp_gelu"`` -- two separate
      ``Linear(input_dim, n_embed // 2)`` projections (``high_up_proj`` and
      ``low_up_proj``) whose outputs are concatenated on the last dimension,
      followed by ``depth - 1`` repetitions of ``GELU -> Linear(n_embed, n_embed)``.

    ``depth`` is read with ``cfg.get("depth", 1)``.

    Args:
        cfg: config object supporting attribute access (``projector_type``,
            ``input_dim``, ``n_embed``) and ``.get``, e.g. an ``AttrDict``.

    Raises:
        ValueError: if ``cfg.projector_type`` is unknown, or if ``cfg.n_embed``
            is odd for the hybrid split projector (the two halves could not be
            concatenated back to ``n_embed`` channels).
    """

    _HYBRID = "low_high_hybrid_split_mlp_gelu"

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        projector_type = cfg.projector_type
        if projector_type == "identity":
            modules = nn.Identity()
        elif projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)
        elif projector_type == "mlp_gelu":
            # Indices inside the Sequential (0: Linear, 1: GELU, 2: Linear, ...)
            # match the original layout, so existing checkpoints still load.
            modules = nn.Sequential(
                nn.Linear(cfg.input_dim, cfg.n_embed),
                *self._gelu_stack(cfg.n_embed, cfg.get("depth", 1)),
            )
        elif projector_type == self._HYBRID:
            # Each branch produces half of n_embed; an odd n_embed would yield
            # n_embed - 1 channels after concatenation and break the stack below.
            if cfg.n_embed % 2 != 0:
                raise ValueError(
                    f"n_embed must be even for projector type {projector_type!r}, "
                    f"got {cfg.n_embed}"
                )
            half = cfg.n_embed // 2
            self.high_up_proj = nn.Linear(cfg.input_dim, half)
            self.low_up_proj = nn.Linear(cfg.input_dim, half)
            modules = nn.Sequential(*self._gelu_stack(cfg.n_embed, cfg.get("depth", 1)))
        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        self.layers = modules

    @staticmethod
    def _gelu_stack(n_embed: int, depth: int) -> list:
        """Return ``depth - 1`` pairs of ``[GELU, Linear(n_embed, n_embed)]`` (empty if depth <= 1)."""
        layers = []
        for _ in range(1, depth):
            layers.append(nn.GELU())
            layers.append(nn.Linear(n_embed, n_embed))
        return layers

    def forward(
        self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
    ) -> torch.Tensor:
        """Project visual features.

        Args:
            x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]):
                if it is a pair of tensors (tuple or list), it comes from the
                hybrid vision encoder and is ``(high_res_x, low_res_x)``. Each
                half is projected separately and the two are concatenated along
                the last dimension. Otherwise it is the feature from a single
                vision encoder and is fed straight to ``self.layers``.

        Returns:
            x (torch.Tensor): projected features with the same leading
            dimensions as the input, e.g. [b, s, c].

        Raises:
            ValueError: if a pair of tensors is given but this projector is not
                ``"low_high_hybrid_split_mlp_gelu"`` (it has no split branches).
        """
        if isinstance(x_or_tuple, (tuple, list)):
            if self.cfg.projector_type != self._HYBRID:
                raise ValueError(
                    "a (high_res_x, low_res_x) pair requires projector_type "
                    f"{self._HYBRID!r}, got {self.cfg.projector_type!r}"
                )
            high_x, low_x = x_or_tuple
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.cat([high_x, low_x], dim=-1)
        else:
            x = x_or_tuple
        return self.layers(x)
if __name__ == "__main__":
    # Smoke test: push a random (high-res, low-res) feature pair through the
    # hybrid split projector and print the projected shape
    # (two 1024-channel halves -> 2048 channels, i.e. torch.Size([4, 576, 2048])).
    demo_cfg = AttrDict(
        input_dim=1024,
        n_embed=2048,
        depth=2,
        projector_type="low_high_hybrid_split_mlp_gelu",
    )
    # Inputs are drawn before the model is built to keep the same RNG order.
    high_res_feats = torch.rand(4, 576, 1024)
    low_res_feats = torch.rand(4, 576, 1024)
    projector = MlpProjector(demo_cfg)
    projected = projector((high_res_feats, low_res_feats))
    print(projected.shape)