Add tiger file
tiger.py
ADDED
@@ -0,0 +1,417 @@
import argparse
import os
import gzip
import pickle
import numpy as np
import pandas as pd
import tensorflow as tf
from Bio import SeqIO

# column names
ID_COL = 'Transcript ID'
SEQ_COL = 'Transcript Sequence'
TARGET_COL = 'Target Sequence'
GUIDE_COL = 'Guide Sequence'
MM_COL = 'Number of Mismatches'
SCORE_COL = 'Guide Score'

# nucleotide tokens
NUCLEOTIDE_TOKENS = dict(zip(['A', 'C', 'G', 'T', 'N'], [0, 1, 2, 3, 255]))
NUCLEOTIDE_COMPLEMENT = dict(zip(['A', 'C', 'G', 'T'], ['T', 'G', 'C', 'A']))

# model hyper-parameters
GUIDE_LEN = 23
CONTEXT_5P = 3
CONTEXT_3P = 0
TARGET_LEN = CONTEXT_5P + GUIDE_LEN + CONTEXT_3P
UNIT_INTERVAL_MAP = 'sigmoid'

# reference transcript files
REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')

# application configuration
BATCH_SIZE_COMPUTE = 500
BATCH_SIZE_SCAN = 20
BATCH_SIZE_TRANSCRIPTS = 50
NUM_TOP_GUIDES = 10
NUM_MISMATCHES = 3
RUN_MODES = dict(
    all='All on-target guides per transcript',
    top_guides='Top {:d} guides per transcript'.format(NUM_TOP_GUIDES),
    titration='Top {:d} guides per transcript & their titration candidates'.format(NUM_TOP_GUIDES)
)


# configure GPUs
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, enable=True)
if len(tf.config.list_physical_devices('GPU')) > 0:
    tf.config.experimental.set_visible_devices(tf.config.list_physical_devices('GPU')[0], 'GPU')

def load_transcripts(fasta_files: list, enforce_unique_ids: bool = True):

    # load all transcripts from fasta files into a DataFrame
    transcripts = pd.DataFrame()
    for file in fasta_files:
        try:
            if os.path.splitext(file)[1] == '.gz':
                with gzip.open(file, 'rt') as f:
                    df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(f, 'fasta')], columns=[ID_COL, SEQ_COL])
            else:
                df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(file, 'fasta')], columns=[ID_COL, SEQ_COL])
        except Exception as e:
            print(e, 'while loading', file)
            continue
        transcripts = pd.concat([transcripts, df])

    # set index
    transcripts[ID_COL] = transcripts[ID_COL].apply(lambda s: s.split('|')[0])
    transcripts.set_index(ID_COL, inplace=True)
    if enforce_unique_ids:
        assert not transcripts.index.has_duplicates, "duplicate transcript ID's detected in fasta file"

    return transcripts


def sequence_complement(sequence: list):
    return [''.join([NUCLEOTIDE_COMPLEMENT[nt] for nt in list(seq)]) for seq in sequence]

def one_hot_encode_sequence(sequence: list, add_context_padding: bool = False):

    # stack list of sequences into a tensor
    sequence = tf.ragged.stack([tf.constant(list(seq)) for seq in sequence], axis=0)

    # tokenize sequence
    nucleotide_table = tf.lookup.StaticVocabularyTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(list(NUCLEOTIDE_TOKENS.keys()), dtype=tf.string),
            values=tf.constant(list(NUCLEOTIDE_TOKENS.values()), dtype=tf.int64)),
        num_oov_buckets=1)
    sequence = tf.RaggedTensor.from_row_splits(values=nucleotide_table.lookup(sequence.values),
                                               row_splits=sequence.row_splits).to_tensor(255)

    # add context padding if requested
    if add_context_padding:
        pad_5p = 255 * tf.ones([sequence.shape[0], CONTEXT_5P], dtype=sequence.dtype)
        pad_3p = 255 * tf.ones([sequence.shape[0], CONTEXT_3P], dtype=sequence.dtype)
        sequence = tf.concat([pad_5p, sequence, pad_3p], axis=1)

    # one-hot encode
    sequence = tf.one_hot(sequence, depth=4, dtype=tf.float16)

    return sequence

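# Note on one_hot_encode_sequence(): 'N' bases, ragged padding, and the optional 5'/3' context
# padding are all tokenized as 255, which lies outside depth=4, so tf.one_hot maps them to
# all-zero vectors rather than to any real base.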

def process_data(transcript_seq: str):

    # convert to upper case
    transcript_seq = transcript_seq.upper()

    # get all target sites
    target_seq = [transcript_seq[i: i + TARGET_LEN] for i in range(len(transcript_seq) - TARGET_LEN + 1)]

    # prepare guide sequences
    guide_seq = sequence_complement([seq[CONTEXT_5P:len(seq) - CONTEXT_3P] for seq in target_seq])

    # model inputs
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(target_seq, add_context_padding=False), [len(target_seq), -1]),
        tf.reshape(one_hot_encode_sequence(guide_seq, add_context_padding=True), [len(guide_seq), -1]),
    ], axis=-1)
    return target_seq, guide_seq, model_inputs

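# Shape note for process_data() with the default hyper-parameters above: each of the N candidate
# sites is TARGET_LEN = 3 + 23 + 0 = 26 nt, and each 23-nt guide is padded back to 26 positions
# before one-hot encoding, so model_inputs has shape [N, 26 * 4 + 26 * 4] = [N, 208].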

def calibrate_predictions(predictions: np.array, num_mismatches: np.array, params: pd.DataFrame = None):
    if params is None:
        params = pd.read_pickle('calibration_params.pkl')
    correction = np.squeeze(params.set_index('num_mismatches').loc[num_mismatches, 'slope'].to_numpy())
    return correction * predictions


def score_predictions(predictions: np.array, params: pd.DataFrame = None):
    if params is None:
        params = pd.read_pickle('scoring_params.pkl')

    if UNIT_INTERVAL_MAP == 'sigmoid':
        params = params.iloc[0]
        return 1 - 1 / (1 + np.exp(params['a'] * predictions + params['b']))

    elif UNIT_INTERVAL_MAP == 'min-max':
        return 1 - (predictions - params['a']) / (params['b'] - params['a'])

    elif UNIT_INTERVAL_MAP == 'exp-lin-exp':
        # regime indices
        active_saturation = predictions < params['a']
        linear_regime = (params['a'] <= predictions) & (predictions <= params['c'])
        inactive_saturation = params['c'] < predictions

        # linear regime
        slope = (params['d'] - params['b']) / (params['c'] - params['a'])
        intercept = -params['a'] * slope + params['b']
        predictions[linear_regime] = slope * predictions[linear_regime] + intercept

        # active saturation regime
        alpha = slope / params['b']
        beta = alpha * params['a'] - np.log(params['b'])
        predictions[active_saturation] = np.exp(alpha * predictions[active_saturation] - beta)

        # inactive saturation regime
        alpha = slope / (1 - params['d'])
        beta = -alpha * params['c'] - np.log(1 - params['d'])
        predictions[inactive_saturation] = 1 - np.exp(-alpha * predictions[inactive_saturation] - beta)

        return 1 - predictions

    else:
        raise NotImplementedError

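# Note on score_predictions(): the default 'sigmoid' branch is algebraically equivalent to
# sigmoid(a * predictions + b), since 1 - 1 / (1 + exp(u)) = exp(u) / (1 + exp(u)).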

def get_on_target_predictions(transcripts: pd.DataFrame, model: tf.keras.Model, status_update_fn=None):

    # loop over transcripts
    predictions = pd.DataFrame()
    for i, (index, row) in enumerate(transcripts.iterrows()):

        # parse transcript sequence
        target_seq, guide_seq, model_inputs = process_data(row[SEQ_COL])

        # get predictions
        lfc_estimate = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)[:, 0]
        lfc_estimate = calibrate_predictions(lfc_estimate, num_mismatches=np.zeros_like(lfc_estimate))
        scores = score_predictions(lfc_estimate)
        predictions = pd.concat([predictions, pd.DataFrame({
            ID_COL: [index] * len(scores),
            TARGET_COL: target_seq,
            GUIDE_COL: guide_seq,
            SCORE_COL: scores})])

        # progress update
        percent_complete = 100 * min((i + 1) / len(transcripts), 1)
        update_text = 'Evaluating on-target guides for each transcript: {:.2f}%'.format(percent_complete)
        print('\r' + update_text, end='')
        if status_update_fn is not None:
            status_update_fn(update_text, percent_complete)
    print('')

    return predictions


def top_guides_per_transcript(predictions: pd.DataFrame):

    # select and sort top guides for each transcript
    top_guides = pd.DataFrame()
    for transcript in predictions[ID_COL].unique():
        df = predictions.loc[predictions[ID_COL] == transcript]
        df = df.sort_values(SCORE_COL, ascending=False).reset_index(drop=True).iloc[:NUM_TOP_GUIDES]
        top_guides = pd.concat([top_guides, df])

    return top_guides.reset_index(drop=True)


def get_titration_candidates(top_guide_predictions: pd.DataFrame):

    # generate a table of all titration candidates
    titration_candidates = pd.DataFrame()
    for _, row in top_guide_predictions.iterrows():
        for i in range(len(row[GUIDE_COL])):
            nt = row[GUIDE_COL][i]
            for mutation in set(NUCLEOTIDE_TOKENS.keys()) - {nt, 'N'}:
                sm_guide = list(row[GUIDE_COL])
                sm_guide[i] = mutation
                sm_guide = ''.join(sm_guide)
                assert row[GUIDE_COL] != sm_guide
                titration_candidates = pd.concat([titration_candidates, pd.DataFrame({
                    ID_COL: [row[ID_COL]],
                    TARGET_COL: [row[TARGET_COL]],
                    GUIDE_COL: [sm_guide],
                    MM_COL: [1]
                })])

    return titration_candidates

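# Note on get_titration_candidates(): each top guide yields 3 alternative bases at each of its
# GUIDE_LEN = 23 positions, i.e. 69 single-mismatch candidates, all recorded with one mismatch.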

def find_off_targets(top_guides: pd.DataFrame, status_update_fn=None):

    # load reference transcripts
    reference_transcripts = load_transcripts([os.path.join('transcripts', f) for f in REFERENCE_TRANSCRIPTS])

    # one-hot encode guides to form a filter
    guide_filter = one_hot_encode_sequence(sequence_complement(top_guides[GUIDE_COL]), add_context_padding=False)
    guide_filter = tf.transpose(guide_filter, [1, 2, 0])

    # loop over transcripts in batches
    i = 0
    off_targets = pd.DataFrame()
    while i < len(reference_transcripts):
        # select batch
        df_batch = reference_transcripts.iloc[i:min(i + BATCH_SIZE_SCAN, len(reference_transcripts))]
        i += BATCH_SIZE_SCAN

        # find locations of off-targets
        transcripts = one_hot_encode_sequence(df_batch[SEQ_COL].values.tolist(), add_context_padding=False)
        num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
        loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()

        # off-targets discovered
        if len(loc_off_targets) > 0:

            # log off-targets
            dict_off_targets = pd.DataFrame({
                'On-target ' + ID_COL: top_guides.iloc[loc_off_targets[:, 2]][ID_COL],
                GUIDE_COL: top_guides.iloc[loc_off_targets[:, 2]][GUIDE_COL],
                'Off-target ' + ID_COL: df_batch.index.values[loc_off_targets[:, 0]],
                'Guide Midpoint': loc_off_targets[:, 1],
                SEQ_COL: df_batch[SEQ_COL].values[loc_off_targets[:, 0]],
                MM_COL: tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
            }).to_dict('records')

            # trim transcripts to targets
            for row in dict_off_targets:
                start_location = row['Guide Midpoint'] - (GUIDE_LEN // 2)
                del row['Guide Midpoint']
                target = row[SEQ_COL]
                del row[SEQ_COL]
                if start_location < CONTEXT_5P:
                    target = target[0:GUIDE_LEN + CONTEXT_3P]
                    target = 'N' * (TARGET_LEN - len(target)) + target
                elif start_location + GUIDE_LEN + CONTEXT_3P > len(target):
                    target = target[start_location - CONTEXT_5P:]
                    target = target + 'N' * (TARGET_LEN - len(target))
                else:
                    target = target[start_location - CONTEXT_5P:start_location + GUIDE_LEN + CONTEXT_3P]
                if row[MM_COL] == 0 and 'N' not in target:
                    assert row[GUIDE_COL] == sequence_complement([target[CONTEXT_5P:TARGET_LEN - CONTEXT_3P]])[0]
                row[TARGET_COL] = target

            # append new off-targets
            off_targets = pd.concat([off_targets, pd.DataFrame(dict_off_targets)])

        # progress update
        percent_complete = 100 * min((i + 1) / len(reference_transcripts), 1)
        update_text = 'Scanning for off-targets: {:.2f}%'.format(percent_complete)
        print('\r' + update_text, end='')
        if status_update_fn is not None:
            status_update_fn(update_text, percent_complete)
    print('')

    return off_targets

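# Note on find_off_targets(): each guide's complement is one-hot encoded into a
# [GUIDE_LEN, 4, num_guides] filter, so tf.nn.conv1d over a one-hot transcript counts matching
# positions at every alignment; GUIDE_LEN minus that count is the mismatch count, and alignments
# with at most NUM_MISMATCHES mismatches are kept ('SAME' padding makes the output index the
# window midpoint, hence 'Guide Midpoint').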

def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
    if len(off_targets) == 0:
        return pd.DataFrame()

    # compute off-target predictions
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(off_targets[TARGET_COL], add_context_padding=False), [len(off_targets), -1]),
        tf.reshape(one_hot_encode_sequence(off_targets[GUIDE_COL], add_context_padding=True), [len(off_targets), -1]),
    ], axis=-1)
    lfc_estimate = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)[:, 0]
    lfc_estimate = calibrate_predictions(lfc_estimate, off_targets['Number of Mismatches'].to_numpy())
    off_targets[SCORE_COL] = score_predictions(lfc_estimate)

    return off_targets.reset_index(drop=True)


def tiger_exhibit(transcripts: pd.DataFrame, mode: str, check_off_targets: bool, status_update_fn=None):

    # load model
    if os.path.exists('model'):
        tiger = tf.keras.models.load_model('model')
    else:
        print('no saved model!')
        exit()

    # evaluate all on-target guides per transcript
    on_target_predictions = get_on_target_predictions(transcripts, tiger, status_update_fn)

    # initialize other outputs
    titration_predictions = off_target_predictions = None

    if mode == 'all' and not check_off_targets:
        off_target_candidates = None

    elif mode == 'top_guides':
        on_target_predictions = top_guides_per_transcript(on_target_predictions)
        off_target_candidates = on_target_predictions

    elif mode == 'titration':
        on_target_predictions = top_guides_per_transcript(on_target_predictions)
        titration_candidates = get_titration_candidates(on_target_predictions)
        titration_predictions = predict_off_target(titration_candidates, model=tiger)
        off_target_candidates = pd.concat([on_target_predictions, titration_predictions])

    else:
        raise NotImplementedError

    # check off-target effects for top guides
    if check_off_targets and off_target_candidates is not None:
        off_target_candidates = find_off_targets(off_target_candidates, status_update_fn)
        off_target_predictions = predict_off_target(off_target_candidates, model=tiger)
        if len(off_target_predictions) > 0:
            off_target_predictions = off_target_predictions.sort_values(SCORE_COL, ascending=False)
            off_target_predictions = off_target_predictions.reset_index(drop=True)

    # finalize tables
    for df in [on_target_predictions, titration_predictions, off_target_predictions]:
        if df is not None and len(df) > 0:
            for col in df.columns:
                if ID_COL in col and set(df[col].unique()) == {'ManualEntry'}:
                    del df[col]
            df[GUIDE_COL] = df[GUIDE_COL].apply(lambda s: s[::-1])  # reverse guide sequences
            df[TARGET_COL] = df[TARGET_COL].apply(lambda seq: seq[CONTEXT_5P:len(seq) - CONTEXT_3P])  # remove context

    return on_target_predictions, titration_predictions, off_target_predictions


if __name__ == '__main__':

    # common arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='titration')
    parser.add_argument('--check_off_targets', action='store_true', default=False)
    parser.add_argument('--fasta_path', type=str, default=None)
    args = parser.parse_args()

    # check for any existing results
    if os.path.exists('on_target.csv') or os.path.exists('titration.csv') or os.path.exists('off_target.csv'):
        raise FileExistsError('please rename or delete existing results')

    # load transcripts from a directory of fasta files
    if args.fasta_path is not None and os.path.exists(args.fasta_path):
        df_transcripts = load_transcripts([os.path.join(args.fasta_path, f) for f in os.listdir(args.fasta_path)])

    # otherwise consider simple test case with first 50 nucleotides from EIF3B-003's CDS
    else:
        df_transcripts = pd.DataFrame({
            ID_COL: ['ManualEntry'],
            SEQ_COL: ['ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC']})
        df_transcripts.set_index(ID_COL, inplace=True)

    # process in batches
    batch = 0
    num_batches = len(df_transcripts) // BATCH_SIZE_TRANSCRIPTS
    num_batches += (len(df_transcripts) % BATCH_SIZE_TRANSCRIPTS > 0)
    for idx in range(0, len(df_transcripts), BATCH_SIZE_TRANSCRIPTS):
        batch += 1
        print('Batch {:d} of {:d}'.format(batch, num_batches))

        # run batch
        idx_stop = min(idx + BATCH_SIZE_TRANSCRIPTS, len(df_transcripts))
        df_on_target, df_titration, df_off_target = tiger_exhibit(
            transcripts=df_transcripts[idx:idx_stop],
            mode=args.mode,
            check_off_targets=args.check_off_targets
        )

        # save batch results
        df_on_target.to_csv('on_target.csv', header=batch == 1, index=False, mode='a')
        if df_titration is not None:
            df_titration.to_csv('titration.csv', header=batch == 1, index=False, mode='a')
        if df_off_target is not None:
            df_off_target.to_csv('off_target.csv', header=batch == 1, index=False, mode='a')

        # clear session to prevent memory blow up
        tf.keras.backend.clear_session()
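
# Example invocation (assumes the SavedModel directory 'model', 'calibration_params.pkl', and
# 'scoring_params.pkl' are in the working directory, plus a 'transcripts' directory holding the
# GENCODE files when --check_off_targets is set; './my_fastas' is only an illustrative path):
#   python tiger.py --mode titration --fasta_path ./my_fastas --check_off_targets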