edmundmiller committed on
Commit
615f84a
1 Parent(s): 89059a8

Add predict_chromosome

Browse files

https://github.com/JinLabBioinfo/DeepLoop/blob/af3186196c1a1a7ad3a3f131d3377cb06a304730/prediction/predict_chromosome.py#L100

Files changed (2) hide show
  1. __init__.py +1 -0
  2. predict_chromosome.py +322 -0
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
predict_chromosome.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import pandas as pd
5
+ import numpy as np
6
+ import time
7
+ from tqdm import tqdm
8
+ from tensorflow.keras.models import model_from_json
9
+ from scipy.sparse import csr_matrix, triu
10
+
11
+
12
def anchor_list_to_dict(anchors):
    """Build a lookup table from anchor name to its positional index.

    Parameters
    ----------
    anchors : iterable
        Ordered collection of anchor identifiers.

    Returns
    -------
    dict
        Mapping anchor -> integer index within *anchors* (later duplicates win).
    """
    return {name: position for position, name in enumerate(anchors)}
17
+
18
+
19
def anchor_to_locus(anchor_dict):
    """Return a callable mapping an anchor name to its matrix locus (index).

    Intended for use with ``np.vectorize`` when converting arrays of anchor
    names into row/column indices. Raises ``KeyError`` for unknown anchors,
    exactly like a direct dict lookup.
    """
    return lambda anchor: anchor_dict[anchor]
24
+
25
+
26
def locus_to_anchor(anchor_list):
    """Return a callable mapping a matrix locus (index) back to its anchor name.

    Inverse of :func:`anchor_to_locus`; intended for use with ``np.vectorize``
    when converting sparse-matrix row/column indices into anchor names.
    """
    return lambda locus: anchor_list[locus]
31
+
32
+
33
def predict_tile(args):
    """Denoise one tile of the contact matrix and accumulate the result into
    shared-memory buffers (worker-style function; takes a single packed tuple,
    which is the usual shape for a multiprocessing map worker).

    NOTE(review): ``small_matrix_size`` is a free variable resolved at module
    scope; it is only assigned under the ``__main__`` guard, so calling this
    function from an importing module raises ``NameError`` — confirm intended
    usage before relying on it outside the script entry point.

    args : tuple of
        model            -- object exposing a Keras-style ``predict``
        shared_denoised  -- ctypes shared array accumulating denoised values
        shared_overlap   -- ctypes shared array counting overlapping predictions
        matrix           -- sparse matrix of ratio values (sliced with ``.A``)
        window_x, window_y -- slices selecting this tile within the matrix
    """
    model, shared_denoised, shared_overlap, matrix, window_x, window_y = args
    tile = matrix[window_x, window_y].A  # densify just this tile
    # only exactly full-size tiles are predicted; ragged edge tiles are skipped
    if tile.shape == (small_matrix_size, small_matrix_size):
        tile = np.expand_dims(tile, 0)  # add batch dimension (axis 0)
        tile = np.expand_dims(tile, 3)  # add channel dimension (axis 3)
        # view the shared ctypes buffers as numpy arrays (no copy)
        tmp_denoised = np.ctypeslib.as_array(shared_denoised)
        tmp_overlap = np.ctypeslib.as_array(shared_overlap)
        denoised = model.predict(tile).reshape((small_matrix_size, small_matrix_size))
        denoised[denoised < 0] = 0  # remove any negative values
        # accumulate; overlapping windows are averaged later via the counts
        tmp_denoised[window_x, window_y] += denoised
        tmp_overlap[window_x, window_y] += 1
