Upload the lib
Browse files- PD_pLMProbXDiff/DataSetPack.py +0 -0
- PD_pLMProbXDiff/ModelPack.py +0 -0
- PD_pLMProbXDiff/PostMDPack.py +375 -0
- PD_pLMProbXDiff/TrainerPack.py +0 -0
- PD_pLMProbXDiff/UtilityPack.py +671 -0
PD_pLMProbXDiff/DataSetPack.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
PD_pLMProbXDiff/ModelPack.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
PD_pLMProbXDiff/PostMDPack.py
ADDED
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
import torch
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import seaborn as sns
|
11 |
+
|
12 |
+
import linecache
|
13 |
+
import re
|
14 |
+
|
15 |
+
from Bio.PDB import PDBParser, PDBIO
|
16 |
+
import math
|
17 |
+
|
18 |
+
from Bio.PDB import PDBIO
|
19 |
+
from Bio.PDB import PDBParser
|
20 |
+
from Bio.PDB import Superimposer
|
21 |
+
from Bio.PDB.vectors import calc_angle, calc_dihedral
|
22 |
+
import Bio.PDB.vectors
|
23 |
+
#
|
24 |
+
from Bio.PDB.DSSP import DSSP # add try a self-made one
|
25 |
+
# from Bio.PDB.DSSP_SelfMade import DSSP_SelfMade # add try a self-made one
|
26 |
+
|
27 |
+
# Three-letter -> one-letter amino-acid code map.
# BUGFIX: this table was previously defined twice; the first (PDB-only)
# version was dead code, immediately shadowed by this one. Only the
# force-field (CHARMM) version is kept: it additionally maps the
# histidine protonation variants HSD/HSE/HSP to "H".
resdict = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "HSD": "H",
    "HSE": "H", "HSP": "H", "ILE": "I", "LYS": "K", "LEU": "L",
    "MET": "M", "PHE": "F", "PRO": "P", "SER": "S", "THR": "T",
    "TRP": "W", "TYR": "Y", "VAL": "V",
}

# SMD setup: default steered-MD pulling velocity (Angstrom / timestep).
# step_data * SMD_Vel = pulling distance.
SMD_Vel = 0.0001
|
82 |
+
|
83 |
+
def collect_geo_of_backbone(chain):
    """Walk one Bio.PDB chain and collect backbone geometry per residue pair.

    For every consecutive pair of standard residues (those in ``resdict``)
    the backbone bonds (CA-N, CA-C, N-C'), backbone angles (CA'-C'-N,
    C'-N-CA, N-CA-C, in degrees) and dihedrals (phi, psi, omega) are
    recorded.  Returns a dict of lists keyed by measurement name plus the
    one-letter sequence under "AA".
    """
    deg = 180.0 / math.pi
    geo = {
        "AA": [],
        "Bond_CA_N": [], "Bond_CA_C": [], "Bond_N_C1": [],
        "Angl_CA1_C1_N": [], "Angl_C1_N_CA": [], "Angl_N_CA_C": [],
        "Dihe_PHI": [], "Dihe_PSI": [], "Dihe_OME": [],
    }
    have_prev = False
    N_prev = CA_prev = C_prev = None

    for res in chain:
        if res.get_resname() not in resdict:
            # skip waters / hetero residues
            continue
        geo["AA"].append(resdict[res.get_resname()])
        # NOTE(review): residues missing any of N/CA/C would raise KeyError
        # below -- presence is assumed, as in the original code.
        if not have_prev:
            # first standard residue: nothing to measure yet
            have_prev = True
        else:
            # previous residue backbone vectors (n1-ca1-c1 -- n-ca-c)
            n1 = N_prev.get_vector()
            ca1 = CA_prev.get_vector()
            c1 = C_prev.get_vector()

            N_curr, CA_curr, C_curr = res["N"], res["CA"], res["C"]
            n = N_curr.get_vector()
            ca = CA_curr.get_vector()
            c = C_curr.get_vector()

            # bond lengths (Atom - Atom gives the distance in Bio.PDB)
            geo["Bond_CA_N"].append(CA_curr - N_curr)
            geo["Bond_CA_C"].append(CA_curr - C_curr)
            geo["Bond_N_C1"].append(N_curr - C_prev)  # peptide bond

            # backbone angles, converted to degrees
            geo["Angl_CA1_C1_N"].append(calc_angle(ca1, c1, n) * deg)
            geo["Angl_C1_N_CA"].append(calc_angle(c1, n, ca) * deg)
            geo["Angl_N_CA_C"].append(calc_angle(n, ca, c) * deg)

            # NOTE(review): dihedrals are NOT multiplied by the rad->deg
            # factor (unlike the angles above) -- kept exactly as the
            # original behaved; confirm intended units downstream.
            geo["Dihe_PHI"].append(calc_dihedral(c1, n, ca, c))
            geo["Dihe_PSI"].append(calc_dihedral(n1, ca1, c1, n))
            geo["Dihe_OME"].append(calc_dihedral(ca1, c1, n, ca))

        # current residue becomes the "previous" one for the next pair
        N_prev, CA_prev, C_prev = res["N"], res["CA"], res["C"]

    return geo
|
157 |
+
#
|
158 |
+
def collect_multi_chain_AA_info(pdb_file):
    """Parse a PDB file and return the amino-acid list for every chain.

    Returns a dict with:
      - "Chain": list of chain ids in file order
      - "AA": mapping chain id -> list of one-letter residue codes
    """
    structure = PDBParser().get_structure("sample", pdb_file)
    info = {"Chain": [], "AA": {}}
    for chain in structure.get_chains():
        chain_id = chain.get_id()
        info["Chain"].append(chain_id)
        # only the AA sequence is kept here; more geometry could be added
        info["AA"][chain_id] = collect_geo_of_backbone(chain)["AA"]
    return info
