|
import os |
|
import pandas as pd |
|
|
|
import torch |
|
from torch.nn import functional as F |
|
from transformers import AutoTokenizer |
|
|
|
from util.utils import * |
|
|
|
from tqdm import tqdm |
|
from train import markerModel |
|
# --- Runtime / device configuration ------------------------------------------
# Both variables are written before the first torch.cuda query below, so the
# order of these statements matters: CUDA picks them up when it initializes.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0,1",
    }
)

# Number of CUDA devices visible to this process.
device_count = torch.cuda.device_count()

# "cuda" when a GPU is usable, otherwise "cpu"; consulted further down to
# decide whether the model gets wrapped in DataParallel.
if torch.cuda.is_available():
    device_biomarker = torch.device("cuda")
else:
    device_biomarker = torch.device("cpu")

# Device the tokenized inputs are moved to before being handed to the model.
device = torch.device("cpu")

# Pretrained checkpoint names used to build the two tokenizers.
d_model_name = "DeepChem/ChemBERTa-10M-MTR"
p_model_name = "DeepChem/ChemBERTa-10M-MLM"

# `tokenizer` encodes the first SMILES argument of biomarker_prediction,
# `prot_tokenizer` the second one.
tokenizer = AutoTokenizer.from_pretrained(d_model_name)
prot_tokenizer = AutoTokenizer.from_pretrained(p_model_name)
|
|
|
|
|
|
|
# Prediction settings. DictX wraps the loaded values so that entries can be
# read as attributes (``config.load_checkpoint`` below).
config = load_hparams('config/predict.json')
config = DictX(config)

# Restore the trained model from the checkpoint path given in the config.
# NOTE(review): strict=False presumably tolerates key mismatches between the
# checkpoint and the module's state dict -- confirm this leniency is intended.
model = markerModel.load_from_checkpoint(config.load_checkpoint, strict=False)

# Inference only: switch to eval mode and freeze the parameters.
model.eval()
model.freeze()

if device_biomarker.type == 'cuda':
    # DataParallel requires the wrapped module's parameters and buffers to
    # already live on device_ids[0]; it only moves the module by itself when
    # exactly one device is visible. Move the model explicitly so the forward
    # pass does not fail when the checkpoint was restored onto the CPU.
    model = model.to(device_biomarker)
    model = torch.nn.DataParallel(model)
|
|
|
def get_biomarker(drug_inputs, prot_inputs):
    """Run the module-level ``model`` on one pair of tokenized inputs.

    Args:
        drug_inputs: Mapping with ``input_ids`` and ``attention_mask``
            tensors, passed to the model as its first argument.
        prot_inputs: Mapping with the same two keys, passed to the model as
            its second argument.

    Returns:
        The model output with every size-1 dimension squeezed away and
        converted to plain Python values via ``tolist()``. A single-element
        output therefore comes back as a bare number rather than a list.
    """
    raw_scores = model(drug_inputs, prot_inputs)
    return torch.squeeze(raw_scores).tolist()
|
|
|
|
|
def _encode_smiles(encoder, smiles):
    """Tokenize *smiles* with *encoder* and move the tensors to ``device``."""
    encoded = encoder(smiles, padding='max_length', max_length=400, truncation=True, return_tensors="pt")
    # Only these two entries are forwarded to the model.
    return {key: encoded[key].to(device) for key in ('input_ids', 'attention_mask')}


def biomarker_prediction(smile_acc, smile_don):
    """Tokenize a pair of SMILES inputs and return the model prediction.

    ``smile_acc`` is encoded with ``tokenizer`` and ``smile_don`` with
    ``prot_tokenizer``; both are padded/truncated to 400 tokens and then
    passed to ``get_biomarker``.

    Args:
        smile_acc: Input for the first model argument.
        smile_don: Input for the second model argument.

    Returns:
        Whatever ``get_biomarker`` returns on success. On any ``Exception``
        the error is printed and ``{'Error_message': <exception object>}``
        is returned instead of raising.
    """
    try:
        drug_inputs = _encode_smiles(tokenizer, smile_acc)
        prot_inputs = _encode_smiles(prot_tokenizer, smile_don)
        return get_biomarker(drug_inputs, prot_inputs)
    except Exception as err:
        # Best-effort boundary: report the failure to the caller as data.
        print(err)
        return {'Error_message': err}
|
|
|
|
|
def smiles_aas_test(smile_acc, smile_don):
    """Entry point that forwards a SMILES pair to ``biomarker_prediction``.

    Args:
        smile_acc: Passed through as the first argument.
        smile_don: Passed through as the second argument.

    Returns:
        The value returned by ``biomarker_prediction``. If that call raises
        an ``Exception``, the error is printed and
        ``{'Error_message': <exception object>}`` is returned instead.
    """
    try:
        return biomarker_prediction(smile_acc, smile_don)
    except Exception as e:
        # Kept as a safety net so this entry point never raises to its caller.
        print(e)
        return {'Error_message': e}
|
|
|
|
|
|
|
|
|
|
|
|
|
|