File size: 7,149 Bytes
b830975
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import gradio as gr
import pandas as pd
import numpy as np
import transformers as T
import torch
import matplotlib.pyplot as plt
import shutil
from torch.utils.data import DataLoader
from utils import load_model,plot_tracks,calc_mutation_effect,plot_tracks_comparision
from attention_extractor import attention_extractor
import subprocess

tokenizer = None
model = None


def load_tokenizer_and_model():
    """Lazily load the module-level tokenizer and model from ``./model/``.

    Each global is only loaded while it is still ``None``, so repeated calls
    are cheap and never reload anything.

    Returns:
        The ``(tokenizer, model)`` pair now stored in the module globals.
    """
    global tokenizer, model
    if tokenizer is None:
        tokenizer = T.BertTokenizer.from_pretrained("./model/")
    if model is None:
        # Load directly; re-invoking this function here would recurse forever
        # because nothing else ever assigns ``model``.
        model = load_model(tokenizer, "./model/model.bin")
    return tokenizer, model


# Load eagerly at import time: the prediction functions below read the
# ``tokenizer`` and ``model`` globals directly, and load_model needs a real
# tokenizer rather than the ``None`` placeholder above.
load_tokenizer_and_model()
prefix_code = pd.read_csv('./data/prefix_codes.csv')
# Maps "<RBP>_<cell_line>" (the ``prefix`` column) to its barcode token (``code_prefix``).
prefix_code_dic = dict(zip(prefix_code.prefix, prefix_code.code_prefix))
rbp_options = prefix_code['RBP'].unique().tolist()

def sequence_process(RBP, cell_line, ss):
    """Convert a raw RNA sequence into the word list handed to the tokenizer.

    The list begins with the barcode token registered for ``"<RBP>_<cell_line>"``
    and continues with the overlapping 3-mers of ``ss``, minus the last one.

    Raises:
        KeyError: If the RBP/cell-line pair has no barcode.
        ValueError: If ``ss`` has fewer than 3 characters.
    """
    target_key = "_".join([RBP, cell_line])
    barcode = prefix_code_dic[target_key]

    if len(ss) < 3:
        raise ValueError("Input sequence length is too short for 3-mer processing.")

    kmers = [ss[start:start + 3] for start in range(len(ss) - 2)]  # 3 mer data
    # The trailing 3-mer is intentionally left out.
    return [barcode, *kmers[:-1]]

