diff --git "a/PD_pLMProbXDiff/DataSetPack.py" "b/PD_pLMProbXDiff/DataSetPack.py" new file mode 100644--- /dev/null +++ "b/PD_pLMProbXDiff/DataSetPack.py" @@ -0,0 +1,3465 @@ +from tensorflow.keras.preprocessing import text, sequence +from tensorflow.keras.preprocessing.text import Tokenizer + +from torch.utils.data import DataLoader,Dataset +import pandas as pd +import seaborn as sns + +import torchvision + +import matplotlib.pyplot as plt +import numpy as np + +from torch import nn +from torch import optim +import torch.nn.functional as F +from torchvision import datasets, transforms, models + +import torch.optim as optim +from torch.optim.lr_scheduler import ExponentialLR, StepLR +from functools import partial, wraps + +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import QuantileTransformer +from sklearn.preprocessing import RobustScaler + +from matplotlib.ticker import MaxNLocator + +import torch + +import esm + +# ============================================================ +# convert csv into df +# ============================================================ +class RegressionDataset(Dataset): + + def __init__(self, X_data, y_data): + self.X_data = X_data + self.y_data = y_data + + def __getitem__(self, index): + return self.X_data[index], self.y_data[index] + + def __len__ (self): + return len(self.X_data) + +# padding a list using a given value +def pad_a_np_arr(x0,add_x,n_len): + n0 = len(x0) + # print(n0) + x1 = x0.copy() + if n0 +# x1 = [add_x]+x1 # somehow, this one doesn't work +# # print(x1) +# # print('x1 len: ',len(x1) ) +# n0 = len(x1) +# # +# if n0 + # x1 = [add_x]+x1 # somehow, this one doesn't work + # print(x1) + # print('x1 len: ',len(x1) ) + n0 = len(x1) + # + if n0max_AASeq_len-2].index, + inplace = True + ) + protein_df.drop( + protein_df[protein_df['seq_len'] max_used_Smo_F].index, + inplace = True + ) + # protein_df.drop( + # protein_df[protein_df['seq_len'] max_AASeq_len-2].index, + inplace = True + ) + protein_df.drop( + protein_df[protein_df['seq_len'] max_used_Smo_F].index, + inplace = True + ) + # protein_df.drop( + # protein_df[protein_df['seq_len'] and + esm_batch_converter = esm_alphabet.get_batch_converter( + truncation_seq_length=PKeys['max_AA_seq_len']-2 + ) + esm_model.eval() # disables dropout for deterministic results + # prepare seqs for the "esm_batch_converter..." 
+    # add dummy labels
+    seqs_ext=[]
+    for i in range(len(seqs)):
+        seqs_ext.append(
+            (" ", seqs[i])
+        )
+    # batch_labels, batch_strs, batch_tokens = esm_batch_converter(seqs_ext)
+    _, y_strs, y_data = esm_batch_converter(seqs_ext)
+    y_strs_lens = (y_data != esm_alphabet.padding_idx).sum(1)
+    # print(batch_tokens.shape)
+    print ("y_data.dtype: ", y_data.dtype)
+
+#     # --
+#     # tokenizer_y = None
+#     if tokenizer_y==None:
+#         tokenizer_y = Tokenizer(char_level=True, filters='!"$%&()*+,-./:;<=>?[\\]^_`{|}\t\n' )
+#         tokenizer_y.fit_on_texts(seqs)
+
+#     #y_data = tokenizer_y.texts_to_sequences(y_data)
+#     y_data = tokenizer_y.texts_to_sequences(seqs)
+
+#     y_data= sequence.pad_sequences(
+#         y_data, maxlen=max_AA_len,
+#         padding='post', truncating='post')
+
+    fig_handle = sns.histplot(
+        data=pd.DataFrame({'AA code': np.array(y_data).flatten()}),
+        x='AA code',
+        bins=np.array([i-0.5 for i in range(0, 33+3, 1)]),  # np.array([i-0.5 for i in range(0,20+3,1)])
+        # binwidth=1,
+    )
+    fig = fig_handle.get_figure()
+    fig_handle.set_xlim(-1, 33+1)
+    # fig_handle.set_ylim(0, 100000)
+    outname = store_path+'CSV_5_DataSet_AACode_dist.jpg'
+    if IF_SaveFig==1:
+        plt.savefig(outname, dpi=200)
+    else:
+        plt.show()
+    plt.close()
+
+    # -----------------------------------------------------------
+    # print ("#################################")
+    # print ("DICTIONARY y_data")
+    # dictt=tokenizer_y.get_config()
+    # print (dictt)
+    # num_words = len(tokenizer_y.word_index) + 1
+    # print ("################## y max token: ",num_words )
+    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+    print ("#################################")
+    print ("DICTIONARY y_data: esm-", PKeys['ESM-2_Model'])
+    print ("################## y max token: ", len_toks)
+
+    # reverse
+    print ("TEST REVERSE: ")
+
+#     # --------------------------------------------------------------
+#     y_data_reversed=tokenizer_y.sequences_to_texts (y_data)
+
+#     for iii in range (len(y_data_reversed)):
+#         y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
+
+    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+    # assume y_data is reversible
+    y_data_reversed = decode_many_ems_token_rec(y_data, esm_alphabet)
+
+    print ("Element 0", y_data_reversed[0])
+    print ("Number of y samples", len(y_data_reversed))
+
+    for iii in [0, 2, 6]:
+        print("Ori and REVERSED SEQ: ", iii)
+        print(seqs[iii])
+        print(y_data_reversed[iii])
+
+    # print ("Original: ", y_data[:3,:])
+    # print ("REVERSED TEXT 0..2: ", y_data_reversed[0:3])
+
+    print ("Len 0 as example: ", len(y_data_reversed[0]))
+    print ("Check ori: ", len(seqs[0]))
+    print ("Len 2 as example: ", len(y_data_reversed[2]))
+    print ("Check ori: ", len(seqs[2]))
+
+    if maxdata < len(y_data):
+        # keep only the first maxdata records (body reconstructed)
+        print('Only uses a subset of the data: ', maxdata)
+        y_data = y_data[:maxdata]
+        y_data_reversed = y_data_reversed[:maxdata]
+
+# ============================================================
+# next loader: tokenize Y with a trivial char-level tokenizer
+# or with an ESM-2 alphabet (header lost in extraction)
+# ============================================================
+    if PKeys['ESM-2_Model']=='trivial':
+        # -- for the trivial tokenizer (branch head reconstructed from
+        # the commented-out block above)
+        if tokenizer_y==None:
+            tokenizer_y = Tokenizer(char_level=True, filters='!"$%&()*+,-./:;<=>?[\\]^_`{|}\t\n' )
+            tokenizer_y.fit_on_texts(seqs)
+        y_data = tokenizer_y.texts_to_sequences(seqs)
+        # Seq -> 0aaa0 (add one 0 at the beginning)
+        # --
+        # y_data= sequence.pad_sequences(
+        #     y_data, maxlen=max_AA_len,
+        #     padding='post', truncating='post')
+        # ++
+        y_data = sequence.pad_sequences(
+            y_data, maxlen=max_AA_len-1,
+            padding='post', truncating='post',
+            value=0.0,
+        )
+        # add one 0 at the beginning
+        y_data = sequence.pad_sequences(
+            y_data, maxlen=max_AA_len,
+            padding='pre', truncating='pre',
+            value=0.0,
+        )
+
+        len_toks = len(tokenizer_y.word_index) + 1
+
+    else:
+        # ++ for pLM: esm
+        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+        print("pLM model: ", PKeys['ESM-2_Model'])
+
+        if PKeys['ESM-2_Model']=='esm2_t33_650M_UR50D':
+            # print('Debug block')
+            # embed dim: 1280
+            esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
+            len_toks = len(esm_alphabet.all_toks)
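+        # --------------------------------------------------------------
+        # Note (assumes the standard fair-esm ESM-2 alphabet): all model
+        # sizes in this chain share one 33-token vocabulary, with
+        #   esm_alphabet.cls_idx      == 0
+        #   esm_alphabet.padding_idx  == 1
+        #   esm_alphabet.eos_idx      == 2
+        # which is why the AA-code histograms in this file use ~33 bins.
+        # --------------------------------------------------------------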
+        elif PKeys['ESM-2_Model']=='esm2_t12_35M_UR50D':
+            # embed dim: 480
+            esm_model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
+            len_toks = len(esm_alphabet.all_toks)
+        elif PKeys['ESM-2_Model']=='esm2_t36_3B_UR50D':
+            # embed dim: 2560
+            esm_model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D()
+            len_toks = len(esm_alphabet.all_toks)
+        elif PKeys['ESM-2_Model']=='esm2_t30_150M_UR50D':
+            # embed dim: 640
+            esm_model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D()
+            len_toks = len(esm_alphabet.all_toks)
+        else:
+            print("protein language model is not defined.")
+            # pass
+
+        # for check
+        print("esm_alphabet.use_msa: ", esm_alphabet.use_msa)
+        print("# of tokens in AA alphabet: ", len_toks)
+
+        # need to save 2 positions for <cls> and <eos>
+        esm_batch_converter = esm_alphabet.get_batch_converter(
+            truncation_seq_length=PKeys['max_AA_seq_len']-2
+        )
+        esm_model.eval()  # disables dropout for deterministic results
+        # prepare seqs for the "esm_batch_converter..."
+        # add dummy labels
+        seqs_ext=[]
+        for i in range(len(seqs)):
+            seqs_ext.append(
+                (" ", seqs[i])
+            )
+        # batch_labels, batch_strs, batch_tokens = esm_batch_converter(seqs_ext)
+        _, y_strs, y_data = esm_batch_converter(seqs_ext)
+        y_strs_lens = (y_data != esm_alphabet.padding_idx).sum(1)
+        #
+        # NEED to check the size of y_data: if the batch holds only
+        # shorter sequences, pad it up with the padding value, int(1)
+        current_seq_len = y_data.shape[1]
+        print("current seq batch len: ", current_seq_len)
+        missing_num_pad = PKeys['max_AA_seq_len'] - current_seq_len
+        if missing_num_pad > 0:
+            print("extra padding is added to match the target seq input length...")
+            # padding is needed
+            y_data = F.pad(
+                y_data,
+                (0, missing_num_pad),
+                "constant", esm_alphabet.padding_idx
+            )
+        else:
+            print("No extra padding is needed")
+
+    # ----------------------------------------------------------------------------------
+    # print(batch_tokens.shape)
+    print ("y_data.dim: ", y_data.shape)
+    print ("y_data.dtype: ", y_data.dtype)
+
+#     # --
+#     # tokenizer_y = None
+#     if tokenizer_y==None:
+#         tokenizer_y = Tokenizer(char_level=True, filters='!"$%&()*+,-./:;<=>?[\\]^_`{|}\t\n' )
+#         tokenizer_y.fit_on_texts(seqs)
+
+#     #y_data = tokenizer_y.texts_to_sequences(y_data)
+#     y_data = tokenizer_y.texts_to_sequences(seqs)
+
+#     y_data= sequence.pad_sequences(
+#         y_data, maxlen=max_AA_len,
+#         padding='post', truncating='post')
+
+    fig_handle = sns.histplot(
+        data=pd.DataFrame({'AA code': np.array(y_data).flatten()}),
+        x='AA code',
+        bins=np.array([i-0.5 for i in range(0, 33+3, 1)]),  # np.array([i-0.5 for i in range(0,20+3,1)])
+        # binwidth=1,
+    )
+    fig = fig_handle.get_figure()
+    fig_handle.set_xlim(-1, 33+1)
+    # fig_handle.set_ylim(0, 100000)
+    outname = store_path+'CSV_5_DataSet_AACode_dist.jpg'
+    if IF_SaveFig==1:
+        plt.savefig(outname, dpi=200)
+    else:
+        plt.show()
+    plt.close()
+
+    # -----------------------------------------------------------
+    # print ("#################################")
+    # print ("DICTIONARY y_data")
+    # dictt=tokenizer_y.get_config()
+    # print (dictt)
+    # num_words = len(tokenizer_y.word_index) + 1
+    # print ("################## y max token: ",num_words )
+    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+    print ("#################################")
+    print ("DICTIONARY y_data: esm-", PKeys['ESM-2_Model'])
+    print ("################## y max token: ", len_toks)
+
+    # reverse
+    print ("TEST REVERSE: ")
+
+    if PKeys['ESM-2_Model']=='trivial':
+        # --------------------------------------------------------------
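+        # ----------------------------------------------------------
+        # Illustrative sketch (hypothetical values): the Keras round
+        # trip used in this branch. With a char-level Tokenizer fit on
+        # AA strings,
+        #   codes = tokenizer_y.texts_to_sequences(["MKT"])
+        #   back  = tokenizer_y.sequences_to_texts(codes)
+        # 'back' comes out lower-case and space-separated ("m k t"),
+        # hence the .upper().strip().replace(" ", "") below.
+        # ----------------------------------------------------------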
+        y_data_reversed = tokenizer_y.sequences_to_texts(y_data)
+        for iii in range(len(y_data_reversed)):
+            y_data_reversed[iii] = y_data_reversed[iii].upper().strip().replace(" ", "")
+    else:
+        # for ESM models
+        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+        # assume y_data is reversible
+        y_data_reversed = decode_many_ems_token_rec(y_data, esm_alphabet)
+
+    print ("Element 0", y_data_reversed[0])
+    print ("Number of y samples", len(y_data_reversed))
+
+    for iii in [0, 2, 6]:
+        print("Ori and REVERSED SEQ: ", iii)
+        print(seqs[iii])
+        print(y_data_reversed[iii])
+
+    # print ("Original: ", y_data[:3,:])
+    # print ("REVERSED TEXT 0..2: ", y_data_reversed[0:3])
+
+    print ("Len 0 as example: ", len(y_data_reversed[0]))
+    print ("Check ori: ", len(seqs[0]))
+    print ("Len 2 as example: ", len(y_data_reversed[2]))
+    print ("Check ori: ", len(seqs[2]))
+
+    if maxdata < len(y_data):
+        # keep only the first maxdata records (body reconstructed)
+        y_data = y_data[:maxdata]
+        y_data_reversed = y_data_reversed[:maxdata]
+
+# ============================================================
+# next loaders: filter on 'Seq_Len' (headers and lower-bound
+# filter lines lost in extraction; lower bounds reconstructed)
+# ============================================================
+    protein_df.drop(protein_df[protein_df['Seq_Len'] > max_AASeq_len-2].index, inplace = True)
+    protein_df.drop(protein_df[protein_df['Seq_Len'] < min_AASeq_len].index, inplace = True)
+    protein_df.drop(protein_df[protein_df['Seq_Len'] > max_AASeq_len-2].index, inplace = True)
+    protein_df.drop(protein_df[protein_df['Seq_Len'] < min_AASeq_len].index, inplace = True)
+
+    protein_df.drop(protein_df[protein_df['Seq_Len'] > max_length-2].index, inplace = True)
+    protein_df.drop(protein_df[protein_df['Seq_Len'] < min_AASeq_len].index, inplace = True)
+    protein_df.drop(protein_df[protein_df['Seq_Len'] > max_length-2].index, inplace = True)
+    protein_df.drop(protein_df[protein_df['Seq_Len'] < min_AASeq_len].index, inplace = True)
+
+    # ++ for pLM: esm (model-selection chain reconstructed to match
+    # the loaders above)
+    print("pLM model: ", PKeys['ESM-2_Model'])
+    if PKeys['ESM-2_Model']=='esm2_t33_650M_UR50D':
+        # embed dim: 1280
+        esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
+        len_toks = len(esm_alphabet.all_toks)
+    elif PKeys['ESM-2_Model']=='esm2_t12_35M_UR50D':
+        # embed dim: 480
+        esm_model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
+        len_toks = len(esm_alphabet.all_toks)
+    elif PKeys['ESM-2_Model']=='esm2_t30_150M_UR50D':
+        # embed dim: 640
+        esm_model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D()
+        len_toks = len(esm_alphabet.all_toks)
+    else:
+        print("protein language model is not defined.")
+
+    # need to save 2 positions for <cls> and <eos>
+    esm_batch_converter = esm_alphabet.get_batch_converter(
+        truncation_seq_length=PKeys['max_AA_seq_len']-2
+    )
+    esm_model.eval()  # disables dropout for deterministic results
+    # prepare seqs for the "esm_batch_converter..."
+    # add dummy labels
+    seqs_ext=[]
+    for i in range(len(seqs)):
+        seqs_ext.append(
+            (" ", seqs[i])
+        )
+    # batch_labels, batch_strs, batch_tokens = esm_batch_converter(seqs_ext)
+    _, y_strs, y_data = esm_batch_converter(seqs_ext)
+    y_strs_lens = (y_data != esm_alphabet.padding_idx).sum(1)
+    # print(batch_tokens.shape)
+    print ("y_data: ", y_data.dtype)
+
+#     # --
+#     # tokenizer_y = None
+#     if tokenizer_y==None:
+#         tokenizer_y = Tokenizer(char_level=True, filters='!"$%&()*+,-./:;<=>?[\\]^_`{|}\t\n' )
+#         tokenizer_y.fit_on_texts(seqs)
+
+#     #y_data = tokenizer_y.texts_to_sequences(y_data)
+#     y_data = tokenizer_y.texts_to_sequences(seqs)
+
+#     y_data= sequence.pad_sequences(
+#         y_data, maxlen=max_AA_len,
+#         padding='post', truncating='post')
+
+    fig_handle = sns.histplot(
+        data=pd.DataFrame({'AA code': np.array(y_data).flatten()}),
+        x='AA code',
+        bins=np.array([i-0.5 for i in range(0, 33+3, 1)]),  # np.array([i-0.5 for i in range(0,20+3,1)])
+        # binwidth=1,
+    )
+    fig = fig_handle.get_figure()
+    fig_handle.set_xlim(-1, 33+1)
+    # fig_handle.set_ylim(0, 100000)
+    outname = store_path+'CSV_5_DataSet_AACode_dist.jpg'
+    if IF_SaveFig==1:
+        plt.savefig(outname, dpi=200)
+    else:
+        plt.show()
+    plt.close()
+
+    # # --------------------------------------------------------
+    # print ("#################################")
+    # print ("DICTIONARY y_data")
+    # dictt=tokenizer_y.get_config()
+    # print (dictt)
+    # num_words = len(tokenizer_y.word_index) + 1
+    # print ("################## y max token: ",num_words )
+    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+    print ("#################################")
+    print ("DICTIONARY y_data: esm-", PKeys['ESM-2_Model'])
+    print ("################## y max token: ", len_toks)
+
+    # reverse
+    print ("TEST REVERSE: ")
+#     # ----------------------------------------------------------------
+#     y_data_reversed=tokenizer_y.sequences_to_texts (y_data)
+
+#     for iii in range (len(y_data_reversed)):
+#         y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
+
+    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
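+    # ------------------------------------------------------------------
+    # Minimal sketch of the decoding used below, assuming
+    # decode_many_ems_token_rec (defined elsewhere in this file) behaves
+    # roughly like this per record (the helper shown is hypothetical):
+    #
+    #   def decode_one(rec, alphabet):
+    #       toks = [alphabet.get_tok(int(i)) for i in rec]
+    #       # keep single-letter AA tokens; drop <cls>/<eos>/<pad> etc.
+    #       return "".join(t for t in toks if len(t) == 1)
+    # ------------------------------------------------------------------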
+    # assume y_data is reversible
+    y_data_reversed = decode_many_ems_token_rec(y_data, esm_alphabet)
+
+    print ("Element 0", y_data_reversed[0])
+    print ("Number of y samples", len(y_data_reversed))
+
+    for iii in [0, 2, 6]:
+        print("Ori and REVERSED SEQ: ", iii)
+        print(seqs[iii])
+        print(y_data_reversed[iii])
+
+    # print ("Original: ", y_data[:3,:])
+    # print ("REVERSED TEXT 0..2: ", y_data_reversed[0:3])
+
+    print ("Len 0 as example: ", len(y_data_reversed[0]))
+    print ("Check ori: ", len(seqs[0]))
+    print ("Len 2 as example: ", len(y_data_reversed[2]))
+    print ("Check ori: ", len(seqs[2]))
+
+    if maxdata < len(y_data):
+        # keep only the first maxdata records (body reconstructed)
+        y_data = y_data[:maxdata]
+        y_data_reversed = y_data_reversed[:maxdata]
+
+# ============================================================
+# next loader (header lost in extraction)
+# ============================================================
+    protein_df.drop(protein_df[protein_df['Seq_Len'] > max_length-2].index, inplace = True)
+# protein_df.drop(protein_df[protein_df['Seq_Len']
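+# ----------------------------------------------------------------------
+# Illustrative sketch (standalone; column and threshold names are
+# hypothetical): the length-filtering pattern used by the loaders above,
+# which reserves two positions for the <cls>/<eos> specials.
+#
+#   import pandas as pd
+#   df = pd.DataFrame({'AA': ['MKT', 'MKTAYIAKQRQISFVK'],
+#                      'Seq_Len': [3, 16]})
+#   max_length = 10
+#   df.drop(df[df['Seq_Len'] > max_length-2].index, inplace=True)
+#   # -> only the 3-residue row survives, since 16 > 10-2
+# ----------------------------------------------------------------------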