# eagle/model/multimodal_encoder/multi_backbone_channel_concatenation_encoder.py

from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

from .convnext_encoder import ConvNextVisionTower
from .hr_clip_encoder import HRCLIPVisionTower
from .vision_models.eva_vit import EVAVITVisionTower
from .sam_encoder import SAMVisionTower
from .pix2struct_encoder import Pix2StructLargeVisionTower

class MultiBackboneChannelConcatenationVisionTower(nn.Module):
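    """Channel-concatenation of multiple vision backbones.

    Each sub-tower (EVA-ViT detection, ConvNeXt, SAM, Pix2Struct, HR-CLIP)
    encodes the (possibly resized) input image; every feature map is
    normalized to a `grid_size x grid_size` token grid, and the per-tower
    features are concatenated along the channel dimension, so `hidden_size`
    is the sum of the sub-towers' hidden sizes.
    """
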
def __init__(self,
vision_tower,
args,
grid_size=32):
super().__init__()
self.is_loaded = False
self.grid_size = grid_size
self.num_tokens = self.grid_size ** 2
vision_tower_name_list = vision_tower.split(";")
        self.input_image_size = 1024  # hardcoded; towers expecting other sizes are resized in forward()
        self.load_vision_towers(vision_tower_name_list, args)

    def load_vision_towers(self, vision_tower_name_list, args):
        """Instantiate each requested sub-tower with a tower-specific copy of `args`."""
        self.vision_towers = nn.ModuleList()
        for name in vision_tower_name_list:
if name == 'det-1024':
det_args = deepcopy(args)
det_args.input_image_size = 1024
det_args.freeze_vision = False
det_args.vision_tower_pretrained_from = '/lustre/fsw/portfolios/llmservice/users/fuxiaol/eva02_L_coco_det_sys_o365.pth'
det_vision_tower = EVAVITVisionTower("eva02-l-16", det_args)
det_vision_tower.load_model()
self.vision_towers.append(det_vision_tower)
elif name == 'convnext-1024':
## ConvNeXt
convnext_args = deepcopy(args)
convnext_args.freeze_vision = False
convnext_args.input_image_size = 1024
convnext_vision_tower = "convnext_xxlarge.clip_laion2b_soup" # hardcode
convnext_vision_tower = ConvNextVisionTower(convnext_vision_tower,
convnext_args)
convnext_vision_tower.load_model()
self.vision_towers.append(convnext_vision_tower)
elif name == "sam-1024":
sam_args = deepcopy(args)
sam_args.freeze_vision = False
sam_args.input_image_size = 1024
sam_args.add_pixel_shuffle = True
sam_vision_tower = SAMVisionTower("SAM-L", sam_args)
sam_vision_tower.load_model()
self.vision_towers.append(sam_vision_tower)
elif name == 'pix2struct-1024':
pix_args = deepcopy(args)
                pix_args.input_image_size = 1024
                pix_args.freeze_vision = False
pix_args.do_resize = True
pix_args.de_normalize = True
pix_vision_tower = Pix2StructLargeVisionTower("pix2struct-large", pix_args)
pix_vision_tower.load_model()
self.vision_towers.append(pix_vision_tower)
elif name == 'clip-448':
clip_args = deepcopy(args)
                clip_args.input_image_size = 336  # the tower actually runs at 448; this setting has no effect here
clip_args.freeze_vision = False
clip_vision_tower = HRCLIPVisionTower("openai/clip-vit-large-patch14-336", clip_args)
clip_vision_tower.load_model()
self.vision_towers.append(clip_vision_tower)
            else:
                raise ValueError(f"Unknown vision tower: {name}")

        # Hardcoded: ConvNeXt is assumed to always be part of the mixture, and
        # its image processor is reused for the combined encoder.
        self.image_processor = convnext_vision_tower.image_processor
        self.is_loaded = True

def load_model(self):
assert self.is_loaded, "All the vision encoders should be loaded during initialization!"

    def forward(self, x):
        """Encode `x` with every sub-tower and concatenate the token features
        channel-wise into a (B, num_tokens, total hidden size) tensor."""
        features = []
        for vision_tower in self.vision_towers:
            # Resize the input when a sub-tower expects a different resolution.
            if vision_tower.input_image_size != self.input_image_size:
                resized_x = F.interpolate(x.float(),
                                          size=(vision_tower.input_image_size, vision_tower.input_image_size),
                                          mode='bilinear',
                                          align_corners=True).to(dtype=x.dtype)
            else:
                resized_x = x
feature = vision_tower(resized_x)
            if len(feature.shape) == 3:  # (b, n, c) token sequence
                b, n, c = feature.shape
                if n == self.num_tokens:
                    features.append(feature)
                    continue
                # Assume a square token grid and fold back to (b, c, h, w).
                w = h = int(n ** 0.5)
                feature = feature.transpose(1, 2).reshape(b, c, h, w)
            else:  # (b, c, h, w) feature map
                b, c, h, w = feature.shape
            # Resample to the shared grid, then flatten back to (b, num_tokens, c).
            if w != self.grid_size:
                feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
            features.append(feature.flatten(2, 3).transpose(1, 2))

        # Channel-wise concatenation across towers.
        features = torch.cat(features, dim=-1)
        return features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # Take dtype/device from the first loaded sub-tower.
        return next(self.vision_towers[0].parameters()).dtype

    @property
    def device(self):
        return next(self.vision_towers[0].parameters()).device

    @property
    def config(self):
        raise NotImplementedError("The concatenated encoder has no single config.")

    @property
    def hidden_size(self):
        # Total width is the sum of the sub-towers' hidden sizes.
        return sum(tower.hidden_size for tower in self.vision_towers)

    @property
    def num_patches(self):
        return self.num_tokens
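
# A minimal usage sketch, not part of the original file. The attribute set on
# `args` is an assumption: only fields this module and its sub-towers appear to
# read are fabricated here (real Eagle code builds `args` from its training
# config), and the sub-tower checkpoints must be available locally.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(freeze_vision=False,
                           input_image_size=1024,
                           add_pixel_shuffle=False,
                           do_resize=False,
                           de_normalize=False)
    tower = MultiBackboneChannelConcatenationVisionTower("convnext-1024;sam-1024", args)
    images = torch.randn(1, 3, 1024, 1024)
    feats = tower(images)
    print(feats.shape)  # (1, 32 ** 2, sum of sub-tower hidden sizes)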