45
+
46
+
47
def sparse_prediction_from_file(
    model,
    matrix,
    anchor_list,
    small_matrix_size=128,
    step_size=64,
    max_dist=384,
    keep_zeros=True,
):
    """Denoise a sparse ratio matrix by sliding a model over overlapping tiles.

    Parameters
    ----------
    model : object
        Trained model exposing a Keras-style ``predict``; consumes tiles of
        shape ``(1, small_matrix_size, small_matrix_size, 1)``.
    matrix : scipy.sparse matrix
        Square ratio matrix for one chromosome region.
    anchor_list : sequence
        Anchors spanning the region; only its length is used here.
    small_matrix_size : int
        Side length of each model input tile.
    step_size : int
        Tile stride; overlapping predictions are averaged.
    max_dist : int
        Maximum tile offset from the diagonal (in bins) to predict.
    keep_zeros : bool
        If False, explicit zeros are dropped from the returned sparse matrix.

    Returns
    -------
    scipy.sparse.coo_matrix
        Upper triangle of the averaged, symmetrized denoised matrix with its
        diagonal zeroed.
    """
    input_matrix_size = len(anchor_list)
    # Allocate dense accumulators from the sparse matrix's metadata directly.
    # The original used np.zeros_like(matrix.A), which densified the whole
    # sparse matrix twice just to copy its shape and dtype.
    denoised_matrix = np.zeros(matrix.shape, dtype=matrix.dtype)  # summed predictions
    overlap_counts = np.zeros(matrix.shape, dtype=matrix.dtype)  # predictions per cell

    for i in range(0, input_matrix_size, step_size):
        for j in range(0, input_matrix_size, step_size):
            if abs(i - j) > max_dist:  # skip tiles far from the diagonal
                continue
            rows = slice(i, i + small_matrix_size)
            cols = slice(j, j + small_matrix_size)
            # clamp edge tiles back inside the matrix so every predicted tile
            # is exactly small_matrix_size wide
            if i + small_matrix_size >= input_matrix_size:
                rows = slice(input_matrix_size - small_matrix_size, input_matrix_size)
            if j + small_matrix_size >= input_matrix_size:
                cols = slice(input_matrix_size - small_matrix_size, input_matrix_size)
            tile = matrix[rows, cols].A  # densify just this tile
            if tile.shape == (small_matrix_size, small_matrix_size):
                tile = np.expand_dims(tile, 0)  # add batch dimension
                tile = np.expand_dims(tile, 3)  # add channel dimension
                denoised = model.predict(tile).reshape(
                    (small_matrix_size, small_matrix_size)
                )
                denoised[denoised < 0] = 0  # clamp negative predictions
                denoised_matrix[rows, cols] += denoised  # accumulate predictions
                overlap_counts[rows, cols] += 1  # count overlaps for averaging

    # average all overlapping areas; cells never covered by a tile stay 0
    denoised_matrix = np.divide(
        denoised_matrix,
        overlap_counts,
        out=np.zeros_like(denoised_matrix),
        where=overlap_counts != 0,
    )

    denoised_matrix = (denoised_matrix + denoised_matrix.T) * 0.5  # force symmetry

    np.fill_diagonal(denoised_matrix, 0)  # self-interactions are not reported

    sparse_denoised_matrix = triu(denoised_matrix, format="coo")

    if not keep_zeros:
        sparse_denoised_matrix.eliminate_zeros()

    return sparse_denoised_matrix
