# eagle/model/multimodal_encoder/multi_backbone_channel_concatenation_encoder.py
# (from the Eagle-X5-13B-Chat repository)
import math
import random
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from torch.utils.checkpoint import checkpoint

from .convnext_encoder import ConvNextVisionTower
from .hr_clip_encoder import HRCLIPVisionTower
from .vision_models.eva_vit import EVAVITVisionTower
from .sam_encoder import SAMVisionTower
from .pix2struct_encoder import Pix2StructLargeVisionTower


class MultiBackboneChannelConcatenationVisionTower(nn.Module):
    """Vision tower that runs several backbones on the same image and
    concatenates their features along the channel dimension on a shared
    grid_size x grid_size token grid."""

    def __init__(self,
                 vision_tower,
                 args,
                 grid_size=32):
        super().__init__()

        self.is_loaded = False
        self.grid_size = grid_size
        self.num_tokens = self.grid_size ** 2  # tokens per image after resampling to the shared grid

        # `vision_tower` is a semicolon-separated list of backbone names,
        # e.g. "det-1024;convnext-1024;sam-1024;pix2struct-1024;clip-448".
        vision_tower_name_list = vision_tower.split(";")
        self.input_image_size = 1024  # hardcode
        self.load_vision_towers(vision_tower_name_list, args)
    def load_vision_towers(self, vision_tower_name_list, args):
        self.vision_towers = nn.ModuleList()
        for name in vision_tower_name_list:
            if name == 'det-1024':
                ## EVA-02 detection backbone
                det_args = deepcopy(args)
                det_args.input_image_size = 1024
                det_args.freeze_vision = False
                det_args.vision_tower_pretrained_from = '/lustre/fsw/portfolios/llmservice/users/fuxiaol/eva02_L_coco_det_sys_o365.pth'
                det_vision_tower = EVAVITVisionTower("eva02-l-16", det_args)
                det_vision_tower.load_model()
                self.vision_towers.append(det_vision_tower)

            elif name == 'convnext-1024':
                ## ConvNeXt
                convnext_args = deepcopy(args)
                convnext_args.freeze_vision = False
                convnext_args.input_image_size = 1024
                convnext_vision_tower = "convnext_xxlarge.clip_laion2b_soup"  # hardcode
                convnext_vision_tower = ConvNextVisionTower(convnext_vision_tower,
                                                            convnext_args)
                convnext_vision_tower.load_model()
                self.vision_towers.append(convnext_vision_tower)

            elif name == 'sam-1024':
                ## SAM
                sam_args = deepcopy(args)
                sam_args.freeze_vision = False
                sam_args.input_image_size = 1024
                sam_args.add_pixel_shuffle = True
                sam_vision_tower = SAMVisionTower("SAM-L", sam_args)
                sam_vision_tower.load_model()
                self.vision_towers.append(sam_vision_tower)

            elif name == 'pix2struct-1024':
                ## Pix2Struct
                pix_args = deepcopy(args)
                pix_args.input_image_size = 1024
                pix_args.freeze_vision = False
                pix_args.do_resize = True
                pix_args.de_normalize = True
                pix_vision_tower = Pix2StructLargeVisionTower("pix2struct-large", pix_args)
                pix_vision_tower.load_model()
                self.vision_towers.append(pix_vision_tower)

            elif name == 'clip-448':
                ## CLIP
                clip_args = deepcopy(args)
                clip_args.input_image_size = 336  # actually 448; this setting has no effect here
                clip_args.freeze_vision = False
                clip_vision_tower = HRCLIPVisionTower("openai/clip-vit-large-patch14-336", clip_args)
                clip_vision_tower.load_model()
                self.vision_towers.append(clip_vision_tower)

        # Hardcoded: ConvNeXt is assumed to always be part of the mixture,
        # so its image processor is reused for the whole tower.
        self.image_processor = convnext_vision_tower.image_processor
        self.is_loaded = True
    def load_model(self):
        assert self.is_loaded, "All the vision encoders should be loaded during initialization!"
    def forward(self, x):
        features = []
        for vision_tower in self.vision_towers:
            # Resize the input to each backbone's expected resolution.
            if vision_tower.input_image_size != self.input_image_size:
                resized_x = F.interpolate(x.float(),
                                          size=(vision_tower.input_image_size, vision_tower.input_image_size),
                                          mode='bilinear',
                                          align_corners=True).to(dtype=x.dtype)
            else:
                resized_x = x

            feature = vision_tower(resized_x)

            if len(feature.shape) == 3:  # (b, n, c) token sequence
                b, n, c = feature.shape
                if n == self.num_tokens:
                    features.append(feature)
                    continue
                # Fold the token sequence back into a square feature map.
                w = h = int(n ** 0.5)
                feature = feature.transpose(1, 2).reshape(b, c, h, w)
            else:  # (b, c, h, w) feature map
                b, c, h, w = feature.shape

            # Resample every backbone to the shared grid, then flatten back to tokens.
            if w != self.grid_size:
                feature = F.interpolate(feature.float(),
                                        size=(self.grid_size, self.grid_size),
                                        mode='bilinear',
                                        align_corners=True).to(dtype=x.dtype)
            features.append(feature.flatten(2, 3).transpose(1, 2))

        # Concatenate along the channel dimension: (b, num_tokens, sum of hidden sizes).
        features = torch.cat(features, dim=-1)
        return features
    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # Use the first loaded backbone as the reference for dtype/device.
        return next(self.vision_towers[0].parameters()).dtype

    @property
    def device(self):
        return next(self.vision_towers[0].parameters()).device

    @property
    def config(self):
        raise NotImplementedError

    @property
    def hidden_size(self):
        return sum(tower.hidden_size for tower in self.vision_towers)

    @property
    def num_patches(self):
        return self.num_tokens
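

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# The backbone string, the `args` fields, and the random input below are
# assumptions for demonstration only; real values come from the Eagle
# configs, and the individual towers may require additional `args`
# attributes. Instantiating the tower loads the backbone weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical args namespace; only fields referenced above are set here.
    args = SimpleNamespace(input_image_size=1024, freeze_vision=False)

    # Semicolon-separated backbone names, matching the branches in
    # load_vision_towers.
    tower = MultiBackboneChannelConcatenationVisionTower(
        "convnext-1024;sam-1024;clip-448", args, grid_size=32
    )

    images = torch.randn(2, 3, 1024, 1024)  # (batch, channels, H, W)
    feats = tower(images)                   # (2, grid_size**2, sum of backbone hidden sizes)
    print(feats.shape)                      # torch.Size([2, 1024, C_total])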