from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import Qwen2_5_VLConfig, AutoConfig, AutoModelForCausalLM

from vlm_fo1.model.multimodal_encoder.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLModel,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2_5_VLCausalLMOutputWithPast,
)
from vlm_fo1.model.multimodal_encoder.qwen2_5_vl_encoder import Qwen2_5_VlVisionTower
from vlm_fo1.constants import (
    IGNORE_INDEX,
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_PATCH_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    DEFAULT_REGION_INDEX,
    QWEN2_5_VL_IMAGE_TOKEN,
    QWEN2_5_VL_IMAGE_TOKEN_INDEX,
)
from ..omchat_arch import OmChatMetaModel, OmChatMetaForCausalLM
# Custom config which extends Qwen2_5_VLConfig for the OmChat multimodal model.
class OmChatQwen25VLConfig(Qwen2_5_VLConfig):
    model_type = "omchat_qwen2_5_vl"
    rotary_type = "normal_rotary"
    multi_scale_im = None
    vision_tower_aux = None


# Core model definition: combines the OmChat meta-model with the Qwen multimodal base.
class OmChatQwen25VLModel(OmChatMetaModel, Qwen2_5_VLModel):
    config_class = OmChatQwen25VLConfig

    def __init__(self, config: Qwen2_5_VLConfig):
        super(OmChatQwen25VLModel, self).__init__(config)


# Main class for the multimodal causal LM.
class OmChatQwen25VLForCausalLM(Qwen2_5_VLForConditionalGeneration, OmChatMetaForCausalLM):
    config_class = OmChatQwen25VLConfig

    def __init__(self, config, delay_load=True):
        # Ensure the config carries a delay_load flag for the vision tower.
        if not hasattr(config, 'delay_load'):
            config.delay_load = delay_load
        # Skip Qwen2_5_VLForConditionalGeneration.__init__ in the MRO so its
        # default submodules are not built; we construct our own model below.
        super(Qwen2_5_VLForConditionalGeneration, self).__init__(config)
        self.model = OmChatQwen25VLModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.rope_deltas = None  # cache rope_deltas here
        self.post_init()
    # Encode input images into feature representations.
    def encode_images(self, images, images_grid_thw=None):
        # The Qwen2.5-specific vision tower has its own forward signature and
        # also returns grid sizes plus multi-level features.
        if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
            image_features = self.get_model().get_vision_tower()(images, images_grid_thw)
            image_features, image_grid_thws, multi_level_features = image_features
            # Multiple images arrive as a list; concatenate along the token axis.
            if type(image_features) is list:
                # Each list item has shape (1, seq_len, dim).
                token_length_list = [i.shape[1] for i in image_features]
                image_features = torch.cat(image_features, dim=1)  # (1, total_seq_len, dim)
        else:
            image_features = self.get_model().get_vision_tower()(images)
            image_grid_thws = None
            multi_level_features = None
        image_features = self.get_model().mm_projector(image_features)
        # Split the concatenated features back into per-image tensors by their
        # original token lengths (multi-image case).
        if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
            start = 0
            new_image_features = []
            for length in token_length_list:
                end = start + length
                new_image_features.append(image_features[:, start:end, :].squeeze(0))
                start = end
            image_features = new_image_features
        return image_features, image_grid_thws, multi_level_features
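    # Shape walkthrough for encode_images (illustrative sizes, not from the
    # source): two images yielding 4 and 6 visual tokens of width d arrive as
    #   [(1, 4, d), (1, 6, d)]  ->  cat(dim=1)  ->  (1, 10, d)
    # mm_projector runs once on the merged tensor, then token_length_list
    # [4, 6] splits it back into per-image tensors of shape (4, d) and (6, d),
    # so the projector sees a single batched call regardless of image count.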
    # Encode regions (bounding boxes) into features, optionally fusing the main
    # vision tower's multi-level features with the auxiliary tower's.
    def encode_regions(self, images, bbox_list, vt_multi_level_features=None, vt_images_size=None):
        aux_image_features_list = self.get_model().get_vision_tower_aux()(images)
        region_features = []
        if getattr(self.config, "mm_use_vision_tower_region_feature", False):
            image_features_list = vt_multi_level_features
            for batch_idx, (image_features, aux_image_features) in enumerate(zip(image_features_list, aux_image_features_list)):
                if getattr(self.config, "mm_use_simpleFPN_for_vt", False):
                    multilevel_visual_feats = image_features[-1]
                else:
                    multilevel_visual_feats = image_features
                multilevel_aux_visual_feats = aux_image_features["image_features"]
                boxes = bbox_list[batch_idx]
                # If no boxes are provided, fall back to a tiny dummy box.
                if boxes is None or len(boxes) == 0:
                    boxes = torch.tensor([[0, 10, 0, 10]], device=multilevel_aux_visual_feats[0].device, dtype=torch.float32)
                boxes = boxes.to(torch.float32).to(multilevel_aux_visual_feats[0].device)
                current_image_height, current_image_width = images[batch_idx].shape[-2:]
                original_height, original_width = vt_images_size[batch_idx]
                # Rescale boxes from the auxiliary image's coordinate space into
                # the main vision tower's (vt_images_size) coordinate space.
                scale_height = original_height / current_image_height
                scale_width = original_width / current_image_width
                vt_boxes = boxes * torch.tensor([scale_width, scale_height, scale_width, scale_height], device=boxes.device)
                extracted_region_feat = self.get_model().object_vp_extractor(
                    aux_multi_level_features=multilevel_aux_visual_feats,
                    vt_multi_level_features=multilevel_visual_feats,
                    aux_boxes=[boxes],
                    vt_boxes=[vt_boxes]
                ).squeeze(0).to(multilevel_aux_visual_feats[0].dtype)
                region_feat = self.get_model().mm_projector_aux(extracted_region_feat)  # [num_bbox, 2048]
                region_features.append(region_feat)
        else:
            # Extract region features from the auxiliary vision tower only.
            for batch_idx, image_features in enumerate(aux_image_features_list):
                multilevel_visual_feats = image_features["image_features"]
                last_feat = image_features["last_feat"]
                boxes = bbox_list[batch_idx]
                if boxes is None or len(boxes) == 0:
                    boxes = torch.tensor([[0, 10, 0, 10]], device=multilevel_visual_feats[0].device, dtype=torch.float32)
                multi_level_aux_features = multilevel_visual_feats
                boxes = boxes.to(torch.float32).to(multi_level_aux_features[0].device)
                extracted_region_feat = self.get_model().object_vp_extractor(
                    multi_level_aux_features,
                    [boxes],
                ).squeeze(0).to(multi_level_aux_features[0].dtype)
                region_feat = self.get_model().mm_projector_aux(extracted_region_feat)  # [num_bbox, 2880]
                region_features.append(region_feat)
        return region_features
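    # Worked example of the box rescaling above (assumed interpretation: boxes
    # arrive in the auxiliary image's pixel space and vt_boxes must live in the
    # main vision tower's space). If the auxiliary image is 500x500 and
    # vt_images_size is (1000, 1000), then scale_width = scale_height = 2.0 and
    # a box [10, 20, 30, 40] maps to [20, 40, 60, 80].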
    def get_model(self):
        # Getter used to access backbone/model internals.
        return self.model

    # Convert input_ids/labels/images/boxes into the multimodal embeddings,
    # masks, and ids the transformer expects.
    def prepare_inputs_labels_for_qwen2_5_vl_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, images_aux=None, bbox_list=None, image_grid_thws=None
    ):
        # ========================= Input parsing and fast paths =========================
        vision_tower = self.get_vision_tower()
        video_tower = self.get_video_tower()
        vision_tower_aux = self.get_vision_tower_aux()
        # Fast path: no multimodal towers/inputs, or a single-token decoding step.
        if (vision_tower is None and video_tower is None) or images is None or input_ids.shape[1] == 1:
            cache_position = None  # stays None unless extended below
            if past_key_values is not None and (vision_tower is not None or video_tower is not None) and images is not None and input_ids.shape[1] == 1:
                # Extend the attention mask to cover the cached sequence plus the new token.
                target_shape = past_key_values[-1][-1].shape[-2] + 1
                attention_mask = torch.cat((attention_mask, torch.ones(
                    (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device
                )), dim=1)
                position_ids = None
                cache_position = torch.tensor([target_shape - 1], device=attention_mask.device)
            return input_ids, position_ids, attention_mask, past_key_values, None, labels, None, cache_position
        # Indices of images (2D or 3D tensors) and videos (4D tensors).
        image_idx = [idx for idx, img in enumerate(images) if img.ndim == 3 or img.ndim == 2]
        is_all_image = len(image_idx) == len(images)
        video_idx = [idx for idx, vid in enumerate(images) if vid.ndim == 4]
        # Group image and video tensors for mini-batch processing.
        if isinstance(vision_tower, Qwen2_5_VlVisionTower):
            images_minibatch = [images[idx] for idx in image_idx] if len(image_idx) > 0 else []  # list of [c, h, w]; shapes may vary
        else:
            images_minibatch = torch.stack([images[idx] for idx in image_idx]) if len(image_idx) > 0 else []  # tensor [mini_b, c, h, w]
        videos_minibatch = torch.stack([images[idx] for idx in video_idx]) if len(video_idx) > 0 else []  # tensor [mini_b, c, t, h, w]
        # Auxiliary mini-batch for region encoding, if relevant.
        if vision_tower_aux is not None and images_aux is not None:
            images_minibatch_aux = [images_aux[idx].unsqueeze(0) for idx in image_idx] if len(image_idx) > 0 else []  # list of [1, c, h, w]
        # tmp_image_features scatters extracted image/video features back into
        # their original batch positions.
        tmp_image_features = [None] * (len(image_idx) + len(video_idx))
        # Pre-initialize so later checks are safe when region encoding is skipped.
        region_features = None
        if getattr(images_minibatch, 'ndim', 0) == 4 or (type(images_minibatch) is list and len(images_minibatch) > 0):  # image mini-batch [mini_b, c, h, w]
            if vision_tower is not None:
                image_features_minibatch, image_grid_thws_minibatch, vt_multi_level_features_minibatch = self.encode_images(images_minibatch, image_grid_thws)  # [mini_b, l, c]
            else:
                image_features_minibatch = torch.randn(1).to(self.device)  # dummy feature for video-only training
            # Map extracted image features back to their places in the original batch.
            for i, pos in enumerate(image_idx):
                tmp_image_features[pos] = image_features_minibatch[i]
            # Auxiliary region features, when a region tower and boxes are present.
            if vision_tower_aux is not None and bbox_list is not None and len(bbox_list) > 0:
                if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
                    patch_size = self.get_model().get_vision_tower().config.patch_size
                    vt_images_size_minibatch = [im_grid_thw[0][-2:] * patch_size for im_grid_thw in image_grid_thws]
                    region_features = self.encode_regions(images_minibatch_aux, bbox_list, vt_multi_level_features_minibatch, vt_images_size_minibatch)  # [mini_b, l, c]
        # Same as above, but for video features if any.
        if getattr(videos_minibatch, 'ndim', 0) == 5:  # video mini-batch [mini_b, c, t, h, w]
            video_features_minibatch = self.encode_videos(videos_minibatch)  # list-like [mini_b, t, l, c]
            for i, pos in enumerate(video_idx):
                tmp_image_features[pos] = video_features_minibatch[i]
        # Flatten the feature slots into batch order; multi-image items hold a
        # list and are expanded in place, e.g. [[f_a, f_b], f_c] -> [f_a, f_b, f_c].
        new_tmp = []
        for image in tmp_image_features:
            if isinstance(image, list):
                new_tmp.extend(image)
            else:
                new_tmp.append(image)
        image_features = new_tmp
        # ===================== Build multimodal input & target sequences =====================
        if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
            raise NotImplementedError
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        # Default construction of masks/position ids/labels when not provided.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)
        # For each batch item, strip padded tokens based on the attention mask.
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
        # Classic image-text path: no auxiliary region tower and no boxes.
        if vision_tower_aux is None and (bbox_list is None or all(x is None for x in bbox_list)):
            new_input_embeds = []
            new_labels = []
            new_input_ids = []
            cur_image_idx = 0
            image_nums_in_batch = []
            for batch_idx, cur_input_ids in enumerate(input_ids):
                num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
                image_nums_in_batch.append(num_images)
                # No image markers: plain text embeddings. The zero-length image
                # slice keeps the multimodal projector in the autograd graph.
                if num_images == 0:
                    cur_image_features = image_features[cur_image_idx]
                    cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                    cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                    new_input_embeds.append(cur_input_embeds)
                    new_labels.append(labels[batch_idx])
                    new_input_ids.append(cur_input_ids)
                    cur_image_idx += 1
                    continue
                # Split on image-token positions; the markers are later replaced
                # by image features.
                image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
                cur_input_ids_noim = []
                cur_labels = labels[batch_idx]
                cur_labels_noim = []
                for i in range(len(image_token_indices) - 1):
                    cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
                    cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
                split_sizes = [x.shape[0] for x in cur_labels_noim]
                cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
                cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
                cur_new_input_embeds = []
                cur_new_labels = []
                cur_new_input_ids = []
                # Interleave text chunks with image features.
                for i in range(num_images + 1):
                    cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                    cur_new_labels.append(cur_labels_noim[i])
                    cur_new_input_ids.append(cur_input_ids_noim[i])
                    if i < num_images:
                        cur_image_features = image_features[cur_image_idx].to(self.device)
                        cur_image_idx += 1
                        cur_new_input_embeds.append(cur_image_features)
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                        cur_new_input_ids.append(torch.full((cur_image_features.shape[0],), self.config.image_token_id, device=cur_labels.device, dtype=cur_labels.dtype))
                cur_new_input_embeds = torch.cat(cur_new_input_embeds)
                cur_new_labels = torch.cat(cur_new_labels)
                cur_new_input_ids = torch.cat(cur_new_input_ids)
                new_input_embeds.append(cur_new_input_embeds)
                new_labels.append(cur_new_labels)
                new_input_ids.append(cur_new_input_ids)
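            # Interleaving example (hypothetical ids): for cur_input_ids
            # [t0, t1, IMG, t2] with a 5-token image feature, the result is
            #   embeds: E(t0) E(t1) [f0..f4]          E(t2)
            #   labels: l0   l1    IGNORE_INDEX * 5   l2
            #   ids:    t0   t1    image_token_id * 5 t2
            # so the loss is never computed on spliced image positions.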
        # Region path: region markers present or region features enabled in config.
        else:
            new_input_embeds = []
            new_labels = []
            new_input_ids = []
            cur_image_idx = 0
            image_nums_in_batch = []
            for batch_idx, cur_input_ids in enumerate(input_ids):
                cur_region_idx = 0
                # Count image and region special tokens.
                num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
                num_regions = (cur_input_ids == DEFAULT_REGION_INDEX).sum() if DEFAULT_REGION_INDEX in cur_input_ids else 0
                image_nums_in_batch.append(num_images)
                # No markers: plain text embeddings. The zero-length image/region
                # slices keep both projectors in the autograd graph.
                if num_images == 0 and num_regions == 0:
                    cur_image_features = image_features[cur_image_idx]
                    cur_region_features = region_features[cur_region_idx]
                    cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                    cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_region_features[0:0]], dim=0)
                    new_input_embeds.append(cur_input_embeds)
                    new_labels.append(labels[batch_idx])
                    new_input_ids.append(cur_input_ids)
                    cur_image_idx += 1
                    continue
                # Collect all special marker positions (image and region).
                image_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
                region_indices = torch.where(cur_input_ids == DEFAULT_REGION_INDEX)[0].tolist() if num_regions > 0 else []
                all_special_indices = sorted([-1] + image_indices + region_indices + [cur_input_ids.shape[0]])
                # Slice out the plain-text chunks between markers.
                cur_input_ids_segments = []
                cur_labels = labels[batch_idx]
                cur_labels_segments = []
                for i in range(len(all_special_indices) - 1):
                    cur_input_ids_segments.append(cur_input_ids[all_special_indices[i]+1:all_special_indices[i+1]])
                    cur_labels_segments.append(cur_labels[all_special_indices[i]+1:all_special_indices[i+1]])
                # Project text ids to word embeddings.
                split_sizes = [x.shape[0] for x in cur_labels_segments]
                cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_segments))
                if num_regions == 0 and vision_tower_aux is not None and region_features is not None:
                    # Append a zero-length region slice so the region projector
                    # stays in the graph even without region tokens.
                    cur_region_features = region_features[cur_region_idx]
                    cur_input_embeds = torch.cat([cur_input_embeds, cur_region_features[0:0]], dim=0)
                cur_input_embeds_segments = torch.split(cur_input_embeds, split_sizes, dim=0)
                # Reassemble text and image/region segments in order.
                cur_new_input_embeds = []
                cur_new_labels = []
                cur_new_input_ids = []
                for i in range(len(all_special_indices) - 1):
                    # Insert the current text segment.
                    cur_new_input_embeds.append(cur_input_embeds_segments[i])
                    cur_new_labels.append(cur_labels_segments[i])
                    cur_new_input_ids.append(cur_input_ids_segments[i])
                    # Next marker is an image: splice in its feature representation.
                    if all_special_indices[i+1] in image_indices:
                        cur_image_features = image_features[cur_image_idx].to(self.device)
                        cur_image_idx += 1
                        cur_new_input_embeds.append(cur_image_features)
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                        cur_new_input_ids.append(torch.full((cur_image_features.shape[0],), self.config.image_token_id, device=cur_labels.device, dtype=cur_labels.dtype))
                    # Next marker is a region token: splice in its extracted feature.
                    elif all_special_indices[i+1] in region_indices:
                        cur_region_features = region_features[batch_idx][cur_region_idx].to(self.device).unsqueeze(0)
                        cur_region_idx += 1
                        cur_new_input_embeds.append(cur_region_features)
                        cur_new_labels.append(torch.full((cur_region_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                        cur_new_input_ids.append(torch.full((cur_region_features.shape[0],), DEFAULT_REGION_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                # Combine segments for this batch item.
                cur_new_input_embeds = torch.cat(cur_new_input_embeds)
                cur_new_labels = torch.cat(cur_new_labels)
                cur_new_input_ids = torch.cat(cur_new_input_ids)
                new_input_embeds.append(cur_new_input_embeds)
                new_labels.append(cur_new_labels)
                new_input_ids.append(cur_new_input_ids)
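            # Note the asymmetry: an image marker expands into one embedding per
            # visual token, while a region marker expands into exactly one
            # embedding (the unsqueeze(0) above yields shape (1, dim)); both are
            # labeled IGNORE_INDEX so neither contributes to the loss.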
        # Truncate to the model's maximum length if image/region tokens caused overflow.
        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
        # Pad sequences in the batch to the same length; rebuild batch masks.
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)
        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        new_input_ids_padded = torch.full((batch_size, max_len), self.config.bos_token_id, dtype=new_input_ids[0].dtype, device=new_input_ids[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
        # Left- or right-pad as per config; fill the padded tensors.
        for i, (cur_new_embed, cur_new_labels, cur_new_input_ids) in enumerate(zip(new_input_embeds, new_labels, new_input_ids)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
                # Left pad: zeros before the tokens/features.
                new_input_embeds_padded.append(torch.cat((
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                    cur_new_embed
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    new_input_ids_padded[i, -cur_len:] = cur_new_input_ids  # mirrors the right-pad branch
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                # Right pad: zeros after the tokens/features.
                new_input_embeds_padded.append(torch.cat((
                    cur_new_embed,
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    new_input_ids_padded[i, :cur_len] = cur_new_input_ids
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
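        # Padding layout for max_len 6 and cur_len 4 (x = real, 0 = pad):
        #   left  padding: [0, 0, x, x, x, x], position_ids [0, 0, 0, 1, 2, 3]
        #   right padding: [x, x, x, x, 0, 0], position_ids [0, 1, 2, 3, 0, 0]
        # Pad positions remain masked out, so their position_ids values are inert.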
        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
        new_input_ids = new_input_ids_padded
        # Only return labels if the caller provided them.
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded
        # Likewise honor caller-provided attention_mask/position_ids overrides.
        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
        if _position_ids is None:
            position_ids = None
        # For Qwen2.5 vision towers, concatenate image_grid_thws in batch order
        # for the rotary positional computations.
        if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
            image_grid_thws = []
            cur_image_idx = 0
            for num_images in image_nums_in_batch:
                if num_images == 0:
                    cur_image_idx += 1
                    continue
                image_grid_thws += image_grid_thws_minibatch[cur_image_idx:cur_image_idx + num_images]
                cur_image_idx += num_images
            if len(image_grid_thws) > 0:
                image_grid_thws = torch.cat(image_grid_thws, dim=0)
            else:
                image_grid_thws = None
            rope_index_kwargs = {
                "input_ids": new_input_ids,
                "image_grid_thw": image_grid_thws,
                "video_grid_thw": None,
                "attention_mask": attention_mask,
            }
            # Compute position_ids and rope_deltas for the rotary embeddings.
            position_ids, rope_deltas = self.get_rope_index(**rope_index_kwargs)
            cache_position = torch.arange(new_input_embeds.shape[1], device=new_input_embeds.device)
        else:
            rope_deltas = None
            cache_position = None
        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, rope_deltas, cache_position
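    # The 8-tuple returned above is unpacked positionally in forward() below:
    #   (input_ids, position_ids, attention_mask, past_key_values,
    #    inputs_embeds, labels, rope_deltas, cache_position)
    # input_ids comes back as None once inputs_embeds is built, so the base
    # forward consumes the precomputed embeddings directly.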
    # Override forward() of the HF CausalLM to allow multimodal embedding with
    # images/regions.
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        images: Optional[torch.FloatTensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        bbox_list: Optional[torch.FloatTensor] = None,
        image_grid_thws: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
                rope_deltas,
                cache_position,
            ) = self.prepare_inputs_labels_for_qwen2_5_vl_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                images_aux,
                bbox_list,
                image_grid_thws,
            )
        if rope_deltas is not None:
            self.rope_deltas = rope_deltas
        # Call the base CausalLM forward with the (possibly replaced) multimodal embeddings.
        out = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            rope_deltas=rope_deltas,
            cache_position=cache_position,
            second_per_grid_ts=second_per_grid_ts,
            return_dict=return_dict,
        )
        return out
    # Prepare the model input dict for autoregressive generation (generate()).
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        second_per_grid_ts=None,
        images: Optional[torch.FloatTensor] = None,
        images_aux: Optional[torch.FloatTensor] = None,
        bbox_list: Optional[torch.FloatTensor] = None,
        image_grid_thws: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        # Wrap the parent logic so the extra multimodal kwargs are preserved.
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            images=images,
            images_aux=images_aux,
            bbox_list=bbox_list,
            image_grid_thws=image_grid_thws,
        )
        return model_inputs
# Register our config and model with the HuggingFace transformers registry.
AutoConfig.register("omchat_qwen2_5_vl", OmChatQwen25VLConfig)
AutoModelForCausalLM.register(OmChatQwen25VLConfig, OmChatQwen25VLForCausalLM)
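
# Minimal loading sketch (the checkpoint path below is hypothetical; it assumes
# a checkpoint trained with this architecture is available and that this module
# has been imported so the registrations above have run). Guarded by __main__
# so importing the module stays side-effect free.
if __name__ == "__main__":
    ckpt = "path/to/omchat-qwen2_5-vl"  # hypothetical local checkpoint
    model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
    # Multimodal kwargs (images, images_aux, bbox_list, image_grid_thws) are
    # forwarded by generate() through prepare_inputs_for_generation() and then
    # into forward() above.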