110
+
111
+
112
def predict_and_write(
    model,
    full_matrix_dir,
    input_name,
    out_dir,
    anchor_dir,
    chromosome,
    small_matrix_size,
    step_size,
    dummy=5,
    max_dist=384,
    val_cols=None,
    keep_zeros=True,
    matrices_per_tile=8,
):
    """Denoise one chromosome's anchor-to-anchor interaction file and write the
    result to ``<out_dir>/<chromosome>.denoised.anchor.to.anchor`` as a
    headerless TSV with columns anchor1, anchor2, denoised.

    Parameters
    ----------
    model : object
        Trained denoising model exposing a Keras-style ``predict``.
    full_matrix_dir : str
        Directory containing the chromosome interaction file.
    input_name : str
        Name of the interaction file inside ``full_matrix_dir``.
    out_dir : str
        Directory the output interaction file is written to.
    anchor_dir : str
        Directory containing ``<chromosome>.bed`` anchor reference files.
    chromosome : str
        Chromosome name, e.g. ``chr1``.
    small_matrix_size, step_size, max_dist : int
        Tiling parameters forwarded to :func:`sparse_prediction_from_file`.
    dummy : int
        Pseudocount for the ratio: (obs + dummy) / (exp + dummy).
    val_cols : list of str or None
        Value column names in the interaction file (defaults to
        ``["obs", "exp"]``). Must contain either both "obs" and "exp", or a
        precomputed "ratio" column.
    keep_zeros : bool
        Forwarded to :func:`sparse_prediction_from_file`.
    matrices_per_tile : int
        Number of ``small_matrix_size`` blocks per processing tile.
    """
    if val_cols is None:
        # avoid a shared mutable default argument
        val_cols = ["obs", "exp"]
    anchor_file = os.path.join(anchor_dir, chromosome + ".bed")
    anchor_list = pd.read_csv(
        anchor_file,
        sep="\t",
        usecols=[0, 1, 2, 3],
        names=["chr", "start", "end", "anchor"],
    )  # read anchor reference list for this chromosome
    chr_anchor_file = pd.read_csv(
        os.path.join(full_matrix_dir, input_name),
        delimiter="\t",
        names=["anchor1", "anchor2"] + val_cols,
        usecols=["anchor1", "anchor2"] + val_cols,
    )  # read chromosome anchor-to-anchor interaction file
    if "obs" in val_cols and "exp" in val_cols:
        # smoothed matrix ratio with a pseudocount in numerator and denominator
        chr_anchor_file["ratio"] = (chr_anchor_file["obs"] + dummy) / (
            chr_anchor_file["exp"] + dummy
        )
    else:
        # Without obs/exp the input must already carry a precomputed ratio
        # column. (The original asserted `"ratio" not in val_cols`
        # unconditionally, which made a ratio-only input impossible and
        # contradicted this very message.)
        assert (
            "ratio" in val_cols
        ), "Must provide either ratio column or obs and exp columns to compute ratio"

    denoised_frames = []  # per-tile results; concatenated once at the end (O(n))

    start_time = time.time()

    anchor_step = matrices_per_tile * small_matrix_size  # anchors per tile

    for i in tqdm(range(0, len(anchor_list), anchor_step)):
        anchors = anchor_list[i : i + anchor_step]
        anchor_dict = anchor_list_to_dict(
            anchors["anchor"].values
        )  # anchor name -> local tile index
        # keep only interactions whose endpoints both fall inside this tile
        chr_tile = chr_anchor_file[
            (chr_anchor_file["anchor1"].isin(anchors["anchor"]))
            & (chr_anchor_file["anchor2"].isin(anchors["anchor"]))
        ]
        rows = np.vectorize(anchor_to_locus(anchor_dict))(
            chr_tile["anchor1"].values
        )  # anchor names -> row indices
        cols = np.vectorize(anchor_to_locus(anchor_dict))(
            chr_tile["anchor2"].values
        )  # anchor names -> column indices
        sparse_matrix = csr_matrix(
            (chr_tile["ratio"], (rows, cols)), shape=(anchor_step, anchor_step)
        )  # sparse CSR matrix of ratio values for this tile

        sparse_denoised_tile = sparse_prediction_from_file(
            model,
            sparse_matrix,
            anchors,
            small_matrix_size,
            step_size,
            max_dist,
            keep_zeros=keep_zeros,
        )
        if len(sparse_denoised_tile.row) > 0:
            anchor_name_list = anchors["anchor"].values.tolist()

            # map local tile indices back to anchor names
            anchor_1_list = np.vectorize(locus_to_anchor(anchor_name_list))(
                sparse_denoised_tile.row
            )
            anchor_2_list = np.vectorize(locus_to_anchor(anchor_name_list))(
                sparse_denoised_tile.col
            )

            tile_anchor_to_anchor = pd.DataFrame.from_dict(
                {
                    "anchor1": anchor_1_list,
                    "anchor2": anchor_2_list,
                    "denoised": sparse_denoised_tile.data,
                }
            ).round({"denoised": 4})
            denoised_frames.append(tile_anchor_to_anchor)

    print("Denoised matrix in %d seconds" % (time.time() - start_time))

    denoised_anchor_to_anchor = (
        pd.concat(denoised_frames) if denoised_frames else pd.DataFrame()
    )
    denoised_anchor_to_anchor.to_csv(
        os.path.join(out_dir, chromosome + ".denoised.anchor.to.anchor"),
        sep="\t",
        index=False,
        header=False,
    )
