Spaces:
Running
on
Zero
Running
on
Zero
import sys
import threading
import time
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
from typing import *
from typing import List, Optional, Union

import joblib
import numpy as np
import pandas as pd
import rdkit
import requests
import torch
from Bio import SeqIO
from rdkit import Chem, DataStructs
from rdkit import RDLogger
from rdkit.Chem import AllChem
from tqdm import tqdm
from xgboost import XGBClassifier, DMatrix
# from bio_embeddings.embed import SeqVecEmbedder, ProtTransBertBFDEmbedder, ProtTransT5XLU50Embedder

# Silence RDKit logging before the project-local modules are imported.
RDLogger.DisableLog("rdApp.*")

from model.barlow_twins import BarlowTwins
# sys.path.append("../utils/")
from utils.sequence import uniprot2sequence, encode_sequences
class DTIModel:
    """Drug-target interaction scorer.

    A drug (SMILES string) is encoded as an RDKit Morgan fingerprint and a
    target (protein sequence) with ``encode_sequences`` using ``self.encoder``.
    The pair is combined by ``BarlowTwins.zero_shot`` and the resulting feature
    matrix is scored by an XGBoost classifier. Encoded inputs are memoised per
    instance in ``smiles_cache`` and ``sequence_cache``.
    """

    def __init__(self, bt_model_path: str, gbm_model_path: str, encoder: str = "prost_t5"):
        """Load both models from disk.

        Args:
            bt_model_path: Path passed to ``BarlowTwins.load_model``.
            gbm_model_path: Path passed to ``XGBClassifier.load_model``.
            encoder: Encoder name forwarded to ``encode_sequences`` for targets.
        """
        self.bt_model = BarlowTwins()
        self.bt_model.load_model(bt_model_path)
        self.gbm_model = XGBClassifier()
        self.gbm_model.load_model(gbm_model_path)
        self.encoder = encoder
        # (smiles, radius, bits, features) -> fingerprint as a numpy array.
        self.smiles_cache = {}
        # sequence string -> output of encode_sequences([sequence], ...).
        self.sequence_cache = {}

    def _encode_smiles(self, smiles: Optional[str], radius: int = 2, bits: int = 1024,
                       features: bool = False) -> Optional[np.ndarray]:
        """Return the Morgan fingerprint of ``smiles`` as a 1-D array of ``bits`` entries.

        Results are memoised. Returns ``None`` when ``smiles`` is ``None`` or
        cannot be encoded; failures are printed and are not cached.
        """
        if smiles is None:
            return None
        # Key on every fingerprint parameter, not only the SMILES: the same
        # molecule with another radius/size must not hit a stale cache entry.
        key = (smiles, radius, bits, features)
        if key in self.smiles_cache:
            return self.smiles_cache[key]
        # Best-effort boundary: any RDKit failure is reported and mapped to None.
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                raise ValueError("RDKit could not parse the SMILES string")
            morgan = np.array(
                AllChem.GetMorganFingerprintAsBitVect(
                    mol,
                    radius=radius,
                    nBits=bits,
                    useFeatures=features,
                )
            )
        except Exception as e:
            print(f"Failed to encode SMILES: {smiles}")
            print(e)
            return None
        self.smiles_cache[key] = morgan
        return morgan

    def _encode_smiles_mult(self, smiles: List[str], radius: int = 2, bits: int = 1024,
                            features: bool = False) -> np.ndarray:
        """Encode several SMILES into a ``(len(smiles), bits)`` fingerprint matrix.

        Raises:
            ValueError: If any SMILES could not be encoded (a ragged or object
                array would otherwise break downstream scoring).
        """
        morgan = [self._encode_smiles(s, radius, bits, features) for s in smiles]
        failed = [s for s, fp in zip(smiles, morgan) if fp is None]
        if failed:
            raise ValueError(f"Failed to encode {len(failed)} SMILES: {failed}")
        return np.array(morgan)

    def _encode_sequence(self, sequence: Optional[str]):
        """Encode one protein sequence with ``self.encoder``, memoising the result.

        Returns whatever ``encode_sequences([sequence], encoder=...)`` returns,
        or ``None`` when ``sequence`` is ``None`` or encoding raises.
        """
        if sequence is None:
            return None
        if sequence in self.sequence_cache:
            return self.sequence_cache[sequence]
        # Free cached GPU memory only when the encoder is actually about to run;
        # a cache hit does no GPU work.
        torch.cuda.empty_cache()
        try:
            encoded_sequence = encode_sequences([sequence], encoder=self.encoder)
        except Exception as e:
            print(f"Failed to encode sequence: {sequence}")
            print(e)
            return None
        self.sequence_cache[sequence] = encoded_sequence
        return encoded_sequence

    def _encode_sequence_mult(self, sequences: List[str]) -> np.ndarray:
        """Encode each sequence with ``_encode_sequence`` and stack the results."""
        seq = [self._encode_sequence(sequence) for sequence in sequences]
        return np.array(seq)

    def __predict_pair(self, drug_emb: np.ndarray, target_emb: np.ndarray, pred_leaf: bool = False):
        """Score drug/target embedding pairs with the Barlow Twins + XGBoost stack.

        Either side may hold a single embedding (1-D, or one row). That side is
        repeated to match the number of rows on the other side.

        Args:
            drug_emb: Fingerprint(s) from ``_encode_smiles``/``_encode_smiles_mult``.
            target_emb: Embedding(s) from ``_encode_sequence``.
            pred_leaf: If True, return ``Booster.predict(..., pred_leaf=True)``
                output instead of probabilities.

        Returns:
            Column 1 of ``predict_proba`` (probability of class 1) for each row,
            or the leaf-index output when ``pred_leaf`` is True.

        Raises:
            ValueError: If an embedding is missing (its encoding failed), or if
                both sides hold several rows with different counts.
        """
        if drug_emb is None or target_emb is None:
            raise ValueError("Missing drug or target embedding; encoding probably failed")
        # Treat a 1-D embedding as one row so len() counts samples, not features.
        # Otherwise a single 1-D fingerprint of length `bits` would cause the
        # target to be tiled `bits` times.
        drug_emb = np.atleast_2d(drug_emb)
        target_emb = np.atleast_2d(target_emb)
        n_drugs, n_targets = len(drug_emb), len(target_emb)
        if n_drugs != n_targets:
            if n_drugs == 1:
                drug_emb = np.repeat(drug_emb, n_targets, axis=0)
            elif n_targets == 1:
                target_emb = np.repeat(target_emb, n_drugs, axis=0)
            else:
                raise ValueError(
                    f"Cannot pair {n_drugs} drug embeddings with {n_targets} target embeddings"
                )
        emb = self.bt_model.zero_shot(drug_emb, target_emb)
        if pred_leaf:
            return self.gbm_model.get_booster().predict(DMatrix(emb), pred_leaf=True)
        return self.gbm_model.predict_proba(emb)[:, 1]

    def predict(self, drug: Union[List[str], str], target: str, pred_leaf: bool = False):
        """Score one or more SMILES against a single protein sequence.

        Args:
            drug: A SMILES string or a list of SMILES strings.
            target: Protein sequence to encode with ``self.encoder``.
            pred_leaf: Return leaf indices instead of probabilities.

        Returns:
            One score per drug (see ``__predict_pair``).
        """
        if isinstance(drug, str):
            drug_emb = self._encode_smiles(drug)
        else:
            drug_emb = self._encode_smiles_mult(drug)
        target_emb = self._encode_sequence(target)
        return self.__predict_pair(drug_emb, target_emb, pred_leaf)

    def get_leaf_weights(self):
        """Return the booster's ``"weight"`` feature importance.

        Per XGBoost's ``get_score``, this maps each feature to the number of
        times it is used in a split. Despite the method name, these are not
        leaf values.
        """
        return self.gbm_model.get_booster().get_score(importance_type="weight")

    def _require_drug_emb(self, drug: str) -> np.ndarray:
        """Encode ``drug`` once up front, failing fast instead of once per target."""
        drug_emb = self._encode_smiles(drug)
        if drug_emb is None:
            raise ValueError(f"Could not encode SMILES: {drug}")
        return drug_emb

    def _predict_record(self, drug: str, drug_emb: np.ndarray, record) -> dict:
        """Score one FASTA record against an encoded drug and build a result row."""
        target_emb = self._encode_sequence(str(record.seq))
        pred = self.__predict_pair(drug_emb, target_emb)
        return {
            "drug": drug,
            "target": record.id,
            "name": record.name,
            "description": record.description,
            "prediction": pred[0],
        }

    def _predict_fasta(self, drug: str, fasta_path: str) -> pd.DataFrame:
        """Score ``drug`` against every record in a FASTA file, with no timeout.

        Returns:
            A DataFrame with columns drug, target, name, description, prediction.
        """
        drug_emb = self._require_drug_emb(drug)
        records = tqdm(SeqIO.parse(fasta_path, "fasta"), desc="Predicting targets")
        return pd.DataFrame([self._predict_record(drug, drug_emb, record) for record in records])

    def predict_fasta(self, drug: str, fasta_path: str, timeout_seconds: int = 120) -> pd.DataFrame:
        """Score ``drug`` against every FASTA record, skipping slow or failing ones.

        Each record is processed in a worker thread that is waited on for at
        most ``timeout_seconds``. Records that time out or raise are reported
        with ``print`` and omitted from the result.

        Returns:
            A DataFrame with columns drug, target, name, description, prediction.
        """
        drug_emb = self._require_drug_emb(drug)

        def process_target(record, outcome):
            # Capture errors so the main loop can report them; otherwise they
            # vanish inside the thread and the target is silently dropped.
            try:
                outcome["row"] = self._predict_record(drug, drug_emb, record)
            except Exception as exc:
                outcome["error"] = exc

        results = []
        # Parse once to count records so tqdm can show a total.
        total_records = sum(1 for _ in SeqIO.parse(fasta_path, "fasta"))
        for record in tqdm(SeqIO.parse(fasta_path, "fasta"), total=total_records, desc="Predicting targets"):
            outcome = {}
            # Python threads cannot be cancelled. A timed-out worker keeps
            # running in the background (and may still write to the caches).
            # daemon=True at least keeps it from blocking interpreter exit.
            worker = threading.Thread(target=process_target, args=(record, outcome), daemon=True)
            worker.start()
            worker.join(timeout_seconds)
            if worker.is_alive():
                print(f"Skipping target {record.id} due to timeout")
                continue
            if "error" in outcome:
                print(f"Skipping target {record.id} due to error: {outcome['error']}")
                continue
            results.append(outcome["row"])
        return pd.DataFrame(results)

    def predict_uniprot(self, drug: Union[List[str], str], uniprot_id: str, pred_leaf: bool = False):
        """Like ``predict``, but resolve the target via ``uniprot2sequence(uniprot_id)``."""
        return self.predict(drug, uniprot2sequence(uniprot_id), pred_leaf)