File size: 4,450 Bytes
5b07bd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import torch.nn as nn 
from transformers import AutoModelWithLMHead, AutoTokenizer
import os
from tqdm import tqdm
import pandas as pd
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import classification_report, confusion_matrix, average_precision_score, roc_auc_score
import math
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import KFold, train_test_split
# ---- Module-level configuration shared by the definitions below ----
adj_max: int = 80  # not referenced by the code in this file
# Output width of chembert_encoder and input width of the chembert head.
# Presumably the fingerprint bit count (167 matches MACCS keys) — TODO confirm.
fps_len: int = 167
max_len: int = 120  # default for pretrain_dataset(max_len=...); stored only
device: torch.device = torch.device('cpu')  # map_location for checkpoint loads
model_path: str = 'model/'  # directory holding the .pt checkpoints
class chembert_encoder(nn.Module):
    """Encode SMILES strings with ChemBERTa and project to ``output_dim``.

    A batch of SMILES is tokenized with the ChemBERTa tokenizer and run
    through the ChemBERTa model loaded with an LM head. The logits at the
    first token position are then passed through dropout and a linear layer.

    Args:
        output_dim: Width of the projected output (defaults to ``fps_len``).
        dropout: Dropout probability applied before the projection.
    """

    def __init__(self, output_dim=fps_len, dropout=0.5):
        super(chembert_encoder, self).__init__()
        self.bert = AutoModelWithLMHead.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.dropout = nn.Dropout(dropout)
        # 767 must equal the last dimension of the LM-head logits used in
        # forward() (presumably the tokenizer vocabulary size — TODO confirm).
        self.w = nn.Linear(767, output_dim)

    def forward(self, x):
        """Encode a batch of SMILES strings.

        Args:
            x: Batch of SMILES strings (assumed ``list[str]`` — TODO confirm).

        Returns:
            Tensor of shape ``(len(x), output_dim)``.
        """
        input_feat = self.tokenizer.batch_encode_plus(
            x,
            max_length=512,
            padding='longest',  # dynamic padding to the longest item in the batch
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True,
            return_token_type_ids=True,
        )

        # Place the inputs on the device that holds this module's weights, so
        # CPU and CUDA placement always agree. This relies on no undefined
        # global flag.
        target_device = next(self.parameters()).device
        input_ids = input_feat['input_ids'].to(target_device)
        attention_mask = input_feat['attention_mask'].to(target_device)

        # Keep only the logits at the first token position: (batch, vocab).
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=None,
        ).logits[:, 0, :]
        return self.w(self.dropout(outputs))
class pretrain_dataset(Dataset):
    """Dataset yielding ``(smiles, fingerprint_tensor, chembl_id)`` tuples.

    Rows are addressed by position (0..len-1), which is what a DataLoader
    sampler requests. The dataframe therefore does not need a default
    RangeIndex (e.g. it may come straight out of ``train_test_split``).

    Args:
        dataframe: DataFrame with ``canonical_smiles``, ``chembl_id`` and
            ``fps`` columns. Each ``fps`` entry is assumed to be a sequence of
            '0'/'1' characters, one per bit — TODO confirm.
        max_len: Stored on the instance; not used by this class itself.
    """

    def __init__(self, dataframe, max_len=max_len):
        super(pretrain_dataset, self).__init__()
        self.len = len(dataframe)
        self.dataframe = dataframe
        self.max_len = max_len

    def __getitem__(self, idx):
        # Positional access: idx is a position, not an index label.
        sml = self.dataframe.canonical_smiles.iloc[idx]
        chem_id = self.dataframe.chembl_id.iloc[idx]
        fps = self.dataframe.fps.iloc[idx]
        # Convert each element (e.g. '0'/'1' characters) to an int tensor.
        adj = torch.tensor([int(bit) for bit in fps])
        return sml, adj, chem_id

    def __len__(self):
        return self.len
class jak_dataset(Dataset):
    """Dataset yielding ``(smiles, label)`` pairs for activity classification.

    Rows are addressed by position (0..len-1), which is what a DataLoader
    sampler requests. The dataframe therefore does not need a default
    RangeIndex.

    Args:
        dataframe: DataFrame with ``Smiles`` and ``Activity`` columns.
    """

    def __init__(self, dataframe):
        super(jak_dataset, self).__init__()
        self.len = len(dataframe)
        self.dataframe = dataframe

    def __getitem__(self, idx):
        # Positional access: idx is a position, not an index label.
        sml = self.dataframe.Smiles.iloc[idx]
        # Only an Activity equal to 1 is a positive label; anything else is 0.
        y = 1 if self.dataframe.Activity.iloc[idx] == 1 else 0
        return sml, y

    def __len__(self):
        return self.len
class chembert(nn.Module):
    """Classification head on top of a pretrained ``chembert_encoder``.

    The encoder is built with its defaults and initialised from the state
    dict at ``load_path``. Its output then goes through dropout and a linear
    layer producing ``output_size`` logits.

    Args:
        load_path: Path of the pretrained encoder state dict.
        last_layer_size: Input width of the head. It should match the
            encoder's output width, which is ``fps_len`` since the encoder is
            built with defaults.
        output_size: Number of output logits (2 by default).
        dropout: Dropout probability applied before the head.
    """

    def __init__(self, load_path='model/chem_bert_encoder_pretrain_9.pt',
                 last_layer_size=fps_len, output_size=2, dropout=0.5):
        super().__init__()
        self.last_layer_size = last_layer_size
        self.output_size = output_size

        encoder = chembert_encoder()
        encoder.load_state_dict(torch.load(load_path, map_location=device))
        self.pretrained = encoder

        self.w = nn.Linear(last_layer_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        features = self.pretrained(x)
        return self.w(self.dropout(features))
    
def chembert_predict(enzyme, smi):
    """Score one SMILES string with the fine-tuned ChemBERT model for ``enzyme``.

    Loads ``<model_path>chembert_<enzyme>.pt`` into a fresh ``chembert``
    classifier. It then runs the classifier in eval mode with gradient
    tracking disabled.

    Args:
        enzyme: Suffix selecting the checkpoint file ``chembert_<enzyme>.pt``.
        smi: SMILES string to score.

    Returns:
        ``(y_prob, y_pred)``:
        - ``y_prob`` is a list with the softmax probability of class 1 for
          each input.
        - ``y_pred`` is a tensor of argmax class indices.
    """
    file_path = os.path.join(model_path, f'chembert_{enzyme}.pt')
    model = chembert()
    model.load_state_dict(torch.load(file_path, map_location=device))
    model.eval()  # inference: disable dropout

    # Activity is a placeholder label; jak_dataset requires the column.
    known_df = pd.DataFrame({'Smiles': [smi], 'Activity': 0})
    known_loader = DataLoader(
        jak_dataset(known_df),
        batch_size=16, shuffle=False, drop_last=False, num_workers=0,
    )

    y_prob, batch_preds = [], []
    with torch.no_grad():  # no autograd graph needed for prediction
        for X, _ in tqdm(known_loader, total=len(known_loader)):
            output = model(list(X))
            _, pred = torch.max(output, 1)
            batch_preds.append(pred)
            y_prob.extend(torch.softmax(output, 1)[:, 1].tolist())

    return y_prob, torch.cat(batch_preds)