from __future__ import annotations

import contextlib
import hashlib
import io
import json
import os
import re
from collections import Counter
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
from rdkit import Chem, DataStructs, RDLogger
from rdkit.Chem import AllChem, Crippen, Descriptors, Lipinski, MACCSkeys, rdMolDescriptors
from rdkit.Chem.MolStandardize import rdMolStandardize
from sentence_transformers import SentenceTransformer
from torch import nn
from transformers import AutoModel, AutoTokenizer
from transformers.utils import logging as transformers_logging
| |
|
# Quiet down third-party libraries before any heavy model loading happens:
# suppress Hugging Face Hub progress bars and the tokenizers fork warning.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
disable_progress_bars()
transformers_logging.set_verbosity_error()

# Silence RDKit's per-molecule warning/error log spam.
RDLogger.DisableLog("rdApp.*")
| |
|
# Fallback instruction used to prompt the assay text encoder when neither the
# checkpoint nor the training config provides one.
DEFAULT_ASSAY_TASK = (
    "Given a bioassay description and metadata, represent the assay for ranking compatible small molecules."
)
# Ordered names of the dense RDKit descriptors computed by
# _molecule_descriptor_vector; the order defines the feature-vector layout.
DEFAULT_DESCRIPTOR_NAMES = (
    "mol_wt",
    "logp",
    "tpsa",
    "heavy_atoms",
    "hbd",
    "hba",
    "rot_bonds",
    "ring_count",
    "aromatic_rings",
    "aliphatic_rings",
    "saturated_rings",
    "fraction_csp3",
    "heteroatoms",
    "amide_bonds",
    "fragments",
    "formal_charge",
    "max_atomic_num",
    "metal_atom_count",
    "halogen_count",
    "nitrogen_count",
    "oxygen_count",
    "sulfur_count",
    "phosphorus_count",
    "fluorine_count",
    "chlorine_count",
    "bromine_count",
    "iodine_count",
    "aromatic_atom_count",
    "spiro_atoms",
    "bridgehead_atoms",
)
# Atomic numbers treated as "organic-like" (H, B, C, N, O, F, Si, P, S, Cl,
# Br, I); atoms outside this set count toward the metal_atom_count descriptor.
ORGANIC_LIKE_ATOMIC_NUMBERS = {1, 5, 6, 7, 8, 9, 14, 15, 16, 17, 35, 53}
# Canonical section order for the serialized assay text produced by
# serialize_assay_query and parsed by _parse_assay_sections.
SECTION_ORDER = [
    "ASSAY_TITLE",
    "DESCRIPTION",
    "ORGANISM",
    "READOUT",
    "ASSAY_FORMAT",
    "ASSAY_TYPE",
    "TARGET_UNIPROT",
]
# Matches a "[SECTION]\n" header; the capturing group keeps the label in the
# output of re.split so sections can be re-paired with their bodies.
ASSAY_SECTION_RE = re.compile(r"\[(ASSAY_TITLE|DESCRIPTION|ORGANISM|READOUT|ASSAY_FORMAT|ASSAY_TYPE|TARGET_UNIPROT)\]\n")
# NCBI taxonomy IDs mapped to normalized organism tokens; applied before
# generic token normalization in _normalize_organism_token.
ORGANISM_ALIASES = {
    "9606": "homo_sapiens",
    "10090": "mus_musculus",
    "10116": "rattus_norvegicus",
    "4932": "saccharomyces_cerevisiae",
}
| |
|
| |
|
@dataclass
class AssayQuery:
    """User-supplied assay metadata, serialized to text by serialize_assay_query."""

    title: str = ""
    description: str = ""
    organism: str = ""  # free text or an NCBI taxon ID (see ORGANISM_ALIASES)
    readout: str = ""
    assay_format: str = ""
    assay_type: str = ""
    # UniProt accessions of the assay target(s); joined with ", " on serialization.
    target_uniprot: list[str] | None = None
| |
|
| |
|
def smiles_sha256(smiles: str) -> str:
    """Return the hex-encoded SHA-256 digest of *smiles* (UTF-8 bytes)."""
    digest = hashlib.sha256()
    digest.update(smiles.encode("utf-8"))
    return digest.hexdigest()
