import torch
import torch.nn as nn
from typing import Optional, List, Union, Tuple

from transformers import Qwen2Model, Qwen2ForCausalLM
from transformers.utils import logging, is_torchdynamo_compiling
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import (
    CausalLMOutputWithPast,
    BaseModelOutputWithPast,
)
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2DecoderLayer,
    Qwen2RMSNorm,
    Qwen2RotaryEmbedding,
)

logger = logging.get_logger(__name__)

# Impl. for transformers==4.42.0


class Qwen2MMConfig(PretrainedConfig):
    model_type = "qwen"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        vision_patch_size=32,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.vision_patch_size = vision_patch_size

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class MultimodalQwen2Model(Qwen2Model):
    # Use the multimodal config (with `vision_patch_size`) when loading via `from_pretrained`.
    config_class = Qwen2MMConfig

    def __init__(self, config: Qwen2MMConfig):
        super(Qwen2Model, self).__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # === Vision Patches ===
        assert config.vision_patch_size == 32
        self.vis_embed = nn.Linear(
            config.vision_patch_size * config.vision_patch_size * 3,  # 32 * 32 * 3
            config.hidden_size,
            bias=False,
        )
        # === Vision Patches ===

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        vision_patch_indices: Optional[torch.LongTensor] = None,  # (batch_size, seq_length), "-1" for text tokens
        vision_patches: Optional[torch.FloatTensor] = None,  # (n_patches, 32 * 32 * 3)
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        use_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            use_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if vision_patch_indices is not None and input_ids is not None:
            assert (
                vision_patch_indices.shape == input_ids.shape
            ), "vision_patch_indices and input_ids should have the same shape"

        # === Handle vision patches ===
        if vision_patches is not None and vision_patches.size(0) > 0:
            assert (
                vision_patch_indices is not None
            ), "HF QwenMM model requires vision_patch_indices for vision_patches input."
            vision_embeds = self.vis_embed(vision_patches)  # (n_patches, hidden_size)
            vision_embeds = torch.cat(
                [
                    vision_embeds,
                    # add a dummy all-zero row that text positions map to
                    # (match device and dtype so the cat also works for half-precision models)
                    torch.zeros(
                        1, self.config.hidden_size, device=vision_embeds.device, dtype=vision_embeds.dtype
                    ),
                ],
            )  # (n_patches + 1, hidden_size)

            # arrange embeddings according to vision_patch_indices
            # - text tokens are -1 (map to the dummy zero row)
            # - vision tokens are 0..n_patches-1 (map to the corresponding rows of vision_embeds)
            vision_embeds = vision_embeds[vision_patch_indices]  # (batch_size, seq_length, hidden_size)

            # merge vision_embeds with inputs_embeds
            inputs_embeds += vision_embeds

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class Qwen2MMForCausalLM(Qwen2ForCausalLM):
    # Use the multimodal config (with `vision_patch_size`) when loading via `from_pretrained`.
    config_class = Qwen2MMConfig

    def __init__(self, config: Qwen2MMConfig):
        super().__init__(config)
        self.model = MultimodalQwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        vision_patch_indices: Optional[torch.LongTensor] = None,  # (batch_size, seq_length), "-1" for text tokens
        vision_patches: Optional[torch.FloatTensor] = None,  # (n_patches, 32 * 32 * 3)
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            vision_patch_indices=vision_patch_indices,
            vision_patches=vision_patches,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if labels is None and not is_torchdynamo_compiling():
            logger.warning_once(
                "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
            )
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # TODO: remove the float() operation in v4.46
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        vision_patches = kwargs.get("vision_patches", None)
        vision_patch_indices = kwargs.get("vision_patch_indices", None)

        has_vision_inp = False
        if vision_patches is not None and vision_patch_indices is not None:
            has_vision_inp = True
            # pad vision_patch_indices with -1 so it has the same shape as input_ids
            _padding = torch.full_like(input_ids, -1, dtype=vision_patch_indices.dtype)
            _padding[:, : vision_patch_indices.shape[1]] = vision_patch_indices
            vision_patch_indices = _padding

        past_length = 0
        # Omit tokens covered by past_key_values
        if past_key_values is not None:
            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
            max_cache_length = (
                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                if past_key_values.get_max_length() is not None
                else None
            )
            cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
                if has_vision_inp:
                    vision_patch_indices = vision_patch_indices[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
                if has_vision_inp:
                    vision_patch_indices = vision_patch_indices[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        elif use_cache:
            cache_position = cache_position[-input_length:]

        if vision_patch_indices is not None:
            assert vision_patch_indices.shape == input_ids.shape

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "vision_patch_indices": vision_patch_indices,
                "vision_patches": vision_patches,
            }
        )
        return model_inputs
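

# Illustrative sketch, not part of the original pipeline: one plausible way to build the
# `vision_patches` tensor that `MultimodalQwen2Model.forward` expects, i.e. a
# (n_patches, 32 * 32 * 3) float tensor of flattened non-overlapping patches. The exact
# flattening order used by the original data pipeline is unknown; channel-last,
# row-major ordering is assumed here.
def _example_patchify(image: torch.Tensor, patch_size: int = 32) -> torch.Tensor:
    """Split a (3, H, W) image into flattened non-overlapping patches.

    Returns a (n_patches, patch_size * patch_size * 3) tensor, with patches ordered
    left-to-right, top-to-bottom.
    """
    c, h, w = image.shape
    assert c == 3 and h % patch_size == 0 and w % patch_size == 0
    # (3, H, W) -> (3, H // p, W // p, p, p)
    patches = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # -> (H // p, W // p, p, p, 3) -> (n_patches, p * p * 3), channel-last per pixel
    patches = patches.permute(1, 2, 3, 4, 0).reshape(-1, patch_size * patch_size * 3)
    return patches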


if __name__ == "__main__":
    mmqwen = Qwen2MMForCausalLM.from_pretrained("Qwen2-0.5B")
    print(mmqwen)
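
    # --- Hedged smoke-test sketch (illustrative only) ---
    # Assumes the checkpoint above loaded; `vis_embed` is newly initialized (it is not
    # present in a text-only Qwen2 checkpoint), so this only checks shapes/plumbing.
    # The prompt layout (patches first, then text) is an assumption, not the original
    # data format; `_example_patchify` is the sketch helper defined above.
    config = mmqwen.config
    image = torch.rand(3, 64, 64)  # dummy RGB image, sides divisible by the 32-pixel patch size
    vision_patches = _example_patchify(image, config.vision_patch_size)  # (4, 3072)
    n_patches = vision_patches.shape[0]

    seq_len = n_patches + 4  # a few text positions after the vision patches
    input_ids = torch.randint(0, config.vocab_size, (1, seq_len))
    # -1 marks text tokens; 0..n_patches-1 index into `vision_patches`
    vision_patch_indices = torch.full((1, seq_len), -1, dtype=torch.long)
    vision_patch_indices[0, :n_patches] = torch.arange(n_patches)

    with torch.no_grad():
        out = mmqwen(
            input_ids=input_ids,
            vision_patches=vision_patches,
            vision_patch_indices=vision_patch_indices,
        )
    print(out.logits.shape)  # expected: (1, seq_len, vocab_size)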