| | """HELM-BERT model implementation. |
| | |
| | This module implements the HELM-BERT model with: |
| | - Disentangled attention (DeBERTa-style) |
| | - Enhanced Mask Decoder (EMD) for MLM |
| | - n-gram Induced Encoding (nGiE) layer |
| | """ |
| |
|
| | import math |
| | from typing import Any, Dict, Optional, Tuple, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from packaging import version |
| | from torch import _softmax_backward_data |
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import ( |
| | BaseModelOutputWithPooling, |
| | MaskedLMOutput, |
| | SequenceClassifierOutput, |
| | ) |
| |
|
| | from .configuration_helmbert import HELMBertConfig |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def masked_layer_norm(
    layer_norm: nn.LayerNorm, x: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Apply LayerNorm, then zero out positions the mask marks as padding.

    Args:
        layer_norm: LayerNorm module
        x: Input tensor (batch_size, seq_len, hidden_size)
        mask: Mask where 0 = padding, 1 = valid token; either (batch, seq)
            or a (batch, 1, 1, seq) attention-style mask

    Returns:
        Normalized tensor with padding positions zeroed out
    """
    normalized = layer_norm(x).to(x.dtype)
    if mask is None:
        return normalized

    if mask.dim() != x.dim():
        # Collapse a (batch, 1, 1, seq) attention mask down to (batch, seq).
        if mask.dim() == 4:
            mask = mask.squeeze(1).squeeze(1)
        # Add the hidden dimension so the mask broadcasts: (batch, seq, 1).
        mask = mask.unsqueeze(2)

    return normalized * mask.to(normalized.dtype)
| |
|
| |
|
class XSoftmax(torch.autograd.Function):
    """Masked Softmax optimized for memory efficiency.

    Applies softmax over ``dim`` while treating positions where ``mask == 0``
    as invalid: they are filled with ``-inf`` before the softmax and zeroed
    afterwards.  Only the softmax *output* is saved for backward, which uses
    less activation memory than composing ``masked_fill`` + ``softmax`` as
    separate autograd ops.
    """

    @staticmethod
    def forward(
        ctx, input: torch.Tensor, mask: Optional[torch.Tensor], dim: int
    ) -> torch.Tensor:
        # Stash the softmax dimension for backward.
        ctx.dim = dim
        if mask is not None:
            # rmask selects the *invalid* (padding) positions.
            rmask = ~(mask.bool())
            # Broadcast a [batch, seq] or [batch, heads, seq] mask up to the
            # 4-D attention-score layout [batch, heads, q_len, k_len].
            if rmask.dim() == 2:
                rmask = rmask.unsqueeze(1).unsqueeze(2)
            elif rmask.dim() == 3:
                rmask = rmask.unsqueeze(2)
            output = input.masked_fill(rmask, float("-inf"))
        else:
            output = input
        output = torch.softmax(output, ctx.dim)
        if mask is not None:
            # NOTE(review): a row whose keys are all masked yields NaNs from
            # softmax over all -inf; the fill below zeroes masked columns but
            # not NaN rows — assumes at least one valid key per query.
            output.masked_fill_(rmask, 0)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
        (output,) = ctx.saved_tensors
        # torch >= 1.11 changed the private _softmax_backward_data signature
        # to take a dtype instead of the input tensor.
        if version.Version(torch.__version__) >= version.Version("1.11.0"):
            input_grad = _softmax_backward_data(
                grad_output, output, ctx.dim, output.dtype
            )
        else:
            input_grad = _softmax_backward_data(grad_output, output, ctx.dim, output)
        # No gradients flow to the mask or the dim argument.
        return input_grad, None, None
| |
|
| |
|
def build_relative_position(
    query_size: int,
    key_size: int,
    bucket_size: int = -1,
    max_position: int = 512,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Build relative position matrix with optional log-bucketing.

    Args:
        query_size: Number of query positions.
        key_size: Number of key positions.
        bucket_size: If > 0, map distances into this many log-spaced buckets
            (DeBERTa-v2 style); otherwise return raw clamped distances.
        max_position: Largest relative distance that is distinguished.
        device: Device for the returned tensor.

    Returns:
        [query_size, key_size] long tensor.  With bucketing, entries are
        unsigned bucket ids in [0, bucket_size - 1]; without bucketing,
        entries are *signed* distances clamped to
        [-max_position, max_position - 1].  Either way the consumer
        (_disentangled_attention_bias) maps entries into the embedding table
        via ``clamp(rel + att_span, 0, 2 * att_span - 1)``.
    """
    q_ids = torch.arange(query_size, dtype=torch.long, device=device)
    k_ids = torch.arange(key_size, dtype=torch.long, device=device)
    # rel_pos[i, j] = i - j: how far key j sits behind query i.
    rel_pos = q_ids.unsqueeze(1) - k_ids.unsqueeze(0)

    if bucket_size > 0:
        rel_buckets = 0
        num_buckets = bucket_size
        # Upper half of the bucket range encodes positive distances.
        rel_buckets += (rel_pos > 0).long() * (num_buckets // 2)
        rel_pos = torch.abs(rel_pos)

        # Distances below max_exact each get their own bucket ...
        max_exact = num_buckets // 4
        is_small = rel_pos < max_exact

        # ... larger distances share log-spaced buckets up to max_position.
        # (log(0) at small positions produces garbage that torch.where
        # discards below.)
        rel_pos_if_large = (
            max_exact
            + (
                torch.log(rel_pos.float() / max_exact)
                / math.log(max_position / max_exact)
                * (num_buckets // 4 - 1)
            ).long()
        )
        rel_pos_if_large = torch.min(
            rel_pos_if_large, torch.full_like(rel_pos_if_large, num_buckets // 2 - 1)
        )

        rel_buckets += torch.where(is_small, rel_pos, rel_pos_if_large)
        return rel_buckets
    else:
        # FIX: return *signed* distances.  The consumer shifts by att_span and
        # clamps into [0, 2 * att_span - 1]; the previous
        # ``rel_pos + max_position`` pre-shift double-counted the offset, so
        # every non-negative relative distance collapsed onto the single top
        # table index after the consumer's clamp.
        return torch.clamp(rel_pos, -max_position, max_position - 1)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class DisentangledSelfAttention(nn.Module):
    """Disentangled self-attention with content and position separation.

    Implements content-to-content (c2c), content-to-position (c2p), and
    position-to-content (p2c) attention as described in DeBERTa, where each
    token is represented by a content vector plus relative-position
    embeddings.

    Internally, attention tensors are kept in a flattened
    [batch * heads, seq, head_size] layout (see ``transpose_for_scores``) so
    score computations can use ``torch.bmm``.
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__()

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )

        self.num_heads = config.num_attention_heads
        self.head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_heads * self.head_size

        # Content projections (also reused for positions when share_att_key).
        self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
        self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
        self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)

        # pos_att_type: "|"-separated combination of "c2p", "p2c", "p2p".
        self.pos_att_type = [x.strip() for x in config.pos_att_type.lower().split("|")]
        self.max_relative_positions = config.max_relative_positions
        self.position_buckets = config.position_buckets
        self.share_att_key = config.share_att_key

        # Number of distinct relative positions per side of 0; bucketing
        # (when enabled) replaces the raw distance range.
        self.pos_ebd_size = config.max_relative_positions
        if config.position_buckets > 0:
            self.pos_ebd_size = config.position_buckets

        # Per-layer relative position embedding table, size 2 * pos_ebd_size.
        self.rel_embeddings = nn.Embedding(self.pos_ebd_size * 2, config.hidden_size)

        # Dedicated position projections unless shared with content ones.
        if not self.share_att_key:
            if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
                self.pos_key_proj = nn.Linear(
                    config.hidden_size, self.all_head_size, bias=True
                )
            if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
                # NOTE(review): bias=False here while every other projection
                # uses bias=True — confirm this asymmetry is intentional.
                self.pos_query_proj = nn.Linear(
                    config.hidden_size, self.all_head_size, bias=False
                )

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.pos_dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape [batch, seq, all_head] into [batch * heads, seq, head_size]."""
        new_shape = x.size()[:-1] + (self.num_heads, self.head_size)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        query_states: Optional[torch.Tensor] = None,
        relative_pos: Optional[torch.Tensor] = None,
        rel_embeddings: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Forward pass of disentangled attention.

        Args:
            hidden_states: Key/value source [batch, seq, hidden].
            attention_mask: 1 = valid token, 0 = padding.
            output_attentions: Caller's request flag (the probs are placed in
                the result dict either way).
            query_states: Optional separate query source (used by the EMD);
                defaults to ``hidden_states`` for ordinary self-attention.
            relative_pos: Precomputed relative-position indices; built lazily
                from the sequence lengths when None.
            rel_embeddings: Relative-position embedding table; defaults to
                this layer's own ``self.rel_embeddings``.

        Returns:
            Dict with "hidden_states" [batch, seq, hidden] and
            "attention_probs" [batch, heads, q_len, k_len].
        """
        if query_states is None:
            query_states = hidden_states

        # Project and flatten to [batch * heads, seq, head_size]; q/k are
        # upcast to float32 so score accumulation is stable under AMP.
        query_layer = self.transpose_for_scores(self.query_proj(query_states)).float()
        key_layer = self.transpose_for_scores(self.key_proj(hidden_states)).float()
        value_layer = self.transpose_for_scores(self.value_proj(hidden_states))

        # Each enabled disentangled term adds to the softmax input, so the
        # DeBERTa scaling uses 1/sqrt(d * (1 + #terms)).
        scale_factor = 1
        if "c2p" in self.pos_att_type:
            scale_factor += 1
        if "p2c" in self.pos_att_type:
            scale_factor += 1
        if "p2p" in self.pos_att_type:
            scale_factor += 1

        scale = 1.0 / math.sqrt(self.head_size * scale_factor)

        # Content-to-content scores.
        c2c_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) * scale)
        attention_scores = c2c_scores

        # Relative-position bias; the second guard also skips the empty
        # string that "".split("|") produces when pos_att_type is empty.
        if len(self.pos_att_type) > 0 and self.pos_att_type[0]:
            rel_att = self._disentangled_attention_bias(
                query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
            )
            if rel_att is not None:
                attention_scores = attention_scores + rel_att

        # Subtract the (detached) row max for softmax numerical stability.
        attention_scores = (
            attention_scores - attention_scores.max(dim=-1, keepdim=True)[0].detach()
        )
        attention_scores = attention_scores.to(hidden_states.dtype)

        # Unflatten heads: [batch, heads, q_len, k_len].
        attention_scores = attention_scores.view(
            -1, self.num_heads, attention_scores.size(-2), attention_scores.size(-1)
        )

        # Masked softmax (padding keys get probability 0), then dropout.
        attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of values, back in the flattened layout.
        attention_probs_flat = attention_probs.view(
            -1, attention_probs.size(-2), attention_probs.size(-1)
        )
        context_layer = torch.bmm(attention_probs_flat, value_layer)

        # Reassemble heads into [batch, seq, hidden].
        context_layer = context_layer.view(
            -1, self.num_heads, context_layer.size(-2), context_layer.size(-1)
        )
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_shape)

        result = {"hidden_states": context_layer, "attention_probs": attention_probs}
        return result

    def _disentangled_attention_bias(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        relative_pos: Optional[torch.Tensor],
        rel_embeddings: Optional[torch.Tensor],
        scale_factor: int,
    ) -> Optional[torch.Tensor]:
        """Compute the c2p/p2c relative-position bias for the attention scores.

        Args:
            query_layer: [batch * heads, q_len, head_size] (float32).
            key_layer: [batch * heads, k_len, head_size] (float32).
            relative_pos: Relative-position index matrix; built when None.
            rel_embeddings: [2 * pos_ebd_size, hidden] table; defaults to
                this layer's own table.
            scale_factor: Number of score terms, for 1/sqrt(d * k) scaling.

        Returns:
            Bias tensor added onto the [batch * heads, q_len, k_len] scores.
        """
        if relative_pos is None:
            q_size = query_layer.size(-2)
            k_size = key_layer.size(-2)
            relative_pos = build_relative_position(
                q_size,
                k_size,
                bucket_size=self.position_buckets,
                max_position=self.max_relative_positions,
                device=query_layer.device,
            )

        # Normalize to 4-D [batch-or-1, 1, q_len, k_len].
        if relative_pos.dim() == 2:
            relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
        elif relative_pos.dim() == 3:
            relative_pos = relative_pos.unsqueeze(1)

        # query_layer is flattened [batch * heads, ...].
        batch_size = query_layer.size(0) // self.num_heads

        if rel_embeddings is None:
            rel_embeddings = self.rel_embeddings.weight

        # att_span == pos_ebd_size, so this slice currently selects the whole
        # table; kept in the sliced form for parity with DeBERTa.
        att_span = self.pos_ebd_size
        rel_embeddings = rel_embeddings[
            self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :
        ].unsqueeze(0)
        rel_embeddings = self.pos_dropout(rel_embeddings)

        # Zero bias of the right shape (expand keeps it allocation-free).
        score = torch.zeros_like(query_layer[:, :, :1]).expand(
            -1, -1, key_layer.size(-2)
        )

        # Shift relative positions into table indices [0, 2 * att_span - 1]
        # and broadcast to the flattened [batch * heads, q_len, k_len] layout.
        c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
        c2p_pos = c2p_pos.squeeze(0).expand(
            query_layer.size(0), query_layer.size(1), relative_pos.size(-1)
        )

        # Content -> position term.
        if "c2p" in self.pos_att_type:
            pos_key_layer = (
                self.pos_key_proj(rel_embeddings)
                if not self.share_att_key
                else self.key_proj(rel_embeddings)
            )
            # [heads, 2 * att_span, head_size] tiled per batch element.
            pos_key_layer = self.transpose_for_scores(pos_key_layer).repeat(
                batch_size, 1, 1
            )

            c2p_scale = 1.0 / math.sqrt(self.head_size * scale_factor)
            c2p_att = torch.bmm(
                query_layer, pos_key_layer.transpose(-1, -2) * c2p_scale
            )
            # For every (query, key) pair, pick the score of their relative
            # position out of the 2 * att_span candidates.
            c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_pos)
            score = score + c2p_att

        # Position -> content term.
        if "p2c" in self.pos_att_type:
            pos_query_layer = (
                self.pos_query_proj(rel_embeddings)
                if not self.share_att_key
                else self.query_proj(rel_embeddings)
            )
            pos_query_layer = self.transpose_for_scores(pos_query_layer).repeat(
                batch_size, 1, 1
            )

            p2c_scale = 1.0 / math.sqrt(self.head_size * scale_factor)
            p2c_att = torch.bmm(
                pos_query_layer * p2c_scale, key_layer.transpose(-1, -2)
            )
            # NOTE(review): gathers along dim=-2 reusing the c2p index;
            # DeBERTa builds a separate p2c index and transposes the result —
            # confirm equivalence here, especially when q_len != k_len.
            p2c_att = torch.gather(p2c_att, dim=-2, index=c2p_pos)
            score = score + p2c_att

        return score
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class HELMBertEmbeddings(nn.Module):
    """Token and position embeddings for HELM-BERT.

    Token embeddings go through a padding-aware LayerNorm and dropout;
    absolute position embeddings are returned separately so the caller
    decides where to inject them (e.g. only in the Enhanced Mask Decoder).
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed token IDs.

        Args:
            input_ids: Token IDs [batch, seq].
            attention_mask: Mask where 0 marks padding positions.

        Returns:
            Tuple of (token_embeddings, position_embeddings), both shaped
            [batch, seq, hidden].
        """
        n_rows, n_cols = input_ids.shape

        # Absolute positions 0..seq_len-1, broadcast across the batch.
        pos_ids = torch.arange(n_cols, dtype=torch.long, device=input_ids.device)
        pos_embeds = self.position_embeddings(pos_ids.unsqueeze(0).expand(n_rows, -1))

        # Token embeddings: padding-aware LayerNorm, then dropout.
        token_embeds = masked_layer_norm(
            self.layer_norm, self.word_embeddings(input_ids), attention_mask
        )
        token_embeds = self.dropout(token_embeds)

        return token_embeds, pos_embeds
| |
|
| |
|
class NgieLayer(nn.Module):
    """n-gram Induced Input Encoding (nGiE) layer.

    Runs a 1-D convolution along the token dimension to capture local n-gram
    patterns, zeroes padding positions, and merges the result into a
    residual stream through a padding-aware LayerNorm.
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__()

        # "Same"-length output for odd kernel sizes.
        self.conv = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=config.hidden_size,
            kernel_size=config.ngie_kernel_size,
            padding=(config.ngie_kernel_size - 1) // 2,
            groups=1,
        )
        self.activation = nn.Tanh()
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.ngie_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Fold n-gram convolution features into the residual stream.

        Args:
            hidden_states: Input to convolution (batch, seq, hidden)
            residual_states: States for residual connection (batch, seq, hidden)
            attention_mask: Mask where 1 = valid, 0 = padding

        Returns:
            Output with n-gram information incorporated
        """
        # Conv1d is channels-first, so go (batch, hidden, seq) and back.
        conv_states = self.conv(hidden_states.permute(0, 2, 1).contiguous())
        conv_states = conv_states.permute(0, 2, 1).contiguous()

        # Invalid-position selector; bool masks need torch >= 1.2.
        if version.Version(torch.__version__) >= version.Version("1.2.0a"):
            pad_positions = (1 - attention_mask).bool()
        else:
            pad_positions = (1 - attention_mask).byte()

        # Zero conv output at padding positions so padding never leaks in.
        conv_states.masked_fill_(
            pad_positions.unsqueeze(-1).expand(conv_states.size()), 0
        )

        # Dropout before tanh — order deliberately preserved.
        conv_states = self.activation(self.dropout(conv_states))

        # Residual merge followed by padding-aware LayerNorm.
        return masked_layer_norm(
            self.layer_norm, residual_states + conv_states, attention_mask
        )
| |
|
| |
|
class TransformerBlock(nn.Module):
    """Transformer block with disentangled attention and GELU FFN.

    Post-LN layout: attention -> residual -> LayerNorm -> FFN -> residual ->
    LayerNorm.  The block's public interface is sequence-first
    ([seq, batch, hidden], like nn.TransformerEncoderLayer) while the
    attention module works batch-first, hence the transposes throughout.
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__()

        self.self_attn = DisentangledSelfAttention(config)
        # Output projection applied to the attention result.
        self.attn_output_dense = nn.Linear(config.hidden_size, config.hidden_size)

        # Feed-forward network: Linear -> GELU -> Linear.
        self.linear1 = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size), nn.GELU()
        )
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)

        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)
        self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
        self.dropout2 = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        src: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        query_states: Optional[torch.Tensor] = None,
        relative_pos: Optional[torch.Tensor] = None,
        rel_embeddings: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward pass.

        Args:
            src: Input embeddings [seq_len, batch, hidden]
            src_key_padding_mask: Padding mask [batch, seq_len], True = pad
            output_attentions: Whether to return attention weights
            query_states: Optional query stream for EMD [seq_len, batch, hidden]
            relative_pos: Relative position indices
            rel_embeddings: Relative position embeddings

        Returns:
            Tuple of (output, optional attention weights)
        """
        # Attention works batch-first: [batch, seq, hidden].
        src_transposed = src.transpose(0, 1)

        # Convert key-padding convention (True = pad) to attention-mask
        # convention (1 = valid token).
        attention_mask = None
        if src_key_padding_mask is not None:
            attention_mask = (~src_key_padding_mask).float()

        query_states_transposed = None
        if query_states is not None:
            query_states_transposed = query_states.transpose(0, 1)

        attn_result = self.self_attn(
            src_transposed,
            attention_mask,
            output_attentions=output_attentions,
            query_states=query_states_transposed,
            relative_pos=relative_pos,
            rel_embeddings=rel_embeddings,
        )
        attn_output = attn_result["hidden_states"].transpose(0, 1)
        attn_weights = attn_result.get("attention_probs") if output_attentions else None

        attn_output = self.attn_output_dense(attn_output)

        # During EMD the query stream (not src) carries the residual.
        residual_input = query_states if query_states is not None else src
        src = residual_input + self.dropout1(attn_output)

        # masked_layer_norm expects batch-first input; called without a mask
        # here, so this is a plain LayerNorm (padding rows are not re-zeroed).
        src = src.transpose(0, 1)
        src = masked_layer_norm(self.norm1, src)
        src = src.transpose(0, 1)

        # FFN with residual.
        ff_output = self.linear1(src)
        ff_output = self.linear2(ff_output)
        ff_output = self.dropout2(ff_output)
        src = src + ff_output

        # Second post-LN, again without masking.
        src = src.transpose(0, 1)
        src = masked_layer_norm(self.norm2, src)
        src = src.transpose(0, 1)

        return src, attn_weights
| |
|
| |
|
class HELMBertEncoder(nn.Module):
    """Stack of transformer blocks with an nGiE convolution after layer 0.

    Also implements the Enhanced Mask Decoder (EMD): when ``use_emd`` is set,
    the output of layer n-2 is combined with absolute position embeddings and
    decoded by re-applying the last layer twice with a separate query stream.
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        self.config = config

        # n-gram induced encoding: conv features merged with layer-0 output.
        self.ngie_layer = NgieLayer(config)

        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )

    def get_rel_embedding(self) -> Optional[torch.Tensor]:
        """Get relative position embeddings from first layer.

        Used only by the EMD pass below.  NOTE(review): the main layer loop
        passes no rel_embeddings, so each layer there uses its own table,
        while the EMD re-runs the last layer with layer 0's table — confirm
        this mix is intended.
        """
        if len(self.layers) > 0:
            first_layer = self.layers[0]
            if hasattr(first_layer, "self_attn") and hasattr(
                first_layer.self_attn, "rel_embeddings"
            ):
                return first_layer.self_attn.rel_embeddings.weight
        return None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        use_emd: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple], Optional[Tuple]]:
        """Forward pass.

        Args:
            hidden_states: Input embeddings [batch, seq, hidden]
            attention_mask: Attention mask [batch, seq], 1 = valid
            position_embeddings: Absolute position embeddings for EMD
                [batch, seq, hidden]
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states
            use_emd: Whether to use Enhanced Mask Decoder

        Returns:
            Tuple of (last_hidden_state, emd_output, all_hidden_states,
            all_attentions); emd_output is None unless use_emd produced one.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Original embeddings feed the nGiE convolution after layer 0.
        ngie_input_states = hidden_states

        # Blocks are sequence-first: [seq, batch, hidden].
        hidden_states = hidden_states.transpose(0, 1)

        # Key-padding convention: True marks padding.
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = ~attention_mask.bool()

        # Output of layer n-2, captured for the EMD.
        layer_minus_2 = None
        num_layers = len(self.layers)

        for layer_idx, layer in enumerate(self.layers):
            if output_hidden_states:
                # Hidden states are reported batch-first.
                all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)

            hidden_states, attn_weights = layer(
                hidden_states,
                src_key_padding_mask=key_padding_mask,
                output_attentions=output_attentions,
            )

            if output_attentions and attn_weights is not None:
                all_attentions = all_attentions + (attn_weights,)

            # nGiE: merge local n-gram features (conv over the raw
            # embeddings) into the first layer's output.
            if layer_idx == 0:
                hidden_states_batch = hidden_states.transpose(0, 1)
                hidden_states_batch = self.ngie_layer(
                    ngie_input_states, hidden_states_batch, attention_mask
                )
                hidden_states = hidden_states_batch.transpose(0, 1)

            if use_emd and layer_idx == num_layers - 2:
                layer_minus_2 = hidden_states

        # Back to batch-first for the caller.
        hidden_states = hidden_states.transpose(0, 1)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Enhanced Mask Decoder: query = layer n-2 output + absolute position
        # embeddings; keys/values stay the layer n-2 output.  The last layer
        # is applied twice, updating only the query stream each time.
        emd_output = None
        if use_emd and layer_minus_2 is not None and position_embeddings is not None:
            emd_keys_values = layer_minus_2
            # position_embeddings is batch-first; layer_minus_2 is seq-first.
            emd_query = layer_minus_2.transpose(0, 1)
            emd_query = position_embeddings + emd_query
            emd_query = emd_query.transpose(0, 1)

            rel_embeddings = self.get_rel_embedding()
            last_layer = self.layers[-1]

            for _ in range(2):
                emd_query, _ = last_layer(
                    emd_keys_values,
                    src_key_padding_mask=key_padding_mask,
                    query_states=emd_query,
                    relative_pos=None,
                    rel_embeddings=rel_embeddings,
                )

            emd_output = emd_query.transpose(0, 1)

        return hidden_states, emd_output, all_hidden_states, all_attentions
| |
|
| |
|
class HELMBertPooler(nn.Module):
    """Mean pooling over sequence, excluding padding when a mask is given."""

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        self.hidden_size = config.hidden_size

    def forward(
        self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Mask-aware mean over the sequence dimension.

        Args:
            hidden_states: [batch, seq, hidden]
            attention_mask: [batch, seq]; 1 = valid, 0 = padding

        Returns:
            Pooled output [batch, hidden]
        """
        if attention_mask is None:
            return hidden_states.mean(dim=1)

        # Weight each position by its mask value, then average over valid
        # positions only (denominator clamped to avoid division by zero).
        weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        total = torch.sum(hidden_states * weights, 1)
        denom = torch.clamp(weights.sum(1), min=torch.finfo(hidden_states.dtype).eps)
        return total / denom
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class HELMBertPreTrainedModel(PreTrainedModel):
    """Base class for HELM-BERT models.

    Wires the HELMBert config class and prefix into the HuggingFace
    PreTrainedModel machinery and supplies weight initialization.
    """

    config_class = HELMBertConfig
    base_model_prefix = "helmbert"

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights with BERT-style initialization (normal, std 0.02)."""
        if isinstance(module, nn.LayerNorm):
            # Identity-like start: unit scale, zero shift.
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # Padding row stays zero so pad tokens contribute nothing.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class HELMBertModel(HELMBertPreTrainedModel):
    """HELM-BERT base model.

    Produces the last hidden states plus a mask-aware mean-pooled output.

    Example:
        >>> from helmbert import HELMBertModel, HELMBertTokenizer
        >>> tokenizer = HELMBertTokenizer()
        >>> model = HELMBertModel.from_pretrained("./checkpoints/helmbert-base")
        >>> inputs = tokenizer("PEPTIDE1{A.C.D.E}$$$$", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooler_output = outputs.pooler_output
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = HELMBertEmbeddings(config)
        self.encoder = HELMBertEncoder(config)
        self.pooler = HELMBertPooler(config)

        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode token IDs.

        Args:
            input_ids: Token IDs [batch, seq]
            attention_mask: Attention mask [batch, seq]; defaults to all-ones
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states
            return_dict: Whether to return a ModelOutput

        Returns:
            BaseModelOutputWithPooling or tuple
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Token embeddings plus separately-returned position embeddings.
        token_embeds, pos_embeds = self.embeddings(input_ids, attention_mask)

        # The base model never runs the Enhanced Mask Decoder; its slot in
        # the encoder output tuple is ignored.
        last_hidden, _, all_hidden, all_attn = self.encoder(
            token_embeds,
            attention_mask=attention_mask,
            position_embeddings=pos_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_emd=False,
        )

        pooled = self.pooler(last_hidden, attention_mask)

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden,
                pooler_output=pooled,
                hidden_states=all_hidden,
                attentions=all_attn,
            )
        return (last_hidden, pooled, all_hidden, all_attn)
