|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
from attrdict import AttrDict |
|
|
|
|
|
class MlpProjector(nn.Module):
    """Project vision-encoder features into the language-model embedding space.

    The architecture is selected by ``cfg.projector_type``:

    - ``"identity"``: pass features through unchanged.
    - ``"linear"``: one linear map from ``cfg.input_dim`` to ``cfg.n_embed``.
    - ``"mlp_gelu"``: ``cfg.depth`` linear layers with GELU activations
      in between (``depth`` defaults to 1, i.e. a single linear layer).
    - ``"low_high_hybrid_split_mlp_gelu"``: separate linear projections for
      the high- and low-resolution feature streams (each to
      ``cfg.n_embed // 2``), concatenated and then passed through the
      GELU/linear stack.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        projector_type = cfg.projector_type

        if projector_type == "identity":
            layers = nn.Identity()

        elif projector_type == "linear":
            layers = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif projector_type == "mlp_gelu":
            depth = cfg.get("depth", 1)
            stack = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(depth - 1):
                stack += [nn.GELU(), nn.Linear(cfg.n_embed, cfg.n_embed)]
            layers = nn.Sequential(*stack)

        elif projector_type == "low_high_hybrid_split_mlp_gelu":
            depth = cfg.get("depth", 1)
            # Each resolution stream contributes half of the final embedding.
            self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)

            # With depth == 1 this stack is empty and acts as a pass-through.
            stack = []
            for _ in range(depth - 1):
                stack += [nn.GELU(), nn.Linear(cfg.n_embed, cfg.n_embed)]
            layers = nn.Sequential(*stack)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        self.layers = layers

    def forward(
        self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
    ):
        """Project encoder features to the embedding space.

        Args:
            x_or_tuple: Either a ``(high_res_x, low_res_x)`` tensor pair from
                the hybrid vision encoder, or a single feature tensor from a
                single vision encoder.

        Returns:
            torch.Tensor: projected features of shape ``[b, s, c]``.
        """
        if isinstance(x_or_tuple, tuple):
            # Hybrid encoder output: project each stream to half the embedding
            # width, then join them along the channel dimension.
            high_res, low_res = x_or_tuple
            x = torch.cat(
                [self.high_up_proj(high_res), self.low_up_proj(low_res)], dim=-1
            )
        else:
            x = x_or_tuple

        return self.layers(x)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: run the hybrid projector on dummy high/low-resolution
    # feature tensors and print the projected shape.
    config = AttrDict(
        projector_type="low_high_hybrid_split_mlp_gelu",
        input_dim=1024,
        n_embed=2048,
        depth=2,
    )
    high_res = torch.rand(4, 576, 1024)
    low_res = torch.rand(4, 576, 1024)

    projector = MlpProjector(config)
    print(projector((high_res, low_res)).shape)
|
|