| """ |
| TD3B Scoring Functions |
| Implements gated allosteric reward combining affinity prediction and directional oracle. |
| """ |
|
|
| import os |
| import torch |
| import numpy as np |
| from typing import List, Tuple, Optional |
| from .direction_oracle import DirectionalOracle |
| from scoring.functions.binding import BindingAffinity |
|
|
|
|
class TD3BRewardFunction:
    """
    Gated TD3B total reward with sigmoid temperature scaling.

        S_total(y; d*, x) = g_ψ(y, x) · σ(d* · (f_φ(y, x) - 0.5) / τ)

    where:
        - g_ψ(y, x): affinity predictor (BindingAffinity)
        - σ: sigmoid function σ(z) = 1 / (1 + exp(-z))
        - d* ∈ {+1, -1}: target direction (agonist/antagonist)
        - f_φ(y, x): directional oracle (DirectionalOracle), documented as
          p(agonist) in [0, 1]
        - τ: temperature parameter (lower = sharper gate)
        - y: peptide sequence
        - x: target protein sequence

    Note: an oracle that outputs 0.5 gives (f_φ - 0.5) = 0, so the gate is
    σ(0) = 0.5 for every sequence regardless of d* (direction-neutral gating).

    Properties of the sigmoid formulation:
        1. The gate is always in [0, 1], so the reward is bounded by |g_ψ|.
        2. The temperature τ controls how sharply directions are separated.
        3. The gate is a smooth function of the oracle output.

    The earlier formulation S_total = g_ψ · (d* · f_φ) has been replaced by
    the one above.
    """

    def __init__(
        self,
        affinity_predictor: "BindingAffinity",
        directional_oracle: "DirectionalOracle",
        target_direction: float,
        target_protein_tokens: torch.Tensor,
        peptide_tokenizer,
        device: torch.device,
        min_affinity_threshold: float = 0.0,
        use_confidence_weighting: bool = True,
        temperature: float = 0.1,
    ):
        """
        Args:
            affinity_predictor: Pretrained g_ψ model; called with a list of
                peptide strings.
            directional_oracle: Pretrained f_φ model; must expose
                ``predict_with_confidence``.
            target_direction: d* in {+1, -1} for agonist/antagonist.
            target_protein_tokens: Tokenized target protein sequence.
            peptide_tokenizer: Tokenizer for converting SMILES to tokens.
            device: Computation device for the peptide tokens.
            min_affinity_threshold: Sequences whose affinity is below this
                value bypass the directional gate (see
                ``compute_gated_reward``).
            use_confidence_weighting: Stored for callers; not read by any
                method of this class.
            temperature: Temperature τ for the sigmoid gate (lower = sharper).

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        # Written as ``not (> 0)`` so that NaN is rejected as well.
        if not temperature > 0:
            raise ValueError(
                f"temperature must be strictly positive, got {temperature!r}"
            )
        self.g_psi = affinity_predictor
        self.f_phi = directional_oracle
        self.target_direction = target_direction
        self.protein_tokens = target_protein_tokens
        self.peptide_tokenizer = peptide_tokenizer
        self.device = device
        self.min_affinity_threshold = min_affinity_threshold
        self.use_confidence_weighting = use_confidence_weighting
        self.temperature = temperature

    def compute_affinity(self, peptide_seqs: List[str]) -> np.ndarray:
        """
        Compute binding affinity g_ψ(y, x).

        Args:
            peptide_seqs: List of peptide SMILES strings.
        Returns:
            affinities: array of affinity scores (one per sequence, assuming
                the predictor returns one score per input).
        """
        affinities = self.g_psi(peptide_seqs)
        # np.array() cannot convert tensors that track gradients or live on an
        # accelerator, so bring them to host memory first.
        if isinstance(affinities, torch.Tensor):
            affinities = affinities.detach().cpu().numpy()
        return np.array(affinities)

    def _tokenize_peptides(self, peptide_seqs: List[str]):
        """
        Tokenize peptides, returning ``(token_dict, input_ids)``.

        A single batched tokenizer call is tried first. If anything in that
        path fails, each sequence is tokenized on its own and the results are
        stacked; in that case ``token_dict`` is None.
        """
        try:
            token_dict = self.peptide_tokenizer(
                peptide_seqs,
                return_tensors='pt',
                padding=True
            )
            token_dict = {k: v.to(self.device) for k, v in token_dict.items()}
            return token_dict, token_dict.get('input_ids')
        except Exception:
            # Deliberate best-effort: fall through to per-sequence
            # tokenization. No dict is returned from this path, so a
            # half-processed batch dict can never reach the oracle.
            pass

        per_seq = [
            self.peptide_tokenizer(seq, return_tensors='pt', padding=True)['input_ids'].to(self.device)
            for seq in peptide_seqs
        ]

        if len({ids.size(1) for ids in per_seq}) == 1:
            # All sequences have the same length: plain concatenation works.
            return None, torch.cat(per_seq, dim=0)

        # Ragged lengths: right-pad with 0 up to the longest sequence.
        # NOTE(review): assumes 0 is an acceptable padding id — confirm
        # against the tokenizer vocabulary.
        max_len = max(ids.size(1) for ids in per_seq)
        padded = torch.zeros(len(per_seq), max_len, dtype=torch.long, device=self.device)
        for row, ids in enumerate(per_seq):
            padded[row, :ids.size(1)] = ids[0]
        return None, padded

    def compute_direction(self, peptide_seqs: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute directional bias f_φ(y, x) and confidence κ(y).

        Args:
            peptide_seqs: List of peptide SMILES strings.
        Returns:
            directions: tensor of directional predictions as returned by the
                oracle (documented as p(agonist) in [0, 1]).
            confidences: tensor of confidence scores as returned by the oracle
                (documented as lying in [0, 1]).
        """
        token_dict, peptide_tokens = self._tokenize_peptides(peptide_seqs)

        # One copy of the protein tokens per peptide.
        # NOTE(review): expand(N, -1) presumes protein tokens of shape (1, L).
        protein_tokens = self.protein_tokens.expand(len(peptide_seqs), -1)

        # Oracles exposing `_normalize_token_dict` are given the full token
        # dict; any other oracle receives the bare input_ids tensor.
        wants_dict = token_dict is not None and hasattr(self.f_phi, "_normalize_token_dict")
        oracle_input = token_dict if wants_dict else peptide_tokens

        with torch.no_grad():
            directions, confidences = self.f_phi.predict_with_confidence(
                oracle_input, protein_tokens
            )

        return directions, confidences

    def compute_gated_reward(
        self,
        peptide_seqs: List[str]
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute gated total reward with sigmoid temperature scaling.

            S_total = g_ψ · σ(d* · (f_φ - 0.5) / τ)

        Where:
            - g_ψ: affinity score
            - σ: sigmoid function
            - d*: target direction (+1 or -1)
            - f_φ: directional oracle prediction, p(agonist) in [0, 1]
            - τ: temperature (lower = sharper gate)

        Sequences whose affinity is below ``min_affinity_threshold`` skip the
        gate and receive ``0.1 * affinity`` instead.

        Args:
            peptide_seqs: List of peptide SMILES strings.
        Returns:
            total_rewards: (N,) array of gated total rewards
            affinities: (N,) array of affinity scores g_ψ
            confidences: (N,) array of confidence scores κ
            directions: (N,) array of directional predictions f_φ
        """
        if len(peptide_seqs) == 0:
            # Nothing to score; avoid calling the tokenizer and models on an
            # empty batch.
            return np.zeros(0), np.zeros(0), np.zeros(0), np.zeros(0)

        affinities = self.compute_affinity(peptide_seqs)

        directions, confidences = self.compute_direction(peptide_seqs)
        directions = directions.detach().cpu().numpy()
        confidences = confidences.detach().cpu().numpy()

        # Gate argument d* · (f_φ - 0.5) / τ, computed in float64 so a small τ
        # cannot push float32 values out of range.
        scaled_score = (
            self.target_direction * (directions.astype(np.float64) - 0.5) / self.temperature
        )

        # Stable sigmoid: σ(z) = exp(-log(1 + exp(-z))). logaddexp never
        # overflows, unlike the direct 1 / (1 + exp(-z)) for very negative z.
        sigmoid_weight = np.exp(-np.logaddexp(0.0, -scaled_score))

        gated_rewards = affinities * sigmoid_weight

        # Low-affinity sequences are not gated by direction at all.
        low_affinity_mask = affinities < self.min_affinity_threshold
        gated_rewards = np.where(low_affinity_mask, affinities * 0.1, gated_rewards)

        return gated_rewards, affinities, confidences, directions

    def __call__(
        self,
        input_seqs: List[str]
    ) -> Tuple[np.ndarray, dict]:
        """
        Main interface for reward computation.

        Args:
            input_seqs: List of peptide SMILES strings.
        Returns:
            rewards: (N,) array of total rewards
            info: dict with 'affinities', 'confidences', 'directions' and
                'score_vectors' (an (N, 2) array of [affinity, total reward])
        """
        total_rewards, affinities, confidences, directions = self.compute_gated_reward(input_seqs)

        info = {
            'affinities': affinities,
            'confidences': confidences,
            'directions': directions,
            'score_vectors': np.stack([affinities, total_rewards], axis=1)
        }

        return total_rewards, info
|
|
|
|
class TD3BConfidenceWeighting:
    """
    Confidence-weighted importance sampling for TD3B.

    The importance weights w(y) are modulated by confidence κ(y):

        w(y) = κ(y) · exp(S_total(y) / α)

    with κ clamped from below at ``min_confidence``. The confidence values
    are supplied by the caller; the intended interpretation is:
        - Full agonists/antagonists: high κ
        - Partial agonists/antagonists: medium κ
        - Non-selective binders: low κ
    """

    def __init__(
        self,
        alpha: float = 0.1,
        min_confidence: float = 0.1
    ):
        """
        Args:
            alpha: Temperature parameter dividing the reward inside exp().
            min_confidence: Lower bound applied to every confidence value.

        Raises:
            ValueError: If ``alpha`` is not strictly positive.
        """
        # Written as ``not (> 0)`` so that NaN is rejected as well.
        if not alpha > 0:
            raise ValueError(f"alpha must be strictly positive, got {alpha!r}")
        self.alpha = alpha
        self.min_confidence = min_confidence

    def _clamp_confidences(self, confidences) -> np.ndarray:
        """Return the confidences as an array, floored at ``min_confidence``."""
        return np.maximum(np.asarray(confidences), self.min_confidence)

    def _scaled_rewards(self, rewards) -> np.ndarray:
        """Return rewards / α as an array; plain Python sequences are accepted."""
        return np.asarray(rewards) / self.alpha

    def compute_importance_weights(
        self,
        rewards: np.ndarray,
        confidences: np.ndarray
    ) -> np.ndarray:
        """
        Compute confidence-weighted importance weights κ · exp(S_total / α).

        exp() can overflow when rewards / α is large; use
        ``compute_log_importance_weights`` when that is a concern.

        Args:
            rewards: (N,) array-like of total rewards S_total
            confidences: (N,) array-like of confidence scores κ
        Returns:
            weights: (N,) array of importance weights
        """
        kappa = self._clamp_confidences(confidences)
        return kappa * np.exp(self._scaled_rewards(rewards))

    def compute_log_importance_weights(
        self,
        rewards: np.ndarray,
        confidences: np.ndarray
    ) -> np.ndarray:
        """
        Compute log importance weights log κ + S_total / α.

        This is the logarithm of ``compute_importance_weights`` evaluated
        without calling exp(), so it stays finite for large rewards.
        If ``min_confidence`` is 0, a zero confidence yields -inf.

        Args:
            rewards: (N,) array-like of total rewards
            confidences: (N,) array-like of confidence scores
        Returns:
            log_weights: (N,) array of log importance weights
        """
        kappa = self._clamp_confidences(confidences)
        return np.log(kappa) + self._scaled_rewards(rewards)
|
|
|
|
| |
def create_td3b_reward_function(
    affinity_predictor: BindingAffinity,
    target_protein_seq: str,
    target_direction: str,
    peptide_tokenizer,
    device: torch.device,
    directional_oracle: Optional[DirectionalOracle] = None,
    directional_oracle_checkpoint: Optional[str] = None,
    base_path: Optional[str] = None,
    direction_oracle_tr2d2_checkpoint: Optional[str] = None,
    direction_oracle_tokenizer_vocab: Optional[str] = None,
    direction_oracle_tokenizer_splits: Optional[str] = None,
    direction_oracle_esm_name: str = "facebook/esm2_t33_650M_UR50D",
    direction_oracle_esm_cache_dir: Optional[str] = None,
    direction_oracle_esm_local_files_only: bool = False,
    direction_oracle_max_ligand_length: int = 768,
    direction_oracle_max_protein_length: int = 1024,
    direction_oracle_d_model: int = 256,
    direction_oracle_n_heads: int = 4,
    direction_oracle_n_self_attn_layers: int = 1,
    direction_oracle_n_bmca_layers: int = 2,
    direction_oracle_dropout: float = 0.3,
    **kwargs
) -> TD3BRewardFunction:
    """
    Factory function to create a TD3B reward function.

    If no ``directional_oracle`` instance is supplied, one is constructed
    from the checkpoint/tokenizer paths; any path left as None defaults to a
    location under ``<base_path>/tr2d2-pep``.

    Args:
        affinity_predictor: Pretrained binding affinity model
        target_protein_seq: Target protein amino acid sequence
        target_direction: 'agonist' (+1) or 'antagonist' (-1); matched
            case-insensitively after stripping surrounding whitespace. Any
            other value falls back to +1 and emits a RuntimeWarning.
        peptide_tokenizer: Tokenizer for peptides
        device: Computation device
        directional_oracle: Preloaded DirectionalOracle instance (optional)
        directional_oracle_checkpoint: Path to the directional oracle
            checkpoint (optional if an instance is provided)
        base_path: Base path for default oracle assets
        direction_oracle_tr2d2_checkpoint: TR2-D2 checkpoint for ligand encoder
        direction_oracle_tokenizer_vocab: SMILES tokenizer vocab path
        direction_oracle_tokenizer_splits: SMILES tokenizer splits path
        direction_oracle_esm_name: Forwarded to DirectionalOracle as ``esm_name``
        direction_oracle_esm_cache_dir: Forwarded as ``esm_cache_dir``
        direction_oracle_esm_local_files_only: Forwarded as ``esm_local_files_only``
        direction_oracle_max_ligand_length: Forwarded as ``max_ligand_length``
        direction_oracle_max_protein_length: Forwarded as ``max_protein_length``
        direction_oracle_d_model: Forwarded as ``d_model``
        direction_oracle_n_heads: Forwarded as ``n_heads``
        direction_oracle_n_self_attn_layers: Forwarded as ``n_self_attn_layers``
        direction_oracle_n_bmca_layers: Forwarded as ``n_bmca_layers``
        direction_oracle_dropout: Forwarded as ``dropout``
        **kwargs: Additional arguments for TD3BRewardFunction

    Returns:
        reward_function: TD3BRewardFunction instance
    """
    # Local import keeps this block self-contained without touching the
    # module-level import list.
    import warnings

    if directional_oracle is None:
        if base_path is None:
            # NOTE(review): placeholder default — unless every oracle path is
            # passed explicitly, the derived paths below will not point at
            # real files. Callers should supply base_path.
            base_path = "To Be Added"
        tr2d2_root = os.path.join(base_path, "tr2d2-pep")

        # Fill in only the asset paths the caller did not provide.
        if directional_oracle_checkpoint is None:
            directional_oracle_checkpoint = os.path.join(
                tr2d2_root, "direction_oracle.pt"
            )
        if direction_oracle_tr2d2_checkpoint is None:
            direction_oracle_tr2d2_checkpoint = os.path.join(
                tr2d2_root, "pretrained", "peptune-pretrained.ckpt"
            )
        if direction_oracle_tokenizer_vocab is None:
            direction_oracle_tokenizer_vocab = os.path.join(
                tr2d2_root, "tokenizer", "new_vocab.txt"
            )
        if direction_oracle_tokenizer_splits is None:
            direction_oracle_tokenizer_splits = os.path.join(
                tr2d2_root, "tokenizer", "new_splits.txt"
            )

        directional_oracle = DirectionalOracle(
            model_ckpt=directional_oracle_checkpoint,
            tr2d2_checkpoint=direction_oracle_tr2d2_checkpoint,
            tokenizer_vocab=direction_oracle_tokenizer_vocab,
            tokenizer_splits=direction_oracle_tokenizer_splits,
            esm_name=direction_oracle_esm_name,
            d_model=direction_oracle_d_model,
            n_heads=direction_oracle_n_heads,
            n_self_attn_layers=direction_oracle_n_self_attn_layers,
            n_bmca_layers=direction_oracle_n_bmca_layers,
            dropout=direction_oracle_dropout,
            max_ligand_length=direction_oracle_max_ligand_length,
            max_protein_length=direction_oracle_max_protein_length,
            device=device,
            esm_cache_dir=direction_oracle_esm_cache_dir,
            esm_local_files_only=direction_oracle_esm_local_files_only,
        )

    # The oracle is only used for inference here, whether it was passed in or
    # freshly built.
    directional_oracle.eval()

    protein_tokens = directional_oracle.encode_protein(target_protein_seq)

    # Map the textual direction onto d*. Unknown labels keep the historical
    # fallback of +1.0 but are no longer silent, so a typo such as
    # "antagonsit" cannot quietly optimise for the opposite pharmacology.
    direction_map = {'agonist': +1.0, 'antagonist': -1.0}
    direction_key = target_direction.strip().lower()
    if direction_key not in direction_map:
        warnings.warn(
            f"Unrecognized target_direction {target_direction!r}; "
            f"expected one of {sorted(direction_map)}. Falling back to 'agonist' (+1.0).",
            RuntimeWarning,
            stacklevel=2,
        )
    d_star = direction_map.get(direction_key, +1.0)

    reward_function = TD3BRewardFunction(
        affinity_predictor=affinity_predictor,
        directional_oracle=directional_oracle,
        target_direction=d_star,
        target_protein_tokens=protein_tokens,
        peptide_tokenizer=peptide_tokenizer,
        device=device,
        **kwargs
    )

    return reward_function
|
|