|
172 |
+
|
173 |
+
|
174 |
+
|
175 |
+
# read one record
|
176 |
+
|
177 |
+
# plot one record ONLY in the non-empty cases
|
178 |
+
#
|
179 |
+
def get_one_force_record(ii, resu_file_name_list):
    """Load one steered-MD force record.

    ii: record index into resu_file_name_list (a frame/dict with
    'PDB_ID' and 'Path' columns).
    Returns (disp_data, force_data, pdb_id, pull_data) where pull_data is
    the pulling-point displacement = SMD velocity * step count.
    """
    pdb_id = resu_file_name_list['PDB_ID'][ii]
    base = resu_file_name_list['Path'][ii]

    # displacement-force curve from the collected SMD results
    data = np.genfromtxt(base + '/1_working_dir/collect_results/smd_resu.dat')
    step_data = data[:, 0]
    disp_data = data[:, 1]
    force_data = data[:, 7]

    # pulling velocity: 3rd token of line 4 of the box-dimension file
    setup_file = base + '/1_working_dir/box_dimension_after_eq.dat'
    vel = float(linecache.getline(setup_file, 4).split()[2])
    pull_data = vel * step_data

    return disp_data, force_data, pdb_id, pull_data
|
202 |
+
|
203 |
+
# collect AA from the record
|
204 |
+
def get_one_AA_record(ii, resu_file_name_list):
    """Return the one-letter AA sequence string for record ii.

    Reads TestProt_chain_0_after_psf.pdb under the record's working dir.
    Assumes a single chain in the file (the tensile-test setup) and
    returns that chain's sequence.
    """
    pdb_file = (resu_file_name_list['Path'][ii]
                + '/1_working_dir/TestProt_chain_0_after_psf.pdb')
    info = collect_multi_chain_AA_info(pdb_file)
    first_chain = info["Chain"][0]
    return ''.join(info["AA"][first_chain])
|
215 |
+
|
216 |
+
# smooth functions
|
217 |
+
def conv_one_record(force_data, kernel_size):
    """Smooth a force trace with a flat moving-average kernel.

    Returns the same-length convolution of force_data with a uniform
    kernel of the given size.
    """
    box = np.full(kernel_size, 1.0 / kernel_size)
    return np.convolve(force_data, box, mode='same')
|
222 |
+
|
223 |
+
from math import factorial

# BUGFIX: the scipy.ndimage.filters namespace was deprecated and has been
# removed in modern SciPy; import directly from scipy.ndimage instead.
from scipy.ndimage import uniform_filter1d
|
226 |
+
#
|
227 |
+
# function to smooth the data
|
228 |
+
def savitzky_golay(y, window_size, order, deriv=0, rate=1):
    """Smooth (and optionally differentiate) data with a Savitzky-Golay filter.

    y           : 1-D array of signal values
    window_size : odd positive window length (must exceed order + 1)
    order       : polynomial order of the local least-squares fit
    deriv       : derivative order to return (0 = smoothed signal)
    rate        : sample rate, scales the returned derivative

    Returns the filtered signal, same length as y.
    Raises ValueError for non-integer window/order, TypeError for an
    invalid window size.
    """
    try:
        window_size = np.abs(int(window_size))
        order = np.abs(int(order))
    except ValueError:
        raise ValueError("window_size and order have to be of type int")

    if window_size % 2 != 1 or window_size < 1:
        raise TypeError("window_size size must be a positive odd number")
    if window_size < order + 2:
        raise TypeError("window_size is too small for the polynomials order")
    order_range = range(order + 1)
    half_window = (window_size - 1) // 2
    # precompute the least-squares coefficients from the Vandermonde matrix.
    # BUGFIX: np.mat / the .A attribute are deprecated and removed in modern
    # NumPy -- use a plain ndarray with np.linalg.pinv (same values).
    b = np.array([[k ** i for i in order_range]
                  for k in range(-half_window, half_window + 1)])
    m = np.linalg.pinv(b)[deriv] * rate ** deriv * factorial(deriv)
    # pad the signal at the extremes with values mirrored from the signal
    firstvals = y[0] - np.abs(y[1:half_window + 1][::-1] - y[0])
    lastvals = y[-1] + np.abs(y[-half_window - 1:-1][::-1] - y[-1])
    y = np.concatenate((firstvals, y, lastvals))
    return np.convolve(m[::-1], y, mode='valid')
|
254 |
+
|
255 |
+
#
|
256 |
+
def read_gap_values_from_dat(file):
    """Read initial and final gap values from a .dat file.

    The 3rd whitespace token of line 2 is the initial gap; the 3rd token
    of line 3 is the final gap. Returns (ini_gap, fin_gap) as floats.
    """
    ini_gap = float(linecache.getline(file, 2).split()[2])
    fin_gap = float(linecache.getline(file, 3).split()[2])
    return ini_gap, fin_gap
|
265 |
+
|
266 |
+
|
267 |
+
def read_one_array_from_df(one_record):
    """Parse a whitespace-separated number string into a float ndarray.

    Generalized from split(" ") to split(): repeated, leading or trailing
    whitespace is now tolerated (single-space strings parse identically).
    """
    return np.array([float(tok) for tok in one_record.split()])
|
269 |
+
#
|
270 |
+
# BUGFIX: this function was accidentally defined twice with identical
# bodies; the redundant duplicate definition has been removed.
def read_string_find_max(reco):
    """Return the maximum value of a whitespace-separated number string."""
    x = read_one_array_from_df(reco)
    return np.amax(x)
|
277 |
+
#
|
278 |
+
def cal_seq_end_gap(x):
    """Sequence-end gap trace: initial gap plus position increments.

    x: record with 'posi_data' (ndarray of positions) and 'ini_gap'
    (scalar initial gap). Returns an array of gap values.
    """
    increments = x['posi_data'] - x['posi_data'][0]
    return x['ini_gap'] + increments
