# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import random
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F

from longvu.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_PATCH_TOKEN,
    IGNORE_INDEX,
    IMAGE_TOKEN_INDEX,
)

from .multimodal_encoder.builder import build_vision_tower_aux_list
from .multimodal_projector.builder import build_vision_projector
from .vision_sampler import VisionTokenSampler

IS_XLA_AVAILABLE = False
| class CambrianMetaModel: | |
| def __init__(self, config): | |
| super(CambrianMetaModel, self).__init__(config) | |
| if hasattr(config, "mm_vision_tower_aux_list"): | |
| projector_type = getattr(config, "mm_projector_type", "linear") | |
| if projector_type == "sva": | |
| vision_hidden_size = config.vision_hidden_size | |
| num_query_group = config.num_query_group | |
| query_num_list = config.query_num_list | |
| connector_only = config.connector_only | |
| connector_depth = config.connector_depth | |
| self.vision_tower_aux_list = build_vision_tower_aux_list( | |
| config, delay_load=True | |
| ) | |
| self.mm_projector = nn.Sequential( | |
| nn.Linear(vision_hidden_size * num_query_group, config.hidden_size), | |
| nn.GELU(), | |
| nn.Linear(config.hidden_size, config.hidden_size), | |
| ) | |
| image_token_len = config.image_token_len | |
| vision_tower_aux_token_len_list = ( | |
| self.config.mm_vision_tower_aux_token_len_list | |
| ) | |
| cross_att_token_len_list = [ | |
| int(vision_tower_aux_token_len**0.5) // int(image_token_len**0.5) | |
| for vision_tower_aux_token_len in vision_tower_aux_token_len_list | |
| ] | |
| for aux_i, vision_tower_aux in enumerate(self.vision_tower_aux_list): | |
| setattr( | |
| self, | |
| "mm_projector_aux_{}".format(aux_i), | |
| nn.Sequential( | |
| nn.Linear(vision_tower_aux.hidden_size, vision_hidden_size), | |
| nn.GELU(), | |
| nn.Linear(vision_hidden_size, vision_hidden_size), | |
| nn.LayerNorm(vision_hidden_size), | |
| ), | |
| ) | |
| for query_group_i in range(num_query_group): | |
| cross_att_token_len_list = [ | |
| int(vision_tower_aux_token_len**0.5) | |
| // int(query_num_list[query_group_i] ** 0.5) | |
| for vision_tower_aux_token_len in vision_tower_aux_token_len_list | |
| ] | |
| setattr( | |
| self, | |
| "vision_sampler_{}".format(query_group_i), | |
| VisionTokenSampler( | |
| vision_hidden_size, | |
| vision_hidden_size, | |
| [vision_hidden_size] * len(self.vision_tower_aux_list), | |
| cross_att_token_len_list, | |
| vision_hidden_size, | |
| connector_depth, | |
| ), | |
| ) | |
if not connector_only:
    num_of_vision_sampler_layers = config.num_of_vision_sampler_layers
| cross_att_token_len_list = [ | |
| int(vision_tower_aux_token_len**0.5) | |
| // int(image_token_len**0.5) | |
| for vision_tower_aux_token_len in vision_tower_aux_token_len_list | |
| ] | |
| self.vision_sampler_layers = nn.ModuleList( | |
| [ | |
| VisionTokenSampler( | |
| config.hidden_size, | |
| vision_hidden_size, | |
| [vision_hidden_size] * len(self.vision_tower_aux_list), | |
| cross_att_token_len_list, | |
| vision_hidden_size, | |
| 1, | |
| ) | |
| for layer_idx in range(0, num_of_vision_sampler_layers) | |
| ] | |
| ) | |
| self.vision_query = nn.Parameter( | |
| torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype) | |
| ) | |
| self.image_newline = nn.Parameter( | |
| torch.empty(config.hidden_size, dtype=self.dtype) | |
| ) | |
| self.frame_pos = torch.stack( | |
| [ | |
| 1 | |
| / torch.pow( | |
| torch.tensor(10000), | |
| torch.tensor(2 * (hid_j // 2) / config.hidden_size), | |
| ) | |
| for hid_j in range(config.hidden_size) | |
| ] | |
| ) | |
| else: | |
| self.vision_tower_aux_list = build_vision_tower_aux_list( | |
| config, delay_load=True | |
| ) | |
| config.mm_hidden_size = sum( | |
| [ | |
| vision_tower_aux.hidden_size | |
| for vision_tower_aux in self.vision_tower_aux_list | |
| ] | |
| ) | |
| self.mm_projector = build_vision_projector(config) | |
| self.image_newline = nn.Parameter( | |
| torch.empty(config.hidden_size, dtype=self.dtype) | |
| ) | |
def get_frame_pos(self, time_range):
    frame_pos = self.frame_pos.reshape(1, -1) * time_range.reshape(-1, 1).to(
        self.frame_pos.device
    )
    # Standard sinusoidal encoding: sin on the even dimensions, cos on the odd
    # dimensions of the same angles (frequencies are paired via 2 * (hid_j // 2)).
    frame_pos[:, 0::2] = torch.sin(frame_pos[:, 0::2])
    frame_pos[:, 1::2] = torch.cos(frame_pos[:, 1::2])
    frame_pos = frame_pos.unsqueeze(1)
    return frame_pos
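# Illustrative sketch (editorial, hypothetical values): how the temporal encoding
# above is consumed downstream. The indices of the kept frames are passed in as
# `time_range`, and the resulting (num_frames, 1, hidden_size) tensor is added to
# each frame's token embeddings by broadcasting over the tokens of a frame:
#
#   frame_indices = torch.arange(3)            # e.g. 3 selected frames
#   pos = model.get_frame_pos(frame_indices)   # -> (3, 1, config.hidden_size)
#   frame_tokens = frame_tokens + pos          # frame_tokens: (3, tokens_per_frame, hidden_size)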
| # def get_vision_tower(self): | |
| # vision_tower = getattr(self, 'vision_tower', None) | |
| # if type(vision_tower) is list: | |
| # vision_tower = vision_tower[0] | |
| # return vision_tower | |
| def get_vision_tower_aux_list(self): | |
| vision_tower_aux_list = getattr(self, "vision_tower_aux_list", None) | |
| return vision_tower_aux_list | |
| def initialize_vision_modules(self, model_args, fsdp=None): | |
| # vision_tower = model_args.vision_tower | |
| num_query_group = model_args.num_query_group | |
| query_num_list = model_args.query_num_list | |
| vision_hidden_size = model_args.vision_hidden_size | |
| vision_tower_aux_list = model_args.vision_tower_aux_list | |
| vision_tower_aux_token_len_list = model_args.vision_tower_aux_token_len_list | |
| image_token_len = model_args.image_token_len | |
| mm_vision_select_layer = model_args.mm_vision_select_layer | |
| mm_vision_select_feature = model_args.mm_vision_select_feature | |
| pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter | |
| connector_only = model_args.connector_only | |
| connector_depth = model_args.connector_depth | |
| # self.config.mm_vision_tower = vision_tower | |
| self.config.image_token_len = image_token_len | |
| self.config.num_query_group = num_query_group | |
| self.config.query_num_list = query_num_list | |
| assert num_query_group == len(query_num_list) | |
| self.config.connector_depth = connector_depth | |
| self.config.mm_vision_tower_aux_list = vision_tower_aux_list | |
| self.config.mm_vision_tower_aux_token_len_list = vision_tower_aux_token_len_list | |
| self.config.connector_only = connector_only | |
| self.config.highres_connect = model_args.highres_connect | |
| self.config.highres = model_args.highres | |
| self.config.frame_pos = model_args.frame_pos | |
| self.config.lowres_token = model_args.lowres_token | |
| self.config.connect_layer = model_args.connect_layer | |
| self.config.dino_threshold = getattr(model_args, "dino_threshold", 0.83) | |
| self.config.drop_threshold = getattr(model_args, "drop_threshold", 0.6) | |
| self.config.is_image_newline = getattr(model_args, "is_image_newline", True) | |
| if self.get_vision_tower_aux_list() is None: | |
| vision_tower_aux_list = build_vision_tower_aux_list(model_args) | |
| if model_args.unfreeze_mm_vision_tower: | |
| self.vision_tower_aux_list = nn.ModuleList(vision_tower_aux_list) | |
| else: | |
| self.vision_tower_aux_list = vision_tower_aux_list | |
| else: | |
| vision_tower_aux_list = self.vision_tower_aux_list | |
| for vision_tower_aux in vision_tower_aux_list: | |
| vision_tower_aux.load_model() | |
| self.config.use_mm_proj = True | |
| self.config.mm_projector_type = getattr( | |
| model_args, "mm_projector_type", "linear" | |
| ) | |
| self.config.vision_hidden_size = vision_hidden_size | |
| self.config.mm_vision_select_layer = mm_vision_select_layer | |
| self.config.mm_vision_select_feature = mm_vision_select_feature | |
| if getattr(self, "mm_projector", None) is None: | |
| if self.config.mm_projector_type == "sva": | |
| self.mm_projector = nn.Sequential( | |
| nn.Linear( | |
| vision_hidden_size * num_query_group, self.config.hidden_size | |
| ), | |
| nn.GELU(), | |
| nn.Linear(self.config.hidden_size, self.config.hidden_size), | |
| ) | |
| for aux_i, vision_tower_aux in enumerate(vision_tower_aux_list): | |
| setattr( | |
| self, | |
| "mm_projector_aux_{}".format(aux_i), | |
| nn.Sequential( | |
| nn.Linear(vision_tower_aux.hidden_size, vision_hidden_size), | |
| nn.GELU(), | |
| nn.Linear(vision_hidden_size, vision_hidden_size), | |
| nn.LayerNorm(vision_hidden_size), | |
| ), | |
| ) | |
| # vision sampler for each group of query as the connector before the LLM | |
| for query_group_i in range(num_query_group): | |
| cross_att_token_len_list = [ | |
| int(vision_tower_aux_token_len**0.5) | |
| // int(query_num_list[query_group_i] ** 0.5) | |
| for vision_tower_aux_token_len in vision_tower_aux_token_len_list | |
| ] | |
| setattr( | |
| self, | |
| "vision_sampler_{}".format(query_group_i), | |
| VisionTokenSampler( | |
| vision_hidden_size, | |
| vision_hidden_size, | |
| [vision_hidden_size] * len(vision_tower_aux_list), | |
| cross_att_token_len_list, | |
| vision_hidden_size, | |
| connector_depth, | |
| ), | |
| ) | |
| # sampler layers within LLM | |
| if not connector_only: | |
| num_of_vision_sampler_layers = ( | |
| self.config.num_of_vision_sampler_layers | |
| ) = model_args.num_of_vision_sampler_layers | |
| self.config.start_of_vision_sampler_layers = ( | |
| model_args.start_of_vision_sampler_layers | |
| ) | |
| self.config.stride_of_vision_sampler_layers = ( | |
| model_args.stride_of_vision_sampler_layers | |
| ) | |
| cross_att_token_len_list = [ | |
| int(vision_tower_aux_token_len**0.5) | |
| // int(image_token_len**0.5) | |
| for vision_tower_aux_token_len in vision_tower_aux_token_len_list | |
| ] | |
| self.vision_sampler_layers = nn.ModuleList( | |
| [ | |
| VisionTokenSampler( | |
| self.config.hidden_size, | |
| vision_hidden_size, | |
| [vision_hidden_size] * len(vision_tower_aux_list), | |
| cross_att_token_len_list, | |
| vision_hidden_size, | |
| 1, | |
| ) | |
| for layer_idx in range(0, num_of_vision_sampler_layers) | |
| ] | |
| ) | |
| vision_embed_std = 1 / torch.sqrt( | |
| torch.tensor(vision_hidden_size, dtype=self.dtype) | |
| ) | |
| self.vision_query = nn.Parameter( | |
| torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype) | |
| * vision_embed_std | |
| ) | |
| embed_std = 1 / torch.sqrt( | |
| torch.tensor(self.config.hidden_size, dtype=self.dtype) | |
| ) | |
| self.image_newline = nn.Parameter( | |
| torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std | |
| ) | |
| else: | |
| self.config.mm_hidden_size = sum( | |
| [ | |
| vision_tower_aux.hidden_size | |
| for vision_tower_aux in vision_tower_aux_list | |
| ] | |
| ) | |
| self.mm_projector = build_vision_projector(self.config) | |
| embed_std = 1 / torch.sqrt( | |
| torch.tensor(self.config.hidden_size, dtype=self.dtype) | |
| ) | |
| self.image_newline = nn.Parameter( | |
| torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std | |
| ) | |
| else: | |
| # In case it is frozen by LoRA | |
| for p in self.mm_projector.parameters(): | |
| p.requires_grad = True | |
| if pretrain_mm_mlp_adapter is not None: | |
| mm_projector_weights = torch.load( | |
| pretrain_mm_mlp_adapter, map_location="cpu" | |
| ) | |
| def get_w(weights, keyword): | |
| return { | |
| k.split(keyword + ".")[1]: v | |
| for k, v in weights.items() | |
| if keyword + "." in k | |
| } | |
| self.mm_projector.load_state_dict( | |
| get_w(mm_projector_weights, "mm_projector"), strict=True | |
| ) | |
| if self.config.mm_projector_type == "sva": | |
| for aux_i in range(len(vision_tower_aux_list)): | |
| getattr(self, "mm_projector_aux_{}".format(aux_i)).load_state_dict( | |
| get_w( | |
| mm_projector_weights, "mm_projector_aux_{}".format(aux_i) | |
| ), | |
| strict=True, | |
| ) | |
| for query_group_i in range(num_query_group): | |
| getattr( | |
| self, "vision_sampler_{}".format(query_group_i) | |
| ).load_state_dict( | |
| get_w( | |
| mm_projector_weights, | |
| "vision_sampler_{}".format(query_group_i), | |
| ), | |
| strict=True, | |
| ) | |
| if not connector_only: | |
| self.vision_sampler_layers.load_state_dict( | |
| get_w(mm_projector_weights, "vision_sampler_layers"), | |
| strict=True, | |
| ) | |
| self.vision_query.data = mm_projector_weights["model.vision_query"] | |
| self.image_newline.data = mm_projector_weights["model.image_newline"] | |
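# Editorial note: the nested get_w helper above strips a sub-module prefix from
# checkpoint keys so each component can load its own state dict directly. A
# hypothetical example of the key transformation it performs:
#
#   weights = {"model.mm_projector.0.weight": torch.zeros(2, 2), "model.vision_query": torch.zeros(1)}
#   # get_w(weights, "mm_projector") -> {"0.weight": tensor of shape (2, 2)}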
| def unmask_attention_mask(mask, original_size): | |
| original_w, original_h = original_size | |
| cur_h, cur_w = mask.shape[1:3] | |
| original_aspect_ratio = original_w / original_h | |
| current_aspect_ratio = cur_w / cur_h | |
| if original_aspect_ratio > current_aspect_ratio: | |
| scale_factor = cur_w / original_w | |
| new_height = int(original_h * scale_factor) | |
| padding = (cur_h - new_height) // 2 | |
| if padding > 0: | |
| mask[:, :padding, :] = 0 | |
| mask[:, -padding:, :] = 0 | |
| return mask | |
| else: | |
| scale_factor = cur_h / original_h | |
| new_width = int(original_w * scale_factor) | |
| padding = (cur_w - new_width) // 2 | |
| if padding > 0: | |
| mask[:, :, :padding] = 0 | |
| mask[:, :, -padding:] = 0 | |
| return mask | |
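# Illustrative example (editorial): unmask_attention_mask marks the rows or columns
# of a boolean attention mask that correspond to letterbox padding as invalid,
# using the same aspect-ratio geometry as unpad_image below. For a 1x24x24 mask
# and an original 640x360 image, the top and bottom 5 rows are zeroed out:
#
#   mask = torch.ones(1, 24, 24, dtype=torch.bool)
#   unmask_attention_mask(mask, (640, 360))[0, :, 0]
#   # -> first 5 and last 5 entries are False, the middle 14 stay True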
def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
        original_size (tuple): The original size of the image as (width, height).

    Returns:
        torch.Tensor: The unpadded image tensor.
    """
| original_width, original_height = original_size | |
| current_height, current_width = tensor.shape[1:3] | |
| original_aspect_ratio = original_width / original_height | |
| current_aspect_ratio = current_width / current_height | |
| if original_aspect_ratio > current_aspect_ratio: | |
| scale_factor = current_width / original_width | |
| new_height = int(original_height * scale_factor) | |
| padding = (current_height - new_height) // 2 | |
| unpadded_tensor = tensor[:, padding : current_height - padding, :] | |
| # if 0 in unpadded_tensor.shape: | |
| # print(f"scale_factor: {scale_factor}, new_height: {new_height}, padding: {padding}, original_width: {original_width}, original_height: {original_height}") | |
| else: | |
| scale_factor = current_height / original_height | |
| new_width = int(original_width * scale_factor) | |
| padding = (current_width - new_width) // 2 | |
| unpadded_tensor = tensor[:, :, padding : current_width - padding] | |
| # if 0 in unpadded_tensor.shape: | |
| # print(f"scale_factor: {scale_factor}, new_width: {new_width}, padding: {padding}, original_width: {original_width}, original_height: {original_height}") | |
| return unpadded_tensor | |
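# Illustrative example (editorial): unpad_image drops the padded rows (or columns)
# that were added when a non-square image was resized onto a square feature grid.
# With a 24x24 grid produced from a 640x360 (width x height) image:
#
#   feats = torch.zeros(1, 24, 24, 8)
#   unpad_image(feats, (640, 360)).shape   # -> torch.Size([1, 14, 24, 8])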
| class CambrianMetaForCausalLM(ABC): | |
| def get_model(self): | |
| pass | |
| # def get_vision_tower(self): | |
| # return self.get_model().get_vision_tower() | |
| def get_vision_tower_aux_list(self): | |
| return self.get_model().get_vision_tower_aux_list() | |
| def rearrange_vision_tower_features_train( | |
| self, | |
| vision_tower_aux_feature_list, | |
| vision_tower_aux_attention_masks_list, | |
| query_side_len, | |
| ): | |
| vision_tower_aux_feature_rearranged_list = [] | |
| vision_tower_aux_attention_masks_rearranged_list = [] | |
| bs = vision_tower_aux_feature_list[0].shape[0] | |
| for vision_tower_aux_feature, vision_tower_aux_attention_masks in zip( | |
| vision_tower_aux_feature_list, vision_tower_aux_attention_masks_list | |
| ): | |
| aux_height = aux_width = int(vision_tower_aux_feature.shape[1] ** 0.5) | |
| assert (aux_height // query_side_len) * query_side_len == aux_height | |
| reduce_factor = aux_height // query_side_len | |
| vision_tower_aux_feature_rearranged = vision_tower_aux_feature.view( | |
| bs, query_side_len, reduce_factor, query_side_len, reduce_factor, -1 | |
| ) | |
| vision_tower_aux_feature_rearranged = ( | |
| vision_tower_aux_feature_rearranged.permute(0, 1, 3, 2, 4, 5) | |
| .contiguous() | |
| .flatten(0, 2) | |
| .flatten(1, 2) | |
| ) | |
| vision_tower_aux_attention_masks_rearranged = ( | |
| vision_tower_aux_attention_masks.view( | |
| bs * query_side_len * query_side_len, reduce_factor * reduce_factor | |
| ) | |
| ) | |
| vision_tower_aux_feature_rearranged_list.append( | |
| vision_tower_aux_feature_rearranged | |
| ) | |
| vision_tower_aux_attention_masks_rearranged_list.append( | |
| vision_tower_aux_attention_masks_rearranged | |
| ) | |
| return ( | |
| vision_tower_aux_feature_rearranged_list, | |
| vision_tower_aux_attention_masks_rearranged_list, | |
| ) | |
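# Shape sketch (editorial, illustrative numbers; `model` is a hypothetical
# instance): each query position gathers the reduce_factor x reduce_factor window
# of auxiliary tokens that sits underneath it. With bs=2, an auxiliary tower
# producing 576 tokens (a 24x24 grid), hidden size 32 and query_side_len=12,
# reduce_factor is 2 and the outputs are:
#
#   feats = torch.randn(2, 576, 32)
#   masks = torch.ones(2, 576, dtype=torch.bool)
#   out_f, out_m = model.rearrange_vision_tower_features_train([feats], [masks], 12)
#   out_f[0].shape   # -> torch.Size([288, 4, 32])   # (2*12*12, 2*2, 32)
#   out_m[0].shape   # -> torch.Size([288, 4])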
| def rearrange_vision_tower_features_inference( | |
| self, vision_tower_aux_feature_list, query_side_len, image_sizes, unpad=False | |
| ): | |
| vision_tower_aux_feature_rearranged_list = [] | |
| vision_tower_aux_attention_masks_rearranged_list = [] | |
| bs = vision_tower_aux_feature_list[0].shape[0] | |
| for vision_tower_aux_feature in vision_tower_aux_feature_list: | |
| aux_height = aux_width = int(vision_tower_aux_feature.shape[1] ** 0.5) | |
| assert (aux_height // query_side_len) * query_side_len == aux_height | |
| reduce_factor = aux_height // query_side_len | |
| vision_tower_aux_feature_rearranged = [] | |
| vision_tower_aux_attention_masks_rearranged = [] | |
| for batch_i in range(bs): | |
| image_size = image_sizes[batch_i] | |
| cur_vision_tower_aux_feature = vision_tower_aux_feature[batch_i] | |
| cur_vision_tower_aux_attention_masks_rearranged = torch.ones( | |
| (1, aux_height, aux_width), | |
| dtype=torch.bool, | |
| device=cur_vision_tower_aux_feature.device, | |
| ) | |
| cur_vision_tower_aux_feature_rearranged = ( | |
| cur_vision_tower_aux_feature.view( | |
| 1, | |
| query_side_len, | |
| reduce_factor, | |
| query_side_len, | |
| reduce_factor, | |
| -1, | |
| ) | |
| ) | |
| cur_vision_tower_aux_feature_rearranged = ( | |
| cur_vision_tower_aux_feature_rearranged.permute( | |
| 0, 1, 3, 2, 4, 5 | |
| ).contiguous() | |
| ) | |
| if unpad: | |
| cur_vision_tower_aux_feature_rearranged = unpad_image( | |
| cur_vision_tower_aux_feature_rearranged, image_size | |
| ) | |
| cur_vision_tower_aux_feature_rearranged = ( | |
| cur_vision_tower_aux_feature_rearranged.flatten(0, 2).flatten(1, 2) | |
| ) # query_side_len*query_side_len X reduce_factor*reduce_factor X C | |
| cur_vision_tower_aux_attention_masks_rearranged = unmask_attention_mask( | |
| cur_vision_tower_aux_attention_masks_rearranged, image_size | |
| ) | |
| cur_vision_tower_aux_attention_masks_rearranged = ( | |
| cur_vision_tower_aux_attention_masks_rearranged.view( | |
| 1, query_side_len, reduce_factor, query_side_len, reduce_factor | |
| ) | |
| .permute(0, 1, 3, 2, 4) | |
| .contiguous() | |
| ) | |
| if unpad: | |
| cur_vision_tower_aux_attention_masks_rearranged = unpad_image( | |
| cur_vision_tower_aux_attention_masks_rearranged, image_size | |
| ) | |
| cur_vision_tower_aux_attention_masks_rearranged = ( | |
| cur_vision_tower_aux_attention_masks_rearranged.flatten( | |
| 0, 2 | |
| ).flatten(1, 2) | |
| ) | |
| cur_vision_tower_aux_attention_masks_rearranged[ | |
| cur_vision_tower_aux_attention_masks_rearranged.sum(-1) == 0 | |
| ] = True | |
| vision_tower_aux_feature_rearranged.append( | |
| cur_vision_tower_aux_feature_rearranged | |
| ) | |
| vision_tower_aux_attention_masks_rearranged.append( | |
| cur_vision_tower_aux_attention_masks_rearranged | |
| ) | |
| vision_tower_aux_feature_rearranged = torch.cat( | |
| vision_tower_aux_feature_rearranged, 0 | |
| ) | |
| vision_tower_aux_attention_masks_rearranged = torch.cat( | |
| vision_tower_aux_attention_masks_rearranged, 0 | |
| ) | |
| vision_tower_aux_feature_rearranged_list.append( | |
| vision_tower_aux_feature_rearranged | |
| ) | |
| vision_tower_aux_attention_masks_rearranged_list.append( | |
| vision_tower_aux_attention_masks_rearranged | |
| ) | |
| return ( | |
| vision_tower_aux_feature_rearranged_list, | |
| vision_tower_aux_attention_masks_rearranged_list, | |
| ) | |
| def encode_images(self, image_aux_list, encode_type=None): | |
| vision_tower_aux_list = self.get_model().get_vision_tower_aux_list() | |
| image_aux_features_list = [] | |
| chunk_size = 64 | |
| if encode_type == "dino": | |
| image_aux = image_aux_list[-1] | |
| vision_tower_aux = vision_tower_aux_list[-1] | |
| if image_aux.shape[0] > chunk_size: | |
| image_aux_features_chunks = [] | |
| for start_idx in range(0, image_aux.shape[0], chunk_size): | |
| end_idx = min(start_idx + chunk_size, image_aux.shape[0]) | |
| chunk = image_aux[start_idx:end_idx] | |
| image_aux_features_chunk = vision_tower_aux(chunk) | |
| image_aux_features_chunks.append(image_aux_features_chunk) | |
| image_aux_features = torch.cat(image_aux_features_chunks, dim=0) | |
| else: | |
| image_aux_features = vision_tower_aux(image_aux) | |
| return image_aux_features | |
| elif encode_type == "siglip": | |
| image_aux = image_aux_list[0] | |
| vision_tower_aux = vision_tower_aux_list[0] | |
| if image_aux.shape[0] > chunk_size: | |
| image_aux_features_chunks = [] | |
| for start_idx in range(0, image_aux.shape[0], chunk_size): | |
| end_idx = min(start_idx + chunk_size, image_aux.shape[0]) | |
| chunk = image_aux[start_idx:end_idx] | |
| image_aux_features_chunk = vision_tower_aux(chunk) | |
| image_aux_features_chunks.append(image_aux_features_chunk) | |
| image_aux_features = torch.cat(image_aux_features_chunks, dim=0) | |
| else: | |
| image_aux_features = vision_tower_aux(image_aux) | |
| return image_aux_features | |
| else: | |
| for image_aux, vision_tower_aux in zip( | |
| image_aux_list, vision_tower_aux_list | |
| ): | |
| if image_aux.shape[0] > chunk_size: | |
| image_aux_features_chunks = [] | |
| for start_idx in range(0, image_aux.shape[0], chunk_size): | |
| end_idx = min(start_idx + chunk_size, image_aux.shape[0]) | |
| chunk = image_aux[start_idx:end_idx] | |
| image_aux_features_chunk = vision_tower_aux(chunk) | |
| image_aux_features_chunks.append(image_aux_features_chunk) | |
| image_aux_features = torch.cat(image_aux_features_chunks, dim=0) | |
| else: | |
| image_aux_features = vision_tower_aux(image_aux) | |
| image_aux_features_list.append(image_aux_features) | |
| return image_aux_features_list | |
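# Editorial note: the chunking in encode_images bounds peak memory when encoding
# long frame sequences. A minimal, generic form of the same pattern (hypothetical
# standalone helper, chunk size 64 as above):
#
#   def encode_in_chunks(encoder, images, chunk_size=64):
#       outs = [encoder(images[i : i + chunk_size]) for i in range(0, images.shape[0], chunk_size)]
#       return torch.cat(outs, dim=0)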
| def select_frame( | |
| self, | |
| feature_list, | |
| split_sizes, | |
| input_ids, | |
| new_image_aux_list, | |
| image_sizes, | |
| window_size=16, | |
| threshold=0.83, | |
| ): | |
| dino_features_batch = torch.split(feature_list, split_sizes, dim=0) | |
| new_image_aux_batch_0 = torch.split(new_image_aux_list[0], split_sizes, dim=0) | |
| new_image_aux_batch_1 = torch.split(new_image_aux_list[1], split_sizes, dim=0) | |
| new_split_sizes = [] | |
| selected_frames_all_0 = [] | |
| selected_frames_all_1 = [] | |
| selected_frames_feature_all = [] | |
| selected_frame_indices_all = [] | |
| for i_batch, frame_features in enumerate(dino_features_batch): | |
try:
    # Find where padding begins to get the true text length (128002 and 151643
    # are the pad-token ids for the Llama-3 and Qwen2 tokenizer setups used here).
    if "llama" in self.get_model().config.model_type:
        text_len = torch.where(input_ids[i_batch] == 128002)[-1][0]
    else:
        text_len = torch.where(input_ids[i_batch] == 151643)[-1][0]
except Exception:
    text_len = len(input_ids[i_batch])
| original_width, original_height = image_sizes[i_batch] | |
| if getattr(self.get_model().config, "highres", False): | |
| token_per_frame = self.get_model().config.lowres_token ** 2 | |
| else: | |
| token_per_frame = self.get_model().config.image_token_len | |
| # current_height, current_width = token_per_side, token_per_side | |
| # original_aspect_ratio = original_width / original_height | |
| # current_aspect_ratio = current_width / current_height | |
| # if original_aspect_ratio > current_aspect_ratio: | |
| # scale_factor = current_width / original_width | |
| # new_height = int(original_height * scale_factor) | |
| # padding = math.ceil((current_height - new_height) / 2.0) | |
| # token_per_frame = ( | |
| # current_height - padding * 2 | |
| # ) * token_per_side + token_per_side | |
| # else: | |
| # scale_factor = current_height / original_height | |
| # new_width = int(original_width * scale_factor) | |
| # padding = math.ceil((current_width - new_width) / 2.0) | |
| # token_per_frame = (current_width - padding * 2) * token_per_side + ( | |
| # current_width - padding * 2 | |
| # ) | |
| # token_per_frame = ( | |
| # token_per_side**2 if token_per_frame < 1 else token_per_frame | |
| # ) | |
| max_num_frames = max( | |
| 1, | |
| ( | |
| self.get_model().config.tokenizer_model_max_length | |
| - text_len | |
| - getattr(self.get_model().config, "inference_max_length", 16) | |
| ) | |
| // token_per_frame, | |
| ) | |
| if len(frame_features) < max_num_frames: | |
| selected_frames_all_0.append(new_image_aux_batch_0[i_batch]) | |
| selected_frames_all_1.append(new_image_aux_batch_1[i_batch]) | |
| selected_frames_feature_all.append(frame_features) | |
| new_split_sizes.append(len(frame_features)) | |
| selected_frame_indices_all.append(torch.arange(len(frame_features))) | |
| continue | |
| num_segments = len(frame_features) // window_size | |
| if num_segments == 0: | |
| query_feature = frame_features.flatten(1, 2) | |
| query_feature = query_feature / torch.norm( | |
| (query_feature), dim=1, keepdim=True | |
| ) | |
| similarities = torch.mean(query_feature @ query_feature.T, dim=1) | |
| similarities[len(frame_features) // 2] = 0 | |
| indices = torch.where(similarities < threshold)[0] | |
| selected_frame_indices_all.append(indices) | |
| selected_frames_all_0.append(new_image_aux_batch_0[i_batch][indices]) | |
| selected_frames_all_1.append(new_image_aux_batch_1[i_batch][indices]) | |
| selected_frames_feature_all.append(frame_features[indices]) | |
| new_split_sizes.append(len(indices)) | |
| continue | |
| segments_frames_0 = [] | |
| segments_frames_1 = [] | |
| segments_features = [] | |
| for start_idx in range(0, len(frame_features), window_size): | |
| end_idx = min(start_idx + window_size, len(frame_features)) | |
| segments_frames_0.append( | |
| new_image_aux_batch_0[i_batch][start_idx:end_idx] | |
| ) | |
| segments_frames_1.append( | |
| new_image_aux_batch_1[i_batch][start_idx:end_idx] | |
| ) | |
| segments_features.append(frame_features[start_idx:end_idx]) | |
| selected_frames_0 = [] | |
| selected_frames_1 = [] | |
| selected_features = [] | |
| selected_frame_indices = [] | |
| for i, segment in enumerate(segments_features): | |
| query_feature = segment.flatten(1, 2) | |
| query_feature = query_feature / torch.norm( | |
| (query_feature), dim=1, keepdim=True | |
| ) | |
| similarities = torch.mean(query_feature @ query_feature.T, dim=1) | |
| similarities[len(segment) // 2] = 0 | |
| indices = torch.where(similarities < threshold)[0] | |
| selected_frames_0.append(segments_frames_0[i][indices]) | |
| selected_frames_1.append(segments_frames_1[i][indices]) | |
| selected_features.append(segment[indices]) | |
| selected_frame_indices.extend(indices + i * window_size) | |
| selected_frames_0 = torch.cat(selected_frames_0, dim=0) | |
| selected_frames_1 = torch.cat(selected_frames_1, dim=0) | |
| selected_features = torch.cat(selected_features, dim=0) | |
| selected_frame_indices = torch.tensor(selected_frame_indices) | |
| # ablation | |
| max_num_frames = 400 # in case of OOM | |
| if len(selected_frames_0) > max_num_frames: | |
| interval = len(selected_frames_0) / float(max_num_frames) | |
| indices = [int(interval * i) for i in range(max_num_frames)] | |
| new_split_sizes.append(len(indices)) | |
| selected_frames_all_0.append(selected_frames_0[indices]) | |
| selected_frames_all_1.append(selected_frames_1[indices]) | |
| selected_frames_feature_all.append(selected_features[indices]) | |
| selected_frame_indices = selected_frame_indices[indices] | |
| else: | |
| new_split_sizes.append(len(selected_frames_0)) | |
| selected_frames_all_0.append(selected_frames_0) | |
| selected_frames_all_1.append(selected_frames_1) | |
| selected_frames_feature_all.append(selected_features) | |
| selected_frame_indices_all.append(selected_frame_indices) | |
| selected_frames_all_0 = torch.cat(selected_frames_all_0, dim=0) | |
| selected_frames_all_1 = torch.cat(selected_frames_all_1, dim=0) | |
| selected_frames_feature_all = torch.cat(selected_frames_feature_all, dim=0) | |
| return ( | |
| selected_frames_feature_all, | |
| new_split_sizes, | |
| [selected_frames_all_0, selected_frames_all_1], | |
| selected_frame_indices_all, | |
| ) | |
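# Editorial sketch of the selection rule used above: within each window, a frame is
# kept when its mean pairwise similarity to the frames of that window falls below
# `threshold`; the middle frame's score is zeroed so at least one frame per window
# always survives. A hypothetical standalone version of that inner step:
#
#   def keep_indices(window_feats, threshold=0.83):
#       q = window_feats.flatten(1, 2)             # (T, H*W*C) per-frame vectors
#       q = q / torch.norm(q, dim=1, keepdim=True)
#       sim = torch.mean(q @ q.T, dim=1)           # mean similarity of each frame
#       sim[len(window_feats) // 2] = 0
#       return torch.where(sim < threshold)[0]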
| def prepare_inputs_labels_for_multimodal( | |
| self, | |
| input_ids, | |
| position_ids, | |
| attention_mask, | |
| past_key_values, | |
| labels, | |
| images, | |
| image_aux_attention_masks_list=None, | |
| image_sizes=None, | |
| ): | |
| # vision_tower = self.get_vision_tower() | |
| vision_tower_aux_list = self.get_model().get_vision_tower_aux_list() | |
| if vision_tower_aux_list is None or images is None or input_ids.shape[1] == 1: | |
| return ( | |
| input_ids, | |
| position_ids, | |
| attention_mask, | |
| past_key_values, | |
| None, | |
| labels, | |
| None, | |
| None, | |
| None, | |
| None, | |
| ) | |
| image_aux_list = images | |
| split_sizes = None | |
| if type(image_aux_list[0]) is list or image_aux_list[0].ndim == 5: | |
| split_sizes_ori = [ | |
| 1 if image.ndim == 3 else image.shape[0] for image in image_aux_list[0] | |
| ] | |
| new_image_aux_list = [] | |
| for image_aux in image_aux_list: | |
| if type(image_aux) is list: | |
| image_aux = [ | |
| x.unsqueeze(0) if x.ndim == 3 else x for x in image_aux | |
| ] | |
| concat_image_aux = torch.cat([image for image in image_aux], dim=0) | |
| new_image_aux_list.append(concat_image_aux) | |
| image_aux_features_dino = self.encode_images( | |
| new_image_aux_list, encode_type="dino" | |
| ) | |
| ( | |
| image_aux_features_dino, | |
| split_sizes, | |
| new_image_aux_list, | |
| selected_frame_indices_all, | |
| ) = self.select_frame( | |
| image_aux_features_dino, | |
| split_sizes_ori, | |
| input_ids, | |
| new_image_aux_list, | |
| image_sizes, | |
| threshold=getattr(self.get_model().config, "dino_threshold", 0.83), | |
| ) | |
| image_aux_features_siglip = self.encode_images( | |
| new_image_aux_list, encode_type="siglip" | |
| ) | |
| image_aux_features_list = [ | |
| image_aux_features_siglip, | |
| image_aux_features_dino, | |
| ] | |
| bs = image_aux_features_list[0].shape[0] | |
| dtype = new_image_aux_list[0].dtype | |
| frame_sizes = [] | |
| for i in range(len(image_sizes)): | |
| for j in range(split_sizes[i]): | |
| frame_sizes.append(image_sizes[i]) | |
| image_sizes = frame_sizes | |
| else: | |
| image_aux_features_list = self.encode_images(image_aux_list) | |
| bs = image_aux_list[0].shape[0] | |
| dtype = image_aux_list[0].dtype | |
| image_token_len = self.get_model().config.image_token_len | |
| query_num_list = self.get_model().config.query_num_list | |
| final_height = final_width = int(image_token_len**0.5) | |
| final_image_features_list = [] | |
| final_image_features_down_list = [] | |
| # only needed for sva | |
| vision_tower_aux_feature_list_final = None | |
| vision_tower_aux_attention_masks_list_final = None | |
| global_context_feature_final = None | |
| if self.get_model().config.mm_projector_type == "sva": | |
| vision_tower_aux_feature_list = [] | |
| vision_tower_aux_attention_masks_list = [] | |
| # get vision tokens from each vision tower | |
| for aux_i in range(len(vision_tower_aux_list)): | |
| image_aux_features = image_aux_features_list[aux_i] | |
| image_aux_features = getattr( | |
| self.get_model(), "mm_projector_aux_{}".format(aux_i) | |
| )(image_aux_features).to(dtype) | |
| if aux_i == 0: | |
| global_context_feature = image_aux_features.mean(1).view( | |
| bs, 1, 1, -1 | |
| ) | |
| vision_tower_aux_feature_list.append(image_aux_features) | |
| input_mix_res = True | |
| input_high_res = True | |
| # perform vision sampling for each query group | |
| for query_group_i, query_num in enumerate(query_num_list): | |
| query_features_i = ( | |
| self.get_model() | |
| .vision_query[query_group_i, :] | |
| .view(1, 1, 1, -1) | |
| .expand(bs, query_num, -1, -1) | |
| ) | |
| global_context_feature_i = global_context_feature.expand( | |
| -1, query_num, 1, -1 | |
| ).flatten(0, 1) | |
| query_side_len = int(query_num**0.5) | |
| if IS_XLA_AVAILABLE: | |
| ( | |
| vision_tower_aux_feature_list_i, | |
| vision_tower_aux_attention_masks_list_i, | |
| ) = self.rearrange_vision_tower_features_train( | |
| vision_tower_aux_feature_list, | |
| image_aux_attention_masks_list, | |
| query_side_len, | |
| ) | |
| else: | |
| ( | |
| vision_tower_aux_feature_list_i, | |
| vision_tower_aux_attention_masks_list_i, | |
| ) = self.rearrange_vision_tower_features_inference( | |
| vision_tower_aux_feature_list, query_side_len, image_sizes | |
| ) | |
| query_features_i = getattr( | |
| self.get_model(), "vision_sampler_{}".format(query_group_i) | |
| )( | |
| query_features_i.flatten(0, 1), | |
| global_context_feature_i, | |
| *vision_tower_aux_feature_list_i, | |
| *vision_tower_aux_attention_masks_list_i, | |
| ) | |
| query_features_i = query_features_i.view(bs, query_num, -1) | |
| if split_sizes is not None: | |
try:
    if "llama" in self.get_model().config.model_type:
        text_len = torch.where(input_ids[0] == 128002)[-1][0]
    else:
        text_len = torch.where(input_ids[0] == 151643)[-1][0]
except Exception:
    text_len = len(input_ids[0])
| max_visual_len = ( | |
| self.get_model().config.tokenizer_model_max_length | |
| - text_len | |
| - getattr(self.get_model().config, "inference_max_length", 16) | |
| ) | |
| max_num_frames = max( | |
| 1, | |
| math.floor(max_visual_len // (final_height * final_width)), | |
| ) | |
| max_num_frames_low = max( | |
| 1, | |
| math.floor( | |
| max_visual_len | |
| // (self.get_model().config.lowres_token ** 2) | |
| ), | |
| ) | |
| if split_sizes[0] < max_num_frames: | |
| input_mix_res = False | |
| elif split_sizes[0] > max_num_frames_low: | |
| input_mix_res = False | |
| input_high_res = False | |
| # input_mix_res = False # ablation | |
| if (getattr(self.config, "highres", False)) and input_mix_res: | |
| _query_features_i = ( | |
| query_features_i.permute(0, 2, 1) | |
| .contiguous() | |
| .view(bs, -1, query_side_len, query_side_len) | |
| ) | |
| _query_features_i = F.interpolate( | |
| _query_features_i.float(), | |
| size=( | |
| self.get_model().config.lowres_token, | |
| self.get_model().config.lowres_token, | |
| ), | |
| mode="bilinear", | |
| align_corners=False, | |
| ).to(dtype=query_features_i.dtype) | |
| _query_features_i = ( | |
| _query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2) | |
| ) | |
| final_image_features_down_list.append(_query_features_i) | |
| # interpolate to the final target size | |
| if query_side_len != final_height: | |
| query_features_i = ( | |
| query_features_i.permute(0, 2, 1) | |
| .contiguous() | |
| .view(bs, -1, query_side_len, query_side_len) | |
| ) | |
| if input_high_res: | |
| query_features_i = F.interpolate( | |
| query_features_i.float(), | |
| size=(final_height, final_width), | |
| mode="bilinear", | |
| align_corners=False, | |
| ).to(dtype=query_features_i.dtype) | |
| else: | |
| query_features_i = F.interpolate( | |
| query_features_i.float(), | |
| size=(8, 8), | |
| mode="bilinear", | |
| align_corners=False, | |
| ).to(dtype=query_features_i.dtype) | |
| query_features_i = ( | |
| query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2) | |
| ) | |
| final_image_features_list.append(query_features_i) | |
| if IS_XLA_AVAILABLE: | |
| ( | |
| vision_tower_aux_feature_list_final, | |
| vision_tower_aux_attention_masks_list_final, | |
| ) = self.rearrange_vision_tower_features_train( | |
| vision_tower_aux_feature_list, | |
| image_aux_attention_masks_list, | |
| final_height, | |
| ) | |
| global_context_feature_final = global_context_feature.expand( | |
| -1, final_height * final_width, 1, -1 | |
| ).flatten(0, 1) | |
| else: | |
| final_image_features_list = image_aux_features_list | |
| image_features = torch.cat(final_image_features_list, -1) | |
| image_features = self.get_model().mm_projector(image_features).to(dtype) | |
| if (getattr(self.config, "highres", False)) and input_mix_res: | |
| image_features_down = torch.cat(final_image_features_down_list, -1) | |
| image_features_down = ( | |
| self.get_model().mm_projector(image_features_down).to(dtype) | |
| ) | |
| if IS_XLA_AVAILABLE: | |
| image_features = image_features.view( | |
| image_features.shape[0], final_height, final_width, -1 | |
| ) | |
| image_features = torch.cat( | |
| ( | |
| image_features, | |
| self.model.image_newline[None, None, None, :].expand( | |
| image_features.shape[0], final_height, 1, -1 | |
| ), | |
| ), | |
| dim=2, | |
| ) | |
| image_features = image_features.flatten(1, 2) | |
| final_size = [(final_height, final_width)] * bs | |
| else: | |
| image_features = image_features.view(bs, final_height, final_width, -1) | |
| if (getattr(self.config, "highres", False)) and input_mix_res: | |
| image_features_down = image_features_down.view( | |
| bs, | |
| self.get_model().config.lowres_token, | |
| self.get_model().config.lowres_token, | |
| -1, | |
| ) | |
| image_features_unpadded = [] | |
| image_features_downsample = [] | |
| final_size = [] | |
| if self.get_model().config.mm_projector_type == "sva": | |
| ( | |
| vision_tower_aux_feature_list_final, | |
| vision_tower_aux_attention_masks_list_final, | |
| ) = self.rearrange_vision_tower_features_inference( | |
| vision_tower_aux_feature_list, final_height, image_sizes, unpad=True | |
| ) | |
| global_context_feature_final = [] | |
| for batch_i in range(bs): | |
| cur_image_feature = image_features[batch_i] | |
| image_size = image_sizes[batch_i] | |
| cur_image_feature = unpad_image( | |
| cur_image_feature.unsqueeze(0), image_size | |
| ) | |
| cur_h, cur_w = cur_image_feature.shape[1:3] | |
| try: # fix bug for some invalid image | |
| cur_image_feature = cur_image_feature.view(1, cur_h, cur_w, -1) | |
| final_size.append((cur_h, cur_w)) | |
| except: | |
| # print(f"invalid after unpad {image_features[batch_i].shape}, {image_sizes[batch_i]}", flush=True) | |
| cur_image_feature = image_features[batch_i].unsqueeze(0) | |
| image_size = image_sizes[batch_i] | |
| cur_h, cur_w = cur_image_feature.shape[1:3] | |
| cur_image_feature = cur_image_feature.view(1, cur_h, cur_w, -1) | |
| final_size.append((cur_h, cur_w)) | |
| if (getattr(self.config, "highres", False)) and input_mix_res: | |
| cur_image_feature_down = unpad_image( | |
| image_features_down[batch_i].unsqueeze(0), | |
| ( | |
| int( | |
| image_size[0] | |
| / ( | |
| image_token_len**0.5 | |
| / self.get_model().config.lowres_token | |
| ) | |
| ), | |
| int( | |
| image_size[1] | |
| / ( | |
| image_token_len**0.5 | |
| / self.get_model().config.lowres_token | |
| ) | |
| ), | |
| ), | |
| ) | |
| _cur_h, _cur_w = cur_image_feature_down.shape[1:3] | |
| try: # fix bug for some invalid image | |
| cur_image_feature_down = cur_image_feature_down.view( | |
| 1, _cur_h, _cur_w, -1 | |
| ) | |
| except: | |
| print("invalid after unpad", flush=True) | |
| cur_image_feature_down = image_features_down[batch_i].unsqueeze( | |
| 0 | |
| ) | |
| _cur_h, _cur_w = cur_image_feature_down.shape[1:3] | |
| cur_image_feature_down = cur_image_feature_down.view( | |
| 1, _cur_h, _cur_w, -1 | |
| ) | |
| cur_image_feature_down = torch.cat( | |
| ( | |
| cur_image_feature_down, | |
| self.model.image_newline.view(1, 1, 1, -1) | |
| .expand(1, _cur_h, 1, -1) | |
| .to(cur_image_feature_down.device), | |
| ), | |
| dim=2, | |
| ).flatten(1, 2) | |
| if split_sizes is None and getattr(self.config, "frame_pos", False): | |
| frame_pos = ( | |
| self.get_model() | |
| .get_frame_pos(torch.arange(1)) | |
| .to(cur_image_feature_down.device) | |
| .to(cur_image_feature_down.dtype) | |
| ) | |
| cur_image_feature_down += frame_pos | |
| image_features_downsample.append(cur_image_feature_down.squeeze(0)) | |
| cur_image_feature = torch.cat( | |
| ( | |
| cur_image_feature, | |
| self.model.image_newline.view(1, 1, 1, -1) | |
| .expand(1, cur_h, 1, -1) | |
| .to(cur_image_feature.device), | |
| ), | |
| dim=2, | |
| ) | |
| if split_sizes is None and getattr(self.config, "frame_pos", False): | |
| frame_pos = ( | |
| self.get_model() | |
| .get_frame_pos(torch.arange(1)) | |
| .to(cur_image_feature.device) | |
| .to(cur_image_feature.dtype) | |
| ) | |
| cur_image_feature += frame_pos | |
| cur_image_feature = cur_image_feature.flatten(1, 2) | |
| image_features_unpadded.append(cur_image_feature.squeeze(0)) | |
| if self.get_model().config.mm_projector_type == "sva": | |
| cur_global_context_feature = global_context_feature[batch_i].expand( | |
| cur_h * cur_w, 1, -1 | |
| ) | |
| global_context_feature_final.append(cur_global_context_feature) | |
| if self.get_model().config.mm_projector_type == "sva": | |
| global_context_feature_final = torch.cat( | |
| global_context_feature_final, 0 | |
| ) | |
| if (getattr(self.config, "highres", False)) and input_mix_res: | |
| image_features = image_features_downsample | |
| else: | |
| image_features = image_features_unpadded | |
| # TODO: image start / end is not implemented here to support pretraining. | |
| if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr( | |
| self.config, "mm_use_im_start_end", False | |
| ): | |
| raise NotImplementedError | |
| split_image_features_unpadded = None | |
| frame_split_sizes = None | |
| if split_sizes is not None: | |
| split_image_features = [] | |
| split_image_features_unpadded = ( | |
| [] | |
| if (getattr(self.config, "highres", False)) and input_mix_res | |
| else None | |
| ) | |
| start_idx = 0 | |
| for split_batch_idx, split_size in enumerate(split_sizes): | |
| if isinstance(image_features[start_idx : start_idx + split_size], list): | |
| if getattr(self.config, "frame_pos", False): | |
| frame_feature = torch.cat( | |
| image_features[start_idx : start_idx + split_size], dim=0 | |
| ).reshape(split_size, -1, image_features[0].shape[-1]) | |
| frame_pos = ( | |
| self.get_model() | |
| .get_frame_pos(selected_frame_indices_all[split_batch_idx]) | |
| .to(frame_feature.device) | |
| .to(frame_feature.dtype) | |
| ) | |
| frame_feature += frame_pos | |
| split_image_features.append( | |
| frame_feature.reshape(-1, image_features[0].shape[-1]) | |
| ) | |
| else: | |
| split_image_features.append( | |
| torch.cat( | |
| image_features[start_idx : start_idx + split_size], | |
| dim=0, | |
| ) | |
| ) | |
| if (getattr(self.config, "highres", False)) and input_mix_res: | |
| if getattr(self.config, "frame_pos", False): | |
| frame_feature = torch.cat( | |
| image_features_unpadded[ | |
| start_idx : start_idx + split_size | |
| ], | |
| dim=0, | |
| ).reshape(split_size, -1, image_features[0].shape[-1]) | |
| frame_pos = ( | |
| self.get_model() | |
| .get_frame_pos( | |
| selected_frame_indices_all[split_batch_idx] | |
| ) | |
| .to(frame_feature.device) | |
| .to(frame_feature.dtype) | |
| ) | |
| frame_feature += frame_pos | |
| split_image_features_unpadded.append( | |
| frame_feature.reshape(-1, image_features[0].shape[-1]) | |
| ) | |
| else: | |
| split_image_features_unpadded.append( | |
| torch.cat( | |
| image_features_unpadded[ | |
| start_idx : start_idx + split_size | |
| ], | |
| dim=0, | |
| ) | |
| ) | |
| else: | |
| if getattr(self.config, "frame_pos", False): | |
| frame_feature = image_features[ | |
| start_idx : start_idx + split_size | |
| ].reshape(split_size, -1, image_features[0].shape[-1]) | |
| frame_pos = ( | |
| self.get_model() | |
| .get_frame_pos(selected_frame_indices_all[split_batch_idx]) | |
| .to(frame_feature.device) | |
| .to(frame_feature.dtype) | |
| ) | |
| frame_feature += frame_pos | |
| split_image_features.append( | |
| frame_feature.reshape(-1, image_features[0].shape[-1]) | |
| ) | |
| else: | |
| split_image_features.append( | |
| image_features[start_idx : start_idx + split_size] | |
| ) | |
| if (getattr(self.config, "highres", False)) and input_mix_res: | |
| if getattr(self.config, "frame_pos", False): | |
| frame_feature = image_features_unpadded[ | |
| start_idx : start_idx + split_size | |
| ] | |
| frame_pos = ( | |
| self.get_model() | |
| .get_frame_pos( | |
| selected_frame_indices_all[split_batch_idx] | |
| ) | |
| .to(frame_feature.device) | |
| .to(frame_feature.dtype) | |
| ) | |
| frame_feature += frame_pos | |
| split_image_features_unpadded.append( | |
| frame_feature.reshape(-1, image_features[0].shape[-1]) | |
| ) | |
| else: | |
| split_image_features_unpadded.append( | |
| image_features_unpadded[ | |
| start_idx : start_idx + split_size | |
| ] | |
| ) | |
| start_idx += split_size | |
| image_features = split_image_features | |
| frame_split_sizes = split_sizes | |
| _labels = labels | |
| _position_ids = position_ids | |
| _attention_mask = attention_mask | |
| if attention_mask is None: | |
| attention_mask = torch.ones_like(input_ids, dtype=torch.bool) | |
| else: | |
| attention_mask = attention_mask.bool() | |
| if position_ids is None: | |
| position_ids = torch.arange( | |
| 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device | |
| ) | |
| if labels is None: | |
| labels = torch.full_like(input_ids, IGNORE_INDEX) | |
| # remove the padding using attention_mask -- FIXME | |
| _input_ids = input_ids | |
| attention_mask = attention_mask | (input_ids == IMAGE_TOKEN_INDEX) | |
| input_ids = [ | |
| cur_input_ids[cur_attention_mask] | |
| for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) | |
| ] | |
| labels = [ | |
| cur_labels[cur_attention_mask] | |
| for cur_labels, cur_attention_mask in zip(labels, attention_mask) | |
| ] | |
| new_input_embeds = [] | |
| new_labels = [] | |
| image_token_indices_batch = [] | |
| cur_image_idx = 0 | |
| for batch_idx, cur_input_ids in enumerate(input_ids): | |
| num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() | |
| if num_images == 0: | |
| cur_image_features = image_features[cur_image_idx] | |
| cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) | |
| cur_input_embeds = torch.cat( | |
| [cur_input_embeds_1, cur_image_features[0:0]], dim=0 | |
| ) | |
| new_input_embeds.append(cur_input_embeds) | |
| new_labels.append(labels[batch_idx]) | |
| cur_image_idx += 1 | |
| continue | |
| image_token_indices = ( | |
| [-1] | |
| + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() | |
| + [cur_input_ids.shape[0]] | |
| ) | |
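# e.g. for cur_input_ids = [bos, IMAGE_TOKEN_INDEX, t1, t2] this gives boundaries
# [-1, 1, 4], so the loop below splits the text into [bos] and [t1, t2] around the
# image slot.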
| image_token_indices_batch.append( | |
| torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()[0] | |
| ) | |
| cur_input_ids_noim = [] | |
| cur_labels = labels[batch_idx] | |
| cur_labels_noim = [] | |
| for i in range(len(image_token_indices) - 1): | |
| cur_input_ids_noim.append( | |
| cur_input_ids[ | |
| image_token_indices[i] + 1 : image_token_indices[i + 1] | |
| ] | |
| ) | |
| cur_labels_noim.append( | |
| cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]] | |
| ) | |
| split_sizes = [x.shape[0] for x in cur_labels_noim] | |
| cur_input_embeds = self.get_model().embed_tokens( | |
| torch.cat(cur_input_ids_noim) | |
| ) | |
| cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) | |
| cur_new_input_embeds = [] | |
| cur_new_labels = [] | |
| text_len = sum([x.shape[0] for x in cur_input_embeds_no_im]) | |
| visual_len = len(image_features[cur_image_idx]) | |
| max_visual_len = ( | |
| self.get_model().config.tokenizer_model_max_length | |
| - getattr(self.get_model().config, "inference_max_length", 16) | |
| - text_len | |
| ) | |
| mix_token = False | |
| # ablation mix | |
| if ( | |
| input_mix_res | |
| and ( | |
| self.get_model().config.image_token_len | |
| > getattr(self.get_model().config, "lowres_token", 8) ** 2 | |
| ) | |
| and frame_split_sizes is not None | |
| and getattr(self.config, "highres", False) | |
| ): | |
| if max_visual_len > visual_len: | |
| visual_emb = image_features[cur_image_idx] | |
| text_emb = cur_input_embeds_no_im[-1] | |
| highres_num = math.floor( | |
| (max_visual_len - visual_len) | |
| / ( | |
| split_image_features_unpadded[cur_image_idx].shape[0] | |
| // frame_split_sizes[cur_image_idx] | |
| - visual_emb.shape[0] // frame_split_sizes[cur_image_idx] | |
| ) | |
| ) | |
| if highres_num >= 1: | |
| mix_token = True | |
| sim = torch.matmul(visual_emb, text_emb.transpose(0, 1)).mean( | |
| dim=-1 | |
| ) | |
| sim_frame = sim.reshape( | |
| frame_split_sizes[cur_image_idx], -1 | |
| ).mean(dim=-1) | |
| highres_num = min(highres_num, sim_frame.shape[0]) | |
| top_values, top_indices = torch.topk(sim_frame, highres_num) | |
| if len(top_indices) > 0: | |
| sorted_indices = torch.sort(top_indices)[1] | |
| top_indices = top_indices[sorted_indices] | |
| visual_emb_frame = image_features[cur_image_idx].reshape( | |
| frame_split_sizes[cur_image_idx], | |
| -1, | |
| image_features[cur_image_idx].shape[-1], | |
| ) | |
| visual_emb_frame_highres = split_image_features_unpadded[ | |
| cur_image_idx | |
| ].reshape( | |
| frame_split_sizes[cur_image_idx], | |
| -1, | |
| split_image_features_unpadded[cur_image_idx].shape[-1], | |
| ) | |
| current_point = 0 | |
| mix_visual_emb_frame = [] | |
| for frame_i in range(len(visual_emb_frame)): | |
| if current_point > len(top_indices) - 1: | |
| mix_visual_emb_frame.append( | |
| visual_emb_frame[frame_i] | |
| ) | |
| continue | |
| if frame_i == top_indices[current_point]: | |
| mix_visual_emb_frame.append( | |
| visual_emb_frame_highres[frame_i] | |
| ) | |
| current_point += 1 | |
| else: | |
| mix_visual_emb_frame.append( | |
| visual_emb_frame[frame_i] | |
| ) | |
| image_features[cur_image_idx] = torch.cat( | |
| mix_visual_emb_frame, dim=0 | |
| ) | |
| # ablation drop | |
| if ( | |
| max_visual_len < visual_len | |
| and frame_split_sizes is not None | |
| and not mix_token | |
| ): | |
| visual_emb_frame = image_features[cur_image_idx].reshape( | |
| frame_split_sizes[cur_image_idx], | |
| -1, | |
| image_features[cur_image_idx].shape[-1], | |
| ) | |
| sim = F.cosine_similarity( | |
| visual_emb_frame[:-1], | |
| visual_emb_frame[1:], | |
| dim=-1, | |
| ) | |
| new_visual_emb_frames = [] | |
| for start_idx in range(0, len(visual_emb_frame), 8): | |
| end_idx = min(start_idx + 8, len(visual_emb_frame)) | |
| chunk_feature = visual_emb_frame[start_idx:end_idx] # 8, HW, C | |
| if len(chunk_feature) == 1: | |
| new_visual_emb_frames.append(chunk_feature[0]) | |
| continue | |
| sim = F.cosine_similarity( | |
| chunk_feature[0] | |
| .unsqueeze(0) | |
| .repeat_interleave(len(chunk_feature[1:]), dim=0), | |
| chunk_feature[1:], | |
| dim=-1, | |
| ) | |
| new_visual_emb_frame = torch.cat( | |
| [ | |
| chunk_feature[0], | |
| chunk_feature[1:].flatten(0, 1)[ | |
| sim.flatten(0, 1) | |
| < getattr( | |
| self.get_model().config, "drop_threshold", 0.7 | |
| ) | |
| ], | |
| ], | |
| dim=0, | |
| ) | |
| new_visual_emb_frames.append(new_visual_emb_frame) | |
| reduced_visual_len = sum([x.shape[0] for x in new_visual_emb_frames]) | |
| if reduced_visual_len > max_visual_len: | |
| force_remove = math.ceil( | |
| (reduced_visual_len - max_visual_len) | |
| / len(new_visual_emb_frames) | |
| ) | |
| for chunk_i in range(len(new_visual_emb_frames)): | |
| new_visual_emb_frames[chunk_i] = new_visual_emb_frames[chunk_i][ | |
| :-force_remove | |
| ] | |
| new_visual_emb_frames = torch.cat(new_visual_emb_frames, dim=0) | |
| else: | |
| new_visual_emb_frames = torch.cat(new_visual_emb_frames, dim=0) | |
| image_features[cur_image_idx] = new_visual_emb_frames[:max_visual_len] | |
| for i in range(num_images + 1): | |
| cur_new_input_embeds.append(cur_input_embeds_no_im[i]) | |
| cur_new_labels.append(cur_labels_noim[i]) | |
| if i < num_images: | |
| cur_image_features = image_features[cur_image_idx] | |
| cur_image_idx += 1 | |
| cur_new_input_embeds.append(cur_image_features) | |
| cur_new_labels.append( | |
| torch.full( | |
| (cur_image_features.shape[0],), | |
| IGNORE_INDEX, | |
| device=cur_labels.device, | |
| dtype=cur_labels.dtype, | |
| ) | |
| ) | |
| cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] | |
| cur_new_input_embeds = torch.cat(cur_new_input_embeds) | |
| cur_new_labels = torch.cat(cur_new_labels) | |
| new_input_embeds.append(cur_new_input_embeds) | |
| new_labels.append(cur_new_labels) | |
| # Truncate sequences to max length as image embeddings can make the sequence longer | |
| tokenizer_model_max_length = getattr( | |
| self.config, "tokenizer_model_max_length", None | |
| ) | |
| if tokenizer_model_max_length is not None: | |
| new_input_embeds = [ | |
| x[:tokenizer_model_max_length] for x in new_input_embeds | |
| ] | |
| new_labels = [x[:tokenizer_model_max_length] for x in new_labels] | |
| # Combine them | |
| max_len = max(x.shape[0] for x in new_input_embeds) | |
| batch_size = len(new_input_embeds) | |
| new_input_embeds_padded = [] | |
| new_labels_padded = torch.full( | |
| (batch_size, max_len), | |
| IGNORE_INDEX, | |
| dtype=new_labels[0].dtype, | |
| device=new_labels[0].device, | |
| ) | |
| attention_mask = torch.zeros( | |
| (batch_size, max_len), | |
| dtype=attention_mask.dtype, | |
| device=attention_mask.device, | |
| ) | |
| position_ids = torch.zeros( | |
| (batch_size, max_len), | |
| dtype=position_ids.dtype, | |
| device=position_ids.device, | |
| ) | |
| for i, (cur_new_embed, cur_new_labels) in enumerate( | |
| zip(new_input_embeds, new_labels) | |
| ): | |
| cur_len = cur_new_embed.shape[0] | |
| if getattr(self.config, "tokenizer_padding_side", "right") == "left": | |
| new_input_embeds_padded.append( | |
| torch.cat( | |
| ( | |
| torch.zeros( | |
| (max_len - cur_len, cur_new_embed.shape[1]), | |
| dtype=cur_new_embed.dtype, | |
| device=cur_new_embed.device, | |
| ), | |
| cur_new_embed, | |
| ), | |
| dim=0, | |
| ) | |
| ) | |
| if cur_len > 0: | |
| new_labels_padded[i, -cur_len:] = cur_new_labels | |
| attention_mask[i, -cur_len:] = True | |
| position_ids[i, -cur_len:] = torch.arange( | |
| 0, | |
| cur_len, | |
| dtype=position_ids.dtype, | |
| device=position_ids.device, | |
| ) | |
| else: | |
| new_input_embeds_padded.append( | |
| torch.cat( | |
| ( | |
| cur_new_embed, | |
| torch.zeros( | |
| (max_len - cur_len, cur_new_embed.shape[1]), | |
| dtype=cur_new_embed.dtype, | |
| device=cur_new_embed.device, | |
| ), | |
| ), | |
| dim=0, | |
| ) | |
| ) | |
| if cur_len > 0: | |
| new_labels_padded[i, :cur_len] = cur_new_labels | |
| attention_mask[i, :cur_len] = True | |
| position_ids[i, :cur_len] = torch.arange( | |
| 0, | |
| cur_len, | |
| dtype=position_ids.dtype, | |
| device=position_ids.device, | |
| ) | |
| new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) | |
| if _labels is None: | |
| new_labels = None | |
| else: | |
| new_labels = new_labels_padded | |
| if _attention_mask is None: | |
| attention_mask = None | |
| else: | |
| attention_mask = attention_mask.to(dtype=_attention_mask.dtype) | |
| if _position_ids is None: | |
| position_ids = None | |
| return ( | |
| None, | |
| position_ids, | |
| attention_mask, | |
| past_key_values, | |
| new_input_embeds, | |
| new_labels, | |
| vision_tower_aux_feature_list_final, | |
| vision_tower_aux_attention_masks_list_final, | |
| final_size, | |
| global_context_feature_final, | |
| ) | |
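# Editorial note on the 10-tuple returned above: input_ids is returned as None
# because the text has already been turned into embeddings; the remaining items are
# position_ids, attention_mask, past_key_values, the padded inputs_embeds, the
# padded labels, and four SVA-specific outputs (the rearranged auxiliary feature
# list, their attention masks, the per-sample final spatial sizes, and the flattened
# global context features), intended for the vision-sampler layers inside the
# language model when connector_only is disabled.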
| def initialize_vision_tokenizer(self, model_args, tokenizer): | |
| if model_args.mm_use_im_patch_token: | |
| tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) | |
| self.resize_token_embeddings(len(tokenizer)) | |
| if model_args.mm_use_im_start_end: | |
| num_new_tokens = tokenizer.add_tokens( | |
| [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True | |
| ) | |
| self.resize_token_embeddings(len(tokenizer)) | |
| if num_new_tokens > 0: | |
| input_embeddings = self.get_input_embeddings().weight.data | |
| output_embeddings = self.get_output_embeddings().weight.data | |
| input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( | |
| dim=0, keepdim=True | |
| ) | |
| output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( | |
| dim=0, keepdim=True | |
| ) | |
| input_embeddings[-num_new_tokens:] = input_embeddings_avg | |
| output_embeddings[-num_new_tokens:] = output_embeddings_avg | |
| if model_args.tune_mm_mlp_adapter: | |
| for p in self.get_input_embeddings().parameters(): | |
| p.requires_grad = True | |
| for p in self.get_output_embeddings().parameters(): | |
| p.requires_grad = False | |
| if model_args.pretrain_mm_mlp_adapter: | |
| mm_projector_weights = torch.load( | |
| model_args.pretrain_mm_mlp_adapter, map_location="cpu" | |
| ) | |
| embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"] | |
| assert num_new_tokens == 2 | |
| if input_embeddings.shape == embed_tokens_weight.shape: | |
| input_embeddings[-num_new_tokens:] = embed_tokens_weight[ | |
| -num_new_tokens: | |
| ] | |
| elif embed_tokens_weight.shape[0] == num_new_tokens: | |
| input_embeddings[-num_new_tokens:] = embed_tokens_weight | |
| else: | |
| raise ValueError( | |
| f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}." | |
| ) | |
| elif model_args.mm_use_im_patch_token: | |
| if model_args.tune_mm_mlp_adapter: | |
| for p in self.get_input_embeddings().parameters(): | |
| p.requires_grad = False | |
| for p in self.get_output_embeddings().parameters(): | |
| p.requires_grad = False | |
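# Editorial note: embeddings for newly added special tokens above are initialized
# to the mean of the pre-existing input/output embeddings, a common warm start that
# places the new tokens near the center of the embedding distribution instead of at
# a random scale.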