| |
|
| |
|
@contextlib.contextmanager
def _silent_imports():
    """Context manager that discards anything written to stdout/stderr.

    Wrapped around noisy third-party loaders (Hub downloads, model init) so
    their chatter never reaches the app's own output streams.
    """
    sink = io.StringIO()
    with contextlib.ExitStack() as stack:
        stack.enter_context(contextlib.redirect_stdout(sink))
        stack.enter_context(contextlib.redirect_stderr(sink))
        yield
| |
|
| |
|
@lru_cache(maxsize=1_000_000)
def _standardize_smiles_v2_cached(smiles: str) -> str | None:
    """Standardize one SMILES string, returning its canonical form or None.

    Pipeline (order matters): RDKit cleanup, keep the parent fragment,
    neutralize charges, canonicalize the tautomer, then re-sanitize.
    Returns None when parsing or any standardization step fails, when the
    result has fewer than 2 heavy atoms, or when it still serializes as a
    multi-fragment ("."-containing) SMILES. Results are memoized (up to 1M
    entries).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        mol = rdMolStandardize.Cleanup(mol)
        mol = rdMolStandardize.FragmentParent(mol)
        mol = rdMolStandardize.Uncharger().uncharge(mol)
        mol = rdMolStandardize.TautomerEnumerator().Canonicalize(mol)
        Chem.SanitizeMol(mol)
    except Exception:
        # Standardization can raise for exotic structures; treat as invalid.
        return None
    if mol.GetNumHeavyAtoms() < 2:
        return None
    standardized = Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True)
    if not standardized or "." in standardized:
        return None
    return standardized
| |
|
| |
|
| | def standardize_smiles_v2(smiles: str | None) -> str | None: |
| | if not smiles: |
| | return None |
| | token = smiles.strip() |
| | if not token: |
| | return None |
| | return _standardize_smiles_v2_cached(token) |
| |
|
| |
|
def serialize_assay_query(query: AssayQuery) -> str:
    """Render *query* as the canonical "[SECTION]\\nvalue" text block.

    Sections appear in SECTION_ORDER separated by blank lines; UniProt
    targets are joined with ", ". Every value is whitespace-trimmed.
    """
    section_values = {
        "ASSAY_TITLE": query.title,
        "DESCRIPTION": query.description,
        "ORGANISM": query.organism,
        "READOUT": query.readout,
        "ASSAY_FORMAT": query.assay_format,
        "ASSAY_TYPE": query.assay_type,
        "TARGET_UNIPROT": ", ".join(query.target_uniprot or []),
    }
    blocks = [f"[{name}]\n{section_values[name].strip()}" for name in SECTION_ORDER]
    return "\n\n".join(blocks)
| |
|
| |
|
def _parse_assay_sections(assay_text: str) -> dict[str, str]:
    """Split serialized assay text back into a {section: value} mapping.

    Sections absent from the text map to ""; labels outside SECTION_ORDER
    are ignored.
    """
    parsed = dict.fromkeys(SECTION_ORDER, "")
    pieces = ASSAY_SECTION_RE.split(assay_text)
    # re.split with a capturing group alternates [prefix, label, body, ...].
    for label, body in zip(pieces[1::2], pieces[2::2]):
        if label in parsed:
            parsed[label] = body.strip()
    return parsed
| |
|
| |
|
def _hash_bucket(value: str, dim: int) -> int:
    """Map *value* to a stable bucket index in [0, max(dim, 1)).

    Uses BLAKE2b instead of the builtin hash(): since Python 3.3, str hashing
    is salted per process (PYTHONHASHSEED), so a feature hasher built on
    hash() would assign different buckets at training time and at inference
    time, silently corrupting the hashed metadata features.
    """
    digest = hashlib.blake2b(value.encode("utf-8"), digest_size=8).digest()
    return int.from_bytes(digest, "big") % max(dim, 1)
| |
|
| |
|
def _normalize_metadata_token(value: str) -> str:
    """Lowercase *value*, collapse non-alphanumeric runs to "_", trim edges."""
    collapsed = re.sub(r"[^a-z0-9]+", "_", value.lower())
    return collapsed.strip("_")
| |
|
| |
|
def _normalize_organism_token(value: str) -> str:
    """Normalize an organism field, translating known NCBI taxon IDs first."""
    stripped = value.strip()
    if not stripped:
        return ""
    return _normalize_metadata_token(ORGANISM_ALIASES.get(stripped, stripped))
| |
|
| |
|
def _assay_metadata_vector(assay_text: str, *, dim: int) -> np.ndarray:
    """Hash assay metadata fields into an L2-normalized count vector of size *dim*.

    Tokens come from the ORGANISM, READOUT, ASSAY_FORMAT, and ASSAY_TYPE
    sections plus each comma-separated TARGET_UNIPROT accession. dim <= 0
    yields an empty vector.
    """
    if dim <= 0:
        return np.zeros((0,), dtype=np.float32)
    sections = _parse_assay_sections(assay_text)
    tokens: list[str] = []
    organism = _normalize_organism_token(sections.get("ORGANISM", ""))
    if organism:
        tokens.append(f"organism:{organism}")
    for field_name in ("READOUT", "ASSAY_FORMAT", "ASSAY_TYPE"):
        normalized = _normalize_metadata_token(sections.get(field_name, ""))
        if normalized:
            tokens.append(f"{field_name.lower()}:{normalized}")
    tokens.extend(
        f"target:{accession.strip().upper()}"
        for accession in sections.get("TARGET_UNIPROT", "").split(",")
        if accession.strip()
    )
    vec = np.zeros((dim,), dtype=np.float32)
    for token in tokens:
        vec[_hash_bucket(token, dim)] += 1.0
    norm = float(np.linalg.norm(vec))
    if norm > 0:
        vec /= norm
    return vec
| |
|
| |
|
def _morgan_bits_from_mol(mol, *, radius: int, n_bits: int, use_chirality: bool) -> np.ndarray:
    """Compute a Morgan bit fingerprint for *mol* as a uint8 array of length *n_bits*."""
    dense = np.zeros((n_bits,), dtype=np.uint8)
    bitvect = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits, useChirality=use_chirality)
    DataStructs.ConvertToNumpyArray(bitvect, dense)
    return dense
| |
|
| |
|
def _maccs_bits_from_mol(mol) -> np.ndarray:
    """Compute the MACCS keys fingerprint for *mol* as a uint8 bit array."""
    keys = MACCSkeys.GenMACCSKeys(mol)
    dense = np.zeros((keys.GetNumBits(),), dtype=np.uint8)
    DataStructs.ConvertToNumpyArray(keys, dense)
    return dense
| |
|
| |
|
def _count_atomic_nums(mol) -> dict[int, int]:
    """Count atoms in *mol*, keyed by atomic number.

    Uses collections.Counter instead of a hand-rolled get/increment loop;
    the result is converted back to a plain dict to keep the return type.
    """
    return dict(Counter(int(atom.GetAtomicNum()) for atom in mol.GetAtoms()))
| |
|
| |
|
def _molecule_descriptor_vector(mol, *, names: tuple[str, ...] = DEFAULT_DESCRIPTOR_NAMES) -> np.ndarray:
    """Compute the dense RDKit descriptor vector for *mol*.

    Returns a float32 array with one entry per name in *names*, looked up
    from the full descriptor table below — *names* therefore controls both
    which descriptors are emitted and their order.
    """
    counts = _count_atomic_nums(mol)
    fragments = Chem.GetMolFrags(mol)
    formal_charge = sum(int(atom.GetFormalCharge()) for atom in mol.GetAtoms())
    max_atomic_num = max(counts) if counts else 0
    # Any atom outside the organic-like set is counted as a "metal" here.
    metal_atom_count = sum(count for atomic_num, count in counts.items() if atomic_num not in ORGANIC_LIKE_ATOMIC_NUMBERS)
    halogen_count = sum(counts.get(item, 0) for item in (9, 17, 35, 53))  # F, Cl, Br, I
    aromatic_atom_count = sum(1 for atom in mol.GetAtoms() if atom.GetIsAromatic())
    values = {
        "mol_wt": float(Descriptors.MolWt(mol)),
        "logp": float(Crippen.MolLogP(mol)),
        "tpsa": float(rdMolDescriptors.CalcTPSA(mol)),
        "heavy_atoms": float(mol.GetNumHeavyAtoms()),
        "hbd": float(Lipinski.NumHDonors(mol)),
        "hba": float(Lipinski.NumHAcceptors(mol)),
        "rot_bonds": float(Lipinski.NumRotatableBonds(mol)),
        "ring_count": float(rdMolDescriptors.CalcNumRings(mol)),
        "aromatic_rings": float(rdMolDescriptors.CalcNumAromaticRings(mol)),
        "aliphatic_rings": float(rdMolDescriptors.CalcNumAliphaticRings(mol)),
        "saturated_rings": float(rdMolDescriptors.CalcNumSaturatedRings(mol)),
        "fraction_csp3": float(rdMolDescriptors.CalcFractionCSP3(mol)),
        "heteroatoms": float(rdMolDescriptors.CalcNumHeteroatoms(mol)),
        "amide_bonds": float(rdMolDescriptors.CalcNumAmideBonds(mol)),
        "fragments": float(len(fragments)),
        "formal_charge": float(formal_charge),
        "max_atomic_num": float(max_atomic_num),
        "metal_atom_count": float(metal_atom_count),
        "halogen_count": float(halogen_count),
        "nitrogen_count": float(counts.get(7, 0)),
        "oxygen_count": float(counts.get(8, 0)),
        "sulfur_count": float(counts.get(16, 0)),
        "phosphorus_count": float(counts.get(15, 0)),
        "fluorine_count": float(counts.get(9, 0)),
        "chlorine_count": float(counts.get(17, 0)),
        "bromine_count": float(counts.get(35, 0)),
        "iodine_count": float(counts.get(53, 0)),
        "aromatic_atom_count": float(aromatic_atom_count),
        "spiro_atoms": float(rdMolDescriptors.CalcNumSpiroAtoms(mol)),
        "bridgehead_atoms": float(rdMolDescriptors.CalcNumBridgeheadAtoms(mol)),
    }
    return np.array([values[name] for name in names], dtype=np.float32)
| |
|
| |
|
def molecule_ui_metrics(smiles: str) -> dict[str, float | int]:
    """Compute the small descriptor set shown in the UI for one SMILES.

    Standardizes the input first (falling back to the raw string) and
    returns all-zero metrics when the SMILES cannot be parsed.
    """
    canonical = standardize_smiles_v2(smiles) or smiles
    mol = Chem.MolFromSmiles(canonical)
    if mol is None:
        return {"mol_wt": 0.0, "logp": 0.0, "tpsa": 0.0, "heavy_atoms": 0}
    metrics: dict[str, float | int] = {
        "mol_wt": float(Descriptors.MolWt(mol)),
        "logp": float(Crippen.MolLogP(mol)),
        "tpsa": float(rdMolDescriptors.CalcTPSA(mol)),
        "heavy_atoms": int(mol.GetNumHeavyAtoms()),
    }
    return metrics
| |
|
| |
|
class CompatibilityHead(nn.Module):
    """Scores assay/molecule compatibility from precomputed feature vectors.

    Each side is layer-normalized and projected into a shared unit-norm
    space; the final logit mixes a learned-scale dot product with a small
    MLP over the pair interaction features.
    """

    def __init__(self, *, assay_dim: int, molecule_dim: int, projection_dim: int, hidden_dim: int, dropout: float) -> None:
        super().__init__()
        self.assay_norm = nn.LayerNorm(assay_dim)
        self.assay_proj = nn.Linear(assay_dim, projection_dim)
        self.mol_norm = nn.LayerNorm(molecule_dim)
        self.mol_proj = nn.Linear(molecule_dim, projection_dim, bias=False)
        # Pair features are [assay, mol, assay*mol, |assay-mol|] -> 4 * projection_dim.
        self.score_mlp = nn.Sequential(
            nn.Linear(projection_dim * 4, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        # Learned scale applied to the dot-product term of the logit.
        self.dot_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

    def encode_assay(self, assay_features: torch.Tensor) -> torch.Tensor:
        """Project assay features onto the shared unit hypersphere."""
        projected = self.assay_proj(self.assay_norm(assay_features))
        return F.normalize(projected, p=2, dim=-1)

    def encode_molecule(self, molecule_features: torch.Tensor) -> torch.Tensor:
        """Project molecule features onto the shared unit hypersphere."""
        projected = self.mol_proj(self.mol_norm(molecule_features))
        return F.normalize(projected, p=2, dim=-1)

    def score_candidates(self, assay_features: torch.Tensor, candidate_features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score every candidate molecule against its assay.

        Args:
            assay_features: (batch, assay_dim) tensor.
            candidate_features: (batch, n_candidates, molecule_dim) tensor.

        Returns:
            (logits, assay_vec, mol_vec) where logits has shape
            (batch, n_candidates).
        """
        assay_vec = self.encode_assay(assay_features)
        mol_vec = self.encode_molecule(candidate_features)
        n_candidates = mol_vec.shape[1]
        assay_tiled = assay_vec.unsqueeze(1).expand(-1, n_candidates, -1)
        elementwise = assay_tiled * mol_vec
        dot_scores = elementwise.sum(dim=-1)
        pair_features = torch.cat(
            [assay_tiled, mol_vec, elementwise, torch.abs(assay_tiled - mol_vec)],
            dim=-1,
        )
        mlp_scores = self.score_mlp(pair_features).squeeze(-1)
        logits = dot_scores * self.dot_scale + mlp_scores
        return logits, assay_vec, mol_vec
