|
import os |
|
import pandas as pd |
|
|
|
import torch |
|
from torch.nn import functional as F |
|
from transformers import AutoTokenizer |
|
|
|
from util.utils import * |
|
from rdkit import Chem |
|
from tqdm import tqdm |
|
from train import markerModel |
|
|
|
# ---------------------------------------------------------------------------
# Module-level inference setup: pin the visible GPU, load both tokenizers and
# restore the trained model from its checkpoint. Everything below runs at
# import time.
# ---------------------------------------------------------------------------

# Order CUDA devices by PCI bus ID and expose only device 0 to this process.
# These variables must be set before the first CUDA query below.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

device_count = torch.cuda.device_count()
# Device the model runs on: CUDA when available, otherwise CPU.
device_biomarker = torch.device('cuda' if torch.cuda.is_available() else "cpu")

# Device the tokenized inputs are placed on before being handed to the model.
device = torch.device('cpu')
a_model_name = 'DeepChem/ChemBERTa-10M-MLM'
d_model_name = 'DeepChem/ChemBERTa-10M-MTR'

tokenizer = AutoTokenizer.from_pretrained(a_model_name)
d_tokenizer = AutoTokenizer.from_pretrained(d_model_name)

# load_hparams / DictX are presumably provided by the util.utils star import.
config = load_hparams('config/predict.json')
config = DictX(config)

# Restore the model directly from the checkpoint. A separately constructed
# markerModel instance would be overwritten by this assignment, so none is
# built here.
model = markerModel.load_from_checkpoint(config.load_checkpoint, strict=False)
model.eval()
model.freeze()

if device_biomarker.type == 'cuda':
    # DataParallel requires the wrapped module's parameters to already live on
    # its first device, so move the model there before wrapping it.
    model = torch.nn.DataParallel(model.to(device_biomarker))
|
|
|
def get_marker(drug_inputs, prot_inputs):
    """Run the module-level ``model`` on one batch and return plain Python values.

    Args:
        drug_inputs: First positional argument forwarded to ``model``.
        prot_inputs: Second positional argument forwarded to ``model``.

    Returns:
        The model output with all size-1 dimensions squeezed away, converted
        with ``tolist()``.

    NOTE(review): if squeezing leaves a 0-d tensor (e.g. a single-item batch
    with a single output), ``tolist()`` yields a bare Python scalar rather
    than a list - callers should be prepared for both.
    """
    return torch.squeeze(model(drug_inputs, prot_inputs)).tolist()
|
|
|
|
|
def marker_prediction(smiles, aas):
    """Tokenize two parallel lists of sequences and return the model prediction.

    Args:
        smiles: Sequences encoded with the module-level ``tokenizer``.
        aas: Sequences encoded with the module-level ``d_tokenizer`` after
            every element has been split into space-separated characters.

    Returns:
        Whatever ``get_marker`` returns for the encoded batch. If anything in
        the pipeline raises, the exception is printed and
        ``{'Error_message': <exception>}`` is returned instead of propagating.
    """

    def _to_model_inputs(encoding):
        # Keep only the two tensors that are forwarded to the model and move
        # them to the module-level input device.
        return {
            'input_ids': encoding['input_ids'].to(device),
            'attention_mask': encoding['attention_mask'].to(device),
        }

    try:
        # Insert a space between every character of each second-list sequence.
        spaced_sequences = [' '.join(list(sequence)) for sequence in aas]

        # Both encodings are padded/truncated to a fixed length of 510 tokens.
        first_encoding = tokenizer(
            smiles,
            padding='max_length',
            max_length=510,
            truncation=True,
            return_tensors="pt",
        )
        first_inputs = _to_model_inputs(first_encoding)

        second_encoding = d_tokenizer(
            spaced_sequences,
            padding='max_length',
            max_length=510,
            truncation=True,
            return_tensors="pt",
        )
        second_inputs = _to_model_inputs(second_encoding)

        return get_marker(first_inputs, second_inputs)

    except Exception as e:
        # Best-effort contract: report the failure and hand it back to the
        # caller as a dict rather than raising.
        print(e)
        return {'Error_message': e}
|
|
|
|
|
def smiles_aas_test(smile_acc, smile_don):
    """Predict the marker value for a single pair of SMILES strings.

    Both inputs are canonicalized with RDKit and then passed, each as a
    one-element list, to ``marker_prediction``.

    Args:
        smile_acc: SMILES string passed as the first sequence (presumably the
            acceptor molecule, going by the name).
        smile_don: SMILES string passed as the second sequence (presumably the
            donor molecule, going by the name).

    Returns:
        The value returned by ``marker_prediction`` for the pair: the model
        prediction on success, or its ``{'Error_message': ...}`` dict when
        prediction fails.

    Raises:
        ValueError: If RDKit cannot parse either SMILES string.
    """
    canonical_smiles = []
    for arg_name, raw_smiles in (('smile_acc', smile_acc), ('smile_don', smile_don)):
        molecule = Chem.MolFromSmiles(raw_smiles)
        if molecule is None:
            # MolFromSmiles signals a parse failure by returning None; fail
            # here with a clear message instead of inside MolToSmiles.
            raise ValueError(f"{arg_name} is not a valid SMILES string: {raw_smiles!r}")
        canonical_smiles.append(Chem.MolToSmiles(molecule, canonical=True))

    canonical_acc, canonical_don = canonical_smiles

    # Only one pair is ever processed, so it is sent as a single batch of one.
    return marker_prediction([canonical_acc], [canonical_don])
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Example run on one hard-coded pair of SMILES strings.
    # The first literal is a raw string because SMILES uses backslashes for
    # bond stereochemistry; "\C" in a normal string is an invalid escape.
    smile_acc = r'CCCCCCCCCCCC1=C(/C=C2\C(=O)C3=C(C=C(F)C(F)=C3)C2=C(C#N)C#N)SC2=C1SC1=C2N(CC(CC)CCCC)C2=C1C1=NSN=C1C1=C2N(CC(CC)CCCC)C2=C1SC1=C2SC(/C=C2\C(=O)C3=C(C=C(F)C(F)=C3)C2=C(C#N)C#N)=C1CCCCCCCCCCC'
    smile_don = 'CCCCCCC(CCCC)CC1=C(C)SC(C2=CC3=C(S2)C2=C(C=C(C4=CC(CC(CCCC)CCCCCC)=C(C5=CC6=C(C7=CC=C(CC(CC)CCCC)S7)C7=C(C=C(C)S7)C(C7=CC=C(CC(CC)CCCC)S7)=C6S5)S4)S2)C2=NSN=C23)=C1'

    a = smiles_aas_test(smile_acc, smile_don)
    # Show the prediction; previously it was computed and discarded.
    print(a)
|
|
|
|