import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class MlpProjector(nn.Module):
    """Projects vision features into the language-model embedding space.

    The variant is selected by cfg.projector_type; cfg is expected to support
    both attribute access and .get() (an EasyDict-style config).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        if cfg.projector_type == "identity":
            modules = nn.Identity()

        elif cfg.projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif cfg.projector_type == "mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)
| elif cfg.projector_type == "normlayer_downsample_mlp_gelu": | |
| mlp_depth = cfg.get("depth", 1) | |
| mlp_ratio = cfg.get("mlp_ratio", 1) | |
| modules = [ | |
| nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio), | |
| nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio) | |
| ] | |
| for _ in range(1, mlp_depth - 1): | |
| modules.append(nn.GELU()) | |
| modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio)) | |
| modules.append(nn.GELU()) | |
| modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed)) | |
| modules = nn.Sequential(*modules) | |
| elif cfg.projector_type == "downsample_mlp_gelu": | |
| mlp_depth = cfg.get("depth", 1) | |
| mlp_ratio = cfg.get("mlp_ratio", 1) | |
| modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)] | |
| for _ in range(1, mlp_depth - 1): | |
| modules.append(nn.GELU()) | |
| modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio)) | |
| modules.append(nn.GELU()) | |
| modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed)) | |
| modules = nn.Sequential(*modules) | |
| elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu": | |
| mlp_depth = cfg.get("depth", 1) | |
| self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) | |
| self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) | |
| modules = [] | |
| for _ in range(1, mlp_depth): | |
| modules.append(nn.GELU()) | |
| modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) | |
| modules = nn.Sequential(*modules) | |
| elif cfg.projector_type == "hybrid_split_feature_mlp_gelu": | |
| mlp_depth = cfg.get("depth", 1) | |
| channel_div = cfg.get("channel_div", 0.5) | |
| self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div)) | |
| self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div)) | |
| modules = [] | |
| for _ in range(1, mlp_depth): | |
| modules.append(nn.GELU()) | |
| modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) | |
| modules = nn.Sequential(*modules) | |
| elif cfg.projector_type == "low_high_split_mlp_gelu": | |
| mlp_depth = cfg.get("depth", 1) | |
| modules = [] | |
| for _ in range(1, mlp_depth): | |
| modules.append(nn.GELU()) | |
| modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2)) | |
| modules = nn.Sequential(*modules) | |
| self.high_layers = nn.Sequential(*modules) | |
| self.low_layers = copy.deepcopy(modules) | |
| else: | |
| raise ValueError(f"Unknown projector type: {cfg.projector_type}") | |
| if cfg.get("token_pooling", False): | |
| self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim) | |
| if cfg.get("conv_fusion_high_low_features", False): | |
| self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim) | |
| self.layers = modules | |
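
    # Expected inputs to forward(), as read from the branches below:
    #   - most variants: a (batch, seq_len, input_dim) feature tensor;
    #   - "low_high_hybrid_split_mlp_gelu" and "low_high_split_mlp_gelu": a
    #     pair of tensors (high-res features, low-res features), read as
    #     x[0] and x[1];
    #   - "hybrid_split_feature_mlp_gelu": both feature sets concatenated
    #     along the last dim, split at cfg.input_dim[0].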
    def forward(self, x):
        if self.cfg.get("token_pooling", False):
            batch_size, wxh, channels = x.shape
            w = h = int(wxh ** 0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)  # B, C, H, W

            # Group tokens into non-overlapping 2x2 patches.
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            # Concatenate the four tokens of each patch along the channel dimension.
            patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)
            # Pass through the linear layer: (B, HW/4, 4C) -> (B, HW/4, C).
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
            x = self.token_pooling_layer(patches)
| if self.cfg.get("conv_fusion_high_low_features", False): | |
| x = self.fusion_layer(x[:, 0]) + x[:, 1] | |
| if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu': | |
| high_x, low_x = x[0], x[1] | |
| high_x = self.high_up_proj(high_x) | |
| low_x = self.low_up_proj(low_x) | |
| x = torch.concat([high_x, low_x], dim=-1) | |
| if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu': | |
| high_x = x[...,:self.cfg.input_dim[0]] | |
| low_x = x[...,self.cfg.input_dim[0]:] | |
| high_x = self.high_up_proj(high_x) | |
| low_x = self.low_up_proj(low_x) | |
| x = torch.concat([high_x, low_x], dim=-1) | |
| if self.cfg.projector_type == 'low_high_split_mlp_gelu': | |
| high_x, low_x = x[0], x[1] | |
| high_x = self.high_layers(high_x) | |
| low_x = self.low_layers(low_x) | |
| x = torch.concat([high_x, low_x], dim=-1) | |
| return x | |

        if self.cfg.projector_type in ("downsample_mlp_gelu", "normlayer_downsample_mlp_gelu"):
            bs, hw, input_dim = x.shape
            h = w = int(hw ** 0.5)

            # Pad so that H and W divide evenly by downsample_ratio.
            if h % self.cfg.downsample_ratio:
                pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
            else:
                pad = 0
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            # ratio**2-to-1 concat: stack each ratio x ratio patch channel-wise.
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(
                x,
                kernel_size=self.cfg.downsample_ratio,
                stride=self.cfg.downsample_ratio,
                padding=0,
            )  # B, C * ratio**2, HW / ratio**2
            x = x.permute(0, 2, 1)

        return self.layers(x)
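

# FLOPs accounting below follows the standard convention: a linear layer with
# in-features m and out-features n costs about 2*m*n FLOPs per token in the
# forward pass (one multiply and one add per weight), and training cost is
# commonly approximated as 3x the forward cost (1x forward + ~2x backward).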
def get_flops_per_sample(cfg):
    if cfg.projector_type == "linear":
        fwd = 2 * cfg.input_dim * cfg.n_embed
    elif "mlp_gelu" in cfg.projector_type:
        mlp_depth = cfg.get("depth", 1)
        downsample_ratio = cfg.get("downsample_ratio", 1)
        input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
        input_dim = input_dim * downsample_ratio * downsample_ratio
        # Hidden layers are counted at width n_embed, so this estimate
        # effectively assumes mlp_ratio == 1.
        fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
    else:
        fwd = 0
    return fwd * 3  # forward + backward
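

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the original module). It
# assumes a config object supporting both attribute access and .get(),
# approximated here with a tiny AttrDict helper; the dimensions are made up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class AttrDict(dict):
        """Dict with attribute access, standing in for the real config class."""
        __getattr__ = dict.__getitem__

    cfg = AttrDict(
        projector_type="downsample_mlp_gelu",
        input_dim=1024,
        n_embed=2048,
        depth=2,
        mlp_ratio=1,
        downsample_ratio=2,
    )
    projector = MlpProjector(cfg)

    feats = torch.randn(1, 576, cfg.input_dim)  # e.g. a 24x24 grid of vision tokens
    out = projector(feats)
    print(out.shape)  # torch.Size([1, 144, 2048]): 576 tokens pooled 4-to-1
    print(get_flops_per_sample(cfg))  # rough FLOPs estimate per projected token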