| |
|
| |
|
class SpaceCompatibilityModel:
    """Bundles the assay text encoder, molecule featurizer, and scoring head.

    Feature construction must mirror training exactly: each molecule row is
    [Morgan bits per configured radius][optional MACCS][optional z-scored
    RDKit descriptors][optional transformer embedding], concatenated in that
    order (see build_molecule_feature_matrix).
    """

    def __init__(
        self,
        *,
        assay_encoder: SentenceTransformer,
        compatibility_head: CompatibilityHead,
        assay_task_description: str,
        fingerprint_radii: tuple[int, ...],
        fingerprint_bits: int,
        use_chirality: bool,
        use_maccs: bool,
        use_rdkit_descriptors: bool,
        descriptor_names: tuple[str, ...],
        descriptor_mean: np.ndarray | None,
        descriptor_std: np.ndarray | None,
        molecule_transformer_model_name: str,
        molecule_transformer_batch_size: int,
        molecule_transformer_max_length: int,
        use_assay_metadata_features: bool,
        assay_metadata_dim: int,
    ) -> None:
        self.assay_encoder = assay_encoder
        # The head is inference-only here; switch it to eval mode up front.
        self.compatibility_head = compatibility_head.eval()
        self.assay_task_description = assay_task_description
        self.fingerprint_radii = fingerprint_radii
        self.fingerprint_bits = fingerprint_bits
        self.use_chirality = use_chirality
        self.use_maccs = use_maccs
        self.use_rdkit_descriptors = use_rdkit_descriptors
        self.descriptor_names = descriptor_names
        # Per-descriptor mean/std used to z-score the dense descriptor block.
        self.descriptor_mean = descriptor_mean
        self.descriptor_std = descriptor_std
        self.molecule_transformer_model_name = molecule_transformer_model_name
        self.molecule_transformer_batch_size = molecule_transformer_batch_size
        self.molecule_transformer_max_length = molecule_transformer_max_length
        self.use_assay_metadata_features = use_assay_metadata_features
        self.assay_metadata_dim = assay_metadata_dim
        # The molecule transformer is loaded lazily on first use.
        self._molecule_transformer_tokenizer = None
        self._molecule_transformer_model = None
        self._molecule_transformer_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _format_assay_query(self, assay_text: str) -> str:
        """Prefix the assay text with the task instruction in Instruct/Query format."""
        return f"Instruct: {self.assay_task_description.strip()}\nQuery: {assay_text.strip()}"

    def _build_assay_feature_array(self, assay_text: str) -> np.ndarray:
        """Encode the assay text; optionally append the hashed metadata vector."""
        assay_features = self.assay_encoder.encode(
            [self._format_assay_query(assay_text)],
            batch_size=1,
            normalize_embeddings=True,
            show_progress_bar=False,
            convert_to_numpy=True,
        )[0].astype(np.float32)
        if self.use_assay_metadata_features and self.assay_metadata_dim > 0:
            metadata_vec = _assay_metadata_vector(assay_text, dim=self.assay_metadata_dim)
            assay_features = np.concatenate([assay_features, metadata_vec.astype(np.float32)], axis=0)
        return assay_features

    def _ensure_molecule_transformer_loaded(self) -> None:
        """Lazily load the molecule transformer; no-op if unconfigured or already loaded."""
        if not self.molecule_transformer_model_name or self._molecule_transformer_model is not None:
            return
        # Half precision on CUDA, full precision on CPU.
        dtype = torch.float16 if self._molecule_transformer_device.type == "cuda" else torch.float32
        with _silent_imports():
            self._molecule_transformer_tokenizer = AutoTokenizer.from_pretrained(
                self.molecule_transformer_model_name,
                trust_remote_code=True,
            )
            self._molecule_transformer_model = AutoModel.from_pretrained(
                self.molecule_transformer_model_name,
                trust_remote_code=True,
                torch_dtype=dtype,
            ).to(self._molecule_transformer_device)
        self._molecule_transformer_model.eval()

    def _encode_molecule_transformer_batch(self, smiles_values: list[str]) -> np.ndarray | None:
        """Embed each SMILES with the molecule transformer; None when no model is configured.

        Embeddings are attention-mask-weighted mean pools over the last
        hidden state, returned as a float32 (n, hidden_dim) array.
        """
        if not self.molecule_transformer_model_name:
            return None
        self._ensure_molecule_transformer_loaded()
        assert self._molecule_transformer_model is not None
        assert self._molecule_transformer_tokenizer is not None
        outputs: list[np.ndarray] = []
        batch_size = max(self.molecule_transformer_batch_size, 1)
        with torch.no_grad():
            for start in range(0, len(smiles_values), batch_size):
                batch = smiles_values[start : start + batch_size]
                encoded = self._molecule_transformer_tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=self.molecule_transformer_max_length,
                    return_tensors="pt",
                )
                encoded = {key: value.to(self._molecule_transformer_device) for key, value in encoded.items()}
                hidden = self._molecule_transformer_model(**encoded).last_hidden_state
                # Mean pooling over real (non-padding) tokens only.
                mask = encoded["attention_mask"].unsqueeze(-1)
                pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
                outputs.append(pooled.detach().cpu().to(torch.float32).numpy())
        return np.concatenate(outputs, axis=0).astype(np.float32)

    def build_molecule_feature_matrix(self, smiles_values: list[str]) -> np.ndarray:
        """Build the (n_molecules, feature_dim) float32 matrix for *smiles_values*.

        Per-row block order: Morgan bits for each configured radius, optional
        MACCS keys, optional z-scored RDKit descriptors, optional transformer
        embedding. Raises ValueError when a SMILES cannot be parsed.

        NOTE(review): the transformer embedding is computed from the raw
        input SMILES while fingerprints/descriptors use the standardized
        form — confirm this matches how training features were built.
        """
        transformer_matrix = self._encode_molecule_transformer_batch(smiles_values)
        rows: list[np.ndarray] = []
        for idx, smiles in enumerate(smiles_values):
            normalized = standardize_smiles_v2(smiles) or smiles
            mol = Chem.MolFromSmiles(normalized)
            if mol is None:
                raise ValueError(f"Could not parse SMILES: {normalized}")
            bit_blocks: list[np.ndarray] = [
                _morgan_bits_from_mol(mol, radius=int(radius), n_bits=self.fingerprint_bits, use_chirality=self.use_chirality)
                for radius in self.fingerprint_radii
            ]
            if self.use_maccs:
                bit_blocks.append(_maccs_bits_from_mol(mol))
            output_blocks: list[np.ndarray] = [np.concatenate(bit_blocks, axis=0).astype(np.float32)]
            if self.use_rdkit_descriptors and self.descriptor_names:
                dense = _molecule_descriptor_vector(mol, names=self.descriptor_names)
                if self.descriptor_mean is not None and self.descriptor_std is not None:
                    dense = (dense - self.descriptor_mean) / self.descriptor_std
                output_blocks.append(dense.astype(np.float32))
            if transformer_matrix is not None:
                output_blocks.append(np.asarray(transformer_matrix[idx], dtype=np.float32))
            rows.append(np.concatenate(output_blocks, axis=0).astype(np.float32))
        return np.stack(rows, axis=0)
