# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE

# Modified from: https://github.com/openai/CLIP/blob/main/clip/model.py

from collections import OrderedDict
from einops import rearrange
from clip.clip import _download

import re
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

from model.modules.utils import QuickGELU, LayerNorm, Adaptor, interpolate_pos_embed
from model.modules.resampler import PerceiverResampler

from huggingface_hub import hf_hub_download
from functools import partial

hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version='2.0.2')

_MODELS = {
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
    "ViT-H/14": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
}


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model)),
        ]))
        self.ln_1 = LayerNorm(d_model)
        self.ln_2 = LayerNorm(d_model)

    def attention(self, x: torch.Tensor):
        return self.attn(x, x, x, need_weights=False)[0]

    def forward(self, x: torch.Tensor, mode='attention'):
        if mode == 'attention':
            return x + self.attention(self.ln_1(x))
        elif mode == 'mlp':
            return x + self.mlp(self.ln_2(x))


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int):
        super().__init__()
        # nn.Sequential is used only as a parameter container here;
        # forward() iterates the (block, adaptor) pairs manually.
        self.resblocks = nn.Sequential(*[nn.ModuleList([
            ResidualAttentionBlock(width, heads),
            Adaptor(width),
        ]) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        for resblock, adaptor in self.resblocks:
            x = resblock(x, mode='attention')
            x = adaptor(x)
            x = resblock(x, mode='mlp')
        return x
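
# Quick Check (Transformer only): a minimal sketch, assuming Adaptor(width) preserves the
# (sequence, batch, width) shape, which the interleaved forward above already requires.
# blocks = Transformer(width=768, layers=2, heads=12)
# tokens = torch.rand(196, 4, 768)
# out = blocks(tokens)  # 196, 4, 768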


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, experts: dict):
        super().__init__()
        self.experts = experts

        # one tokeniser per expert: RGB uses the standard ViT patch embedding, while the
        # auxiliary experts use small convolutional stems that map their label/feature
        # maps onto the same token grid as the RGB patches
        self.conv1 = nn.ModuleDict()
        for e in experts:
            if e == 'rgb':
                self.conv1[e] = nn.Conv2d(in_channels=experts[e], out_channels=width,
                                          kernel_size=patch_size, stride=patch_size, bias=False)
            elif e in ['seg', 'obj_detection', 'ocr_detection']:
                self.conv1[e] = nn.Sequential(
                    nn.UpsamplingBilinear2d(scale_factor=4 / patch_size),
                    nn.Conv2d(in_channels=64, out_channels=width // 8, kernel_size=3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(width // 8),
                    nn.ReLU(),
                    nn.Conv2d(in_channels=width // 8, out_channels=width // 4, kernel_size=3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(width // 4),
                    nn.ReLU(),
                    nn.Conv2d(in_channels=width // 4, out_channels=width // 2, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(width // 2),
                    nn.ReLU(),
                    nn.Conv2d(in_channels=width // 2, out_channels=width, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(width),
                    nn.ReLU(),
                    nn.Conv2d(in_channels=width, out_channels=width, kernel_size=1, stride=1, padding=0, bias=False),
                )
            else:
                self.conv1[e] = nn.Sequential(
                    nn.UpsamplingBilinear2d(scale_factor=16 / patch_size),
                    nn.Conv2d(in_channels=experts[e], out_channels=width // 8, kernel_size=3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(width // 8),
                    nn.ReLU(),
                    nn.Conv2d(in_channels=width // 8, out_channels=width // 4, kernel_size=3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(width // 4),
                    nn.ReLU(),
                    nn.Conv2d(in_channels=width // 4, out_channels=width // 2, kernel_size=3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(width // 2),
                    nn.ReLU(),
                    nn.Conv2d(in_channels=width // 2, out_channels=width, kernel_size=3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(width),
                    nn.ReLU(),
                    nn.Conv2d(in_channels=width, out_channels=width, kernel_size=1, stride=1, padding=0, bias=False),
                )

        scale = width ** -0.5
        self.patch_size = patch_size
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2, width))

        if 'obj_detection' in self.experts:
            self.instance_embedding = nn.Parameter(scale * torch.randn(128, width))

        self.transformer = Transformer(width, layers, heads)
        if len(self.experts) > 1:
            self.resampler = PerceiverResampler(width=width, layers=4, heads=8, num_latents=64)

        self.ln_pre = LayerNorm(width)
        self.ln_post = LayerNorm(width)

    def forward(self, x: dict):
        experts_inputs = []
        for exp in x:
            domain = 'seg' if 'seg' in exp else exp
            x_ = x[exp] if exp != 'obj_detection' else x[exp]['label']
            x_ = self.conv1[domain](x_)

            # add instance embedding (object detection only)
            if exp == 'obj_detection':
                instance_map = F.interpolate(x[exp]['instance'].to(x_.dtype), size=x_.shape[2:], mode='nearest')
                instance_map = rearrange(instance_map, 'b 1 h w -> b h w')
                label_map = rearrange(x_, 'b d h w -> d b h w')
                for l in x[exp]['instance'].unique():
                    l_ = random.randint(0, 127)
                    label_map[:, instance_map == l] += self.instance_embedding[l_].unsqueeze(-1)
                x_ = rearrange(label_map, 'd b h w -> b d h w')

            x_ = rearrange(x_, 'b d h w -> b (h w) d')

            # add position embedding (shared across all modalities)
            if domain == 'rgb':
                x_ = x_ + self.positional_embedding.to(x_.dtype)
                rgb_inputs = x_
            else:
                exp_positional_embedding = interpolate_pos_embed(self.positional_embedding.to(x_.dtype), x_.shape[1])
                x_ = x_ + exp_positional_embedding
                experts_inputs.append(x_)

        # with auxiliary experts present, resample their tokens into a fixed set of
        # latents and concatenate them after the RGB patch tokens
        if len(experts_inputs) > 0:
            experts_inputs = rearrange(torch.cat(experts_inputs, dim=1), 'b l d -> l b d')
            experts_inputs = self.resampler(experts_inputs)
            rgb_inputs = rearrange(rgb_inputs, 'b l d -> l b d')
            x = torch.cat([rgb_inputs, experts_inputs], dim=0)
        else:
            x = rearrange(rgb_inputs, 'b l d -> l b d')

        x = self.ln_pre(x)
        x = self.transformer(x)
        x = self.ln_post(x)
        return x  # latents, batch, output_dim
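
# Note on the expected input format for VisionTransformer.forward (inferred from the code
# above; the shapes below are illustrative rather than an exhaustive specification):
#   x = {'rgb': (B, experts['rgb'], H, W) image tensor,
#        'depth' / 'seg' / ...: (B, experts[e], H, W) expert map
#                               ('seg'/'obj_detection'/'ocr_detection' maps are pre-embedded to 64 channels),
#        'obj_detection': {'label': (B, 64, H, W) pre-embedded label map,
#                          'instance': (B, 1, H, W) per-pixel instance-id map}}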


def load_encoder(name: str, experts: dict, image_resolution: int):
    if name == 'ViT-B/16':
        vision_width = 768
        vision_patch_size = 16
        vision_layers = 12
        vision_heads = 12
    elif name == 'ViT-L/14' or name == 'ViT-L/14@336px':
        vision_width = 1024
        vision_patch_size = 14
        vision_layers = 24
        vision_heads = 16
    else:
        raise ValueError(f'Unsupported encoder: {name}')

    ViT = VisionTransformer(input_resolution=image_resolution,
                            patch_size=vision_patch_size,
                            width=vision_width,
                            layers=vision_layers,
                            heads=vision_heads,
                            experts=experts)
    return ViT


# Quick Check:
# model = load_encoder("ViT-B/16", experts={'rgb': 3, 'depth': 1, 'seg': 64}, image_resolution=224)
# rgb, depth, seg = torch.rand(4, 3, 224, 224), torch.rand(4, 1, 224, 224), torch.rand(4, 64, 224, 224)
# feat = model({'rgb': rgb, 'depth': depth, 'seg': seg})  # 260 [196 + 64], 4, 768
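
# Quick Check (RGB only): a minimal sketch of the single-expert path; with only 'rgb' in
# `experts` the PerceiverResampler branch is skipped, so the output length is just the
# number of RGB patches, (224 // 16) ** 2 = 196.
# model = load_encoder("ViT-B/16", experts={'rgb': 3}, image_resolution=224)
# rgb = torch.rand(2, 3, 224, 224)
# feat = model({'rgb': rgb})  # 196, 2, 768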