from dataclasses import dataclass
import inspect
import warnings
from typing import List, Optional, Tuple, Union
import sys
import os

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.utils import is_flash_attn_2_available
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .configuration_vmistral import VMistralConfig
from .vision import SiglipVisionModel
from .modeling_vmistral import *
from .generation_utils import TreeBuilder, WebGenerationMixin

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)


@dataclass
class WebLMOutputWithPast(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    html_tree: TreeBuilder = None
class WebAttention(nn.Module):
    """
    Multi-headed attention from the 'Attention Is All You Need' paper. Modified to use sliding window
    attention: Longformer and "Generating Long Sequences with Sparse Transformers".
    """

    def __init__(self, config: VMistralConfig, qk_layer_norms: bool = False):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.qk_layer_norms = qk_layer_norms
        if self.qk_layer_norms:
            self.q_layer_norm = MistralRMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_layer_norm = MistralRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        self.rotary_emb = MistralRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )
        self.attention_dropout = config.attention_dropout

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
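    # A quick sanity check of the grouped-query-attention arithmetic above (a sketch with assumed
    # example values, not this repo's real config): with hidden_size=4096, num_attention_heads=32,
    # and num_key_value_heads=8, we get head_dim = 4096 // 32 = 128 and
    # num_key_value_groups = 32 // 8 = 4, i.e. each K/V head is shared by 4 query heads once
    # `repeat_kv` expands the K/V projections in `forward` below.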
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        web_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use"
                " `attention_mask` instead."
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = (
            self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        )
        value_states = (
            self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse cached k, v for self-attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        if self.qk_layer_norms:
            query_states = self.q_layer_norm(query_states)
            key_states = self.k_layer_norm(key_states)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        web_attention_range = self.config.web_attention_range

        def split_tensor(tensor):
            # `web_attention_range` (out of 8) controls how many heads use the web mask;
            # when all 8 do, no split is needed and the caller skips this path entirely.
            if int(web_attention_range) == 8:
                return None
            fraction = float(web_attention_range) / 8
            split_size_2 = int(self.num_heads * fraction)
            split_size_1 = self.num_heads - split_size_2
            return torch.split(tensor, [split_size_1, split_size_2], dim=1)

        if int(web_attention_range) != 8:
            query_states_1, query_states_2 = split_tensor(query_states)
            key_states_1, key_states_2 = split_tensor(key_states)
            value_states_1, value_states_2 = split_tensor(value_states)
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
                attn_output_1 = F.scaled_dot_product_attention(
                    query_states_1, key_states_1, value_states_1, attn_mask=attention_mask
                )
                attn_output_2 = F.scaled_dot_product_attention(
                    query_states_2, key_states_2, value_states_2, attn_mask=web_attention_mask
                )
            attn_output = torch.cat([attn_output_1, attn_output_2], dim=1)
        else:
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
                # note: the keyword for F.scaled_dot_product_attention is `attn_mask`
                attn_output = F.scaled_dot_product_attention(
                    query_states, key_states, value_states, attn_mask=web_attention_mask
                )

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        # F.scaled_dot_product_attention does not expose attention weights
        attn_weights = None

        return attn_output, attn_weights, past_key_value
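# Illustration of the head split performed by `split_tensor` above (a sketch with made-up numbers,
# not values from a real config): `web_attention_range` counts, out of 8, how many head-groups
# should attend with the tree-structured web mask instead of the plain causal mask.
# With num_heads = 32 and web_attention_range = 2:
#     fraction     = 2 / 8            # 0.25
#     split_size_2 = int(32 * 0.25)   # 8 heads attend with web_attention_mask
#     split_size_1 = 32 - 8           # 24 heads attend with the causal attention_mask
# torch.split(q, [24, 8], dim=1) then partitions the (bsz, num_heads, q_len, head_dim) tensor
# along the head dimension, and the two attention outputs are concatenated back on that dim.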
class WebFlashAttention2(WebAttention):
    """
    Mistral flash attention module. This module inherits from `WebAttention`, as the weights of the module stay
    untouched. The only required change would be in the forward pass, where it needs to correctly call the public
    API of flash attention and deal with padding tokens in case the input contains any of them.
    """


class WebDecoderLayer(nn.Module):
    def __init__(self, config: VMistralConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = (
            WebAttention(config=config)
            if not getattr(config, "_flash_attn_2_enabled", False)
            else WebFlashAttention2(config)
        )
        self.mlp = MistralMLP(config)
        self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        web_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            web_attention_mask (`torch.FloatTensor`, *optional*): additive tree-structured attention mask
                derived from the HTML DOM.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up
                decoding (see `past_key_values`).
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use"
                " `attention_mask` instead."
            )

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            web_attention_mask=web_attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs


class WebPreTrainedModel(PreTrainedModel):
    config_class = VMistralConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["WebDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_sdpa = False


class WebModel(WebPreTrainedModel, VMistralModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`WebDecoderLayer`].

    Args:
        config: VMistralConfig
    """

    def __init__(self, config: VMistralConfig, vision_model=None):
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.sliding_window = config.sliding_window

        self.embed_tokens = DecoupledEmbedding(
            num_embeddings=config.vocab_size,
            num_additional_embeddings=config.additional_vocab_size,
            embedding_dim=config.hidden_size,
            partially_freeze=config.freeze_text_layers,
            padding_idx=self.padding_idx,
        )

        # Load an uninitialized model; `from_pretrained` will later load the pre-trained weights -
        # this avoids losing weights in `from_pretrained` on the main model
        self.vision_model = SiglipVisionModel(config.vision_config)

        # Dim projection - projecting from the vision dim to the text dim
        self.modality_projection = ModalityProjection(
            embed_dim_in=self.config.vision_config.hidden_size, embed_dim_out=self.config.hidden_size
        )

        # Perceiver Resampler
        if config.use_resampler:
            self.perceiver_resampler = PerceiverResampler(
                config.hidden_size,
                config.perceiver_config.resampler_depth,
                config.perceiver_config.resampler_n_heads,
                config.perceiver_config.resampler_head_dim,
                config.perceiver_config.resampler_n_latents,
                config.perceiver_config.qk_layer_norms_perceiver,
            )

        if config.use_resampler:
            self.image_seq_len = config.perceiver_config.resampler_n_latents
        else:
            self.image_seq_len = (
                config.vision_config.image_size // config.vision_config.patch_size
            ) ** 2  # TODO: pretty sure that does not work for CLIP models since there is the CLS token
        self.image_token_id = self.config.image_token_id

        self.layers = nn.ModuleList([WebDecoderLayer(config) for _ in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False

        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

        self.freeze_relevant_params(config)
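    # Worked example for `image_seq_len` above (assumed values for illustration only): without the
    # perceiver resampler, a vision encoder with image_size=384 and patch_size=16 yields
    # (384 // 16) ** 2 = 576 image positions per image; with the resampler enabled, the image
    # sequence is instead compressed to `resampler_n_latents` learned latent positions.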
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        web_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, VMistralBaseModelOutputWithPast]:
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # START VISUAL INPUTS INTEGRATION
        if pixel_values is not None and image_hidden_states is not None:
            raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
        elif pixel_values is not None:
            pixel_values = pixel_values.to(dtype=self.dtype, device=input_ids.device)  # fp16 compatibility
            batch_size, num_images = pixel_values.size(0), pixel_values.size(1)

            # this change allows multiple images in a single batch
            pixel_values = pixel_values.contiguous().view(batch_size, num_images, *pixel_values.shape[2:])

            # Remove padding images - padding images are full 0 (currently disabled):
            # real_images_inds = pixel_values.sum(dim=(-1, -2, -3)) != 0.0
            # pixel_values = pixel_values[real_images_inds]

            # Get sequence from the vision encoder
            image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state

            # Modality projection
            image_hidden_states = self.modality_projection(image_hidden_states)

            if self.config.use_resampler:
                image_hidden_states = self.perceiver_resampler(image_hidden_states)
        elif image_hidden_states is not None:
            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

        if past_key_values is None:
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
            # that simply don't exist
            new_inp = self.inputs_merger(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                image_hidden_states=image_hidden_states,
            )
            inputs_embeds = new_inp["inputs_embeds"]

        # Could add token-type embeddings here (image token vs text token),
        # e.g. inputs_embeds += self.token_types(token_types)
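        # Rough shape walkthrough of the visual-input path above (a sketch; exact shapes depend on
        # the SigLIP config, and `inputs_merger` is defined in VMistralModel, not in this file):
        # pixel_values arrives as (bsz, num_images, C, H, W); the vision encoder returns patch
        # features as last_hidden_state; `modality_projection` maps the feature dim from
        # vision_config.hidden_size to the text hidden_size; the optional perceiver resampler
        # shortens the patch sequence to resampler_n_latents; finally `inputs_merger` presumably
        # scatters these image embeddings into `inputs_embeds` at the positions occupied by
        # `image_token_id`.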
        # embed positions
        if (
            attention_mask is not None
            and hasattr(self.config, "_flash_attn_2_enabled")
            and self.config._flash_attn_2_enabled
            and past_key_values is not None
        ):
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right', which"
                    " may lead to unexpected behaviour for the Flash Attention version of Mistral. Make sure to"
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input."
                )

        # We did not implement our model using Flash Attention 2
        self.config._flash_attn_2_enabled = False
        if not getattr(self.config, "_flash_attn_2_enabled", False):
            # 4d causal mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

            # Convert the binary web mask (1 = may attend) into an additive mask and keep only
            # the rows for the current input tokens
            web_attention_mask = web_attention_mask.unsqueeze(1)
            inverted_mask = 1.0 - web_attention_mask.to(inputs_embeds.dtype)
            web_attention_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), -1.0e32)
            if input_ids is not None:
                bsz, L = input_ids.size()[:2]
                web_attention_mask = web_attention_mask[:, :, -L:, :]
            else:
                raise ValueError("`input_ids` is required to slice the web attention mask for the current step.")

            attention_mask[attention_mask == -float("inf")] = torch.finfo(self.dtype).min
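        # Toy illustration of the inversion above (made-up 3-token example, not real data): a
        # binary web_attention_mask row [1, 0, 1] (1 = may attend) becomes, after `1.0 - mask`
        # and `masked_fill`, the additive bias row [0.0, -1e32, 0.0], so disallowed positions are
        # driven to ~zero probability inside the softmax of scaled_dot_product_attention.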
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    web_attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    web_attention_mask=web_attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states]
                if v is not None
            )
        return VMistralBaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            image_hidden_states=image_hidden_states,
        )
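# Minimal calling sketch for WebModel.forward (illustrative only; `model`, `ids`, and the chosen
# shapes are assumptions, not code from this repo):
#     ids = torch.randint(0, 1000, (1, L))     # (bsz, seq_len)
#     web = torch.tril(torch.ones(1, L, L))    # binary (bsz, L, L) web mask, 1 = may attend
#     out = model(input_ids=ids, web_attention_mask=web, use_cache=True, return_dict=True)
#     out.last_hidden_state                    # (bsz, L, hidden_size)
# The binary web mask is expanded to 4D and converted to additive form inside forward.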
class WebForVisionText2Text(WebPreTrainedModel, WebGenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config, vision_model=None):
        super().__init__(config)
        self.model = WebModel(config, vision_model=vision_model)
        self.image_token_id = self.config.image_token_id
        self.lm_head = DecoupledLinear(
            in_features=config.hidden_size,
            out_features=config.vocab_size,
            out_additional_features=config.additional_vocab_size,
            bias=False,
            partially_freeze=config.freeze_lm_head,
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        web_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        html_tree=None,
    ) -> Union[Tuple, WebLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in
                `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to
                `-100` are ignored (masked); the loss is only computed for the tokens with labels in
                `[0, ..., config.vocab_size]`.

        Returns:
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            web_attention_mask=web_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            image_hidden_states=image_hidden_states,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                shift_attention_mask = attention_mask[..., 1:].to(logits.device)
                shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return WebLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
            html_tree=html_tree,
        )
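    # Sketch of the label shifting performed in `forward` above (toy numbers, not real data):
    # logits at position t are scored against the label at position t+1, so for a 4-token
    # sequence with attention_mask [1, 1, 1, 0], the shifted mask is [1, 1, 0] and only
    # positions 0 and 1 contribute to the loss; the position whose shifted mask is 0 is dropped
    # before CrossEntropyLoss, on top of the usual ignore_index=-100 filtering.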
    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        image_hidden_states = kwargs.pop("image_hidden_states", None)
        if image_hidden_states is not None:
            kwargs["pixel_values"] = None
        # the module-level helper (from the `modeling_vmistral` star import), not this method
        inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)

        web_attention_mask, html_tree = None, kwargs.get("html_tree")
        if html_tree.web_attention_mask is None:
            # First step: start from a plain causal (lower-triangular) mask over the prompt
            attention_mask = inputs["attention_mask"]
            web_attention_mask = torch.tril(
                torch.ones((attention_mask.shape[-1], attention_mask.shape[-1]), dtype=attention_mask.dtype)
            ).unsqueeze(0)
            html_tree.web_attention_mask = web_attention_mask
        else:
            input_ids = inputs["input_ids"]
            tokenizer = html_tree.tokenizer
            cur_decoded_token = tokenizer.convert_tokens_to_string(
                [" "] + tokenizer.convert_ids_to_tokens(input_ids[:, -1])
            )
            web_attn_range = html_tree.update_buffer([cur_decoded_token])

            # Grow the mask by one row/column for the newly generated token
            bsz, L = html_tree.web_attention_mask.size()[:2]
            web_attention_mask = torch.zeros((bsz, L + 1, L + 1)).type_as(html_tree.web_attention_mask)
            web_attention_mask[:, :L, :L] = html_tree.web_attention_mask
            web_attn_range = torch.tensor(
                list(range(67)) + [i + 67 for i in web_attn_range], dtype=web_attention_mask.dtype
            )
            web_attention_mask[:, -1, web_attn_range] = 1
            html_tree.web_attention_mask = web_attention_mask

            if html_tree.input_ids is None:
                html_tree.input_ids = input_ids
            else:
                html_tree.input_ids = torch.cat((html_tree.input_ids, input_ids), dim=1)

        unwanted_kwargs = ["token_type_ids"]
        inputs.update(
            {
                "web_attention_mask": web_attention_mask.to(inputs["attention_mask"].device),
                "html_tree": html_tree,
            }
        )
        for kwarg in unwanted_kwargs:
            inputs.pop(kwarg, None)
        return inputs
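# How the web attention mask grows during generation (a hedged sketch of the logic in
# `prepare_inputs_for_generation` above; `html_tree` and the offset 67 come from this repo's
# generation utilities): at each decoding step the (bsz, L, L) binary mask is copied into a
# (bsz, L + 1, L + 1) zero tensor, and the new last row is set to 1 at the first 67 positions
# (apparently a fixed prompt prefix) plus the offset positions that `html_tree.update_buffer`
# returns for the freshly decoded token (presumably its ancestors in the HTML DOM tree), so each
# new token only attends within its branch of the tree:
#     mask_t+1[:, :L, :L] = mask_t
#     mask_t+1[:, -1, [0..66] + [67 + i for i in web_attn_range]] = 1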