| |
|
| |
|
def _load_sentence_transformer(model_name: str) -> SentenceTransformer:
    """Construct the assay text encoder with third-party loader noise suppressed.

    Uses bfloat16 weights when CUDA is available, float32 otherwise.
    """
    preferred_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    with _silent_imports():
        encoder = SentenceTransformer(
            model_name,
            trust_remote_code=True,
            model_kwargs={"torch_dtype": preferred_dtype},
        )
    tokenizer = getattr(encoder, "tokenizer", None)
    if tokenizer is not None:
        # NOTE(review): left padding presumably matches the encoder's pooling
        # strategy — confirm against the model card.
        tokenizer.padding_side = "left"
    return encoder
| |
|
| |
|
def _load_feature_spec(cfg: dict[str, Any], metadata: dict[str, Any], checkpoint: dict[str, Any]) -> dict[str, Any]:
    """Resolve the molecule feature spec: checkpoint, then metadata, then config defaults."""
    stored = checkpoint.get("molecule_feature_spec") or metadata.get("molecule_feature_spec")
    if stored:
        return stored
    # No stored spec: rebuild a minimal one from the training config.
    default_radii = [cfg.get("fingerprint_radius", 2)]
    radii = [int(item) for item in (cfg.get("fingerprint_radii") or default_radii)]
    return {
        "fingerprint_radii": radii,
        "fingerprint_bits": int(cfg["fingerprint_bits"]),
        "use_chirality": bool(cfg.get("use_chirality", False)),
        "use_maccs": bool(cfg.get("use_maccs", False)),
        "use_rdkit_descriptors": bool(cfg.get("use_rdkit_descriptors", False)),
        "descriptor_names": [],
        "descriptor_mean": None,
        "descriptor_std": None,
        "molecule_transformer_model_name": str(cfg.get("molecule_transformer_model_name") or ""),
        "molecule_transformer_max_length": int(cfg.get("molecule_transformer_max_length", 128) or 128),
    }
