import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from .convnext_encoder import ConvNextVisionTower
from .hr_clip_encoder import HRCLIPVisionTower
from .vision_models.eva_vit import EVAVITVisionTower
from .sam_encoder import SAMVisionTower
from .pix2struct_encoder import Pix2StructLargeVisionTower
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from copy import deepcopy
import random
import math


class MultiBackboneChannelConcatenationVisionTower(nn.Module):
    """A mixture of vision encoders whose outputs are aligned to a common
    grid_size x grid_size token layout and concatenated along the channel
    dimension."""

    def __init__(self,
                 vision_tower,
                 args,
                 grid_size=32):
        super().__init__()

        self.is_loaded = False
        self.grid_size = grid_size
        self.num_tokens = self.grid_size ** 2

        # `vision_tower` is a ";"-separated list of encoder names,
        # e.g. "convnext-1024;sam-1024".
        vision_tower_name_list = vision_tower.split(";")
        self.input_image_size = 1024  # hardcode
        self.load_vision_towers(vision_tower_name_list, args)

    def load_vision_towers(self, vision_tower_name_list, args):
        self.vision_towers = nn.ModuleList()
        for name in vision_tower_name_list:
            if name == 'det-1024':
                det_args = deepcopy(args)
                det_args.input_image_size = 1024
                det_args.freeze_vision = False
                det_args.vision_tower_pretrained_from = '/lustre/fsw/portfolios/llmservice/users/fuxiaol/eva02_L_coco_det_sys_o365.pth'
                det_vision_tower = EVAVITVisionTower("eva02-l-16", det_args)
                det_vision_tower.load_model()
                self.vision_towers.append(det_vision_tower)

            elif name == 'convnext-1024':
                ## ConvNeXt
                convnext_args = deepcopy(args)
                convnext_args.freeze_vision = False
                convnext_args.input_image_size = 1024
                convnext_vision_tower = "convnext_xxlarge.clip_laion2b_soup"  # hardcode
                convnext_vision_tower = ConvNextVisionTower(convnext_vision_tower,
                                                            convnext_args)
                convnext_vision_tower.load_model()
                self.vision_towers.append(convnext_vision_tower)

            elif name == "sam-1024":
                sam_args = deepcopy(args)
                sam_args.freeze_vision = False
                sam_args.input_image_size = 1024
                sam_args.add_pixel_shuffle = True
                sam_vision_tower = SAMVisionTower("SAM-L", sam_args)
                sam_vision_tower.load_model()
                self.vision_towers.append(sam_vision_tower)

            elif name == 'pix2struct-1024':
                pix_args = deepcopy(args)
                #pix_args.freeze_vision = True
                pix_args.input_image_size = 1024
                pix_args.freeze_vision = False
                pix_args.do_resize = True
                pix_args.de_normalize = True
                pix_vision_tower = Pix2StructLargeVisionTower("pix2struct-large", pix_args)
                pix_vision_tower.load_model()
                self.vision_towers.append(pix_vision_tower)

            elif name == 'clip-448':
                clip_args = deepcopy(args)
                clip_args.input_image_size = 336  # actually 448, will have no effect
                clip_args.freeze_vision = False
                clip_vision_tower = HRCLIPVisionTower("openai/clip-vit-large-patch14-336", clip_args)
                clip_vision_tower.load_model()
                self.vision_towers.append(clip_vision_tower)

        # Hardcoded: the mixture is expected to always include ConvNeXt,
        # whose image processor is reused for the whole mixture.
        self.image_processor = convnext_vision_tower.image_processor
        self.is_loaded = True

    def load_model(self):
        assert self.is_loaded, "All the vision encoders should be loaded during initialization!"

    def forward(self, x):
        features = []
        for vision_tower in self.vision_towers:
            # Resize the input to each tower's expected resolution.
            if vision_tower.input_image_size != self.input_image_size:
                resized_x = F.interpolate(x.float(),
                                          size=(vision_tower.input_image_size, vision_tower.input_image_size),
                                          mode='bilinear',
                                          align_corners=True).to(dtype=x.dtype)
            else:
                resized_x = x

            feature = vision_tower(resized_x)
            if len(feature.shape) == 3:  # (b, n, c) token sequence
                b, n, c = feature.shape
                if n == self.num_tokens:
                    features.append(feature)
                    continue
                # Fold the token sequence back into a square spatial map.
                w = h = int(n ** 0.5)
                feature = feature.transpose(1, 2).reshape(b, c, h, w)
            else:
                b, c, h, w = feature.shape

            # Align every tower to the same grid, then flatten back to tokens.
            if w != self.grid_size:
                feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
            features.append(feature.flatten(2, 3).transpose(1, 2))

        # Channel-wise concatenation: (b, num_tokens, sum of tower hidden sizes).
        features = torch.cat(features, dim=-1)
        return features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # Follow the first tower in the mixture.
        return next(self.vision_towers[0].parameters()).dtype

    @property
    def device(self):
        return next(self.vision_towers[0].parameters()).device

    @property
    def config(self):
        raise NotImplementedError

    @property
    def hidden_size(self):
        return sum([_.hidden_size for _ in self.vision_towers])

    @property
    def num_patches(self):
        return self.num_tokens
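

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how forward() merges
# towers whose outputs have different layouts. The tensor shapes below are
# made-up examples, not the real encoder dimensions. Because of the relative
# imports above, this only runs when the file is executed inside its package
# (e.g. `python -m <package>.<this_module>`); otherwise copy the body into a
# standalone script that imports torch and torch.nn.functional as F.
if __name__ == "__main__":
    grid_size = 32
    num_tokens = grid_size ** 2

    # Tower A: a (b, c, h, w) spatial map finer than the target grid.
    feat_a = torch.randn(1, 1536, 64, 64)
    feat_a = F.interpolate(feat_a, size=(grid_size, grid_size),
                           mode='bilinear', align_corners=True)
    feat_a = feat_a.flatten(2, 3).transpose(1, 2)   # (1, 1024, 1536)

    # Tower B: a (b, n, c) token sequence that already matches the grid.
    feat_b = torch.randn(1, num_tokens, 1024)       # (1, 1024, 1024)

    # Channel-wise concatenation, as in forward().
    merged = torch.cat([feat_a, feat_b], dim=-1)
    print(merged.shape)                             # torch.Size([1, 1024, 2560])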