|
284 |
+
#
|
285 |
+
def cal_pull_end_gap(x):
    """Pulling-end gap trace: initial gap plus pulling displacement.

    x: record with 'pull_data' (ndarray, already starts at zero) and
    'ini_gap' (scalar). Returns an array of gap values.
    """
    return x['ini_gap'] + x['pull_data']
|
291 |
+
|
292 |
+
#
|
293 |
+
# pick the force at the unfolding of every residues
|
294 |
+
|
295 |
+
def simplify_NormPull_FORCEnF_rec(n_fold, this_seq_len, this_n_PullGap_arr, this_Force_arr):
    """Resample a force-vs-normalized-pull-gap record onto a fixed grid.

    Samples this_seq_len*n_fold + 1 evenly spaced normalized gaps in [0, 1].
    Gaps before the first recorded gap get force 0; others take the force
    at the nearest recorded gap.

    Returns a dict with keys 'sample_NormPullGap' and 'smaple_FORCE'
    (NOTE(review): the 'smaple' typo is intentional here -- it is a public
    key that downstream code reads).
    """
    n_sample = this_seq_len * n_fold
    targets = [1. / n_sample * jj for jj in range(n_sample)]
    targets.append(1.)

    # retrieve the force value at each target gap
    sampled_forces = []
    for t in targets:
        if t < this_n_PullGap_arr[0]:
            # before pulling started: no force yet
            sampled_forces.append(0.)
        else:
            nearest = np.argmin(np.abs(this_n_PullGap_arr - t))
            sampled_forces.append(this_Force_arr[nearest])

    return {
        'sample_NormPullGap': np.array(targets),
        'smaple_FORCE': np.array(sampled_forces),
    }
|
324 |
+
|
325 |
+
# read input conditions
|
326 |
+
def read_input_model_A(file_path):
    """Read model-A conditioning arrays from a text file.

    Every [...]-bracketed group in the file becomes one row of numbers;
    all rows are parsed together by np.loadtxt and returned as an array.
    """
    with open(file_path, 'r') as fh:
        content = fh.read()
    bracketed = re.findall(r'\[([^][]+)\]', content)
    return np.loadtxt(bracketed)
|
335 |
+
|
336 |
+
def read_input_model_B(file_path):
    """Read the FIRST [...]-bracketed numeric array from a model-B file.

    Embedded newlines inside the bracketed group are stripped before
    parsing so the group is read as a single row.
    """
    with open(file_path, 'r') as fh:
        content = fh.read()
    bracketed = re.findall(r'\[([^][]+)\]', content)
    first_row = bracketed[0].replace('\n', '')
    return np.loadtxt([first_row])
|
346 |
+
|
347 |
+
def read_one_input_arr_from_txt(file_path):
    """Read one bracketed numeric array from a text file.

    Same contract as read_input_model_B: only the first [...] group is
    parsed (newlines removed), returned as a numpy array.
    """
    with open(file_path, 'r') as source:
        raw_text = source.read()
    groups = re.findall(r'\[([^][]+)\]', raw_text)
    cleaned = groups[0].replace('\n', '')
    return np.loadtxt([cleaned])
|
357 |
+
|
358 |
+
# this only for this version, in folder3 it is updated
|
359 |
+
# # for folder3
|
360 |
+
# def recover_input_for_model_B(file_path, seq_len):
|
361 |
+
# raw_arr = read_one_input_arr_from_txt(file_path)
|
362 |
+
# arr = raw_arr[1:1+seq_len+1]
|
363 |
+
# return arr
|
364 |
+
# for folder2
|
365 |
+
def recover_input_for_model_B_ver2(file_path, seq_len):
    """Recover a model-B input vector (folder2 layout).

    Takes the first seq_len + 1 entries of the stored array as-is.
    """
    raw = read_one_input_arr_from_txt(file_path)
    return raw[:seq_len + 1]
|
369 |
+
|
370 |
+
# for folder3
|
371 |
+
def recover_input_for_model_B_ver3(file_path, seq_len):
    """Recover a model-B input vector (folder3 layout).

    The stored values are shifted right by one position; index 0 of the
    returned array is zero-padded.
    """
    raw = read_one_input_arr_from_txt(file_path)
    recovered = np.zeros(seq_len + 1)
    recovered[1:seq_len + 1] = raw[:seq_len]
    return recovered
|
PD_pLMProbXDiff/TrainerPack.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
PD_pLMProbXDiff/UtilityPack.py
ADDED
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ==========================================================
|
2 |
+
# Utility functions
|
3 |
+
# ==========================================================
|
4 |
+
import os
|
5 |
+
from scipy.interpolate import CubicSpline, PchipInterpolator, Akima1DInterpolator
|
6 |
+
import numpy as np
|
7 |
+
import math
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
from Bio.PDB import PDBParser
|
11 |
+
from Bio.PDB.DSSP import DSSP
|
12 |
+
from Bio.PDB import PDBList
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from einops import rearrange
|
16 |
+
import esm
|
17 |
+
# =========================================================
|
18 |
+
# create a folder path if not exist
|
19 |
+
def create_path(this_path):
    """Create a directory if it does not exist.

    Returns 1 when the directory was created, 2 when it already existed.
    NOTE: uses os.mkdir, so the parent directory must already exist.
    """
    if os.path.exists(this_path):
        print('The given path already exists!')
        return 2
    print('Creating the given path...')
    os.mkdir(this_path)
    print('Done.')
    return 1
