import os
import gzip
import numpy as np
import pandas as pd
import tensorflow as tf
from Bio import SeqIO

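# sequence geometry constants: each target site is CONTEXT_5P nt of 5' context,
# a GUIDE_LEN nt guide-matching region, and CONTEXT_3P nt of 3' context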
GUIDE_LEN = 23
CONTEXT_5P = 3
CONTEXT_3P = 0
TARGET_LEN = CONTEXT_5P + GUIDE_LEN + CONTEXT_3P
NUCLEOTIDE_TOKENS = dict(zip(['A', 'C', 'G', 'T'], [0, 1, 2, 3]))
NUCLEOTIDE_COMPLEMENT = dict(zip(['A', 'C', 'G', 'T'], ['T', 'G', 'C', 'A']))
NUM_TOP_GUIDES = 10
NUM_MISMATCHES = 3


def sequence_complement(sequence: list):
    return [''.join([NUCLEOTIDE_COMPLEMENT[nt] for nt in list(seq)]) for seq in sequence]


def one_hot_encode_sequence(sequence: list, add_context_padding: bool = False):

    # stack list of sequences into a tensor
    sequence = tf.ragged.stack([tf.constant(list(seq)) for seq in sequence], axis=0)

    # tokenize sequence
    nucleotide_table = tf.lookup.StaticVocabularyTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(list(NUCLEOTIDE_TOKENS.keys()), dtype=tf.string),
            values=tf.constant(list(NUCLEOTIDE_TOKENS.values()), dtype=tf.int64)),
        num_oov_buckets=1)
    sequence = tf.RaggedTensor.from_row_splits(values=nucleotide_table.lookup(sequence.values),
                                               row_splits=sequence.row_splits).to_tensor(255)

    # add context padding if requested
    if add_context_padding:
        pad_5p = 255 * tf.ones([sequence.shape[0], CONTEXT_5P], dtype=sequence.dtype)
        pad_3p = 255 * tf.ones([sequence.shape[0], CONTEXT_3P], dtype=sequence.dtype)
        sequence = tf.concat([pad_5p, sequence, pad_3p], axis=1)

    # one-hot encode
    sequence = tf.one_hot(sequence, depth=4)

    return sequence


def process_data(transcript_seq: str):
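    """Slide a TARGET_LEN window over the transcript and build model inputs.

    Returns the target sites, their complementary guide sequences, and the
    flattened one-hot encodings of both concatenated per candidate.
    """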

    # convert to upper case
    transcript_seq = transcript_seq.upper()

    # get all target sites
    target_seq = [transcript_seq[i: i + TARGET_LEN] for i in range(len(transcript_seq) - TARGET_LEN + 1)]

    # prepare guide sequences
    guide_seq = sequence_complement([seq[CONTEXT_5P:len(seq) - CONTEXT_3P] for seq in target_seq])

    # model inputs
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(target_seq, add_context_padding=False), [len(target_seq), -1]),
        tf.reshape(one_hot_encode_sequence(guide_seq, add_context_padding=True), [len(guide_seq), -1]),
        ], axis=-1)

    return target_seq, guide_seq, model_inputs


def predict_on_target(transcript_seq: str, model: tf.keras.Model):
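    """Score every candidate guide for a transcript.

    Results are sorted ascending by normalized LFC; lower values are expected
    to correspond to stronger predicted knockdown.
    """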

    # parse transcript sequence
    target_seq, guide_seq, model_inputs = process_data(transcript_seq)

    # get predictions
    normalized_lfc = model.predict_step(model_inputs)
    predictions = pd.DataFrame({'Guide': guide_seq, 'Normalized LFC': tf.squeeze(normalized_lfc).numpy()})
    predictions = predictions.set_index('Guide').sort_values('Normalized LFC')

    return predictions


def find_off_targets(guides: list, batch_size: int = 1000):
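    """Scan GENCODE v19 protein-coding transcripts for sites with at most
    NUM_MISMATCHES mismatches to any guide.

    Expects transcripts/gencode.v19.pc_transcripts.fa.gz to exist locally.
    """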

    # load reference transcripts
    with gzip.open(os.path.join('transcripts', 'gencode.v19.pc_transcripts.fa.gz'), 'rt') as file:
        df_transcripts = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(file, 'fasta')], columns=['id', 'seq'])
    df_transcripts['id'] = df_transcripts['id'].apply(lambda s: s.split('|')[4])
    df_transcripts.set_index('id', inplace=True)

    # one-hot encode guides to form a filter
    guide_filter = one_hot_encode_sequence(sequence_complement(guides), add_context_padding=False)
    guide_filter = tf.transpose(guide_filter, [1, 2, 0])
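    # reorder to the conv1d filter layout: [filter_width, in_channels=4, out_channels=num_guides]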

    # loop over transcripts in batches
    i = 0
    print('Scanning for off-targets')
    df_off_targets = pd.DataFrame()
    while i < len(df_transcripts):
        # select batch
        df_batch = df_transcripts.iloc[i:min(i + batch_size, len(df_transcripts))]
        i += batch_size

        # find and log off-targets
        transcripts = one_hot_encode_sequence(df_batch['seq'].values.tolist(), add_context_padding=False)
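        # conv1d of one-hot transcripts with the one-hot guide filters counts matching
        # positions per (transcript, position, guide); all-zero padding rows contribute
        # no matches, so batch padding of short transcripts cannot create false hits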
        num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
        loc_off_targets = tf.where(num_mismatches <= NUM_MISMATCHES).numpy()
        df_off_targets = pd.concat([df_off_targets, pd.DataFrame({
            'Guide': np.array(guides)[loc_off_targets[:, 2]],
            'Isoform': df_batch.index.values[loc_off_targets[:, 0]],
            'Mismatches': tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
            'Midpoint': loc_off_targets[:, 1],
            'Target': df_batch['seq'].values[loc_off_targets[:, 0]],
        })])

        # progress update
        print('\rPercent complete: {:.2f}%'.format(100 * min(i / len(df_transcripts), 1)), end='')
    print('')

    # trim transcripts to targets
    dict_off_targets = df_off_targets.to_dict('records')
    for row in dict_off_targets:
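        # 'SAME' convolution padding centers each filter, so Midpoint marks the middle
        # of the guide match; step back half a guide plus the 5' context to recover
        # the full TARGET_LEN window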
        start_location = row['Midpoint'] - (GUIDE_LEN // 2) - CONTEXT_5P
        row['Target'] = row['Target'][start_location:start_location + TARGET_LEN]
        if row['Mismatches'] == 0:
            assert row['Guide'] == sequence_complement([row['Target'][CONTEXT_5P:TARGET_LEN-CONTEXT_3P]])[0]
    df_off_targets = pd.DataFrame(dict_off_targets)

    return df_off_targets


def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
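    """Append the model's normalized LFC prediction for each off-target site."""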

    # append off-target predictions
    model_inputs = tf.concat([
        tf.reshape(one_hot_encode_sequence(off_targets['Target'], add_context_padding=False), [len(off_targets), -1]),
        tf.reshape(one_hot_encode_sequence(off_targets['Guide'], add_context_padding=True), [len(off_targets), -1]),
        ], axis=-1)
    off_targets['Normalized LFC'] = tf.squeeze(model.predict_step(model_inputs)).numpy()

    return off_targets


def tiger_exhibit(transcript):
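    """Run the full pipeline: on-target scoring, top-guide selection, and off-target scanning."""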

    # load model
    if os.path.exists('model'):
        tiger = tf.keras.models.load_model('model')
    else:
        print('no saved model!')
        exit()

    # on-target predictions
    on_target_predictions = predict_on_target(transcript, model=tiger)

    # keep only top guides
    on_target_predictions = on_target_predictions.iloc[:NUM_TOP_GUIDES]

    # predict off-target effects for top guides
    off_targets = find_off_targets(on_target_predictions.index.values.tolist())
    off_targets = predict_off_target(off_targets, model=tiger)

    return on_target_predictions, off_targets


if __name__ == '__main__':

    # simple test case
    print(tiger_exhibit('ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC'.lower()))  # first 50 from EIF3B-003's CDS