|
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from utils import tokenize_sequences, rna2vec, seq2vec, rna2vec_pretraining, get_dataset, get_scores, argument_seqset
from utils import API_Dataset, Masked_Dataset
from encoders import AptaTrans, Token_Pretrained_Model
from mcts import MCTS

import torch
import torch.nn as nn
import sqlite3
import timeit
import numpy as np
import os
|
|
|
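# Pipeline for AptaTrans: pre-trains the RNA and protein token encoders with a
# masked-token + secondary-structure objective, fine-tunes the interaction
# model on aptamer-protein pairs, and recommends candidate aptamers for a
# target protein via Monte Carlo tree search (MCTS).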
class AptaTransPipeline:
    def __init__(
        self,
        d_model=128,
        d_ff=512,
        n_layers=6,
        n_heads=8,
        dropout=0.1,
        load_best_pt=True,
        device='cpu',
        seed=1004,
    ):
        self.seed = seed
        self.device = device

        # Vocabulary sizes follow the pattern: 1 (padding) + token vocabulary + 1 (mask).
        self.n_apta_vocabs = 1 + 125 + 1      # 125 = 5**3, presumably RNA 3-mer tokens
        self.n_apta_target_vocabs = 1 + 343   # 343 = 7**3, presumably secondary-structure 3-mer targets
        self.n_prot_vocabs = 1 + 713 + 1      # 713 protein tokens from the pre-built vocabulary
        self.n_prot_target_vocabs = 1 + 584   # 584 unique words over the 8 SS symbols (see set_data_protein_pt)

        # Maximum tokenized lengths for aptamers and proteins.
        self.apta_max_len = 275
        self.prot_max_len = 867

        # self.prot_words, the protein token-to-id mapping used by the data-loading
        # methods below, is built elsewhere in the full pipeline.
|
        # Token-level encoder for aptamer (RNA) sequences.
        self.encoder_aptamer = Token_Pretrained_Model(
            n_vocabs=self.n_apta_vocabs,
            n_target_vocabs=self.n_apta_target_vocabs,
            d_ff=d_ff,
            d_model=d_model,
            n_layers=n_layers,
            n_heads=n_heads,
            dropout=dropout,
            max_len=self.apta_max_len,
        ).to(device)
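
        # The protein encoder is not shown in this excerpt. A sketch is
        # reconstructed here, assuming it mirrors the aptamer encoder with the
        # protein vocabulary sizes and maximum length; self.encoder_protein is
        # required by pretrain_protein() below.
        self.encoder_protein = Token_Pretrained_Model(
            n_vocabs=self.n_prot_vocabs,
            n_target_vocabs=self.n_prot_target_vocabs,
            d_ff=d_ff,
            d_model=d_model,
            n_layers=n_layers,
            n_heads=n_heads,
            dropout=dropout,
            max_len=self.prot_max_len,
        ).to(device)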
|
        if load_best_pt:
            try:
                self.encoder_aptamer.load_state_dict(torch.load("./encoders/rna_pretrained_encoder.pt", map_location=device))
                # Assumption: the protein encoder weights live alongside the RNA encoder weights.
                self.encoder_protein.load_state_dict(torch.load("./encoders/protein_pretrained_encoder.pt", map_location=device))
                print('Best pre-trained models are loaded!')
            except FileNotFoundError:
                print('Pre-trained encoder weights were not found.')
                print('You need to pre-train the encoders!')
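
        # The end-to-end interaction model is built from the two encoders. Its
        # exact constructor is not shown in this excerpt; the keyword arguments
        # below are assumptions to be checked against the AptaTrans definition
        # in encoders.py. self.model is required by train(), inference(), and
        # recommend() below.
        self.model = AptaTrans(
            apta_encoder=self.encoder_aptamer,
            prot_encoder=self.encoder_protein,
            d_model=d_model,
            dropout=dropout,
        ).to(device)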
|
    def set_data_for_training(self, batch_size):
        datapath = "./data/new_dataset.pickle"
        ds_train, ds_test, ds_bench = get_dataset(datapath, self.prot_max_len, self.n_prot_vocabs, self.prot_words)

        self.train_loader = DataLoader(API_Dataset(ds_train[0], ds_train[1], ds_train[2]), batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(API_Dataset(ds_test[0], ds_test[1], ds_test[2]), batch_size=batch_size, shuffle=False)
        self.bench_loader = DataLoader(API_Dataset(ds_bench[0], ds_bench[1], ds_bench[2]), batch_size=batch_size, shuffle=False)
|
    def set_data_rna_pt(self, batch_size, masked_rate=0.15):
        # Load RNA sequences and their secondary structures from bpRNA.
        conn = sqlite3.connect("./data/bpRNA.db")
        fetch = conn.execute("SELECT * FROM RNA").fetchall()
        conn.close()
        seqset = [[f[1], f[2]] for f in fetch if len(f[1]) <= 277]
        seqset = argument_seqset(seqset)  # presumably augments the sequence set (utility from utils)

        train_seq, test_seq = train_test_split(seqset, test_size=0.05, random_state=self.seed)
        train_x, train_y = rna2vec_pretraining(train_seq)
        test_x, test_y = rna2vec_pretraining(test_seq)

        # The mask token is the last vocabulary index (n_apta_vocabs - 1).
        rna_train = Masked_Dataset(train_x, train_y, self.apta_max_len, masked_rate, self.n_apta_vocabs - 1, isrna=True)
        rna_test = Masked_Dataset(test_x, test_y, self.apta_max_len, masked_rate, self.n_apta_vocabs - 1, isrna=True)

        self.rna_train = DataLoader(rna_train, batch_size=batch_size, shuffle=True)
        self.rna_test = DataLoader(rna_test, batch_size=batch_size, shuffle=False)
|
    def set_data_protein_pt(self, batch_size, masked_rate=0.15):
        # Build the secondary-structure target vocabulary: all unique 1- to 3-letter
        # words over the 8 SS symbols ('H', 'B', 'E', 'G', 'I', 'T', 'S', '-'),
        # 8 + 8**2 + 8**3 = 584 words, matching n_prot_target_vocabs. Id 0 is
        # left free, presumably for padding.
        ss = ['', 'H', 'B', 'E', 'G', 'I', 'T', 'S', '-']
        words_ss = np.unique(np.array([i + j + k for i in ss for j in ss for k in ss[1:]]))
        words_ss = {word: i + 1 for i, word in enumerate(words_ss)}

        conn = sqlite3.connect("./data/protein_ss_keywords.db")
        fetch = conn.execute("SELECT SEQUENCE, SS FROM PROTEIN").fetchall()
        conn.close()
        seqset = [[f[0], f[1]] for f in fetch]

        train_seq, test_seq = train_test_split(seqset, test_size=0.05, random_state=self.seed)
        train_x, train_y = seq2vec(train_seq, self.prot_max_len, self.n_prot_vocabs, self.n_prot_target_vocabs, self.prot_words, words_ss)
        test_x, test_y = seq2vec(test_seq, self.prot_max_len, self.n_prot_vocabs, self.n_prot_target_vocabs, self.prot_words, words_ss)

        protein_train = Masked_Dataset(train_x, train_y, self.prot_max_len, masked_rate, self.n_prot_vocabs - 1)
        protein_test = Masked_Dataset(test_x, test_y, self.prot_max_len, masked_rate, self.n_prot_vocabs - 1)

        self.protein_train = DataLoader(protein_train, batch_size=batch_size, shuffle=True)
        self.protein_test = DataLoader(protein_test, batch_size=batch_size, shuffle=False)
|
    def train(self, epochs, lr=1e-5):
        print('Training the model!')
        best_auc = 0
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=1e-5)
        self.criterion = nn.BCELoss().to(self.device)

        for epoch in range(1, epochs + 1):
            self.model.train()
            loss_train, pred_train, target_train = self.batch_step(self.train_loader, train_mode=True)
            print("\n[EPOCH: {}], \tTrain Loss: {: .6f}".format(epoch, loss_train), end='')

            self.model.eval()
            with torch.no_grad():
                loss_test, pred_test, target_test = self.batch_step(self.test_loader, train_mode=False)
                scores = get_scores(target_test, pred_test)
            print("\tTest Loss: {: .6f}\tTest ACC: {:.6f}\tTest AUC: {:.6f}\tTest MCC: {:.6f}\tTest PR_AUC: {:.6f}\tF1: {:.6f}\n".format(loss_test, scores['acc'], scores['roc_auc'], scores['mcc'], scores['pr_auc'], scores['f1']))

            # Checkpoint on the best test ROC-AUC seen so far.
            if scores['roc_auc'] > best_auc:
                best_auc = scores['roc_auc']
                os.makedirs("./models", exist_ok=True)
                # Unwrap nn.DataParallel if the model was wrapped; otherwise save directly.
                state_dict = self.model.module.state_dict() if hasattr(self.model, 'module') else self.model.state_dict()
                torch.save(state_dict, "./models/AptaTrans_best_auc.pt")
                print('Saved at ./models/AptaTrans_best_auc.pt!')
        print('Done!')
|
    def batch_step(self, loader, train_mode=True):
        loss_total = 0
        pred = np.array([])
        target = np.array([])
        for batch_idx, (apta, prot, y) in enumerate(loader):
            if train_mode:
                self.optimizer.zero_grad()

            y_pred = self.predict(apta, prot)
            y_true = y.to(self.device, dtype=torch.float32)
            loss = self.criterion(torch.flatten(y_pred), y_true)

            if train_mode:
                loss.backward()
                self.optimizer.step()

            loss_total += loss.item()
            pred = np.append(pred, torch.flatten(y_pred).detach().cpu().numpy())
            target = np.append(target, torch.flatten(y_true).detach().cpu().numpy())

            mode = 'train' if train_mode else 'eval'
            print(mode + "[{}/{}({:.0f}%)]".format(batch_idx, len(loader), 100. * batch_idx / len(loader)), end="\r", flush=True)
        loss_total /= len(loader)
        return loss_total, pred, target
|
    def predict(self, apta, prot):
        apta, prot = apta.to(self.device), prot.to(self.device)
        y_pred = self.model(apta, prot)
        return y_pred
|
    def pretrain_aptamer(self, epochs, lr=1e-5):
        savepath = "./models/rna_pretrained_encoder.pt"
        self.encoder_aptamer = self.pretraining(self.encoder_aptamer, self.rna_train, self.rna_test, savepath, epochs, lr)

    def pretrain_protein(self, epochs, lr=1e-5):
        savepath = "./models/protein_pretrained_encoder.pt"
        self.encoder_protein = self.pretraining(self.encoder_protein, self.protein_train, self.protein_test, savepath, epochs, lr)
|
    def pretraining(self, model, train_loader, test_loader, savepath_model, epochs, lr=1e-5):
        print('Pre-training the model')
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        best = float('inf')

        start_time = timeit.default_timer()
        for epoch in range(1, epochs + 1):
            model.train()
            model, loss_train, mlm_train, ssp_train = self.batch_step_pt(model, train_loader, train_mode=True)
            print("\n[EPOCH: {}], \tTrain Loss: {: .6f}\tTrain mlm: {: .6f}\tTrain ssp: {: .6f}".format(epoch, loss_train, mlm_train, ssp_train), end='')

            model.eval()
            with torch.no_grad():
                model, loss_test, mlm_test, ssp_test = self.batch_step_pt(model, test_loader, train_mode=False)
            print("\tTest Loss: {: .6f}\tTest mlm: {: .6f}\tTest ssp: {: .6f}".format(loss_test, mlm_test, ssp_test))

            elapsed = timeit.default_timer() - start_time
            print("elapsed time: %02d:%02d:%05.2f" % (elapsed // 3600, (elapsed // 60) % 60, elapsed % 60))

            # Checkpoint on the best (lowest) test loss seen so far.
            if loss_test < best:
                best = loss_test
                os.makedirs(os.path.dirname(savepath_model), exist_ok=True)
                state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
                torch.save(state_dict, savepath_model)
        return model
|
    def batch_step_pt(self, model, loader, train_mode=True):
        loss_total, loss_mlm, loss_ssp = 0, 0, 0
        for batch_idx, (x_masked, y_masked, x, y_ss) in enumerate(loader):
            if train_mode:
                self.optimizer.zero_grad()

            inputs_mlm, inputs = x_masked.to(self.device), x.to(self.device)
            y_pred_mlm, y_pred_ssp = model(inputs_mlm, inputs)

            # CrossEntropyLoss expects (batch, classes, seq_len), hence the transpose.
            l_mlm = self.criterion(torch.transpose(y_pred_mlm, 1, 2), y_masked.to(self.device))
            l_ssp = self.criterion(torch.transpose(y_pred_ssp, 1, 2), y_ss.to(self.device))
            # Masked-language-modeling loss is weighted twice as heavily as
            # secondary-structure prediction.
            loss = l_mlm * 2 + l_ssp

            # Accumulate detached scalars (.item()) so the computation graph
            # can be freed after each step.
            loss_mlm += l_mlm.item()
            loss_ssp += l_ssp.item()
            loss_total += loss.item()

            if train_mode:
                loss.backward()
                self.optimizer.step()

            mode = 'train' if train_mode else 'eval'
            print(mode + "[{}/{}({:.0f}%)]".format(batch_idx, len(loader), 100. * batch_idx / len(loader)), end="\r", flush=True)

        loss_mlm /= len(loader)
        loss_ssp /= len(loader)
        loss_total /= len(loader)
        return model, loss_total, loss_mlm, loss_ssp
|
    def inference(self, apta, prot):
        print('Predict the Aptamer-Protein Interaction')
        try:
            print("load the best model for api!")
            self.model = torch.load('./models/aptatrans/v1_BS=16/v1_BS=16.pt', map_location=self.device)
        except FileNotFoundError:
            print('There is no best model file.')
            print('You need to train the model for predicting API!')

        print('Aptamer : ', apta)
        print('Target Protein : ', prot)

        # The tokenizers expect a batch of sequences (cf. recommend() below),
        # so single inputs are wrapped in a one-element batch.
        apta_tokenized = torch.tensor(rna2vec(np.array([apta])), dtype=torch.int64).to(self.device)
        if len(prot) > self.prot_max_len:
            prot = prot[:self.prot_max_len]
        prot_tokenized = torch.tensor(tokenize_sequences([prot], self.prot_max_len, self.n_prot_vocabs, self.prot_words), dtype=torch.int64).to(self.device)

        self.model.eval()
        with torch.no_grad():
            y_pred = self.model(apta_tokenized, prot_tokenized)
        score = y_pred.detach().cpu().numpy()
        print('Score : ', score)
        return score
|
    def recommend(self, target, n_aptamers, depth, iteration, verbose=True):
        try:
            print("load the best model for api!")
            self.model.load_state_dict(torch.load('./models/AptaTrans_best_auc.pt', map_location=self.device))
        except FileNotFoundError:
            print('There is no best model file.')
            print('You need to train the model for predicting API!')

        candidates = []
        encoded_targetprotein = torch.tensor(tokenize_sequences([target], self.prot_max_len, self.n_prot_vocabs, self.prot_words), dtype=torch.int64).to(self.device)
        mcts = MCTS(encoded_targetprotein, depth=depth, iteration=iteration, states=8, target_protein=target, device=self.device)

        # Generate candidates one at a time with Monte Carlo tree search,
        # scoring each against the target with the trained interaction model.
        for _ in range(n_aptamers):
            mcts.make_candidate(self.model)
            candidates.append(mcts.get_candidate())

            self.model.eval()
            with torch.no_grad():
                sim_seq = np.array([mcts.get_candidate()])
                apta = torch.tensor(rna2vec(sim_seq), dtype=torch.int64).to(self.device)
                score = self.model(apta, encoded_targetprotein)

            if verbose:
                print("candidate:\t", mcts.get_candidate(), "\tscore:\t", score)
                print("*" * 80)
            mcts.reset()

        # Re-score the collected candidates against the same encoded target.
        self.model.eval()
        for candidate in candidates:
            with torch.no_grad():
                sim_seq = np.array([candidate])
                apta = torch.tensor(rna2vec(sim_seq), dtype=torch.int64).to(self.device)
                score = self.model(apta, encoded_targetprotein)
            if verbose:
                print(f'Candidate : {candidate}, Score: {score}')
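

# Minimal usage sketch showing how the pipeline methods fit together.
# Batch sizes, epoch counts, MCTS settings, and the target sequence below are
# illustrative assumptions only, and the data files under ./data are assumed
# to exist.
if __name__ == '__main__':
    pipeline = AptaTransPipeline(device='cuda' if torch.cuda.is_available() else 'cpu')

    # Pre-train the encoders with the masked-token + secondary-structure objectives.
    pipeline.set_data_rna_pt(batch_size=32)
    pipeline.pretrain_aptamer(epochs=10)
    pipeline.set_data_protein_pt(batch_size=32)
    pipeline.pretrain_protein(epochs=10)

    # Fine-tune the interaction model, then recommend candidate aptamers
    # for a (hypothetical placeholder) target protein sequence.
    pipeline.set_data_for_training(batch_size=16)
    pipeline.train(epochs=50)
    pipeline.recommend(target='MKT...', n_aptamers=5, depth=10, iteration=100)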