Spaces:
Sleeping
Sleeping
File size: 6,359 Bytes
84bfd88 005026b 84bfd88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import sys
from typing import List
from tqdm import tqdm
import pandas as pd
import numpy as np
import threading
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import time
import requests
import joblib
# from bio_embeddings.embed import SeqVecEmbedder, ProtTransBertBFDEmbedder, ProtTransT5XLU50Embedder
from Bio import SeqIO
import rdkit
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
import torch
from typing import *
from rdkit import RDLogger
RDLogger.DisableLog("rdApp.*")
from xgboost import XGBClassifier, DMatrix
from model.barlow_twins import BarlowTwins
# sys.path.append("../utils/")
from utils.sequence import uniprot2sequence, encode_sequences
class DTIModel:
    """Drug-target interaction (DTI) scorer.

    A drug (SMILES string) is encoded as an RDKit Morgan fingerprint and a
    protein target (amino-acid sequence) with ``encode_sequences``. The pair of
    embeddings is combined by ``BarlowTwins.zero_shot`` and the result is
    scored by an XGBoost classifier.

    Encodings are memoised per instance in ``smiles_cache`` and
    ``sequence_cache``; neither cache is bounded.
    """

    # Column order of the DataFrames returned by the FASTA prediction methods.
    _FASTA_COLUMNS = ("drug", "target", "name", "description", "prediction")

    def __init__(self, bt_model_path: str, gbm_model_path: str, encoder: str = "prost_t5"):
        """Load both models from disk.

        Args:
            bt_model_path: Path passed to ``BarlowTwins.load_model``.
            gbm_model_path: Path passed to ``XGBClassifier.load_model``.
            encoder: Protein encoder name forwarded to ``encode_sequences``.
        """
        self.bt_model = BarlowTwins()
        self.bt_model.load_model(bt_model_path)
        self.gbm_model = XGBClassifier()
        self.gbm_model.load_model(gbm_model_path)
        self.encoder = encoder
        # (smiles, radius, bits, features) -> fingerprint array.
        self.smiles_cache = {}
        # protein sequence -> value returned by encode_sequences for it.
        self.sequence_cache = {}

    def _encode_smiles(self, smiles: Optional[str], radius: int = 2, bits: int = 1024, features: bool = False):
        """Return the Morgan fingerprint of *smiles* as a 1-D NumPy array.

        Args:
            smiles: SMILES string, or None.
            radius: Morgan radius.
            bits: Fingerprint length.
            features: Passed to RDKit as ``useFeatures``.

        Returns:
            An array of length *bits*, or None if *smiles* is None or cannot be
            parsed/encoded (the reason is printed). Failures are not cached.
        """
        if smiles is None:
            return None
        # The fingerprint depends on every parameter, so all of them form the key;
        # keying on the SMILES alone would return stale results for other settings.
        key = (smiles, radius, bits, features)
        if key in self.smiles_cache:
            return self.smiles_cache[key]
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # MolFromSmiles signals a parse failure by returning None, not by raising.
                raise ValueError("RDKit could not parse the SMILES string")
            fingerprint = AllChem.GetMorganFingerprintAsBitVect(
                mol,
                radius=radius,
                nBits=bits,
                useFeatures=features,
            )
        except Exception as e:
            print(f"Failed to encode SMILES: {smiles}")
            print(e)
            return None
        morgan = np.array(fingerprint)
        self.smiles_cache[key] = morgan
        return morgan

    def _encode_smiles_mult(self, smiles: List[str], radius: int = 2, bits: int = 1024, features: bool = False):
        """Encode several SMILES strings and stack them into one array (one row each).

        Raises:
            ValueError: If any SMILES fails to encode; stacking a None entry would
                otherwise produce a ragged array that fails later with an obscure error.
        """
        fingerprints = [self._encode_smiles(s, radius, bits, features) for s in smiles]
        failed = [s for s, fp in zip(smiles, fingerprints) if fp is None]
        if failed:
            raise ValueError(f"Failed to encode SMILES: {failed}")
        return np.array(fingerprints)

    def _encode_sequence(self, sequence: Optional[str]):
        """Encode one protein sequence with ``encode_sequences`` (memoised).

        Returns:
            The value of ``encode_sequences([sequence], encoder=self.encoder)``, or
            None if *sequence* is None or encoding raises (the error is printed).
            Failures are not cached.
        """
        # Release cached GPU memory before a potential encoder run (done on every call).
        torch.cuda.empty_cache()
        if sequence is None:
            return None
        if sequence in self.sequence_cache:
            return self.sequence_cache[sequence]
        try:
            encoded_sequence = encode_sequences([sequence], encoder=self.encoder)
        except Exception as e:
            print(f"Failed to encode sequence: {sequence}")
            print(e)
            return None
        self.sequence_cache[sequence] = encoded_sequence
        return encoded_sequence

    def _encode_sequence_mult(self, sequences: List[str]):
        """Encode several protein sequences and wrap the results in one NumPy array.

        NOTE(review): each element is whatever ``_encode_sequence`` returns (possibly
        None on failure), so the stacked shape depends on ``encode_sequences`` —
        confirm before feeding this to ``__predict_pair``.
        """
        seq = [self._encode_sequence(sequence) for sequence in sequences]
        return np.array(seq)

    def __predict_pair(self, drug_emb: np.ndarray, target_emb: np.ndarray, pred_leaf: bool = False):
        """Score drug/target embedding pairs.

        Both inputs are promoted to 2-D (one row per item). When one side has a
        single row and the other has several, that row is repeated so each item
        on the other side gets paired with it.

        Args:
            drug_emb: One fingerprint (1-D) or a stack of fingerprints (2-D).
            target_emb: Target embedding(s), promoted to 2-D the same way.
            pred_leaf: If True, return the XGBoost leaf indices of each pair
                instead of probabilities.

        Returns:
            Leaf indices from ``Booster.predict(..., pred_leaf=True)``, or column 1
            of ``predict_proba`` (one value per pair).

        Raises:
            ValueError: If an embedding is None (its encoding failed) or the row
                counts differ and neither side has exactly one row.
        """
        if drug_emb is None or target_emb is None:
            raise ValueError("Cannot predict: the drug or target embedding is missing (encoding failed)")
        # A single SMILES encodes to a 1-D bit vector; treat it as a batch of one so the
        # row counts below compare items rather than fingerprint bits.
        drug_emb = np.atleast_2d(drug_emb)
        target_emb = np.atleast_2d(target_emb)
        n_drugs, n_targets = drug_emb.shape[0], target_emb.shape[0]
        if n_drugs != n_targets:
            if n_drugs == 1:
                drug_emb = np.repeat(drug_emb, n_targets, axis=0)
            elif n_targets == 1:
                target_emb = np.repeat(target_emb, n_drugs, axis=0)
            else:
                raise ValueError(
                    f"Cannot pair {n_drugs} drug embeddings with {n_targets} target embeddings"
                )
        emb = self.bt_model.zero_shot(drug_emb, target_emb)
        if pred_leaf:
            d_emb = DMatrix(emb)
            return self.gbm_model.get_booster().predict(d_emb, pred_leaf=True)
        return self.gbm_model.predict_proba(emb)[:, 1]

    def predict(self, drug: Union[List[str], str], target: str, pred_leaf: bool = False):
        """Score one SMILES string (or a list of them) against a protein sequence.

        Args:
            drug: A SMILES string or a list of SMILES strings.
            target: Protein sequence.
            pred_leaf: If True, return XGBoost leaf indices instead of probabilities.

        Returns:
            One score (or leaf-index row) per drug.
        """
        if isinstance(drug, str):
            drug_emb = self._encode_smiles(drug)
        else:
            drug_emb = self._encode_smiles_mult(drug)
        target_emb = self._encode_sequence(target)
        return self.__predict_pair(drug_emb, target_emb, pred_leaf)

    def get_leaf_weights(self):
        """Return the booster's per-feature importance with ``importance_type="weight"``.

        NOTE(review): despite the method name this is feature importance from
        ``Booster.get_score`` (per XGBoost docs, the number of splits using each
        feature), not leaf values.
        """
        return self.gbm_model.get_booster().get_score(importance_type="weight")

    @staticmethod
    def _fasta_record(drug: str, target, pred) -> dict:
        """Build one result row from a SeqIO record and its prediction array."""
        return {
            "drug": drug,
            "target": target.id,
            "name": target.name,
            "description": target.description,
            "prediction": pred[0],
        }

    def _predict_fasta(self, drug: str, fasta_path: str) -> pd.DataFrame:
        """Sequential variant of ``predict_fasta`` with no per-target timeout.

        Returns:
            DataFrame with columns drug, target, name, description, prediction.

        Raises:
            ValueError: If the drug or a target cannot be encoded.
        """
        drug_emb = self._encode_smiles(drug)
        results = []
        for target in tqdm(SeqIO.parse(fasta_path, "fasta"), desc="Predicting targets"):
            target_emb = self._encode_sequence(str(target.seq))
            pred = self.__predict_pair(drug_emb, target_emb, pred_leaf=False)
            results.append(self._fasta_record(drug, target, pred))
        return pd.DataFrame(results, columns=list(self._FASTA_COLUMNS))

    def predict_fasta(self, drug: str, fasta_path: str, timeout_seconds: int = 120) -> pd.DataFrame:
        """Score *drug* against every record of a FASTA file, with a per-target timeout.

        Each target is processed in its own worker thread. If the thread does not
        finish within *timeout_seconds*, that target is skipped. Python cannot kill a
        thread, so a timed-out worker keeps running in the background; it is a daemon
        thread so it cannot keep the interpreter alive at exit. Targets whose encoding
        or prediction fails are skipped as well.

        Args:
            drug: SMILES string of the drug.
            fasta_path: Path to a FASTA file readable by ``Bio.SeqIO``.
            timeout_seconds: Maximum time to wait for each target.

        Returns:
            DataFrame with columns drug, target, name, description, prediction. It is
            empty (with those columns) if the drug cannot be encoded or no target succeeds.
        """
        drug_emb = self._encode_smiles(drug)
        if drug_emb is None:
            # Every target would fail, so skip encoding sequences entirely.
            # _encode_smiles has already printed the reason.
            return pd.DataFrame(columns=list(self._FASTA_COLUMNS))

        def process_target(target, out):
            try:
                target_emb = self._encode_sequence(str(target.seq))
                if target_emb is None:
                    return  # _encode_sequence already printed the failure
                pred = self.__predict_pair(drug_emb, target_emb, pred_leaf=False)
            except Exception as e:
                # Report and skip rather than letting the worker die with a traceback.
                print(f"Skipping target {target.id} due to error: {e}")
                return
            out.append(self._fasta_record(drug, target, pred))

        results = []
        # Parse once just to count records, so the progress bar has a total.
        total_records = sum(1 for _ in SeqIO.parse(fasta_path, "fasta"))
        for target in tqdm(SeqIO.parse(fasta_path, "fasta"), total=total_records, desc="Predicting targets"):
            thread_results = []
            thread = threading.Thread(target=process_target, args=(target, thread_results), daemon=True)
            thread.start()
            thread.join(timeout_seconds)
            if thread.is_alive():
                print(f"Skipping target {target.id} due to timeout")
                continue
            results.extend(thread_results)
        return pd.DataFrame(results, columns=list(self._FASTA_COLUMNS))

    def predict_uniprot(self, drug: Union[List[str], str], uniprot_id: str, pred_leaf: bool = False):
        """Like ``predict``, but fetch the target sequence with ``uniprot2sequence``.

        Args:
            drug: A SMILES string or a list of SMILES strings.
            uniprot_id: UniProt accession passed to ``uniprot2sequence``.
            pred_leaf: Forwarded to ``predict``.
        """
        return self.predict(drug, uniprot2sequence(uniprot_id), pred_leaf=pred_leaf)
|