216
+
217
+
218
if __name__ == "__main__":
    # Command-line entry point: load a trained model and denoise one
    # chromosome's interaction file.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--full_matrix_dir",
        type=str,
        help="directory containing chromosome interaction files to be used as input",
    )
    parser.add_argument(
        "--input_name",
        type=str,
        help="name of file in full_matrix_dir that we want to feed into model",
    )
    parser.add_argument("--h5_file", type=str, help="path to model weights .h5 file")
    parser.add_argument(
        "--json_file",
        type=str,
        help="path to model architecture .json file (by default it is assumed to be the same as the weights file)",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        help="directory where the output interaction file will be stored",
    )
    parser.add_argument(
        "--anchor_dir",
        type=str,
        help="directory containing anchor .bed reference files",
    )
    parser.add_argument(
        "--chromosome", type=str, help="chromosome string (e.g chr1, chr20, chrX)"
    )
    parser.add_argument(
        "--small_matrix_size",
        type=int,
        default=128,
        help="size of input tiles (symmetric)",
    )
    parser.add_argument(
        "--step_size",
        type=int,
        default=128,
        help="step size when tiling matrix (overlapping values will be averaged if different)",
    )
    parser.add_argument(
        "--max_dist",
        type=int,
        default=384,
        help="maximum distance from diagonal (in pixels) where we consider interactions (default to ~2Mb)",
    )
    parser.add_argument(
        "--dummy",
        type=int,
        default=5,
        help="dummy value to compute ratio (obs + dummy) / (exp + dummy)",
    )
    parser.add_argument(
        "--val_cols",
        "--list",
        nargs="+",
        help="names of value columns in interaction files (not including a1, a2)",
        default=["obs", "exp"],
    )
    parser.add_argument(
        "--keep_zeros",
        action="store_true",
        help="if provided, the output file will contain all pixels in every tile, even if no value is present",
    )
    args = parser.parse_args()

    full_matrix_dir = args.full_matrix_dir
    input_name = args.input_name
    h5_file = args.h5_file
    if args.json_file is not None:
        json_file = args.json_file
    else:
        # Derive the architecture file by swapping the file extension.
        # The original `h5_file.replace("h5", "json")` replaced the FIRST
        # occurrence of "h5" anywhere in the path, corrupting paths such as
        # ".../h5_models/weights.h5".
        json_file = os.path.splitext(h5_file)[0] + ".json"
    out_dir = args.out_dir
    anchor_dir = args.anchor_dir
    chromosome = args.chromosome
    # these module-level names are also read as globals by predict_tile
    small_matrix_size = args.small_matrix_size
    step_size = args.step_size
    dummy = args.dummy
    max_dist = args.max_dist
    val_cols = args.val_cols
    keep_zeros = args.keep_zeros

    os.makedirs(out_dir, exist_ok=True)

    with open(json_file, "r") as f:
        model = model_from_json(f.read())  # load model architecture
    model.load_weights(h5_file)  # load model weights
    predict_and_write(
        model,
        full_matrix_dir,
        input_name,
        out_dir,
        anchor_dir,
        chromosome,
        small_matrix_size,
        step_size,
        dummy,
        max_dist,
        val_cols,
        keep_zeros,
    )