import math

import torch
import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_
from torch.nn.parameter import Parameter
from transformers import PreTrainedModel

from configuration_dlmberta import InteractionModelATTNConfig


class StdScaler:
    def fit(self, X):
        self.mean_ = torch.mean(X).item()
        self.std_ = torch.std(X, correction=0).item()

    def fit_transform(self, X):
        self.fit(X)
        return self.transform(X)

    def transform(self, X):
        return (X - self.mean_) / self.std_

    def inverse_transform(self, X):
        return (X * self.std_) + self.mean_

    def save(self, directory):
        with open(directory + "/scaler.config", "w") as f:
            f.write(str(self.mean_) + "\n")
            f.write(str(self.std_) + "\n")

    def load(self, directory):
        with open(directory + "/scaler.config", "r") as f:
            self.mean_ = float(f.readline())
            self.std_ = float(f.readline())


class InteractionModelATTNForRegression(PreTrainedModel):
    config_class = InteractionModelATTNConfig

    def __init__(self, config, target_encoder, drug_encoder, scaler=None):
        super().__init__(config)
        self.model = InteractionModelATTN(
            target_encoder,
            drug_encoder,
            scaler,
            config.attention_dropout,
            config.hidden_dropout,
            config.num_heads,
        )
        self.scaler = scaler

    def INTERPR_ENABLE_MODE(self):
        self.model.INTERPR_ENABLE_MODE()

    def INTERPR_DISABLE_MODE(self):
        self.model.INTERPR_DISABLE_MODE()

    def INTERPR_OVERRIDE_ATTN(self, new_weights):
        self.model.INTERPR_OVERRIDE_ATTN(new_weights)

    def INTERPR_RESET_OVERRIDE_ATTN(self):
        self.model.INTERPR_RESET_OVERRIDE_ATTN()

    def forward(self, x1, x2):
        return self.model(x1, x2)

    def unscale(self, x):
        return self.model.unscale(x)
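
# --- Illustrative example (not part of the original module) ---
# A minimal sketch of how StdScaler is meant to be used: fit it on the raw affinity labels,
# train on the scaled values, and map predictions back with inverse_transform (as done in
# InteractionModelATTNForRegression.unscale). The dummy label values below are arbitrary.
def _example_stdscaler_roundtrip():
    """Sketch: fit StdScaler on dummy labels and check that inverse_transform undoes transform."""
    labels = torch.tensor([5.0, 6.2, 7.4, 8.1])
    scaler = StdScaler()
    scaled = scaler.fit_transform(labels)        # zero mean, unit (population) std
    restored = scaler.inverse_transform(scaled)  # back to the original label scale
    assert torch.allclose(restored, labels, atol=1e-6)
    return scaler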

class CrossAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, attention_dropout=0.0, hidden_dropout=0.0, add_bias_kv=False, **factory_kwargs):
        """
        Initializes the CrossAttention layer.

        Args:
            embed_dim (int): Dimension of the input embeddings.
            num_heads (int): Number of attention heads.
            attention_dropout (float): Dropout probability for the attention weights.
            hidden_dropout (float): Dropout probability for the output projection.
        """
        super().__init__()
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        if self.head_dim * num_heads != embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")

        # Linear projections for query, key, and value.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(attention_dropout)
        xavier_uniform_(self.q_proj.weight)
        xavier_uniform_(self.k_proj.weight)
        xavier_uniform_(self.v_proj.weight)
        constant_(self.q_proj.bias, 0.)
        constant_(self.k_proj.bias, 0.)
        constant_(self.v_proj.bias, 0.)

        # Output projection.
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        constant_(self.out_proj.bias, 0.)
        self.drop_out = nn.Dropout(hidden_dropout)

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, replace_weights=None):
        """
        Forward pass for cross attention.

        Args:
            query (Tensor): Query embeddings of shape (batch_size, query_len, embed_dim).
            key (Tensor): Key embeddings of shape (batch_size, key_len, embed_dim).
            value (Tensor): Value embeddings of shape (batch_size, key_len, embed_dim).
            key_padding_mask (Tensor, optional): Boolean mask of shape (batch_size, key_len);
                True marks padding positions that must not be attended to.
            attn_mask (Tensor, optional): Boolean attention mask of shape (batch_size, query_len, key_len).
            replace_weights (Tensor, optional): If given, replaces the attention scores before the
                softmax (used by the interpretability mode).

        Returns:
            output (Tensor): The attended output of shape (batch_size, query_len, embed_dim).
            attn_weights (Tensor): The attention weights of shape (batch_size, num_heads, query_len, key_len).
        """
        batch_size, query_len, _ = query.size()
        _, key_len, _ = key.size()

        Q = self.q_proj(query)
        K = self.k_proj(key)
        V = self.v_proj(value)

        # Split the embedding dimension into heads:
        # (batch_size, seq_len, embed_dim) -> (batch_size, num_heads, seq_len, head_dim)
        Q = Q.view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute scaled dot-product attention scores: (batch_size, num_heads, query_len, key_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if key_padding_mask is not None:
            # True marks padding; set those key positions to -inf so they receive zero attention after softmax
            key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1)  # (B, 1, 1, key_len) for broadcasting
            scores = scores.masked_fill(key_padding_mask, float('-inf'))

        if replace_weights is not None:
            scores = replace_weights

        # Compute attention weights using softmax: (batch_size, num_heads, query_len, key_len)
        attn_weights = torch.nn.functional.softmax(scores, dim=-1)
        self.scores = scores

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1)  # (batch_size, 1, query_len, key_len)
            attn_weights = attn_weights.masked_fill(attn_mask, 0)  # Zero out masked positions

        # Apply dropout to the attention weights
        attn_weights = self.attn_dropout(attn_weights)

        # Weighted sum of the values: (batch_size, num_heads, query_len, head_dim)
        attn_output = torch.matmul(attn_weights, V)

        # Recombine heads: transpose and reshape back to (batch_size, query_len, embed_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_len, self.embed_dim)

        # Final linear projection and dropout
        output = self.out_proj(attn_output)
        output = self.drop_out(output)
        return output, attn_weights
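
# --- Illustrative example (not part of the original module) ---
# A minimal sketch of the tensor shapes CrossAttention expects and returns, using random inputs.
# embed_dim=384 matches the projections used by InteractionModelATTN below; the batch size,
# sequence lengths, and num_heads=4 are arbitrary choices for the demonstration.
def _example_crossattention_shapes():
    """Sketch: run CrossAttention on random target/drug embeddings and return the output shapes."""
    batch_size, query_len, key_len, embed_dim = 2, 16, 12, 384
    layer = CrossAttention(embed_dim, num_heads=4).eval()
    query = torch.randn(batch_size, query_len, embed_dim)  # e.g. mapped target token embeddings
    key = torch.randn(batch_size, key_len, embed_dim)      # e.g. mapped drug token embeddings
    key_padding_mask = torch.zeros(batch_size, key_len, dtype=torch.bool)  # True would mark padding
    with torch.no_grad():
        output, attn_weights = layer(query, key, key, key_padding_mask=key_padding_mask)
    # output: (2, 16, 384); attn_weights: (2, 4, 16, 12)
    return output.shape, attn_weights.shape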

class InteractionModelATTN(nn.Module):
    def __init__(self, target_encoder, drug_encoder, scaler, attention_dropout, hidden_dropout, num_heads=1, kernel_size=1):
        super().__init__()
        self.replace_weights = None
        self.crossattention_weights = None
        self.presum_layer = None
        self.INTERPR_MODE = False
        self.scaler = scaler
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout
        self.target_encoder = target_encoder
        self.drug_encoder = drug_encoder
        self.kernel_size = kernel_size

        self.lin_map_target = nn.Linear(512, 384)
        self.dropout_map_target = nn.Dropout(hidden_dropout)
        self.lin_map_drug = nn.Linear(384, 384)
        self.dropout_map_drug = nn.Dropout(hidden_dropout)
        self.crossattention = CrossAttention(384, num_heads, attention_dropout, hidden_dropout)
        self.norm = nn.LayerNorm(384)
        self.summary1 = nn.Linear(384, 384)
        self.summary2 = nn.Linear(384, 1)
        self.dropout_summary = nn.Dropout(hidden_dropout)
        self.layer_norm = nn.LayerNorm(384)
        self.gelu = nn.GELU()

        self.w = Parameter(torch.empty(512, 1))
        self.b = Parameter(torch.zeros(1))
        self.pdng = Parameter(torch.tensor(0.0))  # learnable padding value (0-dimensional)
        xavier_uniform_(self.w)

    def forward(self, x1, x2):
        """
        Forward pass for the attention interaction model.

        Args:
            x1 (dict): A dictionary containing input tensors for the target encoder.
                Expected keys:
                - 'input_ids' (torch.Tensor): Token IDs for the target input.
                - 'attention_mask' (torch.Tensor): Attention mask for the target input.
            x2 (dict): A dictionary containing input tensors for the drug encoder.
                Expected keys:
                - 'input_ids' (torch.Tensor): Token IDs for the drug input.
                - 'attention_mask' (torch.Tensor): Attention mask for the drug input.

        Returns:
            torch.Tensor: A tensor representing the predicted binding affinity.
        """
        # Cast masks to bool to avoid the dropout/mask issue: https://github.com/pytorch/pytorch/issues/86120
        x1["attention_mask"] = x1["attention_mask"].bool()
        y1 = self.target_encoder(**x1).last_hidden_state  # Target (protein) token embeddings
        query_mask = x1["attention_mask"].unsqueeze(-1).to(y1.dtype)
        y1 = y1 * query_mask

        x2["attention_mask"] = x2["attention_mask"].bool()  # Same bool cast as above
        y2 = self.drug_encoder(**x2).last_hidden_state  # Drug token embeddings
        key_mask = x2["attention_mask"].unsqueeze(-1).to(y2.dtype)
        y2 = y2 * key_mask

        # Map both encoders into a shared 384-dimensional space
        y1 = self.lin_map_target(y1)
        y1 = self.gelu(y1)
        y1 = self.dropout_map_target(y1)
        y2 = self.lin_map_drug(y2)
        y2 = self.gelu(y2)
        y2 = self.dropout_map_drug(y2)

        key_padding_mask = (x2["attention_mask"] == 0)  # True at padding positions of the drug sequence

        # If in interpretation mode, allow the replacement of cross-attention weights
        replace_weights = None
        if self.INTERPR_MODE:
            if self.replace_weights is not None:
                replace_weights = self.replace_weights

        out, attn_weights = self.crossattention(y1, y2, y2, key_padding_mask=key_padding_mask, attn_mask=None, replace_weights=replace_weights)

        # If in interpretation mode, make cross-attention weights and scores accessible from the outside
        if self.INTERPR_MODE:
            self.crossattention_weights = attn_weights
            self.scores = self.crossattention.scores

        out = self.summary1(out * query_mask)
        out = self.gelu(out)
        out = self.dropout_summary(out)
        out = self.summary2(out).squeeze(-1)

        # If in interpretation mode, make the final summation layer contributions accessible from the outside
        if self.INTERPR_MODE:
            self.presum_layer = out

        weighted = out * self.w.squeeze(1)  # [batch, seq_len]
        padding_positions = ~x1["attention_mask"]  # True at padding
        # Assign the learnable padding value to all padding positions
        # (torch.where keeps the gradient path to self.pdng, which masked_fill with .item() would break)
        weighted = torch.where(padding_positions, self.pdng.to(weighted.dtype), weighted)
        # Sum across the sequence and add the bias
        result = weighted.sum(dim=1, keepdim=True) + self.b
        return result

    def train(self, mode=True):
        super().train(mode)
        self.target_encoder.train(mode)
        self.drug_encoder.train(mode)
        self.crossattention.train(mode)
        return self

    def eval(self):
        super().eval()
        self.target_encoder.eval()
        self.drug_encoder.eval()
        self.crossattention.eval()
        return self

    def INTERPR_ENABLE_MODE(self):
        """
        Enables the interpretability mode for the model.
        """
        if self.training:
            raise RuntimeError("Cannot enable interpretability mode while the model is training.")
        self.INTERPR_MODE = True

    def INTERPR_DISABLE_MODE(self):
        """
        Disables the interpretability mode for the model.
        """
        if self.training:
            raise RuntimeError("Cannot disable interpretability mode while the model is training.")
        self.INTERPR_MODE = False

    def INTERPR_OVERRIDE_ATTN(self, new_weights):
        self.replace_weights = new_weights

    def INTERPR_RESET_OVERRIDE_ATTN(self):
        self.replace_weights = None

    def unscale(self, x):
        """
        Unscales predictions with the fitted scaler. If no scaler is specified, returns the input unchanged.

        Parameters:
            x: the values to be unscaled
        """
        with torch.no_grad():
            if self.scaler is None:
                return x
            return self.scaler.inverse_transform(x)
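
# --- Illustrative example (not part of the original module) ---
# A rough sketch of how the pieces above could be wired together. The encoder checkpoint paths
# are placeholders, and it is assumed that InteractionModelATTNConfig accepts attention_dropout,
# hidden_dropout and num_heads as keyword arguments (they are read as attributes in
# InteractionModelATTNForRegression.__init__); adapt both to the actual checkpoints and config.
def _example_build_model():
    """Sketch: assemble InteractionModelATTNForRegression from two pretrained encoders."""
    from transformers import AutoModel

    # The target encoder must produce 512-dim hidden states (see lin_map_target) and its inputs
    # must be padded/truncated to 512 tokens (the length of self.w); the drug encoder must
    # produce 384-dim hidden states (see lin_map_drug).
    target_encoder = AutoModel.from_pretrained("path/to/target-encoder")  # placeholder checkpoint
    drug_encoder = AutoModel.from_pretrained("path/to/drug-encoder")      # placeholder checkpoint

    config = InteractionModelATTNConfig(attention_dropout=0.1, hidden_dropout=0.1, num_heads=1)
    scaler = StdScaler()  # optionally fit on the training labels, or restore with scaler.load(...)
    return InteractionModelATTNForRegression(config, target_encoder, drug_encoder, scaler)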