# Copyright 2024 Alibaba DAMO Academy
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import re

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.regnet import RegStage
from timm.models.layers import LayerNorm2d
from transformers import TRANSFORMERS_CACHE


def parse_snapshot_folder(repo_id, cache_dir=None, repo_type="model"):
    """Compute the local snapshot folder of a Hub repo inside a cache directory.

    Only the path is computed; nothing is downloaded and the returned folder
    may not exist.

    Args:
        repo_id: repository id, e.g. ``"org/name"``.
        cache_dir: cache root; defaults to ``TRANSFORMERS_CACHE``.
        repo_type: repository type used in the cache folder name.

    Returns:
        ``<cache_dir>/<repo_type>s--<org>--<name>/snapshots/<revision>``, where
        ``<revision>`` is the content of ``refs/main`` when that file exists and
        the literal ``"main"`` otherwise.
    """
    revision = "main"
    # 1. parse the downloaded cache folder
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    object_id = repo_id.replace("/", "--")
    repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
    # 2. resolve refs (for instance to convert main to the associated commit sha)
    revision_file = os.path.join(repo_cache, "refs", revision)
    if os.path.isfile(revision_file):
        with open(revision_file, encoding="utf-8") as f:
            # Strip so a trailing newline never ends up inside the path; fall
            # back to the ref name if the file is empty.
            revision = f.read().strip() or revision
    # 3. acquire the snapshot folder
    return os.path.join(repo_cache, "snapshots", revision)


def load_mm_projector(model_path, cache_dir=None, token=None):
    """Load ``mm_projector.bin`` from a local folder or a Hub repo, cast to fp16.

    Args:
        model_path: local directory containing ``mm_projector.bin``, or a Hub
            repository id.
        cache_dir: cache root used for the lookup and for the download.
        token: access token forwarded to ``snapshot_download``.

    Returns:
        dict mapping each key of the loaded object to ``value.to(torch.float16)``.
    """
    filename = 'mm_projector.bin'
    if os.path.exists(os.path.join(model_path, filename)):
        folder = model_path
    else:
        folder = parse_snapshot_folder(model_path, cache_dir=cache_dir, repo_type="model")
        if not os.path.exists(os.path.join(folder, filename)):
            # downloading from remote repo
            from huggingface_hub import snapshot_download
            # Use the path returned by the download: `folder` was resolved before
            # `refs/main` existed (so it pointed at `snapshots/main`), and the
            # hub's default cache may differ from TRANSFORMERS_CACHE.
            folder = snapshot_download(repo_id=model_path, cache_dir=cache_dir, token=token)

    # NOTE(review): torch.load unpickles the file, and a repo fetched from the Hub
    # is untrusted input; consider `weights_only=True` if the minimum supported
    # torch version allows it.
    mm_projector_weights = torch.load(os.path.join(folder, filename), map_location='cpu')
    return {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}


