import torch
from transformers.models.gemma2 import modeling_gemma2
from transformers.utils import logging

logger = logging.get_logger(__name__)

# Monkey patch the Gemma2Model's forward function.
# Save a reference to the original forward function so it can be restored later.
original_forward = modeling_gemma2.Gemma2Model.forward


# Define the patched version of the forward function
def patched_forward(
    self,
    input_ids=None,
    attention_mask=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    cache_position=None,
):
    # Update parameters based on the input or configuration defaults
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Ensure either input_ids or inputs_embeds is specified, but not both
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    # Handle gradient checkpointing case to ensure compatibility with caching
    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    # Embed tokens if inputs_embeds is not provided
    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # Handle caching mechanism: lazily create a HybridCache at inference time
    if use_cache and past_key_values is None and not self.training:
        batch_size, seq_len, _ = inputs_embeds.shape
        past_key_values = modeling_gemma2.HybridCache(
            self.config,
            batch_size=batch_size,
            max_cache_len=seq_len,
            device=self.device,
            dtype=inputs_embeds.dtype,
        )

    # Handle cache position
    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    # Handle position IDs
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    # Compute causal mask
    causal_mask = self._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    )

    # Initialize hidden states from the input embeddings
    hidden_states = inputs_embeds

    # Create the normalizer tensor on the correct device and scale the embeddings
    normalizer = torch.tensor(
        self.config.hidden_size**0.5, dtype=hidden_states.dtype, device=hidden_states.device
    )
    hidden_states = hidden_states * normalizer

    # Initialize variables to store outputs if required
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None

    # Pass through decoder layers
    for decoder_layer in self.layers:
        # Store the hidden state if requested
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Use gradient checkpointing if applicable
        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            # Normal forward pass through the layer
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        # Update hidden states with the output from the layer
        hidden_states = layer_outputs[0]

        # Store self-attentions if requested
        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    # Apply final normalization
    hidden_states = self.norm(hidden_states)

    # Store the last hidden state if required
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    # Return the cache only when caching was requested
    next_cache = past_key_values if use_cache else None

    # Prepare output
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return modeling_gemma2.BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


# Apply the patched forward function to the Gemma2Model class
# modeling_gemma2.Gemma2Model.forward = patched_forward


# Optional: a helper that applies the patch on demand.
def apply_patch():
    modeling_gemma2.Gemma2Model.forward = patched_forward
    print("Gemma2Model's forward function has been patched.")
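

# Example usage: a minimal sketch of applying the patch and running one forward
# pass through a Gemma 2 checkpoint. The checkpoint name "google/gemma-2-2b" is
# only illustrative (an assumption, not part of the patch itself); substitute
# any Gemma 2 model you have access to.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Install the patched forward before the model is used
    apply_patch()

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")

    inputs = tokenizer("Hello, Gemma!", return_tensors="pt")
    with torch.no_grad():
        # The patched Gemma2Model.forward is invoked under the hood here
        outputs = model(**inputs, output_hidden_states=True)
    print(outputs.logits.shape)

    # The saved reference allows restoring the original behaviour if needed:
    # modeling_gemma2.Gemma2Model.forward = original_forward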