# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE
# Modified from: https://github.com/openai/CLIP/blob/main/clip/model.py
from collections import OrderedDict
from einops import rearrange
from clip.clip import _download
import re
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from model.modules.utils import QuickGELU, LayerNorm, Adaptor, interpolate_pos_embed
from model.modules.resampler import PerceiverResampler
from huggingface_hub import hf_hub_download
from functools import partial
hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version='2.0.2')
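# downloads made through hf_hub_download are tagged with open_clip 2.0.2 library metadata
# (this only affects the Hub user agent; the call behaves the same otherwise)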
_MODELS = {
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
"ViT-H/14": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
}
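
# Checkpoint resolution (sketch): the OpenAI entries are direct URLs that clip's _download
# helper can fetch, while the ViT-H/14 entry is a Hugging Face Hub repo id; the filename
# used here is an assumption about that repo's layout.
#   path = _download(_MODELS['ViT-L/14'], os.path.expanduser('~/.cache/clip'))
#   path = hf_hub_download(_MODELS['ViT-H/14'], 'open_clip_pytorch_model.bin')
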
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
self.ln_1 = LayerNorm(d_model)
self.ln_2 = LayerNorm(d_model)
def attention(self, x: torch.Tensor):
return self.attn(x, x, x, need_weights=False)[0]
def forward(self, x: torch.Tensor, mode='attention'):
if mode == 'attention':
return x + self.attention(self.ln_1(x))
elif mode == 'mlp':
return x + self.mlp(self.ln_2(x))
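
# Usage sketch: a block is called twice per layer so that an adaptor can be slotted between
# its attention and MLP sub-layers (see Transformer.forward below), e.g.
#   blk = ResidualAttentionBlock(d_model=768, n_head=12)
#   y = blk(torch.rand(196, 4, 768), mode='attention')  # tokens are (seq, batch, dim)
#   y = blk(y, mode='mlp')                               # same shape out
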
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int):
super().__init__()
self.resblocks = nn.Sequential(*[nn.ModuleList([
ResidualAttentionBlock(width, heads),
Adaptor(width),
]) for _ in range(layers)])
def forward(self, x: torch.Tensor):
for resblock, adaptor in self.resblocks:
x = resblock(x, mode='attention')
x = adaptor(x)
x = resblock(x, mode='mlp')
return x
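
# Usage sketch: each layer applies attention, then the Adaptor, then the MLP, with the
# residual connections handled inside ResidualAttentionBlock, e.g.
#   t = Transformer(width=768, layers=12, heads=12)
#   out = t(torch.rand(196, 4, 768))  # shape preserved: (seq, batch, width)
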
class VisionTransformer(nn.Module):
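    # A multi-modal ViT trunk: RGB is patch-embedded as in CLIP, each auxiliary expert
    # (depth, segmentation, detection, OCR, ...) gets its own conv stem that lands on the same
    # patch grid, and all non-RGB tokens are compressed by a Perceiver resampler before being
    # concatenated with the RGB tokens and passed through the shared Transformer.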
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, experts: dict):
super().__init__()
self.experts = experts
self.conv1 = nn.ModuleDict()
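        # one stem per expert: RGB uses the standard stride-`patch_size` patch conv, while
        # dense-label experts (seg / detection / OCR, 64-channel label embeddings) and
        # pixel-level experts (e.g. depth) use small conv stacks whose combined stride also
        # reduces the input to the (H / patch_size, W / patch_size) token grid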
for e in experts:
if e == 'rgb':
self.conv1[e] = nn.Conv2d(in_channels=experts[e], out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
elif e in ['seg', 'obj_detection', 'ocr_detection']:
self.conv1[e] = nn.Sequential(
nn.UpsamplingBilinear2d(scale_factor=4 / patch_size),
nn.Conv2d(in_channels=64, out_channels=width // 8, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(width // 8),
nn.ReLU(),
nn.Conv2d(in_channels=width // 8, out_channels=width // 4, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(width // 4),
nn.ReLU(),
nn.Conv2d(in_channels=width // 4, out_channels=width // 2, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(width // 2),
nn.ReLU(),
nn.Conv2d(in_channels=width // 2, out_channels=width, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(width),
nn.ReLU(),
nn.Conv2d(in_channels=width, out_channels=width, kernel_size=1, stride=1, padding=0, bias=False),
)
else:
self.conv1[e] = nn.Sequential(
nn.UpsamplingBilinear2d(scale_factor=16 / patch_size),
nn.Conv2d(in_channels=experts[e], out_channels=width // 8, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(width // 8),
nn.ReLU(),
nn.Conv2d(in_channels=width // 8, out_channels=width // 4, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(width // 4),
nn.ReLU(),
nn.Conv2d(in_channels=width // 4, out_channels=width // 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(width // 2),
nn.ReLU(),
nn.Conv2d(in_channels=width // 2, out_channels=width, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(width),
nn.ReLU(),
nn.Conv2d(in_channels=width, out_channels=width, kernel_size=1, stride=1, padding=0, bias=False),
)
scale = width ** -0.5
self.patch_size = patch_size
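        # learned position embedding over the (input_resolution // patch_size)^2 RGB patch grid;
        # expert token sequences of a different length reuse it via interpolate_pos_embed in forward()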
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2, width))
if 'obj_detection' in self.experts:
self.instance_embedding = nn.Parameter(scale * torch.randn(128, width))
self.transformer = Transformer(width, layers, heads)
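        # with more than one input modality, a Perceiver resampler compresses all non-RGB tokens
        # into num_latents=64 latent tokens before they are concatenated with the RGB tokens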
if len(self.experts) > 1:
self.resampler = PerceiverResampler(width=width, layers=4, heads=8, num_latents=64)
self.ln_pre = LayerNorm(width)
self.ln_post = LayerNorm(width)
def forward(self, x: dict):
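        # x maps expert names to inputs, e.g. {'rgb': (B, 3, H, W), 'depth': (B, 1, H, W),
        # 'obj_detection': {'label': (B, 64, h, w), 'instance': (B, 1, h, w)}, ...}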
experts_inputs = []
for exp in x:
domain = 'seg' if 'seg' in exp else exp
x_ = x[exp] if exp != 'obj_detection' else x[exp]['label']
x_ = self.conv1[domain](x_)
# add instance embedding (object detection only)
if exp == 'obj_detection':
instance_map = F.interpolate(x[exp]['instance'].to(x_.dtype), size=x_.shape[2:], mode='nearest')
instance_map = rearrange(instance_map, 'b 1 h w -> b h w')
label_map = rearrange(x_, 'b d h w -> d b h w')
for l in x[exp]['instance'].unique():
l_ = random.randint(0, 127)
label_map[:, instance_map == l] += self.instance_embedding[l_].unsqueeze(-1)
x_ = rearrange(label_map, 'd b h w -> b d h w')
x_ = rearrange(x_, 'b d h w -> b (h w) d')
# add position embedding (shared across all modalities)
if domain == 'rgb':
x_ = x_ + self.positional_embedding.to(x_.dtype)
rgb_inputs = x_
else:
exp_positional_embedding = interpolate_pos_embed(self.positional_embedding.to(x_.dtype), x_.shape[1])
x_ = x_ + exp_positional_embedding
experts_inputs.append(x_)
if len(experts_inputs) > 0:
experts_inputs = rearrange(torch.cat(experts_inputs, dim=1), 'b l d -> l b d')
experts_inputs = self.resampler(experts_inputs)
rgb_inputs = rearrange(rgb_inputs, 'b l d -> l b d')
x = torch.cat([rgb_inputs, experts_inputs], dim=0)
else:
x = rearrange(rgb_inputs, 'b l d -> l b d')
x = self.ln_pre(x)
x = self.transformer(x)
x = self.ln_post(x)
return x # latents, batch, output_dim


def load_encoder(name: str, experts: dict, image_resolution: int):
if name == 'ViT-B/16':
vision_width = 768
vision_patch_size = 16
vision_layers = 12
vision_heads = 12
    elif name in ['ViT-L/14', 'ViT-L/14@336px']:
        vision_width = 1024
        vision_patch_size = 14
        vision_layers = 24
        vision_heads = 16
    else:
        raise ValueError(f'Unsupported vision encoder: {name}')
ViT = VisionTransformer(input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
experts=experts)
return ViT
# Quick Check:
# model = load_encoder("ViT-B/16", experts={'rgb': 3, 'depth': 1, 'seg': 64}, image_resolution=224)
# rgb, depth, seg = torch.rand(4, 3, 224, 224), torch.rand(4, 1, 224, 224), torch.rand(4, 64, 224, 224)
# feat = model({'rgb': rgb, 'depth': depth, 'seg': seg})  # shape (260, 4, 768): 196 RGB patch tokens + 64 resampled expert latents
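#
# An 'obj_detection' input would be a dict with a 64-channel 'label' map and a single-channel
# 'instance' map, and requires the encoder to be built with 'obj_detection' in `experts`
# (sketch only; the expert dims here are assumptions):
# model = load_encoder("ViT-B/16", experts={'rgb': 3, 'obj_detection': 64}, image_resolution=224)
# det = {'label': torch.rand(4, 64, 224, 224), 'instance': torch.randint(0, 5, (4, 1, 224, 224))}
# feat = model({'rgb': rgb, 'obj_detection': det})  # again (260, 4, 768)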