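"""Inference script for the biomarker (acceptor/donor) interaction model.

Loads a trained markerModel checkpoint (path taken from config/predict.json),
tokenizes acceptor and donor SMILES with two ChemBERTa tokenizers, and scores
SMILES pairs read from a CSV file in batches.
"""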
import os

import pandas as pd
import torch
from rdkit import Chem
from torch.nn import functional as F
from tqdm import tqdm
from transformers import AutoTokenizer

from util.utils import *  # provides load_hparams, DictX
from train import markerModel

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0 '

device_count = torch.cuda.device_count()
device_biomarker = torch.device('cuda' if torch.cuda.is_available() else "cpu")

device = torch.device('cpu')
a_model_name = 'DeepChem/ChemBERTa-10M-MLM'
d_model_name = 'DeepChem/ChemBERTa-10M-MTR'

tokenizer = AutoTokenizer.from_pretrained(a_model_name)
d_tokenizer = AutoTokenizer.from_pretrained(d_model_name)
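# `tokenizer` (ChemBERTa-10M-MLM) encodes the acceptor SMILES and
# `d_tokenizer` (ChemBERTa-10M-MTR) encodes the donor SMILES in
# marker_prediction() below.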

# -- biomarker model -- #
# -- load hyper-parameter config file -- #
config = load_hparams('config/predict.json')
config = DictX(config)
model = markerModel(config.d_model_name, config.p_model_name,
                    config.lr, config.dropout, config.layer_features,
                    config.loss_fn, config.layer_limit,
                    config.pretrained['chem'], config.pretrained['prot'])

# The freshly constructed instance above is replaced by the checkpointed weights.
model = markerModel.load_from_checkpoint(config.load_checkpoint, strict=False)
model.eval()
model.freeze()

if device_biomarker.type == 'cuda':
    model = torch.nn.DataParallel(model)

def get_marker(drug_inputs, prot_inputs):
    """Run the model on pre-tokenized acceptor/donor inputs and return predictions as a list."""
    output_preds = model(drug_inputs, prot_inputs)

    predict = torch.squeeze(output_preds).tolist()

    # output_preds = torch.relu(output_preds)
    # predict = torch.tanh(output_preds)
    # predict = predict.squeeze(dim=1).tolist()

    return predict


def marker_prediction(smiles, aas):
    """Tokenize one batch of acceptor (`smiles`) and donor (`aas`) SMILES and return per-pair predictions."""
    try:
        # Space-separate each donor SMILES string for the donor tokenizer
        aas_input = []
        for aas_data in aas:
            aas_input.append(' '.join(list(aas_data)))

        a_inputs = tokenizer(smiles, padding='max_length', max_length=510, truncation=True, return_tensors="pt")
        # d_inputs = tokenizer(smiles, truncation=True, return_tensors="pt")
        a_input_ids = a_inputs['input_ids'].to(device)
        a_attention_mask = a_inputs['attention_mask'].to(device)
        a_inputs = {'input_ids': a_input_ids, 'attention_mask': a_attention_mask}

        d_inputs = d_tokenizer(aas_input, padding='max_length', max_length=510, truncation=True, return_tensors="pt")
        # p_inputs = prot_tokenizer(aas_input, truncation=True, return_tensors="pt")
        d_input_ids = d_inputs['input_ids'].to(device)
        d_attention_mask = d_inputs['attention_mask'].to(device)
        d_inputs = {'input_ids': d_input_ids, 'attention_mask': d_attention_mask}

        output_predict = get_marker(a_inputs, d_inputs)

        output_list = [{'acceptor': smiles[i], 'donor': aas[i], 'predict': output_predict[i]}
                       for i in range(len(aas))]

        return output_list

    except Exception as e:
        print(e)
        return {'Error_message': str(e)}


def smiles_aas_test(file):
    """Read SMILES pairs from a CSV file, canonicalize them, and predict in batches of `batch_size`."""
    batch_size = 80
    try:
        datas = []
        marker_list = []
        marker_datas = []

        smiles_aas = pd.read_csv(file)

        ## -- 1 to 1 pair predict check -- ##
        for data in smiles_aas.values:
            # Canonicalize both SMILES columns with RDKit before batching
            mola = Chem.MolFromSmiles(data[2])
            data[2] = Chem.MolToSmiles(mola, canonical=True)
            mola = Chem.MolFromSmiles(data[1])
            data[1] = Chem.MolToSmiles(mola, canonical=True)
            marker_datas.append([data[2], data[1]])
            if len(marker_datas) == batch_size:
                marker_list.append(list(marker_datas))
                marker_datas.clear()

        # Flush the last (partial) batch
        if len(marker_datas) != 0:
            marker_list.append(list(marker_datas))
            marker_datas.clear()

        for marker_datas in tqdm(marker_list, total=len(marker_list)):
            smiles_d, smiles_a = zip(*marker_datas)
            output_pred = marker_prediction(list(smiles_d), list(smiles_a))
            if len(datas) == 0:
                datas = output_pred
            else:
                datas = datas + output_pred

        return datas

    except Exception as e:
        print(e)
        return {'Error_message': str(e)}
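

# Minimal sketch of how this module might be driven from the command line.
# The default CSV path ('data/predict_input.csv') and the output file name are
# assumptions for illustration; the real column layout is whatever the CSV
# read above provides (SMILES pairs in columns 1 and 2).
if __name__ == '__main__':
    import sys

    input_csv = sys.argv[1] if len(sys.argv) > 1 else 'data/predict_input.csv'
    results = smiles_aas_test(input_csv)

    if isinstance(results, list):
        pd.DataFrame(results).to_csv('predict_result.csv', index=False)
        print(f"Saved {len(results)} predictions to predict_result.csv")
    else:
        # smiles_aas_test() returns an error dict on failure
        print(results)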