def predict_affinity(RBP, cell_line, ss):
    """Predict the binding affinity of one RNA sequence and save its track plot.

    Args:
        RBP: RNA-binding protein name.
        cell_line: Cell line name; together with ``RBP`` selects the barcode token.
        ss: Raw RNA sequence (at least 3 characters).

    Returns:
        ``(affinity_score, "track_plot.png")``: the scalar taken from the model
        logits and the path of the plot written to the working directory.

    Raises:
        KeyError: If the RBP/cell-line pair has no barcode (from ``sequence_process``).
        ValueError: If ``ss`` is shorter than 3 characters (from ``sequence_process``).
    """
    sequence = sequence_process(RBP, cell_line, ss)
    # ``sequence`` is already a list of words (barcode + 3-mers). Without
    # ``is_split_into_words=True`` the tokenizer treats every word as its own
    # sequence and builds a batch; this call matches SequenceDataset.
    inputs = tokenizer(
        sequence,
        is_split_into_words=True,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Put the inputs on the same device as the model weights.
    device = model.model.device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits.squeeze()
    # NOTE(review): ``.item()`` only succeeds when the squeezed logits hold a
    # single value; confirm this matches the model head that plot_tracks expects.
    affinity_score = logits.item()
    plot_tracks(sequence, logits)

    plt.savefig("track_plot.png")
    plt.close()
    return affinity_score, "track_plot.png"

def predict_affinity_with_mutation(RBP, cell_line, ss_wt, ss_mt):
    """Compare the predicted affinity of a wild-type and a mutated sequence.

    Args:
        RBP: RNA-binding protein name.
        cell_line: Cell line name.
        ss_wt: Wild-type RNA sequence.
        ss_mt: Mutated RNA sequence.

    Returns:
        ``(change_in_affinity, "mutation_plot.png")``. The first item is whatever
        ``calc_mutation_effect`` returns; the second is the saved comparison plot.
    """
    # predict_affinity returns (score, plot_path); only the score is wanted here.
    # Each call also rewrites "track_plot.png" as a side effect.
    original_affinity, _ = predict_affinity(RBP, cell_line, ss_wt)
    mutated_affinity, _ = predict_affinity(RBP, cell_line, ss_mt)

    # NOTE(review): confirm calc_mutation_effect and plot_tracks_comparision
    # expect scalar scores rather than per-position tracks.
    change_in_affinity = calc_mutation_effect(original_affinity, mutated_affinity)
    plot_tracks_comparision(ss_wt, ss_mt, original_affinity, mutated_affinity)
    plt.savefig("mutation_plot.png")
    plt.close()

    return change_in_affinity, "mutation_plot.png"

def load_eclip_data(path):
    """Read a two-line-per-record FASTA file and return its sequences.

    Every record must be exactly one ``>`` header line followed by one sequence
    line. Blank lines are ignored.

    Args:
        path: Path to the file (str, bytes or ``os.PathLike``), or an upload
            object exposing its on-disk path as ``.name`` (e.g. a tempfile
            wrapper, which older Gradio ``gr.File`` inputs deliver).

    Returns:
        ``np.ndarray`` of the stripped sequence strings, in file order.

    Raises:
        ValueError: If the headers and sequences do not strictly alternate.
    """
    import os

    if not isinstance(path, (str, bytes, os.PathLike)) and hasattr(path, "name"):
        path = path.name

    with open(path, "r") as f:
        # Blank lines carry no data; dropping them keeps the alternation intact.
        lines = [line for line in f if line.strip()]

    headers, sequences = lines[0::2], lines[1::2]
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not all(h.startswith(">") for h in headers) or any(s.startswith(">") for s in sequences):
        raise ValueError("Input file is not in the expected FASTA format.")

    return np.array([s.strip() for s in sequences])

class SequenceDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes raw RNA sequences for the model.

    Each item is the barcode token followed by the sequence's overlapping
    3-mers (the final 3-mer omitted), tokenized as pre-split words.

    Args:
        barcode: Barcode token prepended to every item.
        seq: Indexable collection of raw RNA sequence strings.
        tokenizer: Tokenizer callable (Hugging Face style keyword arguments).
        max_length: Maximum tokenized length; longer inputs are truncated.
    """

    def __init__(self, barcode, seq, tokenizer, max_length=512):
        self.barcode = barcode
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.sequence = seq
        self.n = len(self.sequence)

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        raw = self.sequence[i]
        kmers = [raw[start:start + 3] for start in range(len(raw) - 2)]  # 3 mer data
        # Same word layout as sequence_process: barcode, then all but the last 3-mer.
        words = [self.barcode, *kmers[:-1]]
        inputs = self.tokenizer(
            words,
            is_split_into_words=True,
            add_special_tokens=True,
            truncation=True,  # enforce self.max_length
            max_length=self.max_length,
            return_tensors='pt',
        )
        return inputs['input_ids']

def extract_attention_regions(barcode, seq, layer, head):
    """Find the positions whose attention exceeds the 99th percentile.

    Args:
        barcode: Barcode token for the RBP/cell-line pair.
        seq: Raw RNA sequences (as returned by ``load_eclip_data``).
        layer: Attention layer index. Gradio sliders may deliver a float, so it
            is converted with ``int()`` before indexing.
        head: Attention head index within ``layer`` (converted likewise).

    Returns:
        The ``nonzero(as_tuple=True)`` result of the above-threshold mask: one
        index tensor per dimension of the selected attention data.

    Raises:
        ValueError: If ``layer`` or ``head`` is negative or out of range.
    """
    layer, head = int(layer), int(head)
    # Reject negative indices up front, before the expensive extraction.
    if layer < 0 or head < 0:
        raise ValueError("Invalid layer or head index for attention extraction.")

    val_dataset = SequenceDataset(barcode, seq, tokenizer)
    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=False,
        drop_last=False,
        num_workers=2
    )
    seq_attn = attention_extractor(model, val_loader, model.model.device, len(val_dataset))
    if len(seq_attn) <= layer or len(seq_attn[layer]) <= head:
        raise ValueError("Invalid layer or head index for attention extraction.")

    attention_data = seq_attn[layer][head].squeeze()
    threshold = np.percentile(attention_data, 99)
    high_attention_regions = (attention_data > threshold).nonzero(as_tuple=True)

    return high_attention_regions

def predict_motif(RBP, cell_line, file, layer, head):
    """Export the 10-mers at high-attention positions as a zipped FASTA file.

    Writes "high_attention_regions.fasta" in the working directory (one record
    per high-attention position, numbered from 0) and archives it into
    "high_attention_regions.zip".

    Args:
        RBP: RNA-binding protein name.
        cell_line: Cell line name; together with ``RBP`` selects the barcode token.
        file: FASTA input accepted by ``load_eclip_data``.
        layer: Attention layer index.
        head: Attention head index.

    Returns:
        The archive path ``"high_attention_regions.zip"``.
    """
    barcode = prefix_code_dic["_".join([RBP, cell_line])]
    sequences = load_eclip_data(file)
    regions = extract_attention_regions(barcode, sequences, layer, head)

    # NOTE(review): ``pos`` is the index reported by extract_attention_regions.
    # If those are token positions (special/barcode tokens first), they are
    # offset from nucleotide positions in the raw sequence. Confirm.
    with open("high_attention_regions.fasta", "w") as fasta:
        for record_id, (seq_idx, pos) in enumerate(zip(*regions)):
            window = sequences[seq_idx][pos:pos + 10]  # 10-mer
            fasta.write(f">{record_id}\n{window}\n")

    shutil.make_archive("high_attention_regions", 'zip', ".", "high_attention_regions.fasta")
    archive_path = "high_attention_regions.zip"
    return archive_path

def single_sequence_affinity(RBP, cell_line, ss):
    """Gradio handler for the "Binding Affinity Prediction" tab.

    Delegates to ``predict_affinity`` and passes its ``(score, plot_path)``
    result through unchanged.
    """
    score_and_plot = predict_affinity(RBP, cell_line, ss)
    return score_and_plot

def mutation_affinity(RBP, cell_line, ss_wt, ss_mt):
    """Gradio handler for the "Mutation Effect Prediction" tab.

    Delegates to ``predict_affinity_with_mutation`` and passes its
    ``(change, plot_path)`` result through unchanged.
    """
    change_and_plot = predict_affinity_with_mutation(RBP, cell_line, ss_wt, ss_mt)
    return change_and_plot

def motif_enrichment(RBP, cell_line, file, layer, head):
    """Gradio handler for the "Motif Enrichment" tab.

    Delegates to ``predict_motif`` and returns the path of the ZIP it writes.
    """
    archive_path = predict_motif(RBP, cell_line, file, layer, head)
    return archive_path

# Gradio interface: input/output components and their callbacks
def get_cell_lines(selected_rbp):
    """Build a dropdown update listing the cell lines available for an RBP.

    Args:
        selected_rbp: RBP name chosen in the RBP dropdown.

    Returns:
        A ``gr.update`` with the matching cell lines (first-seen order) as
        choices, and the first of them, or ``None`` if there are none, as the
        selected value.
    """
    cell_lines = (
        prefix_code.loc[prefix_code['RBP'] == selected_rbp, 'cell_line']
        .unique()
        .tolist()
    )
    # Also reset the selection. Otherwise a cell line picked for the previous
    # RBP stays selected and can produce an RBP/cell-line pair with no barcode.
    return gr.update(choices=cell_lines, value=cell_lines[0] if cell_lines else None)

def build_gradio_interface():
    """Assemble the three-tab Gradio app.

    The tabs are binding-affinity prediction, mutation-effect prediction and
    motif-enrichment preparation. Each tab has its own RBP and cell-line
    dropdowns. Changing the RBP refreshes that tab's cell-line choices through
    ``get_cell_lines``.

    Returns:
        A ``gr.Blocks`` app; call ``.launch()`` on it.
    """

    def _rbp_cell_line_dropdowns():
        # Create fresh components for every tab: a Gradio component can only be
        # rendered once, and listeners such as ``.change`` must be registered
        # inside an active Blocks context (this helper is only called in one).
        default_rbp = rbp_options[0]
        initial_cell_lines = (
            prefix_code.loc[prefix_code['RBP'] == default_rbp, 'cell_line']
            .unique()
            .tolist()
        )
        rbp_dropdown = gr.Dropdown(label="Select RBP", choices=rbp_options, value=default_rbp)
        cell_line_dropdown = gr.Dropdown(
            label="Select Cell Line",
            choices=initial_cell_lines,
            value=initial_cell_lines[0] if initial_cell_lines else None,
        )
        rbp_dropdown.change(fn=get_cell_lines, inputs=rbp_dropdown, outputs=cell_line_dropdown)
        return rbp_dropdown, cell_line_dropdown

    with gr.Blocks() as app:
        with gr.Tab("Binding Affinity Prediction"):
            gr.Markdown("Predicted binding affinity")
            rbp, cell_line = _rbp_cell_line_dropdowns()
            sequence = gr.Textbox(label='Input RNA Sequence')
            run = gr.Button("Submit")
            affinity = gr.Textbox(label="Predicted Affinity")
            track_plot = gr.Image(label="Track Plot")
            run.click(
                fn=single_sequence_affinity,
                inputs=[rbp, cell_line, sequence],
                outputs=[affinity, track_plot],
            )

        with gr.Tab("Mutation Effect Prediction"):
            gr.Markdown("Predicted mutation effect")
            rbp, cell_line = _rbp_cell_line_dropdowns()
            wild_type = gr.Textbox(label='Input Wild-type RNA Sequence')
            mutant = gr.Textbox(label='Input Mutated RNA Sequence')
            run = gr.Button("Submit")
            change = gr.Textbox(label="Affinity Change")
            mutation_plot = gr.Image(label="Mutation Plot")
            run.click(
                fn=mutation_affinity,
                inputs=[rbp, cell_line, wild_type, mutant],
                outputs=[change, mutation_plot],
            )

        with gr.Tab("Motif Enrichment"):
            gr.Markdown("Motif enrichment preparation: Download high attention regions and run AME manually.")
            rbp, cell_line = _rbp_cell_line_dropdowns()
            upload = gr.File(label="Upload Sequence File")
            layer = gr.Slider(minimum=0, maximum=12, step=1, label="Select Attention Layer")
            head = gr.Slider(minimum=0, maximum=11, step=1, label="Select Attention Head")
            run = gr.Button("Submit")
            regions_zip = gr.File(label="Download High Attention Regions (ZIP) (result.zip)")
            run.click(
                fn=motif_enrichment,
                inputs=[rbp, cell_line, upload, layer, head],
                outputs=regions_zip,
            )

    return app

if __name__ == "__main__":
    # Script entry point: build the UI and serve it; share=True asks Gradio
    # for a public share link in addition to the local server.
    build_gradio_interface().launch(share=True)