| |
|
| |
|
class HELMBertLMHead(nn.Module):
    """MLM head with weight tying (HuggingFace standard).

    Transform: Linear -> GELU -> LayerNorm, then the vocabulary decoder
    whose weight the model ties to the input embedding matrix.
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.activation = nn.GELU()

        # Decoder keeps its own bias even when the weight is tied.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits.

        Args:
            hidden_states: [batch, seq, hidden]

        Returns:
            Logits [batch, seq, vocab]
        """
        transformed = self.layer_norm(self.activation(self.dense(hidden_states)))
        return self.decoder(transformed)
| |
|
| |
|
class HELMBertForMaskedLM(HELMBertPreTrainedModel):
    """HELM-BERT for Masked Language Modeling with Enhanced Mask Decoder (EMD).

    Example:
        >>> from helmbert import HELMBertForMaskedLM, HELMBertTokenizer
        >>> tokenizer = HELMBertTokenizer()
        >>> model = HELMBertForMaskedLM.from_pretrained("./checkpoints/helmbert-base")
        >>> inputs = tokenizer("PEPTIDE1{A.¶.D.E}$$$$", return_tensors="pt")  # ¶ is mask
        >>> outputs = model(**inputs)
        >>> predictions = outputs.logits.argmax(dim=-1)
    """

    # Decoder weight is tied to the input embedding matrix.
    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config: HELMBertConfig):
        super().__init__(config)
        self.helmbert = HELMBertModel(config)
        self.lm_head = HELMBertLMHead(config)

        self.post_init()

    def get_output_embeddings(self) -> nn.Linear:
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.lm_head.decoder = new_embeddings

    def get_input_embeddings(self) -> nn.Embedding:
        return self.helmbert.embeddings.word_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        use_emd: bool = True,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Forward pass.

        Args:
            input_ids: Token IDs [batch, seq]
            attention_mask: Attention mask [batch, seq]
            labels: Labels for MLM (-100 for non-masked tokens)
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states
            return_dict: Whether to return a ModelOutput
            use_emd: Whether to use Enhanced Mask Decoder

        Returns:
            MaskedLMOutput or tuple
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Call embeddings/encoder directly (rather than self.helmbert(...))
        # so the EMD flag and position embeddings can be threaded through.
        token_embeds, pos_embeds = self.helmbert.embeddings(
            input_ids, attention_mask
        )

        last_hidden, emd_out, all_hidden, all_attn = self.helmbert.encoder(
            token_embeds,
            attention_mask=attention_mask,
            position_embeddings=pos_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_emd=use_emd,
        )

        # Prefer the EMD stream when it was produced.
        sequence_output = emd_out if use_emd and emd_out is not None else last_hidden

        prediction_scores = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            # -100 marks unmasked positions, excluded from the loss.
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores, all_hidden, all_attn)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=all_hidden,
            attentions=all_attn,
        )
