| """Bidirectional GPT-2 variants for LLM2Vec-style conversion.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import Optional, Tuple |
|
|
| import torch |
| from torch import nn |
| from transformers import GPT2Config, GPT2LMHeadModel, GPT2Model |
| from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block |
|
|
|
|
| class ModifiedGPT2Attention(GPT2Attention): |
| """GPT-2 attention with causal masking removed.""" |
|
|
| def _attn( |
| self, |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| head_mask: Optional[torch.Tensor] = None, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| attn_weights = torch.matmul(query, key.transpose(-1, -2)) |
|
|
| if self.scale_attn_weights: |
| attn_weights = attn_weights / torch.full( |
| [], |
| value.size(-1) ** 0.5, |
| dtype=attn_weights.dtype, |
| device=attn_weights.device, |
| ) |
|
|
| if self.scale_attn_by_inverse_layer_idx: |
| attn_weights = attn_weights / float(self.layer_idx + 1) |
|
|
| |
| |
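        # Note: the causal-mask branch from GPT2Attention._attn is intentionally
        # omitted here, so every position may attend to the full sequence.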
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights


class ModifiedGPT2Block(GPT2Block):
    """GPT-2 block using ModifiedGPT2Attention for self-attention."""

    def __init__(self, config: GPT2Config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx=layer_idx)
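        # Replace the causal attention modules built by GPT2Block with the
        # bidirectional variant (including cross-attention, when enabled).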
        self.attn = ModifiedGPT2Attention(config=config, layer_idx=layer_idx)
        if config.add_cross_attention:
            self.crossattention = ModifiedGPT2Attention(
                config=config,
                is_cross_attention=True,
                layer_idx=layer_idx,
            )


class GPT2BiModel(GPT2Model):
    """GPT-2 encoder stack with bidirectional self-attention."""

    def __init__(self, config: GPT2Config):
        super().__init__(config)
        self.h = nn.ModuleList(
            [ModifiedGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )
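        # Re-run weight init so the freshly built bidirectional blocks are
        # initialized consistently with the rest of the model.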
        self.post_init()


class GPT2BiForMNTP(GPT2LMHeadModel):
    """GPT-2 LM-head model whose backbone is GPT2BiModel.

    Intended for masked next token prediction (MNTP) training, the first
    stage of the LLM2Vec recipe.
    """

    def __init__(self, config: GPT2Config):
        super().__init__(config)
        self.transformer = GPT2BiModel(config)
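        # post_init re-ties lm_head to the new backbone's token embeddings
        # (when config.tie_word_embeddings is set) and re-runs weight init.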
        self.post_init()
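

# A minimal usage sketch (an assumption, not part of the original module): it
# loads the standard "gpt2" checkpoint into the bidirectional variant via the
# usual Hugging Face from_pretrained API and mean-pools the last hidden states
# into a simple text embedding. It assumes a transformers version whose GPT-2
# attention still routes through GPT2Attention._attn (the eager path).
if __name__ == "__main__":
    from transformers import GPT2TokenizerFast

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    model = GPT2BiForMNTP.from_pretrained("gpt2")
    model.eval()

    inputs = tokenizer("LLM2Vec-style conversion of GPT-2.", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # Last-layer token states from the bidirectional backbone, mean-pooled
    # over the sequence dimension to form a single embedding vector.
    embedding = outputs.hidden_states[-1].mean(dim=1)
    print(embedding.shape)  # torch.Size([1, 768]) for the base "gpt2" config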