Bo-Ni committed on
Commit
269fa8c
1 Parent(s): d55dc15

Upload the lib

PD_pLMProbXDiff/DataSetPack.py ADDED
The diff for this file is too large to render. See raw diff
 
PD_pLMProbXDiff/ModelPack.py ADDED
The diff for this file is too large to render. See raw diff
 
PD_pLMProbXDiff/PostMDPack.py ADDED
@@ -0,0 +1,375 @@
+ import os
+ import sys
+
+ import pandas as pd
+ import torch
+
+ import numpy as np
+
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+
+ import linecache
+ import re
+
+ import math
+
+ from Bio.PDB import PDBParser, PDBIO, Superimposer
+ from Bio.PDB.vectors import calc_angle, calc_dihedral
+ import Bio.PDB.vectors
+ #
+ from Bio.PDB.DSSP import DSSP  # alternatively, try a self-made one:
+ # from Bio.PDB.DSSP_SelfMade import DSSP_SelfMade
+
+ resdict = {
+     "ALA": "A",
+     "CYS": "C",
+     "ASP": "D",
+     "GLU": "E",
+     "PHE": "F",
+     "GLY": "G",
+     "HIS": "H",
+     "ILE": "I",
+     "LYS": "K",
+     "LEU": "L",
+     "MET": "M",
+     "ASN": "N",
+     "PRO": "P",
+     "GLN": "Q",
+     "ARG": "R",
+     "SER": "S",
+     "THR": "T",
+     "VAL": "V",
+     "TRP": "W",
+     "TYR": "Y",
+ }
+ # using the residue names from the force-field file:
+ # the second definition below supersedes the first and also covers
+ # the CHARMM histidine variants (HSD/HSE/HSP)
+ resdict = {
+     "ALA": "A",
+     "ARG": "R",
+     "ASN": "N",
+     "ASP": "D",
+     "CYS": "C",
+     "GLN": "Q",
+     "GLU": "E",
+     "GLY": "G",
+     "HIS": "H",
+     "HSD": "H",
+     "HSE": "H",
+     "HSP": "H",
+     "ILE": "I",
+     "LYS": "K",
+     "LEU": "L",
+     "MET": "M",
+     "PHE": "F",
+     "PRO": "P",
+     "SER": "S",
+     "THR": "T",
+     "TRP": "W",
+     "TYR": "Y",
+     "VAL": "V",
+ }
+ #
+ # SMD setup
+ SMD_Vel = 0.0001  # A/timestep
+
+ # step_data * SMD_Vel = pulling_dist
+
+ def collect_geo_of_backbone(chain):
+     prev = "0"  # flag: have we seen the first residue yet?
+     rad = 180.0 / math.pi
+     # result
+     resu = {"AA": [],
+             "Bond_CA_N": [], "Bond_CA_C": [], "Bond_N_C1": [],
+             "Angl_CA1_C1_N": [], "Angl_C1_N_CA": [], "Angl_N_CA_C": [],
+             "Dihe_PHI": [], "Dihe_PSI": [], "Dihe_OME": []}
+     #
+     for res in chain:
+         if res.get_resname() in resdict.keys():
+
+             # seq += resdict[res.get_resname()]
+             resu["AA"].append(resdict[res.get_resname()])
+             # ToDo: check whether this residue has N, CA, C
+             # if not (res.has_key("N") and res.has_key("CA") and res.has_key("C")):
+             #     print("Key backbone atom is missing")
+
+             if prev == "0":
+                 # 1st AA:
+                 N_prev = res["N"]
+                 CA_prev = res["CA"]
+                 C_prev = res["C"]
+                 # update the flag
+                 prev = "1"
+             else:
+                 n1 = N_prev.get_vector()
+                 ca1 = CA_prev.get_vector()
+                 c1 = C_prev.get_vector()
+
+                 # print(res)
+                 C_curr = res["C"]
+                 N_curr = res["N"]
+                 CA_curr = res["CA"]
+
+                 # get the coordinates
+                 c = C_curr.get_vector()
+                 n = N_curr.get_vector()
+                 ca = CA_curr.get_vector()
+
+                 # get the measurements (angles converted to degrees)
+                 ca1_c1_n_ThisAngle = calc_angle(ca1, c1, n) * rad
+                 c1_n_ca_ThisAngle = calc_angle(c1, n, ca) * rad
+                 n_ca_c_ThisAngle = calc_angle(n, ca, c) * rad
+
+                 # bond lengths: Atom - Atom gives the interatomic distance
+                 ca_n_ThisBond = CA_curr - N_curr
+                 ca_c_ThisBond = CA_curr - C_curr
+                 n_c1_ThisBond = N_curr - C_prev
+
+                 # dihedrals (calc_dihedral returns radians)
+                 ThisPsi = calc_dihedral(n1, ca1, c1, n)
+                 ThisOmega = calc_dihedral(ca1, c1, n, ca)
+                 ThisPhi = calc_dihedral(c1, n, ca, c)
+
+                 # store the results
+                 # n1-ca1-c1--n-ca-c--n2-ca2-c2
+                 resu["Bond_CA_N"].append(ca_n_ThisBond)
+                 resu["Bond_CA_C"].append(ca_c_ThisBond)
+                 resu["Bond_N_C1"].append(n_c1_ThisBond)  # peptide bond
+                 #
+                 resu["Angl_CA1_C1_N"].append(ca1_c1_n_ThisAngle)
+                 resu["Angl_C1_N_CA"].append(c1_n_ca_ThisAngle)
+                 resu["Angl_N_CA_C"].append(n_ca_c_ThisAngle)
+                 #
+                 resu["Dihe_PHI"].append(ThisPhi)
+                 resu["Dihe_PSI"].append(ThisPsi)
+                 resu["Dihe_OME"].append(ThisOmega)
+
+                 # update the previous-residue atoms
+                 N_prev = res["N"]
+                 CA_prev = res["CA"]
+                 C_prev = res["C"]
+
+     # summarize the result
+     return resu
+ #
+ def collect_multi_chain_AA_info(pdb_file):
+     parser = PDBParser()
+     structure = parser.get_structure("sample", pdb_file)
+     resu_full = {"Chain": [], "AA": {}}
+     for chain in structure.get_chains():
+         this_chain_id = chain.get_id()
+         # print('Working on Chain ', this_chain_id)
+         # working on one chain; assume there is only one chain
+         resu_full["Chain"].append(this_chain_id)
+         resu_test = collect_geo_of_backbone(chain)
+         resu_full["AA"][this_chain_id] = resu_test["AA"]
+         # can add more
+
+     return resu_full
+
+
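For reference, a minimal usage sketch of the two collectors above; the PDB file name matches the one this module expects later but is hypothetical here:

# extract per-chain sequences and backbone geometry from a PDB file
info = collect_multi_chain_AA_info("TestProt_chain_0_after_psf.pdb")
first_chain = info["Chain"][0]
print("sequence:", "".join(info["AA"][first_chain]))

chain = PDBParser().get_structure("sample", "TestProt_chain_0_after_psf.pdb")[0][first_chain]
geo = collect_geo_of_backbone(chain)
print("mean N-CA bond length (A):", np.mean(geo["Bond_CA_N"]))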
+ # read one record
+
+ # plot one record ONLY in the non-empty cases
+ #
+ def get_one_force_record(ii, resu_file_name_list):
+     # ii = pick_file_list[i]
+     pdb_id = resu_file_name_list['PDB_ID'][ii]
+     data_one_file = resu_file_name_list['Path'][ii] + '/1_working_dir/collect_results/smd_resu.dat'
+     data = np.genfromtxt(data_one_file)
+     # print(data.shape)
+     # kernel = np.ones(kernel_size) / kernel_size
+
+     # focus on the disp-force curve
+     # print('# of data point: ', data.shape[0])
+     disp_data = data[:, 1]
+     force_data = data[:, 7]
+
+     # + add the pulling-point info
+     # pulling-point displacement
+     step_data = data[:, 0]
+     setdata_one_file = resu_file_name_list['Path'][ii] + '/1_working_dir/box_dimension_after_eq.dat'
+     line_4 = linecache.getline(setdata_one_file, 4)
+     SMD_Vel = float(line_4.split()[2])
+     pull_data = SMD_Vel * step_data
+
+     # force_data_convolved_10 = np.convolve(force_data, kernel, mode='same')
+     return disp_data, force_data, pdb_id, pull_data
+
+ # collect the AA sequence from the record
+ def get_one_AA_record(ii, resu_file_name_list):
+     # ii = pick_file_list[i]
+     # TestProt_chain_0_after_psf.pdb
+     pdb_file = resu_file_name_list['Path'][ii] + '/1_working_dir/TestProt_chain_0_after_psf.pdb'
+
+     resu_full = collect_multi_chain_AA_info(pdb_file)
+     # here we assume there is only one chain in the file, which is the case for a tensile test
+     # AA_seq = resu_full["AA"][resu_full["Chain"][0]]
+     AA_seq = ''.join(resu_full["AA"][resu_full["Chain"][0]])
+
+     return AA_seq
+
+ # smoothing functions
+ def conv_one_record(force_data, kernel_size):
+     # simple moving-average smoothing
+     kernel = np.ones(kernel_size) / kernel_size
+     force_data_convolved = np.convolve(force_data, kernel, mode='same')
+
+     return force_data_convolved
+
+ from math import factorial
+
+ from scipy.ndimage import uniform_filter1d  # scipy.ndimage.filters is deprecated
+ #
+ # Savitzky-Golay filter to smooth the data
+ def savitzky_golay(y, window_size, order, deriv=0, rate=1):
+
+     try:
+         window_size = np.abs(int(window_size))
+         order = np.abs(int(order))
+     except ValueError:
+         raise ValueError("window_size and order have to be of type int")
+
+     if window_size % 2 != 1 or window_size < 1:
+         raise ValueError("window_size must be a positive odd number")
+     if window_size < order + 2:
+         raise ValueError("window_size is too small for the polynomial order")
+     order_range = range(order + 1)
+     half_window = (window_size - 1) // 2
+     # precompute coefficients (np.mat is deprecated; use a plain array)
+     b = np.array([[k**i for i in order_range] for k in range(-half_window, half_window + 1)])
+     m = np.linalg.pinv(b)[deriv] * rate**deriv * factorial(deriv)
+     # pad the signal at the extremes with
+     # values taken from the signal itself
+     firstvals = y[0] - np.abs(y[1:half_window + 1][::-1] - y[0])
+     lastvals = y[-1] + np.abs(y[-half_window - 1:-1][::-1] - y[-1])
+     y = np.concatenate((firstvals, y, lastvals))
+
+     return np.convolve(m[::-1], y, mode='valid')
+
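A quick smoothing sketch on synthetic data; the window/kernel values are arbitrary:

# smooth a noisy, made-up force trace with both helpers
t = np.linspace(0, 1, 500)
noisy = np.sin(4 * np.pi * t) + 0.2 * np.random.randn(500)

smooth_ma = conv_one_record(noisy, kernel_size=11)            # moving average
smooth_sg = savitzky_golay(noisy, window_size=31, order=3)    # Savitzky-Golay

plt.plot(t, noisy, alpha=0.3, label='raw')
plt.plot(t, smooth_sg, label='Savitzky-Golay')
plt.legend(); plt.show()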
+ #
+ def read_gap_values_from_dat(file):
+     line_2 = linecache.getline(file, 2)
+     line_3 = linecache.getline(file, 3)
+     # get the values
+     ini_gap = float(line_2.split()[2])
+     fin_gap = float(line_3.split()[2])
+     return ini_gap, fin_gap
+
+
+ def read_one_array_from_df(one_record):
+     # parse a space-separated string record into a float array
+     return np.array(list(map(float, one_record.split(" "))))
+ #
+ def read_string_find_max(reco):
+     x = read_one_array_from_df(reco)
+     return np.amax(x)
+ #
+ def cal_seq_end_gap(x):
+     inc_gap_arr = x['posi_data'] - x['posi_data'][0]
+     ini_gap = x['ini_gap']
+     gap_arr = ini_gap + inc_gap_arr
+
+     return gap_arr
+ #
+ def cal_pull_end_gap(x):
+     inc_gap_arr = x['pull_data']  # - x['pull_data'][0]
+     ini_gap = x['ini_gap']
+     gap_arr = ini_gap + inc_gap_arr
+
+     return gap_arr
+
+ #
+ # pick the force at the unfolding of every residue
+
+ def simplify_NormPull_FORCEnF_rec(n_fold, this_seq_len, this_n_PullGap_arr, this_Force_arr):
+
+     # resample the normalized pull-gap axis at n_fold points per residue
+     target_pull_gap_list = [1. / (this_seq_len * n_fold) * jj for jj in range(this_seq_len * n_fold)]
+     target_pull_gap_list.append(1.)
+
+     # retrieve the force values
+     target_force = []
+     for jj in range(len(target_pull_gap_list)):
+         # for jj in range(10):
+         this_t_n_PullGap = target_pull_gap_list[jj]
+
+         if this_t_n_PullGap < this_n_PullGap_arr[0]:
+             this_t_F = 0.
+         else:
+             # find the nearest one
+             disp_arr = np.abs(this_n_PullGap_arr - this_t_n_PullGap)
+             pick_id = np.argmin(disp_arr)
+             this_t_F = this_Force_arr[pick_id]
+         #
+         target_force.append(this_t_F)
+     #
+     target_pull_gap_arr = np.array(target_pull_gap_list)
+     target_force_arr = np.array(target_force)
+
+     # for delivery
+     resu = {}
+     resu['sample_NormPullGap'] = target_pull_gap_arr
+     resu['smaple_FORCE'] = target_force_arr  # key spelling kept as-is for compatibility with downstream code
+     return resu
+
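A small sketch of the per-residue resampling; the arrays are synthetic:

# resample a made-up normalized pull-gap/force curve at 1 point per residue
n_PullGap = np.linspace(0.05, 1.0, 200)   # normalized pull-gap
Force = 1000.0 * n_PullGap**2             # made-up force values

rec = simplify_NormPull_FORCEnF_rec(
    n_fold=1, this_seq_len=64,
    this_n_PullGap_arr=n_PullGap, this_Force_arr=Force,
)
print(rec['sample_NormPullGap'].shape, rec['smaple_FORCE'].shape)   # (65,) (65,)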
+ # read input conditions
+ def read_input_model_A(file_path):
+     with open(file_path, 'r') as f:
+         txt = f.read()
+     # grab every bracketed block "[...]" and parse each as a row of numbers
+     nums = re.findall(r'\[([^][]+)\]', txt)
+     arr = np.loadtxt(nums)
+     # print(arr)
+     # print(arr[0])
+
+     return arr
+
+ def read_input_model_B(file_path):
+     with open(file_path, 'r') as f:
+         txt = f.read()
+     nums = re.findall(r'\[([^][]+)\]', txt)
+     # only the first bracketed block is used; line breaks inside it are stripped
+     # (note: this assumes a space precedes each line break inside the brackets)
+     arr = np.loadtxt([nums[0].replace('\n', '')])
+
+     return arr
+
+ def read_one_input_arr_from_txt(file_path):
+     # same parsing as read_input_model_B
+     with open(file_path, 'r') as f:
+         txt = f.read()
+     nums = re.findall(r'\[([^][]+)\]', txt)
+     arr = np.loadtxt([nums[0].replace('\n', '')])
+
+     return arr
+
+ # note: this is only for this version; in folder3 it is updated
+ # # for folder3
+ # def recover_input_for_model_B(file_path, seq_len):
+ #     raw_arr = read_one_input_arr_from_txt(file_path)
+ #     arr = raw_arr[1:1+seq_len+1]
+ #     return arr
+ # for folder2
+ def recover_input_for_model_B_ver2(file_path, seq_len):
+     raw_arr = read_one_input_arr_from_txt(file_path)
+     arr = raw_arr[0:0 + seq_len + 1]
+     return arr
+
+ # for folder3
+ def recover_input_for_model_B_ver3(file_path, seq_len):
+     raw_arr = read_one_input_arr_from_txt(file_path)
+     # prepend a zero, then take the first seq_len values
+     arr = np.zeros(seq_len + 1)
+     arr[1:1 + seq_len] = raw_arr[0:0 + seq_len]
+     return arr
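
An illustration of the text format these readers assume; demo_cond.txt is a hypothetical file, and the trailing space before the line break matters because the newline is stripped rather than replaced by a space:

with open("demo_cond.txt", "w") as f:
    f.write("condition = [0.0 0.1 0.25 \n0.4 0.6 1.0]\n")

arr = read_one_input_arr_from_txt("demo_cond.txt")
print(arr)                                             # [0.   0.1  0.25 0.4  0.6  1.  ]
cond = recover_input_for_model_B_ver3("demo_cond.txt", seq_len=5)
print(cond)                                            # leading zero, then the first 5 values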
PD_pLMProbXDiff/TrainerPack.py ADDED
The diff for this file is too large to render. See raw diff
 
PD_pLMProbXDiff/UtilityPack.py ADDED
@@ -0,0 +1,671 @@
+ # ==========================================================
+ # Utility functions
+ # ==========================================================
+ import os
+ from scipy.interpolate import CubicSpline, PchipInterpolator, Akima1DInterpolator
+ import numpy as np
+ import math
+ import matplotlib.pyplot as plt
+
+ from Bio.PDB import PDBParser
+ from Bio.PDB.DSSP import DSSP
+ from Bio.PDB import PDBList
+
+ import torch
+ from einops import rearrange
+ import esm
+ # =========================================================
+ # create a folder path if it does not exist
+ def create_path(this_path):
+     if not os.path.exists(this_path):
+         print('Creating the given path...')
+         os.mkdir(this_path)
+         path_stat = 1
+         print('Done.')
+     else:
+         print('The given path already exists!')
+         path_stat = 2
+     return path_stat
+
+ # ==========================================================
+
+ # measure the model size
+ def params(model):
+     pytorch_total_params = sum(p.numel() for p in model.parameters())
+     pytorch_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+     print("Total parameters: ", pytorch_total_params, " trainable parameters: ", pytorch_total_params_trainable)
+
+ # ==========================================================
+ # initialization functions for model config dicts
+ def prepare_UNet_keys(write_dict):
+     # keys that are not set stay None, so the defaults are used
+     Full_Keys = ['dim', 'text_embed_dim', 'num_resnet_blocks', 'cond_dim', 'num_image_tokens', 'num_time_tokens', 'learned_sinu_pos_emb_dim', 'out_dim', 'dim_mults', 'cond_images_channels', 'channels', 'channels_out', 'attn_dim_head', 'attn_heads', 'ff_mult', 'lowres_cond', 'layer_attns', 'layer_attns_depth', 'layer_attns_add_text_cond', 'attend_at_middle', 'layer_cross_attns', 'use_linear_attn', 'use_linear_cross_attn', 'cond_on_text', 'max_text_len', 'init_dim', 'resnet_groups', 'init_conv_kernel_size', 'init_cross_embed', 'init_cross_embed_kernel_sizes', 'cross_embed_downsample', 'cross_embed_downsample_kernel_sizes', 'attn_pool_text', 'attn_pool_num_latents', 'dropout', 'memory_efficient', 'init_conv_to_final_conv_residual', 'use_global_context_attn', 'scale_skip_connection', 'final_resnet_block', 'final_conv_kernel_size', 'cosine_sim_attn', 'self_cond', 'combine_upsample_fmaps', 'pixel_shuffle_upsample', 'beginning_and_final_conv_present']
+     # initialization
+     PKeys = {}
+     for key in Full_Keys:
+         PKeys[key] = None
+     # overwrite entries for which keys are provided
+     for write_key in write_dict.keys():
+         if write_key in PKeys.keys():
+             PKeys[write_key] = write_dict[write_key]
+         else:
+             print("Wrong key found: ", write_key)
+
+     return PKeys
+
+ def prepare_ModelB_keys(write_dict):
+     Full_Keys = ['timesteps', 'dim', 'pred_dim', 'loss_type', 'elucidated', 'padding_idx', 'cond_dim', 'text_embed_dim', 'input_tokens', 'sequence_embed', 'embed_dim_position', 'max_text_len', 'cond_images_channels', 'max_length', 'device']
+     # initialization
+     PKeys = {}
+     for key in Full_Keys:
+         PKeys[key] = None
+     # overwrite entries for which keys are provided
+     for write_key in write_dict.keys():
+         if write_key in PKeys.keys():
+             PKeys[write_key] = write_dict[write_key]
+         else:
+             print("Wrong key found: ", write_key)
+
+     return PKeys
+
+ def modify_keys(old_dict, write_dict):
+     new_dict = old_dict.copy()
+     for w_key in write_dict.keys():
+         if w_key in old_dict.keys():
+             new_dict[w_key] = write_dict[w_key]
+         else:
+             print("Alien key found: ", w_key)
+     return new_dict
+
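A short sketch of how the key helpers compose; the parameter values are made up:

# build a default config dict, then override one entry
base_keys = prepare_ModelB_keys({'timesteps': 96, 'dim': 128, 'device': 'cuda:0'})
run_keys = modify_keys(base_keys, {'dim': 256})      # accepted: 'dim' is a known key
bad_keys = modify_keys(base_keys, {'not_a_key': 1})  # prints "Alien key found:  not_a_key"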
+ # ==========================================================
+ # mix two NForce records for a given AA length
+ # ==========================================================
+ def mixing_two_FORCE_for_AA_Len(NGap1, Force1, NGap2, Force2, LenAA, mix_fac):
+     # build a dense common grid, blend the two curves, then resample
+     N = np.amax([len(NGap1), len(NGap2)])
+     N_Base = math.ceil(N * 2)
+     fun_PI_0 = PchipInterpolator(NGap1, Force1)
+     fun_PI_1 = PchipInterpolator(NGap2, Force2)
+     xx = np.linspace(0, 1, N_Base)
+     yy = fun_PI_0(xx) * mix_fac + fun_PI_1(xx) * (1 - mix_fac)
+     fun_PI = PchipInterpolator(xx, yy)
+     # discrete result
+     x1 = np.linspace(0, 1, LenAA + 1)
+     y1 = fun_PI(x1)
+     return fun_PI, x1, y1
+
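An illustrative blend of two synthetic normalized force curves:

g = np.linspace(0, 1, 50)
f_interp, x1, y1 = mixing_two_FORCE_for_AA_Len(
    g, g**2,          # record 1
    g, np.sqrt(g),    # record 2
    LenAA=32,         # target sequence length
    mix_fac=0.7,      # 70% of record 1, 30% of record 2
)
print(x1.shape, y1.shape)   # (33,) (33,)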
+ # =========================================================
+ # compare DSSP secondary-structure content against a target condition
+ # =========================================================
+ def get_Model_A_error(fname, cond, plotit=True, ploterror=False):
+
+     sec_structure, sec_structure_3state, sequence = get_DSSP_result(fname)
+     sscount = []
+     length = len(sec_structure)
+     sscount.append(sec_structure.count('H') / length)
+     sscount.append(sec_structure.count('E') / length)
+     sscount.append(sec_structure.count('T') / length)
+     sscount.append(sec_structure.count('~') / length)
+     sscount.append(sec_structure.count('B') / length)
+     sscount.append(sec_structure.count('G') / length)
+     sscount.append(sec_structure.count('I') / length)
+     sscount.append(sec_structure.count('S') / length)
+     sscount = np.asarray(sscount)
+
+     error = np.abs(sscount - cond)
+     print("Abs error per SS structure type (H, E, T, ~, B, G, I, S): ", error)
+
+     if ploterror:
+         fig, ax = plt.subplots(1, 1, figsize=(6, 3))
+         plt.plot(error, 'o-', label='Error over SS type')
+         plt.legend()
+         plt.ylabel('SS content')
+         plt.show()
+
+     x = np.linspace(0, 7, 8)
+
+     sslabels = ['H', 'E', 'T', '~', 'B', 'G', 'I', 'S']
+
+     # note: plotit is currently unused; the bar charts are always drawn
+     fig, ax = plt.subplots(1, 1, figsize=(6, 3))
+
+     ax.bar(x - 0.15, cond, width=0.3, color='b', align='center')
+     ax.bar(x + 0.15, sscount, width=0.3, color='r', align='center')
+
+     ax.set_ylim([0, 1])
+
+     plt.xticks(range(len(sslabels)), sslabels, size='medium')
+     plt.legend(['GT', 'Prediction'])
+
+     plt.ylabel('SS content')
+     plt.show()
+
+     ######################## 3 types
+
+     sscount = []
+     length = len(sec_structure)
+     sscount.append(sec_structure_3state.count('H') / length)
+     sscount.append(sec_structure_3state.count('E') / length)
+     sscount.append(sec_structure_3state.count('~') / length)
+     cond_p = [np.sum([cond[0], cond[5], cond[6]]), np.sum([cond[1], cond[4]]), np.sum([cond[2], cond[3], cond[7]])]
+
+     print("cond 3type: ", cond_p)
+     sscount = np.asarray(sscount)
+
+     error3 = np.abs(sscount - cond_p)
+     print("Abs error per 3-type SS structure type (H, E, ~): ", error3)  # was printing the 8-type error
+
+     if ploterror:
+         fig, ax = plt.subplots(1, 1, figsize=(6, 3))
+
+         plt.plot(error3, 'o-', label='Error over SS type')
+         plt.legend()
+         plt.ylabel('SS content')
+         plt.show()
+
+     x = np.linspace(0, 2, 3)
+
+     sslabels = ['H', 'E', '~']
+
+     # ax = plt.subplot(111, figsize=(4,4))
+     fig, ax = plt.subplots(1, 1, figsize=(6, 3))
+
+     ax.bar(x - 0.15, cond_p, width=0.3, color='b', align='center')
+     ax.bar(x + 0.15, sscount, width=0.3, color='r', align='center')
+
+     ax.set_ylim([0, 1])
+
+     plt.xticks(range(len(sslabels)), sslabels, size='medium')
+     plt.legend(['GT', 'Prediction'])
+
+     plt.ylabel('SS content')
+     plt.show()
+
+     # note: only the 8-type error is returned
+     return error
+
+ def get_DSSP_result(fname):
+     pdb_list = [fname]
+
+     # parse structure
+     p = PDBParser()
+     for i in pdb_list:
+         structure = p.get_structure(i, fname)
+         # use only the first model
+         model = structure[0]
+         # calculate DSSP
+         dssp = DSSP(model, fname, file_type='PDB')
+         # extract sequence and secondary structure from the DSSP tuple
+         sequence = ''
+         sec_structure = ''
+         for z in range(len(dssp)):
+             a_key = list(dssp.keys())[z]
+             sequence += dssp[a_key][1]
+             sec_structure += dssp[a_key][2]
+
+     # print(i)
+     # print(sequence)
+     # print(sec_structure)
+     #
+     # The DSSP codes for secondary structure used here are:
+     # =====  ====
+     # Code   Structure
+     # =====  ====
+     # H      Alpha helix (4-12)
+     # B      Isolated beta-bridge residue
+     # E      Strand
+     # G      3-10 helix
+     # I      Pi helix
+     # T      Turn
+     # S      Bend
+     # ~      None
+     # =====  ====
+     #
+
+     sec_structure = sec_structure.replace('-', '~')
+     sec_structure_3state = sec_structure
+
+     # if desired, convert DSSP's 8-state assignments into 3 states
+     # [~ - coil, E - extended (beta-strand), H - helix]
+     sec_structure_3state = sec_structure_3state.replace('H', 'H')  # 0
+     sec_structure_3state = sec_structure_3state.replace('E', 'E')
+     sec_structure_3state = sec_structure_3state.replace('T', '~')
+     sec_structure_3state = sec_structure_3state.replace('~', '~')
+     sec_structure_3state = sec_structure_3state.replace('B', 'E')
+     sec_structure_3state = sec_structure_3state.replace('G', 'H')  # 5
+     sec_structure_3state = sec_structure_3state.replace('I', 'H')  # 6
+     sec_structure_3state = sec_structure_3state.replace('S', '~')
+     return sec_structure, sec_structure_3state, sequence
+
+
+ def string_diff(seq1, seq2):
+     # Hamming-style mismatch count plus the length difference
+     return sum(1 for a, b in zip(seq1, seq2) if a != b) + abs(len(seq1) - len(seq2))
+
+
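An illustrative call of the DSSP helpers; this requires the DSSP executable to be installed, and designed.pdb is a hypothetical file name:

ss8, ss3, seq = get_DSSP_result("designed.pdb")
print("helix fraction:", ss3.count('H') / len(ss3))
print(string_diff("MKTAY", "MKSAY"))   # -> 1 (one mismatched position)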
+ # ============================================================
+ # on esm, rebuild AA sequence from embedding
+ # ============================================================
+ import esm
+
+ def decode_one_ems_token_rec(this_token, esm_alphabet):
+     # print( (this_token==esm_alphabet.cls_idx).nonzero(as_tuple=True)[0] )
+     # print( (this_token==esm_alphabet.eos_idx).nonzero(as_tuple=True)[0] )
+     # print( (this_token==100).nonzero(as_tuple=True)[0]==None )
+
+     id_b = (this_token == esm_alphabet.cls_idx).nonzero(as_tuple=True)[0]
+     id_e = (this_token == esm_alphabet.eos_idx).nonzero(as_tuple=True)[0]
+
+     if len(id_e) == 0:
+         # no ending for this one, so id_e points to the end
+         id_e = len(this_token)
+     else:
+         id_e = id_e[0]
+     if len(id_b) == 0:
+         id_b = 0
+     else:
+         id_b = id_b[-1]
+
+     this_seq = []
+     # this_token_used = []
+     for ii in range(id_b + 1, id_e, 1):
+         # this_token_used.append(this_token[ii])
+         this_seq.append(
+             esm_alphabet.get_tok(this_token[ii])
+         )
+
+     this_seq = "".join(this_seq)
+
+     # print(this_seq)
+     # print(len(this_seq))
+     # # print(this_token[id_b+1:id_e])
+     return this_seq
+
+
+ def decode_many_ems_token_rec(batch_tokens, esm_alphabet):
+     rev_y_seq = []
+     for jj in range(len(batch_tokens)):
+         # do for one seq: this_seq
+         this_seq = decode_one_ems_token_rec(
+             batch_tokens[jj], esm_alphabet
+         )
+         rev_y_seq.append(this_seq)
+     return rev_y_seq
+
+ # ++ for omegafold sequences: treat unknowns as X
+ # token indices that are specials/uncommon and cannot be folded
+ uncomm_idx_list = [0, 1, 2, 3, 24, 25, 26, 27, 28, 29, 30, 31, 32]
+
+ # this one decides the beginning and the ending AUTOMATICALLY
+ def decode_one_ems_token_rec_for_folding(
+         this_token,
+         this_logits,
+         esm_alphabet,
+         esm_model):
+
+     # 1. use this_token to find the beginning and the ending
+     # 2. use the logits to generate tokens that ONLY contain foldable AAs
+     #
+     id_b_0 = (this_token == esm_alphabet.cls_idx).nonzero(as_tuple=True)[0]
+     id_e_0 = (this_token == esm_alphabet.eos_idx).nonzero(as_tuple=True)[0]
+
+     # ------------------------------------------------------------------
+     # principle:
+     # 1. begin at the 0th position
+     # 2. end as soon as possible: rely on the first ending being learned
+     id_b = 0
+     #
+     if len(id_e_0) == 0:
+         id_e = len(this_token)
+     else:
+         id_e = id_e_0[0]
+     # correct if needed
+     if id_e <= id_b + 1:
+         if len(id_e_0) > 1:
+             id_e = id_e_0[1]
+         else:
+             id_e = len(this_token)
+     # -------------------------------------------------------------------
+
+     # # ------------------------------------------------------------------
+     # # not perfect
+     # # principle:
+     # # 1. begin as late as possible
+     # # 2. end as soon as possible
+     # #
+     # if len(id_b_0) == 0:
+     #     id_b = 0
+     # else:
+     #     id_b = id_b_0[-1]
+     # # so, the beginning is set
+     # # look for the nearest ending signal if we can find one
+     # # 1. pick those in id_e_0 with id_b < id_e
+     # id_e_1 = []
+     # for this_e in id_e_0:
+     #     if this_e > id_b:
+     #         id_e_1.append(this_e)
+     # # 2. check what we found
+     # if len(id_e_1) == 0:
+     #     # no ending, so id_e points to the end
+     #     id_e = len(this_token)
+     # else:
+     #     # otherwise, pick the first ending point
+     #     id_e = id_e_1[0]
+     # # 3. if id_b+1 == id_e, we still get nothing, so this is a fallback fix
+     # if id_e == id_b + 1:
+     #     if len(id_e_1) > 1:
+     #         id_e = id_e_1[1]
+     #     else:
+     #         id_e = len(this_token)
+     # # --------------------------------------------------------------------
+
+     # for debug:
+     print("start at: ", id_b)
+     print("end at: ", id_e)
+
+     # along the sequence, pick only indices [id_b+1:id_e]; this excludes <cls> and <eos>
+     use_logits = this_logits[id_b + 1:id_e]  # (seq_len_eff, token_len)
+     use_logits[:, uncomm_idx_list] = -float('inf')  # mask out non-foldable tokens
+     use_token = use_logits.max(1).indices
+
+     # print(use_token)
+
+     this_seq = []
+     for ii in range(len(use_token)):
+         this_seq.append(
+             esm_alphabet.get_tok(use_token[ii])
+         )
+
+     this_seq = "".join(this_seq)
+
+     # # alternative: generate a foldable sequence by
+     # # mapping all uncommon tokens onto X (token 24)
+     # for idx, one_token in enumerate(this_token_used):
+     #     find_id = 0
+     #     for this_uncomm in uncomm_idx_list:
+     #         find_id = find_id + (this_uncomm == one_token)
+     #     #
+     #     if find_id > 0:
+     #         this_token_used[idx] = 24  # 24 means X
+     # # translate tokens into sequences
+     # this_seq_foldable = []
+     # for one_token in this_token_used:
+     #     this_seq_foldable.append(
+     #         esm_alphabet.get_tok(one_token)
+     #     )
+     # return this_seq, this_seq_foldable
+     return this_seq
+
+
+ def decode_many_ems_token_rec_for_folding(
+         batch_tokens,
+         batch_logits,
+         esm_alphabet,
+         esm_model):
+
+     rev_y_seq = []
+     for jj in range(len(batch_tokens)):
+         # decode one sequence
+         this_seq = decode_one_ems_token_rec_for_folding(
+             batch_tokens[jj],
+             batch_logits[jj],
+             esm_alphabet,
+             esm_model,
+         )
+         rev_y_seq.append(this_seq)
+     return rev_y_seq
+
+
+ def convert_into_logits(esm_model, result):
+     # cf. convert_into_tokens below, which rearranges in the opposite direction
+     repre = rearrange(
+         result,
+         'b l c -> b c l'
+     )
+     with torch.no_grad():
+         logits = esm_model.lm_head(repre)
+
+     return logits
+
+ # this one returns the unmodified tokens and logits
+ def convert_into_tokens(model, result, pLM_Model_Name):
+     if pLM_Model_Name == 'esm2_t33_650M_UR50D' \
+             or pLM_Model_Name == 'esm2_t36_3B_UR50D' \
+             or pLM_Model_Name == 'esm2_t30_150M_UR50D' \
+             or pLM_Model_Name == 'esm2_t12_35M_UR50D':
+
+         repre = rearrange(
+             result,
+             'b c l -> b l c'
+         )
+         with torch.no_grad():
+             logits = model.lm_head(repre)  # (b, l, token_dim)
+
+         tokens = logits.max(2).indices  # (b, l)
+
+     else:
+         # fail loudly instead of returning undefined names
+         raise ValueError("pLM_Model is not defined: " + str(pLM_Model_Name))
+     return tokens, logits
+ # ++
+ def convert_into_tokens_using_prob(prob_result, pLM_Model_Name):
+     if pLM_Model_Name == 'esm2_t33_650M_UR50D' \
+             or pLM_Model_Name == 'esm2_t36_3B_UR50D' \
+             or pLM_Model_Name == 'esm2_t30_150M_UR50D' \
+             or pLM_Model_Name == 'esm2_t12_35M_UR50D':
+
+         repre = rearrange(
+             prob_result,
+             'b c l -> b l c'
+         )
+         # the input already holds per-token probabilities, so no lm_head is needed
+         logits = repre
+
+         tokens = logits.max(2).indices  # (b, l)
+
+     else:
+         raise ValueError("pLM_Model is not defined: " + str(pLM_Model_Name))
+     return tokens, logits
+
+
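A minimal token round trip through the decoders above; this downloads the smallest ESM-2 checkpoint on first use:

pLM, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
batch_converter = alphabet.get_batch_converter()
_, _, toks = batch_converter([("demo", "MKTAYIAKQR")])
print(decode_many_ems_token_rec(toks, alphabet))   # -> ['MKTAYIAKQR']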
+ #
+ def read_mask_from_input(
+         # consider different types of inputs
+         # raw data: seq_data (sequences)
+         # tokenized: tokenized_data
+         tokenized_data=None,  # X_train_batch,
+         mask_value=None,
+         seq_data=None,
+         max_seq_length=None,
+ ):
+     # # old:
+     # mask = X_train_batch != mask_value
+     # new: use `is not None` (comparing a tensor to None with != is ambiguous)
+     if seq_data is not None:
+         # use the real sequence length to create the mask
+         n_seq = len(seq_data)
+         mask = torch.zeros(n_seq, max_seq_length)
+         for ii in range(n_seq):
+             this_len = len(seq_data[ii])
+             mask[ii, 1:1 + this_len] = 1
+         mask = mask == 1
+         #
+     elif tokenized_data is not None:
+         n_seq = len(tokenized_data)
+         mask = tokenized_data != mask_value
+         # fix the beginning part: 0+content+00, not 00+content+00
+         for ii in range(n_seq):
+             # get all nonzero indices
+             id_1 = (mask[ii] == True).nonzero(as_tuple=True)[0]
+             # correction for ForcPath:
+             # pick up 0.0 for zero-force padding at the beginning
+             mask[ii, 1:id_1[0]] = True
+
+     return mask
+
+ # ++ read one length
+ def read_one_len_from_padding_vec(
+         in_np_array,
+         padding_val=0.0,
+ ):
+     # effective length = index of the last non-padding entry + 1
+     mask = in_np_array != padding_val
+     id_list_all_1 = mask.nonzero()[0]
+     vec_len = id_list_all_1[-1] + 1
+
+     return vec_len
+
+
+ # this one decides the beginning and the ending using a mask
+ def decode_one_ems_token_rec_for_folding_with_mask(
+         this_token,
+         this_logits,
+         esm_alphabet,
+         esm_model,
+         this_mask,
+ ):
+     # translate all logits into tokens, then keep only the unmasked part
+
+     use_logits = this_logits  # (seq_len_eff, token_len)
+     use_logits[:, uncomm_idx_list] = -float('inf')  # mask out non-foldable tokens
+     use_token = use_logits.max(1).indices
+     #
+     # print(use_token)  # debug
+     use_token = use_token[this_mask == True]
+
+     this_seq = []
+     for ii in range(len(use_token)):
+         this_seq.append(
+             esm_alphabet.get_tok(use_token[ii])
+         )
+
+     this_seq = "".join(this_seq)
+
+     return this_seq
+
+ def decode_many_ems_token_rec_for_folding_with_mask(
+         batch_tokens,
+         batch_logits,
+         esm_alphabet,
+         esm_model,
+         mask):
+
+     rev_y_seq = []
+     for jj in range(len(batch_tokens)):
+         # decode one sequence
+         this_seq = decode_one_ems_token_rec_for_folding_with_mask(
+             batch_tokens[jj],
+             batch_logits[jj],
+             esm_alphabet,
+             esm_model,
+             mask[jj]
+         )
+         rev_y_seq.append(this_seq)
+     return rev_y_seq
+
+ # =====================================================
+ # create new input conditions for the ForcPath case
+ # =====================================================
+ from scipy import interpolate
+
+ def interpolate_and_resample_ForcPath(y0, seq_len1):
+     # resample a force path defined on seq_len0+1 points onto seq_len1+1 points
+     seq_len0 = len(y0) - 1
+     # np.linspace avoids the endpoint round-off of np.arange with a float step
+     x0 = np.linspace(0., 1., seq_len0 + 1)
+     f = interpolate.interp1d(x0, y0)
+     #
+     x1 = np.linspace(0., 1., seq_len1 + 1)
+     y1 = f(x1)
+     #
+     resu = {}
+     resu['y1'] = y1
+     resu['x1'] = x1
+     resu['x0'] = x0
+     return resu
+ #
+ def mix_two_ForcPath(y0, y1, seq_len2):
+     seq_len0 = len(y0) - 1
+     x0 = np.linspace(0., 1., seq_len0 + 1)
+     seq_len1 = len(y1) - 1
+     x1 = np.linspace(0., 1., seq_len1 + 1)
+     f0 = interpolate.interp1d(x0, y0)
+     f1 = interpolate.interp1d(x1, y1)
+     #
+     x2 = np.linspace(0., 1., seq_len2 + 1)
+     # note: the two paths are summed, not averaged
+     y2 = (f0(x2) + f1(x2)) / 1.
+     #
+     resu = {}
+     resu['y2'] = y2
+     resu['x2'] = x2
+     resu['x1'] = x1
+     resu['x0'] = x0
+     return resu
+ #
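An illustrative resampling of a synthetic force path:

y0 = np.array([0., 0.1, 0.3, 0.6, 1.0])            # 5 points, i.e. seq_len 4
out = interpolate_and_resample_ForcPath(y0, seq_len1=8)
print(out['y1'].shape)                             # (9,)
mixed = mix_two_ForcPath(y0, out['y1'], seq_len2=16)
print(mixed['y2'][0], mixed['y2'][-1])             # 0.0 2.0 (the paths are summed)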
+ # =====================================================
+ # load-in function for the language model
+ # =====================================================
+ import esm
+
+ def load_in_pLM(pLM_Model_Name, device):
+     #
+     # ++ for pLM
+     if pLM_Model_Name == 'trivial':
+         pLM_Model = None
+         esm_alphabet = None
+         len_toks = 0
+         esm_layer = 0
+
+     elif pLM_Model_Name == 'esm2_t33_650M_UR50D':
+         # dim: 1280
+         esm_layer = 33
+         pLM_Model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
+         len_toks = len(esm_alphabet.all_toks)
+         pLM_Model.eval()
+         pLM_Model.to(device)
+
+     elif pLM_Model_Name == 'esm2_t36_3B_UR50D':
+         # dim: 2560
+         esm_layer = 36
+         pLM_Model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D()
+         len_toks = len(esm_alphabet.all_toks)
+         pLM_Model.eval()
+         pLM_Model.to(device)
+
+     elif pLM_Model_Name == 'esm2_t30_150M_UR50D':
+         # dim: 640
+         esm_layer = 30
+         pLM_Model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D()
+         len_toks = len(esm_alphabet.all_toks)
+         pLM_Model.eval()
+         pLM_Model.to(device)
+
+     elif pLM_Model_Name == 'esm2_t12_35M_UR50D':
+         # dim: 480
+         esm_layer = 12
+         pLM_Model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
+         len_toks = len(esm_alphabet.all_toks)
+         pLM_Model.eval()
+         pLM_Model.to(device)
+
+     else:
+         # fail loudly instead of returning undefined names
+         raise ValueError("pLM model is missing: " + str(pLM_Model_Name))
+
+     return pLM_Model, esm_alphabet, esm_layer, len_toks
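
A minimal loader call; weights are downloaded on first use, so a small model is chosen here:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
pLM_Model, esm_alphabet, esm_layer, len_toks = load_in_pLM('esm2_t12_35M_UR50D', device)
print('layers:', esm_layer, 'alphabet size:', len_toks)   # layers: 12 alphabet size: 33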