| |
|
| |
|
def load_compatibility_model(model_dir: str | Path) -> SpaceCompatibilityModel:
    """Reconstruct a SpaceCompatibilityModel from a local checkpoint directory.

    Expects `best_model.pt` (torch checkpoint containing model_state_dict
    plus optional overrides such as the feature spec and assay model name)
    and `training_metadata.json` with the training config under "config".

    Raises RuntimeError when the checkpoint's state dict does not line up
    with the rebuilt head (only the mol_norm parameters may be missing).
    """
    model_path = Path(model_dir)
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(model_path / "best_model.pt", map_location="cpu", weights_only=False)
    metadata = json.loads((model_path / "training_metadata.json").read_text())
    cfg = metadata["config"]
    feature_spec = _load_feature_spec(cfg, metadata, checkpoint)

    encoder = _load_sentence_transformer(checkpoint.get("assay_model_name") or cfg["assay_model_name"])
    # Recover input dims from the projection weights rather than trusting cfg.
    assay_dim = int(checkpoint["model_state_dict"]["assay_proj.weight"].shape[1])
    molecule_dim = int(checkpoint["model_state_dict"]["mol_proj.weight"].shape[1])
    head = CompatibilityHead(
        assay_dim=assay_dim,
        molecule_dim=molecule_dim,
        projection_dim=int(cfg["projection_dim"]),
        hidden_dim=int(cfg["hidden_dim"]),
        dropout=float(cfg["dropout"]),
    )
    load_result = head.load_state_dict(checkpoint["model_state_dict"], strict=False)
    # NOTE(review): mol_norm params are tolerated as missing — presumably
    # absent from older checkpoints; anything else is a hard mismatch.
    allowed_missing = {"mol_norm.weight", "mol_norm.bias"}
    unexpected = set(load_result.unexpected_keys)
    missing = set(load_result.missing_keys)
    if unexpected or (missing - allowed_missing):
        raise RuntimeError(
            f"Checkpoint mismatch: unexpected={sorted(unexpected)} missing={sorted(missing)}"
        )
    # Checkpoint values take priority over the feature spec, which in turn
    # takes priority over raw config defaults.
    return SpaceCompatibilityModel(
        assay_encoder=encoder,
        compatibility_head=head,
        assay_task_description=checkpoint.get("assay_task_description") or cfg.get("assay_task_description", DEFAULT_ASSAY_TASK),
        fingerprint_radii=tuple(int(item) for item in feature_spec.get("fingerprint_radii") or [2]),
        fingerprint_bits=int(feature_spec.get("fingerprint_bits", cfg.get("fingerprint_bits", 2048))),
        use_chirality=bool(feature_spec.get("use_chirality", cfg.get("use_chirality", False))),
        use_maccs=bool(feature_spec.get("use_maccs", cfg.get("use_maccs", False))),
        use_rdkit_descriptors=bool(feature_spec.get("use_rdkit_descriptors", cfg.get("use_rdkit_descriptors", False))),
        descriptor_names=tuple(feature_spec.get("descriptor_names") or ()),
        descriptor_mean=np.array(feature_spec["descriptor_mean"], dtype=np.float32) if feature_spec.get("descriptor_mean") is not None else None,
        descriptor_std=np.array(feature_spec["descriptor_std"], dtype=np.float32) if feature_spec.get("descriptor_std") is not None else None,
        molecule_transformer_model_name=str(feature_spec.get("molecule_transformer_model_name") or cfg.get("molecule_transformer_model_name") or ""),
        molecule_transformer_batch_size=int(cfg.get("molecule_transformer_batch_size", 128) or 128),
        molecule_transformer_max_length=int(feature_spec.get("molecule_transformer_max_length") or cfg.get("molecule_transformer_max_length", 128) or 128),
        use_assay_metadata_features=bool(cfg.get("use_assay_metadata_features", False)),
        assay_metadata_dim=int(cfg.get("assay_metadata_dim", 0) or 0),
    )