class IdentityMap(nn.Module):
    """Pass-through projector used for ``mm_projector_type == 'identity'``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        """Return ``x`` unchanged; any extra arguments are ignored."""
        return x

    @property
    def config(self):
        """Config fragment describing this projector."""
        return {"mm_projector_type": 'identity'}


class SimpleResBlock(nn.Module):
    """LayerNorm followed by a residual Linear-GELU-Linear block of width ``channels``."""

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        # NOTE: the skip connection adds the *normalized* input rather than the
        # raw input; kept as-is so existing checkpoints produce the same output.
        return x + self.proj(x)


def build_vision_projector(config, delay_load=False, **kwargs):
    """Build the vision-to-language projector named by ``config.mm_projector_type``.

    Supported types:
        * ``"mlp<N>x_gelu"``: N Linear layers separated by GELU
          (``mm_hidden_size -> hidden_size -> ... -> hidden_size``);
        * ``"linear"`` (used when the attribute is missing): one Linear layer;
        * ``"stc_connector"``, ``"stp_connector"``, ``"stc_connector_v35"``,
          ``"spatial_conv"``, ``"spatial_pool"``: the connector classes below;
        * ``"identity"``: :class:`IdentityMap`.

    Args:
        config: object exposing ``mm_hidden_size`` and ``hidden_size`` and,
            optionally, ``mm_projector_type``.
        delay_load: unused; accepted for call-site compatibility.
        **kwargs: unused; accepted for call-site compatibility.

    Returns:
        The projector module.

    Raises:
        ValueError: if the projector type is not recognised.
    """
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        # Linear(mm_hidden, hidden) followed by (GELU, Linear(hidden, hidden)) * (N - 1).
        return build_mlp(mlp_depth, config.mm_hidden_size, config.hidden_size)

    if projector_type == "linear":
        # NOTE: for both linear and mlp2x_gelu projector type, mean pooling is adopted to aggregate video features
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    connectors = {
        "stc_connector": STCConnector,
        "stp_connector": STPConnector,
        "stc_connector_v35": STCConnectorV35,
        "spatial_conv": SpatialConv,
        "spatial_pool": SpatialPool,
    }
    if projector_type in connectors:
        return connectors[projector_type](config)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')


def build_mlp(depth, hidden_size, output_hidden_size):
    """Return ``Linear(hidden_size, output_hidden_size)`` followed by ``depth - 1``
    repetitions of ``GELU, Linear(output_hidden_size, output_hidden_size)``."""
    modules = [nn.Linear(hidden_size, output_hidden_size)]
    for _ in range(1, depth):
        modules.append(nn.GELU())
        modules.append(nn.Linear(output_hidden_size, output_hidden_size))
    return nn.Sequential(*modules)


class STCConnector(nn.Module):
    """Spatial-temporal convolutional vision-language connector.

    Pipeline: per-frame RegStage (``s1``) -> 3-D down-sampler (``sampler``) ->
    per-frame RegStage (``s2``) -> MLP readout.
    """

    def __init__(self, config, downsample=(2, 2, 2), depth=4, mlp_depth=2):
        """Temporal Convolutional Vision-Language Connector.

        Args:
            config: config object exposing ``mm_hidden_size`` and ``hidden_size``.
            downsample: (temporal, height, width) downsample rate.
            depth: depth of the spatial interaction blocks; 0 replaces them with
                ``nn.Identity``.
            mlp_depth: depth of the vision-language projector layers.
        """
        super().__init__()
        self.encoder_hidden_size = encoder_hidden_size = config.mm_hidden_size
        self.hidden_size = hidden_size = config.hidden_size
        self.output_hidden_size = output_hidden_size = config.hidden_size
        # TODO: make these as config arguments
        self.depth = depth
        self.mlp_depth = mlp_depth
        self.downsample = downsample
        if depth != 0:
            self.s1 = RegStage(
                depth=depth,
                in_chs=encoder_hidden_size,
                out_chs=hidden_size,
                stride=1,
                dilation=1,
                act_layer=nn.SiLU,
                norm_layer=LayerNorm2d,
            )
        else:
            self.s1 = nn.Identity()
        # NOTE(review): padding=1 pads every axis, including time. With kernel ==
        # stride == k on an axis, that axis becomes floor(n / k) + 1; for a temporal
        # kernel of 1 (SpatialConv) this adds two frames computed from zero padding
        # only. Kept as-is because the resulting token count is what existing
        # checkpoints were trained with.
        self.sampler = nn.Sequential(
            nn.Conv3d(
                in_channels=hidden_size,
                out_channels=hidden_size,
                kernel_size=downsample,
                stride=downsample,
                padding=1,
                bias=True
            ),
            nn.SiLU()
        )
        if depth != 0:
            self.s2 = RegStage(
                depth=depth,
                in_chs=hidden_size,
                out_chs=hidden_size,
                stride=1,
                dilation=1,
                act_layer=nn.SiLU,
                norm_layer=LayerNorm2d,
            )
        else:
            self.s2 = nn.Identity()
        self.readout = build_mlp(mlp_depth, hidden_size, output_hidden_size)

    def forward(self, x):
        """Aggregate tokens on the temporal and spatial dimensions.

        Args:
            x: input tokens [b, t, h, w, d] / [b, t, l, d]; in the 4-D form
                ``l`` must be a perfect square (a square patch grid).

        Returns:
            aggregated tokens [b, l, d]

        Raises:
            ValueError: if ``x`` is not 4-D/5-D, or ``l`` is not a perfect square.
        """
        t = x.size(1)
        if x.ndim == 4:
            num_tokens = x.size(2)
            hw = math.isqrt(num_tokens)
            if hw * hw != num_tokens:
                raise ValueError(
                    f"Expected a square number of tokens per frame, got {num_tokens}"
                )
            x = einops.rearrange(x, "b t (h w) d -> b d t h w", h=hw, w=hw)
        elif x.ndim == 5:
            x = einops.rearrange(x, "b t h w d -> b d t h w")
        else:
            raise ValueError(f"Expected a 4-D or 5-D input, got {x.ndim}-D")

        x = einops.rearrange(x, "b d t h w -> (b t) d h w")
        # 1. the first stage of the adapter
        x = self.s1(x)
        x = einops.rearrange(x, "(b t) d h w -> b d t h w", t=t)
        # 2. downsampler
        x = self.sampler(x)
        new_t = x.size(2)
        # 3. the second stage of the adapter
        x = einops.rearrange(x, "b d t h w -> (b t) d h w")
        x = self.s2(x)
        x = einops.rearrange(x, "(b t) d h w -> b (t h w) d", t=new_t)
        x = self.readout(x)
        return x


class STPConnector(STCConnector):
    """STCConnector variant whose down-sampler is ``AvgPool3d`` + SiLU."""

    def __init__(self, config, downsample=(2, 2, 2), depth=4, mlp_depth=2):
        super().__init__(config=config, downsample=downsample, depth=depth, mlp_depth=mlp_depth)
        self.sampler = nn.Sequential(nn.AvgPool3d(downsample), nn.SiLU())


class STCConnectorV35(STCConnector):
    """STCConnector variant whose Conv3d down-sampler uses ``padding=0``."""

    def __init__(self, config, downsample=(2, 2, 2), depth=4, mlp_depth=2):
        super().__init__(config=config, downsample=downsample, depth=depth, mlp_depth=mlp_depth)
        self.sampler = nn.Sequential(
            nn.Conv3d(
                in_channels=self.hidden_size,
                out_channels=self.hidden_size,
                kernel_size=downsample,
                stride=downsample,
                padding=0,
                bias=True
            ),
            nn.SiLU())


class SpatialConv(STCConnector):
    """STCConnector with spatial-only down-sampling (1, 2, 2) and no RegStages."""

    def __init__(self, config, downsample=(1, 2, 2), depth=0, mlp_depth=2):
        super().__init__(config=config, downsample=downsample, depth=depth, mlp_depth=mlp_depth)


class SpatialPool(STPConnector):
    """STPConnector with spatial-only average pooling (1, 2, 2) and no RegStages."""

    def __init__(self, config, downsample=(1, 2, 2), depth=0, mlp_depth=2):
        super().__init__(config=config, downsample=downsample, depth=depth, mlp_depth=mlp_depth)