| |
|
| |
|
| | class MLPHead(nn.Module): |
| | """MLP head with skip connections for classification/regression. |
| | |
| | Architecture: input -> [Linear -> GELU -> LayerNorm -> Dropout (+ skip)] x N -> Linear -> output |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | input_dim: int, |
| | output_dim: int, |
| | hidden_dims: list, |
| | dropout: float = 0.1, |
| | ): |
| | super().__init__() |
| | self.layers = nn.ModuleList() |
| | self.norms = nn.ModuleList() |
| | self.dropouts = nn.ModuleList() |
| |
|
| | prev_dim = input_dim |
| | for hidden_dim in hidden_dims: |
| | self.layers.append(nn.Linear(prev_dim, hidden_dim)) |
| | self.norms.append(nn.LayerNorm(hidden_dim)) |
| | self.dropouts.append(nn.Dropout(dropout)) |
| | prev_dim = hidden_dim |
| |
|
| | self.output_layer = nn.Linear(prev_dim, output_dim) |
| | self.activation = nn.GELU() |
| |
|
| | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| | for layer, norm, dropout in zip(self.layers, self.norms, self.dropouts): |
| | identity = x |
| | x = layer(x) |
| | if x.shape == identity.shape: |
| | x = x + identity |
| | x = self.activation(x) |
| | x = norm(x) |
| | x = dropout(x) |
| | return self.output_layer(x) |
| |
|
| |
|
class HELMBertForSequenceClassification(HELMBertPreTrainedModel):
    """HELM-BERT for sequence classification/regression.

    Example:
        >>> from helmbert import HELMBertForSequenceClassification, HELMBertConfig
        >>> # Simple linear head (default)
        >>> config = HELMBertConfig(num_labels=1)
        >>> model = HELMBertForSequenceClassification(config)
        >>>
        >>> # MLP head with 2 layers (for permeability prediction)
        >>> config = HELMBertConfig(num_labels=1, classifier_num_layers=2)
        >>> model = HELMBertForSequenceClassification(config)
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.helmbert = HELMBertModel(config)

        # classifier_num_layers > 0 selects the MLP head (which applies its
        # own dropout); otherwise a dropout + single linear layer is used.
        if config.classifier_num_layers > 0:
            self.classifier = MLPHead(
                input_dim=config.hidden_size,
                output_dim=config.num_labels,
                hidden_dims=[config.hidden_size] * config.classifier_num_layers,
                dropout=config.classifier_dropout,
            )
        else:
            self.dropout = nn.Dropout(config.classifier_dropout)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        """Forward pass.

        Args:
            input_ids: Token IDs [batch, seq]
            attention_mask: Attention mask [batch, seq]
            labels: Labels for classification/regression
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states
            return_dict: Whether to return a ModelOutput

        Returns:
            SequenceClassifierOutput or tuple
        """
        outputs = self.helmbert(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        pooled = outputs.pooler_output
        # Only the linear-head configuration defines self.dropout; the MLP
        # head applies dropout internally.
        if hasattr(self, "dropout"):
            pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            # Infer and cache the problem type on first use (HF convention).
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                mse = nn.MSELoss()
                if self.num_labels == 1:
                    loss = mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = nn.CrossEntropyLoss()(
                    logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss = nn.BCEWithLogitsLoss()(logits, labels)

        if not return_dict:
            # ModelOutput slicing skips None fields, matching HF behavior.
            tail = (logits,) + outputs[2:]
            return ((loss,) + tail) if loss is not None else tail

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
| |
|