from typing import Callable, Optional, Union
import copy

import torch
import torch.nn.functional as F
from torch import nn
import matplotlib.pyplot as plt

from transformers import GPT2Config, GPT2LMHeadModel, GPT2Model
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, eager_attention_forward
from transformers.utils import logging

logger = logging.get_logger(__name__)


class GPT2AttentionModified(GPT2Attention):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx)
        self.config = config
        # Re-register the causal bias buffer with a larger window (2048) so the eager
        # attention path still works after the positional embeddings are expanded below.
        max_positions = 2048
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )

    def forward(
        self,
        hidden_states: Optional[tuple[torch.FloatTensor]],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
        is_cross_attention = encoder_hidden_states is not None
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
                    curr_past_key_value = past_key_values.cross_attention_cache
                else:
                    curr_past_key_value = past_key_values.self_attention_cache
            else:
                curr_past_key_value = past_key_values

        if is_cross_attention:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )
            query_states = self.q_attn(hidden_states)
            attention_mask = encoder_attention_mask

            # Try to get key/value states from cache if possible
            if past_key_values is not None and is_updated:
                key_states = curr_past_key_value.layers[self.layer_idx].keys
                value_states = curr_past_key_value.layers[self.layer_idx].values
            else:
                key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
                shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
                key_states = key_states.view(shape_kv).transpose(1, 2)
                value_states = value_states.view(shape_kv).transpose(1, 2)
        else:
            query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
            shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
            key_states = key_states.view(shape_kv).transpose(1, 2)
            value_states = value_states.view(shape_kv).transpose(1, 2)

        shape_q = (*query_states.shape[:-1], -1, self.head_dim)
        query_states = query_states.view(shape_q).transpose(1, 2)

        if (past_key_values is not None and not is_cross_attention) or (
            past_key_values is not None and is_cross_attention and not is_updated
        ):
            # save all key/value_layer to cache to be re-used for fast auto-regressive generation
            cache_position = cache_position if not is_cross_attention else None
            key_states, value_states = curr_past_key_value.update(
                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
            )
            # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
            if is_cross_attention:
                past_key_values.is_updated[self.layer_idx] = True

        is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention

        using_eager = self.config._attn_implementation == "eager"
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if using_eager and self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query_states, key_states, value_states, attention_mask, head_mask
            )
        else:
            if getattr(self.config, "prefix_allowed_length", None) is not None:
                # When a bidirectional prefix is allowed, flag the module as cross-attention so
                # eager_attention_forward skips its built-in causal bias and relies entirely on
                # the (already modified) attention_mask passed in from the model.
                temp = self
                temp.is_cross_attention = True
            attn_output, attn_weights = attention_interface(
                self if getattr(self.config, "prefix_allowed_length", None) is None else temp,
                query_states,
                key_states,
                value_states,
                attention_mask,
                head_mask=head_mask,
                dropout=self.attn_dropout.p if self.training else 0.0,
                is_causal=is_causal if getattr(self.config, "is_prefix", None) is None else False,
                **kwargs,
            )
| if getattr(self.config, "plot_attention_map", False) and self.layer_idx in getattr(self.config, "plot_attention_map_layer", []): | |
| # pick batch=0, head=0 | |
| attn_bh = attn_weights[0, 0] # [L,S] | |
| L, S = attn_bh.shape | |
| if L > 1: | |
| if getattr(self.config, "plot_attention_map_generation", 0) == 0: | |
| print(f"Plotting attention map for inputs on layer {self.layer_idx}") | |
| # full 2D heatmap | |
| data = attn_bh.detach().float().cpu().numpy() # [L,S] | |
| plt.figure(figsize=(6,5)) | |
| plt.imshow(data, aspect="auto", cmap="hot", vmin=0, vmax=0.01) | |
| plt.colorbar() | |
| plt.xlabel("Keys (S)") | |
| plt.ylabel("Queries (L)") | |
| plt.title(f"Attention map (B0,H0) L={L}, S={S}") | |
| plt.show() | |
| else: | |
| if getattr(self.config, "plot_attention_map_generation", 0) == S: | |
| print(f"Plotting attention row map for token {S} generation on layer {self.layer_idx}") | |
| # attn_bh expected shape: [..., S] for the selected (B0, H0) row | |
| row = attn_bh[0].detach().float().cpu().numpy() # -> np.ndarray shape [S] | |
| n = row.shape[0] | |
| # ----- First 1024 as 32x32 ----- | |
| head_1024 = row[:min(1024, n)] | |
| grid = head_1024.reshape(32, 32) | |
| plt.figure(figsize=(6, 5)) | |
| plt.imshow(grid, aspect="auto", cmap="hot", vmin=0, vmax=0.01) | |
| plt.yticks([]) | |
| plt.colorbar() | |
| plt.xlabel("Keys (S) [indices 0..1023]") | |
| plt.title(f"Attention row (B0,H0) L={self.layer_idx}, S={S} — first 1024") | |
| plt.tight_layout() | |
| plt.show() | |
| # ----- Tail (>=1024) as a single-row heatmap ----- | |
| tail = row[1024:] | |
| if tail.size > 0: | |
| plt.figure(figsize=(10, 1.2)) | |
| # one-row heatmap | |
| plt.imshow(tail[None, :], aspect="auto", cmap="hot", vmin=0, vmax=0.01) | |
| plt.yticks([]) | |
| plt.colorbar() | |
| plt.xlabel(f"Keys (S) [indices 1024..{n-1}]") | |
| plt.title(f"Attention row tail (B0,H0) L={self.layer_idx}, S={S}") | |
| plt.tight_layout() | |
| plt.show() | |

        attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output, attn_weights


class GPT2BlockModified(GPT2Block):
    def __init__(self, config, layer_idx=None):
        super().__init__(config=config, layer_idx=layer_idx)
        # Swap in the modified attention so the per-layer mask edits and plotting hooks take effect.
        self.attn = GPT2AttentionModified(config=config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states: Optional[tuple[torch.FloatTensor]],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output, self_attn_weights = self.attn(
            hidden_states,
            past_key_values=past_key_values,
            cache_position=cache_position,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs,
        )
        # residual connection
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_output, cross_attn_weights = self.crossattention(
                hidden_states,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            # residual connection
            hidden_states = residual + cross_attn_output

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
            if encoder_hidden_states is not None:
                outputs += (cross_attn_weights,)

        return outputs


class GPT2ModelModified(GPT2Model):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Use a copy of the config for mask creation: forcing "eager" here guarantees that
        # create_causal_mask() returns a materialized 4D mask without switching the model's
        # own attention implementation.
        self.config_causal = copy.deepcopy(config)
        self.config_causal._attn_implementation = "eager"
        # TEMPORARY: override the transformer blocks to pass segmentation masks
        self.h = nn.ModuleList([GPT2BlockModified(config, layer_idx=i) for i in range(config.num_hidden_layers)])

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[tuple[tuple[torch.Tensor]], Cache]] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        segmentation_mask: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
        if use_cache:
            if past_key_values is None:
                past_key_values = DynamicCache()
            elif isinstance(past_key_values, tuple):
                logger.warning_once(
                    "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
                    "You should pass an instance of `Cache` instead, e.g. "
                    "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
                )
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)

        # Attention mask.
        # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
        if attention_mask is not None and attention_mask.ndim < 4:
            attention_mask = attention_mask.view(batch_size, -1)
        causal_mask = create_causal_mask(
            config=self.config_causal,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
        if self.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            if _use_sdpa:
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
                )
            elif self._attn_implementation != "flash_attention_2":
                encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None

        for i, block in enumerate(self.h):
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if segmentation_mask is not None and causal_mask is not None:
                # Make a safe copy of the causal mask and ensure its spatial
                # dimensions match the sequence length that the attention
                # functions expect. This prevents off-by-one shape errors
                # when using eager attention (torch.where requires same sizes).
                causal_mask_modified = causal_mask.clone()
                if getattr(self.config, "prefix_allowed_length", None) is not None:
                    causal_mask_modified[:, :, :, :self.config.prefix_allowed_length].zero_()
                # Use the input sequence length to crop the causal mask if needed
                seq_len = input_shape[-1]
                if causal_mask_modified.shape[2] != seq_len or causal_mask_modified.shape[3] != seq_len:
                    causal_mask_modified = causal_mask_modified[:, :, :seq_len, :seq_len]
                # Clip segmentation mask to fit into causal_mask_modified before adding.
                _, _, M, N = segmentation_mask.shape
                M = min(M, causal_mask_modified.shape[2])
                N = min(N, causal_mask_modified.shape[3])
                causal_mask_modified[:, :, :M, :N] += segmentation_mask[:, i, :M, :N].unsqueeze(1)

            if getattr(self.config, "plot_attention_mask", False) and i in getattr(self.config, "plot_attention_mask_layer", [0]):
                if segmentation_mask is not None and causal_mask is not None:
                    print(f"Block {i}: segmentation mask added to causal mask.")
                    plt.imshow(causal_mask_modified[0, 0].detach().cpu(), aspect="auto", cmap="hot", vmin=-1, vmax=1)
                    plt.colorbar()
                    plt.title(f"Causal Mask with Segmentation (Block {i})")
                    plt.show()
                else:
                    print(f"Block {i}: no segmentation mask applied.")
                    plt.imshow(causal_mask[0, 0].detach().cpu(), aspect="auto", cmap="hot", vmin=-1, vmax=1)
                    plt.colorbar()
                    plt.title(f"Causal Mask (Block {i})")
                    plt.show()

            outputs = block(
                hidden_states,
                past_key_values if not (self.gradient_checkpointing and self.training) else None,
                cache_position,
                causal_mask_modified if segmentation_mask is not None and causal_mask is not None else causal_mask,
                head_mask[i],
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
                **kwargs,
            )
            hidden_states = outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[2],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        past_key_values = past_key_values if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
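

# The per-layer loop above adds `segmentation_mask[:, layer_idx]` onto the additive causal mask, so the
# expected shape is (batch, num_layers, seq_len, seq_len): 0.0 where attention stays allowed and a large
# negative value where it should be suppressed. The helper below is only an illustrative sketch of that
# convention; its name, the `segment_ids` input, and the -1e4 fill value are assumptions, not part of
# the original module.
def build_segment_block_mask(segment_ids: torch.Tensor, num_layers: int, blocked_value: float = -1e4) -> torch.Tensor:
    """Illustrative sketch: block attention between tokens that carry different segment ids."""
    batch_size, seq_len = segment_ids.shape
    same_segment = segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)  # (B, S, S) boolean
    mask = torch.full((batch_size, seq_len, seq_len), blocked_value, dtype=torch.float32)
    mask = mask.masked_fill(same_segment, 0.0)  # same-segment positions remain fully visible
    # Repeat the same pattern for every layer; callers could also vary the mask per layer.
    return mask.unsqueeze(1).expand(batch_size, num_layers, seq_len, seq_len).contiguous()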


class GPT2LMHeadModelModified(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        # replace the base transformer with our modified transformer implementation
        self.transformer = GPT2ModelModified(config)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        segmentation_mask: Optional[torch.FloatTensor] = None,
        prefix_allowed_length: Optional[int] = None,
        plot_attention_mask: Optional[bool] = False,
        plot_attention_mask_layer: Optional[list[int]] = [0],
        plot_attention_map: Optional[bool] = False,
        plot_attention_map_layer: Optional[list[int]] = [0],
        plot_attention_map_generation: Optional[int] = 0,
        **kwargs,
    ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Stash the per-call options on the config so the modified blocks/attention can read them.
        if prefix_allowed_length is not None:
            self.config.prefix_allowed_length = prefix_allowed_length
        if plot_attention_mask is not None:
            self.config.plot_attention_mask = plot_attention_mask
        if plot_attention_mask_layer is not None:
            self.config.plot_attention_mask_layer = plot_attention_mask_layer
        if plot_attention_map is not None:
            if plot_attention_map_layer is not None:
                self.config.plot_attention_map_layer = plot_attention_map_layer
            if plot_attention_map_generation is not None:
                self.config.plot_attention_map_generation = plot_attention_map_generation
            self.config.plot_attention_map = plot_attention_map

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            cache_position=cache_position,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            segmentation_mask=segmentation_mask,  # added parameter
            **kwargs,
        )
        hidden_states = transformer_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # Flatten the tokens
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )


def expand_gpt2_positional_embeddings(
    model: torch.nn.Module,
    new_max_positions: int,
    mode: str = "linear",  # "linear" | "copy_last" | "zeros"
    align_corners: bool = True,  # for linear interpolation
):
    """
    Expand GPT-2's learned positional embeddings (wpe) to `new_max_positions`.

    Works with GPT2LMHeadModel or GPT2Model (HF). Updates model.config.n_positions (and n_ctx if present).
    Does NOT mutate token embeddings; only the position table + config.

    Args:
        model: HF GPT2LMHeadModel or GPT2Model (already loaded).
        new_max_positions: int, desired max sequence length (e.g., 1536 or 2048).
        mode: how to initialize new rows if expanding:
            - "linear": 1D linear interpolation along the position dim (recommended)
            - "copy_last": copy the last learned vector into all new rows
            - "zeros": initialize new rows to zero
        align_corners: passed to F.interpolate for "linear" mode.

    Returns:
        model (same instance) with expanded wpe and updated config.
    """
    # Locate the position embedding table.
    # Support both:
    #   - GPT2LMHeadModel (has .transformer, which is a GPT2Model with .wpe)
    #   - GPT2Model (exposes .wpe directly)
    if hasattr(model, "transformer") and hasattr(model.transformer, "wpe"):
        model_for_wpe = model.transformer
    elif hasattr(model, "wpe"):
        model_for_wpe = model
    else:
        raise ValueError("Model does not look like a GPT-2 family model with a position embedding 'wpe'.")

    wpe = model_for_wpe.wpe
    old_n, d = wpe.weight.shape

    if new_max_positions <= 0:
        raise ValueError("new_max_positions must be positive")
    if new_max_positions == old_n:
        # Still update config for consistency
        if hasattr(model.config, "n_positions"):
            model.config.n_positions = new_max_positions
        if hasattr(model.config, "n_ctx"):
            model.config.n_ctx = new_max_positions
        return model

    device = wpe.weight.device
    dtype = wpe.weight.dtype

    if new_max_positions < old_n:
        # Shrink (rare): just slice
        new_weight = wpe.weight[:new_max_positions].clone()
    else:
        # Expand
        if mode == "linear":
            # Interpolate along the position dimension.
            # Treat the embedding dim as channels: (1, d, old_n) -> (1, d, new_n) -> (new_n, d)
            w = wpe.weight.transpose(0, 1).unsqueeze(0)  # (1, d, old_n)
            w_new = F.interpolate(w, size=new_max_positions, mode="linear", align_corners=align_corners)
            new_weight = w_new.squeeze(0).transpose(0, 1).contiguous()  # (new_n, d)
        elif mode == "copy_last":
            new_weight = torch.empty((new_max_positions, d), device=device, dtype=dtype)
            new_weight[:old_n].copy_(wpe.weight)
            new_weight[old_n:].copy_(wpe.weight[old_n - 1].expand(new_max_positions - old_n, d))
        elif mode == "zeros":
            new_weight = torch.zeros((new_max_positions, d), device=device, dtype=dtype)
            new_weight[:old_n].copy_(wpe.weight)
        else:
            raise ValueError(f"Unknown mode '{mode}'")

    # Replace the embedding module on whichever object held the original table.
    new_wpe = torch.nn.Embedding(new_max_positions, d, device=device, dtype=dtype)
    with torch.no_grad():
        # copy_ into a leaf parameter that requires grad must run under no_grad,
        # otherwise autograd rejects the in-place operation.
        new_wpe.weight.copy_(new_weight)
    # Keep requires_grad True (default). If you want to freeze, set .requires_grad_(False).
    if hasattr(model, "transformer") and hasattr(model.transformer, "wpe"):
        model.transformer.wpe = new_wpe
    else:
        model.wpe = new_wpe

    # Update config fields used by HF
    if hasattr(model.config, "n_positions"):
        model.config.n_positions = new_max_positions
    if hasattr(model.config, "n_ctx"):
        model.config.n_ctx = new_max_positions

    return model
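

# Usage sketch for the expansion helper (the checkpoint name and shapes below are assumptions for
# illustration, not something this module requires):
#
#     model = GPT2LMHeadModel.from_pretrained("gpt2")  # 1024 learned positions
#     model = expand_gpt2_positional_embeddings(model, new_max_positions=2048, mode="linear")
#     assert model.transformer.wpe.weight.shape[0] == 2048
#
# "linear" interpolates the existing rows so the new table stays smooth; "copy_last" and "zeros"
# only initialize the appended rows, so some fine-tuning on longer sequences is usually still needed.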


def create_decoder(attention="sdpa"):
    config = GPT2Config.from_pretrained("gpt2")
    config._attn_implementation = attention
    new_max_positions = 2048
    decoder = GPT2LMHeadModelModified.from_pretrained("gpt2", config=config)
    decoder.config._attn_implementation = attention
    decoder = expand_gpt2_positional_embeddings(decoder, new_max_positions=new_max_positions, mode="linear")
    return decoder
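

# Minimal end-to-end sketch. Everything here is illustrative: it assumes the "gpt2" checkpoint and
# tokenizer are available and reuses the hypothetical `build_segment_block_mask` helper defined above.
if __name__ == "__main__":
    from transformers import GPT2TokenizerFast

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    decoder = create_decoder(attention="eager")  # "eager" keeps attention weights available for plotting

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    seq_len = inputs["input_ids"].shape[1]

    # One segment for every token, i.e. nothing is blocked; real callers would pass meaningful ids.
    segment_ids = torch.zeros_like(inputs["input_ids"])
    segmentation_mask = build_segment_block_mask(segment_ids, num_layers=decoder.config.n_layer)

    outputs = decoder(
        **inputs,
        segmentation_mask=segmentation_mask,
        prefix_allowed_length=seq_len,  # let the prompt attend bidirectionally within itself
    )
    print(outputs.logits.shape)  # (1, seq_len, vocab_size)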