|
29 |
+
|
30 |
+
# ==========================================================
|
31 |
+
|
32 |
+
# measure the model size
|
33 |
+
def params (model):
    """Print total and trainable parameter counts of a torch model."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print ("Total parameters: ", total, " trainable parameters: ", trainable)
|
38 |
+
|
39 |
+
# ==========================================================
|
40 |
+
# initialization function for dict for models
|
41 |
+
def prepare_UNet_keys(write_dict):
    """Build the full UNet config dict, defaulting every key to None.

    Keys present in write_dict overwrite the defaults; unknown keys are
    reported ("Wrong key found") and ignored.
    """
    full_keys = [
        'dim', 'text_embed_dim', 'num_resnet_blocks', 'cond_dim',
        'num_image_tokens', 'num_time_tokens', 'learned_sinu_pos_emb_dim',
        'out_dim', 'dim_mults', 'cond_images_channels', 'channels',
        'channels_out', 'attn_dim_head', 'attn_heads', 'ff_mult',
        'lowres_cond', 'layer_attns', 'layer_attns_depth',
        'layer_attns_add_text_cond', 'attend_at_middle', 'layer_cross_attns',
        'use_linear_attn', 'use_linear_cross_attn', 'cond_on_text',
        'max_text_len', 'init_dim', 'resnet_groups',
        'init_conv_kernel_size', 'init_cross_embed',
        'init_cross_embed_kernel_sizes', 'cross_embed_downsample',
        'cross_embed_downsample_kernel_sizes', 'attn_pool_text',
        'attn_pool_num_latents', 'dropout', 'memory_efficient',
        'init_conv_to_final_conv_residual', 'use_global_context_attn',
        'scale_skip_connection', 'final_resnet_block',
        'final_conv_kernel_size', 'cosine_sim_attn', 'self_cond',
        'combine_upsample_fmaps', 'pixel_shuffle_upsample',
        'beginning_and_final_conv_present',
    ]
    PKeys = dict.fromkeys(full_keys, None)
    for write_key, value in write_dict.items():
        if write_key in PKeys:
            PKeys[write_key] = value
        else:
            print("Wrong key found: ", write_key)
    return PKeys
|
56 |
+
|
57 |
+
def prepare_ModelB_keys(write_dict):
    """Build the full Model-B config dict, defaulting every key to None.

    Keys present in write_dict overwrite the defaults; unknown keys are
    reported ("Wrong key found") and ignored.
    """
    full_keys = [
        'timesteps', 'dim', 'pred_dim', 'loss_type', 'elucidated',
        'padding_idx', 'cond_dim', 'text_embed_dim', 'input_tokens',
        'sequence_embed', 'embed_dim_position', 'max_text_len',
        'cond_images_channels', 'max_length', 'device',
    ]
    PKeys = dict.fromkeys(full_keys, None)
    for write_key, value in write_dict.items():
        if write_key in PKeys:
            PKeys[write_key] = value
        else:
            print("Wrong key found: ", write_key)
    return PKeys
|
71 |
+
|
72 |
+
def modify_keys(old_dict, write_dict):
    """Return a copy of old_dict with values overwritten from write_dict.

    Keys in write_dict that do not exist in old_dict are reported
    ("Alien key found") and ignored; old_dict itself is not mutated.
    """
    updated = old_dict.copy()
    for key, value in write_dict.items():
        if key in old_dict:
            updated[key] = value
        else:
            print("Alien key found: ", key)
    return updated
|
80 |
+
|
81 |
+
# ==========================================================
|
82 |
+
# mix two NForce record for a given AA length
|
83 |
+
# ==========================================================
|
84 |
+
def mixing_two_FORCE_for_AA_Len(NGap1, Force1, NGap2, Force2, LenAA, mix_fac):
    """Blend two normalized force records and resample for a sequence length.

    Each record is interpolated (monotone PCHIP) on [0, 1], mixed as
    mix_fac * record1 + (1 - mix_fac) * record2, re-interpolated, and
    finally sampled at LenAA + 1 evenly spaced points.

    Returns (interpolator, x1, y1).
    """
    n_base = math.ceil(np.amax([len(NGap1), len(NGap2)]) * 2)
    interp_1 = PchipInterpolator(NGap1, Force1)
    interp_2 = PchipInterpolator(NGap2, Force2)
    xx = np.linspace(0, 1, n_base)
    yy = interp_1(xx) * mix_fac + interp_2(xx) * (1 - mix_fac)
    fun_PI = PchipInterpolator(xx, yy)
    # discretized result on the AA grid
    x1 = np.linspace(0, 1, LenAA + 1)
    y1 = fun_PI(x1)
    return fun_PI, x1, y1
|
96 |
+
|
97 |
+
# =========================================================
|
98 |
+
#
|
99 |
+
# =========================================================
|
100 |
+
def get_Model_A_error (fname, cond, plotit=True, ploterror=False):
    """Compare DSSP secondary-structure content of a PDB against a target.

    fname     : path to the folded PDB file
    cond      : length-8 array of target SS fractions, ordered
                (H, E, T, ~, B, G, I, S)
    plotit    : show GT-vs-prediction bar charts
    ploterror : show the per-type error line plots

    Returns the absolute per-type error for the 8-state comparison.
    """
    sec_structure, sec_structure_3state, sequence = get_DSSP_result(fname)
    length = len(sec_structure)

    ss8_labels = ['H', 'E', 'T', '~', 'B', 'G', 'I', 'S']
    sscount = np.asarray([sec_structure.count(c) / length for c in ss8_labels])

    error = np.abs(sscount - cond)
    print ("Abs error per SS structure type (H, E, T, ~, B, G, I S): ", error)

    if ploterror:
        fig, ax = plt.subplots(1, 1, figsize=(6, 3))
        plt.plot (error, 'o-', label='Error over SS type')
        plt.legend()
        plt.ylabel ('SS content')
        plt.show()

    # BUGFIX: the bar charts were drawn unconditionally even though the
    # function takes a `plotit` flag; honor it here.
    if plotit:
        x = np.linspace (0, 7, 8)
        fig, ax = plt.subplots(1, 1, figsize=(6, 3))
        ax.bar(x - 0.15, cond, width=0.3, color='b', align='center')
        ax.bar(x + 0.15, sscount, width=0.3, color='r', align='center')
        ax.set_ylim([0, 1])
        plt.xticks(range(len(ss8_labels)), ss8_labels, size='medium')
        plt.legend (['GT', 'Prediction'])
        plt.ylabel ('SS content')
        plt.show()

    ######################## 3 types
    ss3_labels = ['H', 'E', '~']
    sscount3 = np.asarray(
        [sec_structure_3state.count(c) / length for c in ss3_labels])
    # collapse the 8-state ground truth: H+G+I -> H, E+B -> E, T+~+S -> ~
    cond_p = [np.sum([cond[0], cond[5], cond[6]]),
              np.sum([cond[1], cond[4]]),
              np.sum([cond[2], cond[3], cond[7]])]
    print ("cond 3type: ", cond_p)

    error3 = np.abs(sscount3 - cond_p)
    # BUGFIX: this report previously printed the 8-state `error` array
    # instead of `error3`.
    print ("Abs error per 3-type SS structure type (C, H, E): ", error3)

    if ploterror:
        fig, ax = plt.subplots(1, 1, figsize=(6, 3))
        plt.plot (error3, 'o-', label='Error over SS type')
        plt.legend()
        plt.ylabel ('SS content')
        plt.show()

    if plotit:
        x = np.linspace (0, 2, 3)
        fig, ax = plt.subplots(1, 1, figsize=(6, 3))
        ax.bar(x - 0.15, cond_p, width=0.3, color='b', align='center')
        ax.bar(x + 0.15, sscount3, width=0.3, color='r', align='center')
        ax.set_ylim([0, 1])
        plt.xticks(range(len(ss3_labels)), ss3_labels, size='medium')
        plt.legend (['GT', 'Prediction'])
        plt.ylabel ('SS content')
        plt.show()

    return error
|
186 |
+
|
187 |
+
def get_DSSP_result (fname):
    """Run DSSP on a PDB file and return its secondary-structure strings.

    Returns (sec_structure, sec_structure_3state, sequence):
      - sec_structure: 8-state DSSP string with '-' rewritten to '~'
        (codes: H alpha-helix, B beta-bridge, E strand, G 3-10 helix,
         I pi-helix, T turn, S bend, ~ none)
      - sec_structure_3state: collapsed 3-state string
        (helix H,G,I -> H;  strand E,B -> E;  coil T,S,~ -> ~)
      - sequence: one-letter AA sequence, in DSSP key order
    Only the first model of the structure is used.
    """
    p = PDBParser()
    structure = p.get_structure(fname, fname)
    model = structure[0]  # use only the first model
    dssp = DSSP(model, fname, file_type='PDB')

    # extract sequence and secondary structure from the DSSP mapping.
    # PERF/BUGFIX: the original rebuilt list(dssp.keys()) on every
    # iteration (O(n^2)); iterate the keys directly instead.
    sequence = ''
    sec_structure = ''
    for a_key in dssp.keys():
        sequence += dssp[a_key][1]
        sec_structure += dssp[a_key][2]

    sec_structure = sec_structure.replace('-', '~')

    # collapse 8 states to 3 in a single pass; this replaces the previous
    # chain of str.replace calls, several of which were no-ops ('H'->'H').
    to3 = str.maketrans({'T': '~', 'B': 'E', 'G': 'H', 'I': 'H', 'S': '~'})
    sec_structure_3state = sec_structure.translate(to3)

    return sec_structure, sec_structure_3state, sequence
|
239 |
+
|
240 |
+
|
241 |
+
def string_diff (seq1, seq2):
    """Count mismatching positions between two strings, plus the length gap.

    Positions are compared pairwise up to the shorter length; the absolute
    length difference is added on top.
    """
    mismatches = sum(a != b for a, b in zip(seq1, seq2))
    return mismatches + abs(len(seq1) - len(seq2))
|
243 |
+
|
244 |
+
|
245 |
+
# ============================================================
|
246 |
+
# on esm, rebuild AA sequence from embedding
|
247 |
+
# ============================================================
|
248 |
+
import esm
|
249 |
+
|
250 |
+
def decode_one_ems_token_rec(this_token, esm_alphabet):
    """Decode one ESM token tensor back into an amino-acid string.

    Takes the span strictly between the LAST cls token and the FIRST eos
    token; when cls is absent the span starts after index 0, when eos is
    absent it runs to the end of the tensor.
    """
    cls_hits = (this_token == esm_alphabet.cls_idx).nonzero(as_tuple=True)[0]
    eos_hits = (this_token == esm_alphabet.eos_idx).nonzero(as_tuple=True)[0]

    begin = 0 if len(cls_hits) == 0 else cls_hits[-1]
    end = len(this_token) if len(eos_hits) == 0 else eos_hits[0]

    letters = [esm_alphabet.get_tok(this_token[pos])
               for pos in range(begin + 1, end)]
    return "".join(letters)
|
283 |
+
|
284 |
+
|
285 |
+
def decode_many_ems_token_rec(batch_tokens, esm_alphabet):
    """Decode a batch of ESM token tensors into a list of AA strings."""
    return [decode_one_ems_token_rec(tokens, esm_alphabet)
            for tokens in batch_tokens]
|
294 |
+
|
295 |
+
# ++ for omegafold sequence: treat unknows as X
|
296 |
+
# Token indices of special / non-standard alphabet entries; sequences fed to
# OmegaFold treat these as unknown residues ('X').
uncomm_idx_list = [0, 1, 2, 3, 24, 25, 26, 27, 28, 29, 30, 31, 32]
|
297 |
+
|
298 |
+
# this one decide the beginning and ending AUTOMATICALLY
|
299 |
+
def decode_one_ems_token_rec_for_folding(
|
300 |
+
this_token,
|
301 |
+
this_logits,
|
302 |
+
esm_alphabet,
|
303 |
+
esm_model):
|
304 |
+
|
305 |
+
# print( (this_token==esm_alphabet.cls_idx).nonzero(as_tuple=True)[0] )
|
306 |
+
# print( (this_token==esm_alphabet.eos_idx).nonzero(as_tuple=True)[0] )
|
307 |
+
# print( (this_token==100).nonzero(as_tuple=True)[0]==None )
|
308 |
+
|
309 |
+
# 1. use this_token to find the beginning and ending
|
310 |
+
# 2. to logits to generate tokens that ONLY contains foldable AAs
|
311 |
+
#
|
312 |
+
id_b_0=(this_token==esm_alphabet.cls_idx).nonzero(as_tuple=True)[0]
|
313 |
+
id_e_0=(this_token==esm_alphabet.eos_idx).nonzero(as_tuple=True)[0]
|
314 |
+
|
315 |
+
# ------------------------------------------------------------------
|
316 |
+
# principle:
|
317 |
+
# 1. begin at 0th
|
318 |
+
# 2. end as soon as possible: relay on that the first endding is learned
|
319 |
+
id_b = 0
|
320 |
+
#
|
321 |
+
if len(id_e_0)==0:
|
322 |
+
id_e=len(this_token)
|
323 |
+
else:
|
324 |
+
id_e=id_e_0[0]
|
325 |
+
# correct if needed
|
326 |
+
if id_e<=id_b+1:
|
327 |
+
if len(id_e_0)>1:
|
328 |
+
id_e=id_e_0[1]
|
329 |
+
else:
|
330 |
+
id_e=len(this_token)
|
331 |
+
# -------------------------------------------------------------------
|
332 |
+
|
333 |
+
# # ------------------------------------------------------------------
|
334 |
+
# # not perfect
|
335 |
+
# # principle:
|
336 |
+
# # 1. begin as late as possible
|
337 |
+
# # 2. end as soon as possible
|
338 |
+
# #
|
339 |
+
# if len(id_b_0)==0:
|
340 |
+
# id_b=0
|
341 |
+
# else:
|
342 |
+
# id_b=id_b_0[-1]
|
343 |
+
# # so, beginning is set
|
344 |
+
# # looking for the nearest ending signal if we can find one
|
345 |
+
# # 1. pick those in id_e that id_b<id_e
|
346 |
+
# id_e_1=[]
|
347 |
+
# for this_e in id_e_0:
|
348 |
+
# if this_e>id_b:
|
349 |
+
# id_e_1.append(this_e)
|
350 |
+
# # 2. check what we find
|
351 |
+
# if len(id_e_1)==0:
|
352 |
+
# # no endding, id_e points to the end
|
353 |
+
# id_e=len(this_token)
|
354 |
+
# else:
|
355 |
+
# # otherwise, find endding point and pick the first one
|
356 |
+
# id_e=id_e_1[0]
|
357 |
+
# # 3. if id_b+1==id_e, we still get nothing. So, this is a fake fix
|
358 |
+
# if id_e==id_b+1:
|
359 |
+
# if len(id_e_1)>1:
|
360 |
+
# id_e=id_e_1[1]
|
361 |
+
# else:
|
362 |
+
# id_e=len(this_token)
|
363 |
+
# # --------------------------------------------------------------------
|
364 |
+
|
365 |
+
# if id_b>id_e:
|
366 |
+
# for debug:
|
367 |
+
print("start at: ", id_b)
|
368 |
+
print("end at: ", id_e)
|
369 |
+
|
370 |
+
# along the sequence, we pick only index [id_b+1:id_e]. This exclude the <cls> and <eos>
|
371 |
+
use_logits = this_logits[id_b+1:id_e] # (seq_len_eff, token_len)
|
372 |
+
use_logits[:,uncomm_idx_list]=-float('inf')
|
373 |
+
use_token = use_logits.max(1).indices
|
374 |
+
|
375 |
+
# print(use_token)
|
376 |
+
|
377 |
+
this_seq = []
|
378 |
+
# this_token_used = []
|
379 |
+
# for ii in range(id_b+1,id_e,1):
|
380 |
+
for ii in range(len(use_token)):
|
381 |
+
# this_token_used.append(this_token[ii])
|
382 |
+
# print(esm_alphabet.get_tok(use_token[ii]))
|
383 |
+
# print(ii)
|
384 |
+
this_seq.append(
|
385 |
+
esm_alphabet.get_tok(use_token[ii])
|
386 |
+
)
|
387 |
+
|
388 |
+
this_seq = "".join(this_seq)
|
389 |
+
|
390 |
+
# # generate a foldable sequece
|
391 |
+
# # map all uncommon ones into X/24
|
392 |
+
# for idx, one_token in enumerate( this_token_used):
|
393 |
+
# find_it=0
|
394 |
+
# for this_uncomm in uncomm_idx_list:
|
395 |
+
# find_id=find_id+(this_uncomm==one_token)
|
396 |
+
# #
|
397 |
+
# if find_id>0:
|
398 |
+
# this_token_used[idx]=24 # 24 means X
|
399 |
+
# # translate token into sequences
|
400 |
+
# this_seq_foldable=[]
|
401 |
+
# for one_token in this_token_used:
|
402 |
+
# this_seq_foldable.append(
|
403 |
+
# esm_alphabet.get_tok(one_token)
|
404 |
+
# )
|
405 |
+
|
406 |
+
# # print(this_seq)
|
407 |
+
# # print(len(this_seq))
|
408 |
+
# # # print(this_token[id_b+1:id_e])
|
409 |
+
# return this_seq, this_seq_foldable
|
410 |
+
return this_seq
|
411 |
+
|
412 |
+
|
413 |
+
def decode_many_ems_token_rec_for_folding(
    batch_tokens,
    batch_logits,
    esm_alphabet,
    esm_model):
    """Decode a batch of (tokens, logits) pairs into foldable sequences.

    batch_tokens : indexable batch of 1-D token-id rows.
    batch_logits : indexable batch of (seq_len, n_tokens) logit rows.
    esm_alphabet : ESM alphabet passed through to the per-row decoder.
    esm_model    : passed through unchanged.
    Returns a list of decoded sequence strings, in batch order.
    """
    return [
        decode_one_ems_token_rec_for_folding(
            batch_tokens[row_idx],
            batch_logits[row_idx],
            esm_alphabet,
            esm_model,
        )
        for row_idx in range(len(batch_tokens))
    ]
|
430 |
+
|
431 |
+
|
432 |
+
def convert_into_logits(esm_model, result):
    """Project per-residue representations to token logits via lm_head.

    result is rearranged from (batch, length, channels) to
    (batch, channels, length) before the head is applied.
    NOTE(review): the sibling convert_into_tokens rearranges the
    opposite way ('b c l -> b l c') before calling lm_head — confirm
    which layout `result` actually carries here.
    """
    channels_first = rearrange(result, 'b l c -> b c l')
    with torch.no_grad():
        return esm_model.lm_head(channels_first)
|
441 |
+
|
442 |
+
# this one returns the unmodified tokens and logits
def convert_into_tokens(model, result, pLM_Model_Name):
    """Convert channel-first model output into ESM tokens and logits.

    model          : pLM whose lm_head maps representations to logits.
    result         : (batch, channels, length) representation tensor.
    pLM_Model_Name : must be one of the supported ESM-2 model names.
    Returns (tokens, logits): tokens is the (b, l) argmax ids,
    logits is (b, l, n_tokens).
    Raises ValueError for an unsupported model name (the original
    printed a message and then crashed with an unbound local at return).
    """
    supported = (
        'esm2_t33_650M_UR50D',
        'esm2_t36_3B_UR50D',
        'esm2_t30_150M_UR50D',
        'esm2_t12_35M_UR50D',
    )
    if pLM_Model_Name not in supported:
        raise ValueError(f"pLM_Model is not defined: {pLM_Model_Name!r}")

    # (b, c, l) -> (b, l, c); same as rearrange(result, 'b c l -> b l c')
    repre = result.permute(0, 2, 1)
    with torch.no_grad():
        logits = model.lm_head(repre)  # (b, l, token_dim)

    tokens = logits.max(2).indices  # (b, l)
    return tokens, logits
|
461 |
+
# ++
|
462 |
+
# ++
def convert_into_tokens_using_prob(prob_result, pLM_Model_Name):
    """Convert channel-first per-token probabilities into tokens + logits.

    Unlike convert_into_tokens, the input already lives in token space,
    so no lm_head projection is applied — the rearranged probabilities
    are used as the logits directly.

    prob_result    : (batch, n_tokens, length) probability tensor.
    pLM_Model_Name : must be one of the supported ESM-2 model names.
    Returns (tokens, logits): tokens is the (b, l) argmax ids,
    logits is (b, l, n_tokens).
    Raises ValueError for an unsupported model name (the original
    printed a message and then crashed with an unbound local at return).
    """
    supported = (
        'esm2_t33_650M_UR50D',
        'esm2_t36_3B_UR50D',
        'esm2_t30_150M_UR50D',
        'esm2_t12_35M_UR50D',
    )
    if pLM_Model_Name not in supported:
        raise ValueError(f"pLM_Model is not defined: {pLM_Model_Name!r}")

    # (b, c, l) -> (b, l, c); same as rearrange(prob_result, 'b c l -> b l c')
    logits = prob_result.permute(0, 2, 1)
    tokens = logits.max(2).indices  # (b, l)
    return tokens, logits
|
481 |
+
|
482 |
+
|
483 |
+
#
|
484 |
+
def read_mask_from_input(
    # consider different types of inputs
    # raw data: seq_data (sequences)
    # tokenized: tokenized_data
    tokenized_data=None,  # X_train_batch,
    mask_value=None,
    seq_data=None,
    max_seq_length=None,
):
    """Build a boolean content mask for a batch of sequences.

    Exactly one of seq_data / tokenized_data should be provided:
    - seq_data + max_seq_length: positions 1 .. len(seq) are True
      (position 0 is reserved, presumably for <cls> — confirm with
      callers).
    - tokenized_data + mask_value: positions not equal to mask_value
      are True, and everything between position 1 and the first real
      entry is forced True (correction for ForcPath zero-force padding
      at the beginning: 0+content+00, not 00+content+00).

    Returns a boolean tensor of shape (n_seq, seq_len).
    Raises ValueError if neither input is given (the original fell
    through and crashed with an unbound local).
    """
    if seq_data is not None:
        # use the real sequence length to create the mask
        n_seq = len(seq_data)
        mask = torch.zeros(n_seq, max_seq_length)
        for ii in range(n_seq):
            mask[ii, 1:1 + len(seq_data[ii])] = 1
        mask = mask == 1
    elif tokenized_data is not None:
        mask = tokenized_data != mask_value
        # fix the beginning part: 0+content+00, not 00+content+00
        for ii in range(len(tokenized_data)):
            # index of the first non-padding entry in this row
            first_real = (mask[ii] == True).nonzero(as_tuple=True)[0][0]
            mask[ii, 1:first_real] = True
    else:
        raise ValueError(
            "read_mask_from_input requires seq_data or tokenized_data"
        )

    return mask
|
517 |
+
|
518 |
+
# ++ read one length
def read_one_len_from_padding_vec(
    in_np_array,
    padding_val=0.0,
):
    """Return the effective length of a right-padded 1-D numpy vector.

    The length is one past the index of the last entry that differs
    from padding_val. Raises IndexError when every entry is padding.
    """
    nonpad_positions = (in_np_array != padding_val).nonzero()[0]
    return nonpad_positions[-1] + 1
|
528 |
+
|
529 |
+
|
530 |
+
# this one decides the beginning and ending using a mask
def decode_one_ems_token_rec_for_folding_with_mask(
    this_token,
    this_logits,
    esm_alphabet,
    esm_model,
    this_mask,
):
    """Decode one sequence from logits, keeping only unmasked positions.

    All logits are restricted to foldable tokens (columns listed in
    uncomm_idx_list suppressed), argmax-decoded, then filtered down to
    the positions where this_mask is True.

    this_token   : unused; kept for signature compatibility.
    this_logits  : (seq_len, n_tokens) logits for one sequence.
    esm_alphabet : ESM alphabet providing get_tok().
    esm_model    : unused; kept for signature compatibility.
    this_mask    : boolean tensor marking content positions.
    Returns the decoded sequence as a plain string.
    """
    # clone() so the -inf masking cannot clobber the caller's tensor
    # (the original assigned through a shared reference and mutated
    # this_logits in place)
    use_logits = this_logits.clone()  # (seq_len, n_tokens)
    use_logits[:, uncomm_idx_list] = -float('inf')
    use_token = use_logits.max(1).indices
    # keep only the content positions selected by the mask
    use_token = use_token[this_mask == True]

    return "".join(esm_alphabet.get_tok(tok) for tok in use_token)
|
564 |
+
|
565 |
+
def decode_many_ems_token_rec_for_folding_with_mask(
    batch_tokens,
    batch_logits,
    esm_alphabet,
    esm_model,
    mask):
    """Decode a batch of logit rows into sequences using per-row masks.

    batch_tokens : indexable batch of 1-D token-id rows.
    batch_logits : indexable batch of (seq_len, n_tokens) logit rows.
    esm_alphabet : ESM alphabet passed through to the per-row decoder.
    esm_model    : passed through unchanged.
    mask         : indexable batch of boolean content masks.
    Returns a list of decoded sequence strings, in batch order.
    """
    return [
        decode_one_ems_token_rec_for_folding_with_mask(
            batch_tokens[row_idx],
            batch_logits[row_idx],
            esm_alphabet,
            esm_model,
            mask[row_idx]
        )
        for row_idx in range(len(batch_tokens))
    ]
|
584 |
+
|
585 |
+
# =====================================================
|
586 |
+
# create new input condition for ForcPath case
|
587 |
+
# =====================================================
|
588 |
+
from scipy import interpolate
|
589 |
+
|
590 |
+
def interpolate_and_resample_ForcPath(y0, seq_len1):
    """Linearly resample a force path onto a new uniform grid over [0, 1].

    y0       : values sampled on a uniform grid of len(y0)-1 intervals.
    seq_len1 : number of intervals for the target grid.
    Returns a dict with 'y1' (resampled values), 'x1' (target grid)
    and 'x0' (source grid).
    """
    source_intervals = len(y0) - 1
    grid_src = np.arange(0., 1. + 1. / source_intervals, 1. / source_intervals)
    grid_dst = np.arange(0., 1. + 1. / seq_len1, 1. / seq_len1)
    resampled = interpolate.interp1d(grid_src, y0)(grid_dst)
    return {'y1': resampled, 'x1': grid_dst, 'x0': grid_src}
|
603 |
+
#
|
604 |
+
def mix_two_ForcPath(y0, y1, seq_len2):
    """Combine two force paths on a common uniform grid over [0, 1].

    Both inputs are linearly interpolated onto a grid of seq_len2
    intervals and added pointwise. NOTE(review): the combination is
    '(f0+f1)/1.' — a sum, not an average; presumably intentional,
    verify against callers.
    Returns a dict with 'y2' (combined values), 'x2' (common grid),
    'x1' and 'x0' (the two source grids).
    """
    intervals0 = len(y0) - 1
    grid0 = np.arange(0., 1. + 1. / intervals0, 1. / intervals0)
    intervals1 = len(y1) - 1
    grid1 = np.arange(0., 1. + 1. / intervals1, 1. / intervals1)
    path0 = interpolate.interp1d(grid0, y0)
    path1 = interpolate.interp1d(grid1, y1)
    grid2 = np.arange(0., 1. + 1. / seq_len2, 1. / seq_len2)
    combined = (path0(grid2) + path1(grid2)) / 1.
    return {'y2': combined, 'x2': grid2, 'x1': grid1, 'x0': grid0}
|
621 |
+
#
|
622 |
+
# =====================================================
|
623 |
+
# load in function for language model
|
624 |
+
# =====================================================
|
625 |
+
import esm
|
626 |
+
|
627 |
+
def load_in_pLM(pLM_Model_Name, device):
    """Load a pretrained protein language model and its alphabet.

    pLM_Model_Name : 'trivial' (no pLM) or one of the supported ESM-2
                     checkpoint names.
    device         : torch device the loaded model is moved to.
    Returns (pLM_Model, esm_alphabet, esm_layer, len_toks); for the
    'trivial' case these are (None, None, 0, 0).
    Raises ValueError for an unknown model name (the original printed
    a message and then crashed with unbound locals at return).
    """
    # ++ for pLM
    if pLM_Model_Name == 'trivial':
        return None, None, 0, 0

    # name -> (representation layer to read, pretrained loader)
    esm2_zoo = {
        'esm2_t33_650M_UR50D': (33, esm.pretrained.esm2_t33_650M_UR50D),  # dim: 1280
        'esm2_t36_3B_UR50D':   (36, esm.pretrained.esm2_t36_3B_UR50D),    # dim: 2560
        'esm2_t30_150M_UR50D': (30, esm.pretrained.esm2_t30_150M_UR50D),  # dim: 640
        'esm2_t12_35M_UR50D':  (12, esm.pretrained.esm2_t12_35M_UR50D),   # dim: 480
    }
    if pLM_Model_Name not in esm2_zoo:
        raise ValueError(f"pLM model is missing: {pLM_Model_Name!r}")

    esm_layer, loader = esm2_zoo[pLM_Model_Name]
    pLM_Model, esm_alphabet = loader()
    len_toks = len(esm_alphabet.all_toks)
    # inference only: freeze batch-norm/dropout and move to the target device
    pLM_Model.eval()
    pLM_Model.to(device)

    return pLM_Model, esm_alphabet, esm_layer, len_toks
|