| |
|
| |
|
@lru_cache(maxsize=1)
def load_compatibility_model_from_hub(model_repo_id: str) -> SpaceCompatibilityModel:
    """Download the model snapshot from the Hugging Face Hub and load it.

    Cached with maxsize=1: only the most recently requested repo stays in
    memory, so requesting a different repo evicts and reloads.
    """
    with _silent_imports():
        model_dir = snapshot_download(
            repo_id=model_repo_id,
            repo_type="model",
            # Only the files load_compatibility_model actually reads.
            allow_patterns=["best_model.pt", "training_metadata.json", "README.md"],
        )
    return load_compatibility_model(model_dir)
| |
|
| |
|
def rank_compounds(
    model: SpaceCompatibilityModel,
    *,
    assay_text: str,
    smiles_list: list[str],
    top_k: int | None = None,
) -> list[dict[str, Any]]:
    """Score and rank candidate SMILES against one serialized assay description.

    Valid molecules come first, sorted by descending model score (optionally
    truncated to *top_k*); entries that fail SMILES standardization follow
    with score=None and error="invalid_smiles". Empty input yields [].
    """
    if not smiles_list:
        return []
    assay_features = model._build_assay_feature_array(assay_text)
    assay_tensor = torch.from_numpy(assay_features.astype(np.float32)).unsqueeze(0)

    valid_items: list[tuple[str, str]] = []
    invalid_items: list[dict[str, Any]] = []
    for raw_smiles in smiles_list:
        canonical = standardize_smiles_v2(raw_smiles)
        if canonical is not None:
            valid_items.append((raw_smiles, canonical))
            continue
        invalid_items.append(
            {
                "input_smiles": raw_smiles,
                "canonical_smiles": None,
                "smiles_hash": None,
                "score": None,
                "valid": False,
                "error": "invalid_smiles",
            }
        )

    ranked_items: list[dict[str, Any]] = []
    if valid_items:
        feature_matrix = model.build_molecule_feature_matrix([canonical for _, canonical in valid_items])
        candidate_tensor = torch.from_numpy(feature_matrix).unsqueeze(0)
        with torch.no_grad():
            logits, _, _ = model.compatibility_head.score_candidates(
                assay_tensor.to(dtype=torch.float32),
                candidate_tensor.to(dtype=torch.float32),
            )
        scores = logits.squeeze(0).cpu().numpy().tolist()
        ranked_items = [
            {
                "input_smiles": raw_smiles,
                "canonical_smiles": canonical,
                "smiles_hash": smiles_sha256(canonical),
                "score": float(score),
                "valid": True,
            }
            for (raw_smiles, canonical), score in zip(valid_items, scores, strict=True)
        ]
    ranked_items.sort(key=lambda item: item["score"], reverse=True)
    if top_k is not None and top_k > 0:
        ranked_items = ranked_items[:top_k]

    # Invalid entries are appended unranked, after the scored results.
    return ranked_items + invalid_items
| |
|