# JAK_ML/pages/test_chembert.py
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelWithLMHead, AutoTokenizer

adj_max = 80
fps_len = 167  # length of the fingerprint bit strings the encoder is trained to predict
max_len = 120
device = torch.device('cpu')
model_path = 'model/'
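

# Encoder: ChemBERTa backbone whose first-token LM-head logits are projected down to a
# fps_len-sized vector matching the fingerprint bit strings in pretrain_dataset below.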
class chembert_encoder(nn.Module):
    def __init__(self, output_dim=fps_len, dropout=0.5):
        super(chembert_encoder, self).__init__()
        # Note: AutoModelWithLMHead is deprecated in newer transformers releases
        # (AutoModelForMaskedLM is its replacement).
        self.bert = AutoModelWithLMHead.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.dropout = nn.Dropout(dropout)
        # Project the 767-wide token logits taken in forward() down to the fingerprint length.
        self.w = nn.Linear(767, output_dim)

    def forward(self, x):
        # Tokenize a batch of SMILES strings with dynamic (longest-in-batch) padding.
        input_feat = self.tokenizer.batch_encode_plus(x, max_length=512,
                                                      padding='longest',  # implements dynamic padding
                                                      truncation=True,
                                                      return_tensors='pt',
                                                      return_attention_mask=True,
                                                      return_token_type_ids=True
                                                      )
        # Move the encoded inputs onto the same device as the model parameters.
        model_device = next(self.parameters()).device
        input_feat['input_ids'] = input_feat['input_ids'].to(model_device)
        input_feat['attention_mask'] = input_feat['attention_mask'].to(model_device)
        # Use the LM-head logits of the first token as the sequence-level representation.
        outputs = self.bert(input_feat['input_ids'],
                            attention_mask=input_feat['attention_mask']).logits[:, 0, :]
        return self.w(self.dropout(outputs))
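

# Dataset used for fingerprint pretraining: yields (SMILES, fingerprint-bit tensor, ChEMBL id)
# from a dataframe with 'canonical_smiles', 'fps', and 'chembl_id' columns.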
class pretrain_dataset(Dataset):
    def __init__(self, dataframe, max_len=max_len):
        super(pretrain_dataset, self).__init__()
        self.len = len(dataframe)
        self.dataframe = dataframe
        self.max_len = max_len

    def __getitem__(self, idx):
        sml = self.dataframe.canonical_smiles[idx]
        chem_id = self.dataframe.chembl_id[idx]
        s = list(self.dataframe.fps[idx])
        # Convert the fingerprint bit string (e.g. '0101...') into a tensor of 0/1 integers.
        adj = torch.tensor([int(b) for b in s])
        return sml, adj, chem_id

    def __len__(self):
        return self.len
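

# Dataset for JAK activity prediction: yields (SMILES, binary activity label)
# from a dataframe with 'Smiles' and 'Activity' columns.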
class jak_dataset(Dataset):
    def __init__(self, dataframe):
        super(jak_dataset, self).__init__()
        self.len = len(dataframe)
        self.dataframe = dataframe

    def __getitem__(self, idx):
        sml = self.dataframe.Smiles[idx]
        # Binarize the activity label: 1 stays 1 (active), anything else becomes 0 (inactive).
        y = 1 if self.dataframe.Activity[idx] == 1 else 0
        return sml, y

    def __len__(self):
        return self.len
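

# Classifier: the pretrained chembert_encoder plus a dropout + linear head that maps the
# fps_len-dimensional encoding to two logits (inactive / active).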
class chembert(nn.Module):
    def __init__(self, load_path='model/chem_bert_encoder_pretrain_9.pt',
                 last_layer_size=fps_len, output_size=2, dropout=0.5):
        super(chembert, self).__init__()
        self.last_layer_size = last_layer_size
        self.output_size = output_size
        # Load the pretrained encoder weights, then stack a dropout + linear classification head.
        self.pretrained = chembert_encoder()
        self.pretrained.load_state_dict(torch.load(load_path, map_location=device))
        self.w = nn.Linear(self.last_layer_size, self.output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w(self.dropout(self.pretrained(x)))
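

# Inference helper: loads the fine-tuned checkpoint 'model/chembert_<enzyme>.pt' and returns,
# for a single query SMILES, the active-class probability (list) and the predicted class (tensor).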
def chembert_predict(enzyme, smi):
    ml = 'chembert'
    known_drugs = [smi]
    file_path = 'model/' + ml + '_' + enzyme + '.pt'
    model = chembert()
    model.load_state_dict(torch.load(file_path, map_location=torch.device('cpu')))
    model.eval()
    # Fine-tuning used AdamW (lr=1e-5, weight_decay=1e-2) with per-enzyme class weights
    # {1: [3, 1], 2: [2, 1], 3: [2, 1], 4: [2, 1]}; only the saved weights are needed here.
    params = {'batch_size': 16, 'shuffle': False, 'drop_last': False, 'num_workers': 0}
    # Wrap the query SMILES in a one-row dataframe so it can be fed through jak_dataset/DataLoader.
    known_df = pd.DataFrame(known_drugs)
    known_df.columns = ['Smiles']
    known_df['Activity'] = 0  # dummy label; only the SMILES column is used at inference
    known_data = jak_dataset(known_df)
    known_loader = DataLoader(known_data, **params)
    with torch.no_grad():
        for idx, (X, y_true) in tqdm(enumerate(known_loader), total=len(known_loader)):
            output = model(list(X))
            _, y_pred = torch.max(output, 1)
            y_prob = torch.softmax(output, 1)[:, 1].tolist()
    return y_prob, y_pred
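

# Example usage (illustrative sketch): the enzyme name 'jak1' and the query SMILES below are
# placeholder assumptions, and the checkpoints 'model/chembert_jak1.pt' and
# 'model/chem_bert_encoder_pretrain_9.pt' must already exist for this to run.
if __name__ == '__main__':
    example_smiles = 'CC(=O)Oc1ccccc1C(=O)O'  # aspirin, used only as a placeholder query
    prob, pred = chembert_predict('jak1', example_smiles)
    print('P(active):', prob, 'predicted class:', pred.tolist())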