File size: 7,149 Bytes
b830975 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
import gradio as gr
import pandas as pd
import numpy as np
import transformers as T
import torch
import matplotlib.pyplot as plt
import shutil
from torch.utils.data import DataLoader
from utils import load_model,plot_tracks,calc_mutation_effect,plot_tracks_comparision
from attention_extractor import attention_extractor
import subprocess
# Module-level model state shared by the prediction functions below;
# populated by load_tokenizer_and_model().
tokenizer = None
model = None


def load_tokenizer_and_model():
    """Load the BERT tokenizer and model from ./model/ into the module globals.

    Each object is loaded only if it is still None, so repeated calls are cheap
    and never reload.

    Returns:
        tuple: (tokenizer, model) — the populated module globals.
    """
    global tokenizer, model
    if tokenizer is None:
        tokenizer = T.BertTokenizer.from_pretrained("./model/")
    if model is None:
        # load_model takes the tokenizer as an argument, so the tokenizer is loaded first.
        model = load_model(tokenizer, "./model/model.bin")
    return tokenizer, model


# Populate the globals at import time: the prediction functions read
# `tokenizer` / `model` directly and never call the loader themselves.
load_tokenizer_and_model()
# Target metadata table. Rows are keyed by the "prefix" column, which the
# lookups in this file query with "<RBP>_<cell_line>" strings; the matching
# "code_prefix" value is used as the barcode token fed to the model.
prefix_code = pd.read_csv('./data/prefix_codes.csv')
prefix_code_dic = dict(zip(prefix_code.prefix, prefix_code.code_prefix))
# Distinct RBP names, in first-appearance order, offered in the UI dropdowns.
rbp_options = prefix_code['RBP'].unique().tolist()
def sequence_process(RBP, cell_line, ss):
    """Convert a raw sequence into the word list the model consumes.

    The first word is the barcode registered for the "<RBP>_<cell_line>"
    target in prefix_code_dic. It is followed by the overlapping 3-mers of
    `ss`, excluding the final 3-mer.

    Raises:
        KeyError: if the RBP / cell-line combination has no barcode.
        ValueError: if `ss` is shorter than 3 characters.
    """
    barcode = prefix_code_dic["_".join([RBP, cell_line])]
    if len(ss) < 3:
        raise ValueError("Input sequence length is too short for 3-mer processing.")
    # Stopping the start index one early is the same as building every 3-mer
    # and then dropping the last one.
    kmers = [ss[start:start + 3] for start in range(len(ss) - 3)]
    return [barcode, *kmers]
def predict_affinity(RBP, cell_line, ss):
    """Predict the binding affinity of one RNA sequence for an RBP / cell line.

    Args:
        RBP: RBP name; combined with `cell_line` to select the barcode token.
        cell_line: Cell-line name.
        ss: Raw input sequence.

    Returns:
        tuple: (affinity_score, "track_plot.png") — the scalar read from the
        model's logits and the path of the plot written via plot_tracks.

    Raises:
        KeyError, ValueError: propagated from sequence_process for an unknown
        target or a sequence shorter than 3 characters.
    """
    sequence = sequence_process(RBP, cell_line, ss)
    # `sequence` is already a list of words (barcode + 3-mers). Without
    # is_split_into_words=True the tokenizer would treat every word as a
    # separate example and return a batch of one-word inputs instead of a
    # single sequence (SequenceDataset tokenizes the same word list this way).
    inputs = tokenizer(
        sequence,
        is_split_into_words=True,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Move inputs to the model's device (same attribute that
    # extract_attention_regions hands to attention_extractor).
    device = model.model.device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    # Bring results to host memory so .item() and plotting work on any device.
    logits = outputs.logits.squeeze().cpu()
    # .item() needs exactly one value: assumes the model emits one logit per
    # sequence — TODO confirm against load_model.
    affinity_score = logits.item()
    plot_tracks(sequence, logits)
    # NOTE(review): fixed file name in the working directory; concurrent
    # requests can overwrite each other's plot.
    plt.savefig("track_plot.png")
    plt.close()
    return affinity_score, "track_plot.png"
def predict_affinity_with_mutation(RBP, cell_line, ss_wt, ss_mt):
    """Compare predicted affinities of a wild-type and a mutated sequence.

    Args:
        RBP: RBP name.
        cell_line: Cell-line name.
        ss_wt: Wild-type sequence.
        ss_mt: Mutated sequence.

    Returns:
        tuple: (change_in_affinity, "mutation_plot.png") — the value computed
        by calc_mutation_effect and the path of the saved comparison plot.
    """
    # predict_affinity returns (score, plot_path); only the scalar scores are
    # meaningful here. Assumes calc_mutation_effect / plot_tracks_comparision
    # take scalar affinities — TODO confirm against utils.
    original_affinity, _ = predict_affinity(RBP, cell_line, ss_wt)
    mutated_affinity, _ = predict_affinity(RBP, cell_line, ss_mt)
    change_in_affinity = calc_mutation_effect(original_affinity, mutated_affinity)
    plot_tracks_comparision(ss_wt, ss_mt, original_affinity, mutated_affinity)
    # NOTE(review): fixed file name in the working directory; concurrent
    # requests can overwrite each other's plot.
    plt.savefig("mutation_plot.png")
    plt.close()
    return change_in_affinity, "mutation_plot.png"
def load_eclip_data(path):
    """Read the sequences from a FASTA file.

    Lines starting with '>' are headers. The sequence lines following a header
    are concatenated into one record, so both two-line and wrapped FASTA are
    accepted. Blank lines are ignored, and headers with no sequence lines are
    skipped.

    Args:
        path: Path to the file (str, bytes or os.PathLike), or an upload
            object exposing the path as ``.name`` (presumably the tempfile
            wrapper some Gradio versions pass for ``gr.File`` — verify).

    Returns:
        numpy.ndarray of the sequence strings, in file order.

    Raises:
        ValueError: if sequence data appears before the first header.
    """
    # Path-like objects are used as-is (a pathlib.Path's .name is only the
    # basename); anything else is treated as an upload wrapper.
    if not isinstance(path, (str, bytes)) and not hasattr(path, "__fspath__"):
        path = getattr(path, "name", path)

    sequences = []
    chunks = None  # sequence lines of the current record; None before the first header
    with open(path, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith('>'):
                if chunks:
                    sequences.append("".join(chunks))
                chunks = []
            elif chunks is None:
                # Explicit raise rather than assert: asserts vanish under -O.
                raise ValueError("Input file is not in the expected FASTA format.")
            else:
                chunks.append(line)
    if chunks:
        sequences.append("".join(chunks))
    return np.array(sequences)
class SequenceDataset(torch.utils.data.Dataset):
    """Dataset that tokenizes raw sequences for the attention-extraction pass.

    Each item is the barcode word followed by the overlapping 3-mers of one
    sequence (the final 3-mer is dropped, matching sequence_process). The words
    are tokenized as pre-split words and truncated to `max_length` tokens.

    Args:
        barcode: Barcode word prepended to every sequence.
        seq: Indexable collection of sequence strings.
        tokenizer: Callable tokenizer returning a mapping with 'input_ids'.
        max_length: Maximum number of tokens per item (special tokens included).
    """

    def __init__(self, barcode, seq, tokenizer, max_length=512):
        self.barcode = barcode
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.sequence = seq
        self.n = len(self.sequence)

    def __len__(self):
        return self.n

    @staticmethod
    def _kmer_words(sequence, k=3):
        """Return the overlapping k-mers of `sequence`, excluding the last one."""
        return [sequence[start:start + k] for start in range(len(sequence) - k + 1)][:-1]

    def __getitem__(self, i):
        words = [self.barcode, *self._kmer_words(self.sequence[i])]
        inputs = self.tokenizer(
            words,
            is_split_into_words=True,
            add_special_tokens=True,
            # Without truncation, max_length was never applied and long
            # sequences could exceed the model's position-embedding limit.
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        return inputs['input_ids']
def extract_attention_regions(barcode, seq, layer, head, percentile=99, batch_size=4, num_workers=2):
    """Find positions whose attention exceeds a percentile for one layer/head.

    Args:
        barcode: Barcode word prepended to every sequence.
        seq: Array of sequence strings.
        layer: Attention layer index (slider values may arrive as floats).
        head: Attention head index within `layer`.
        percentile: Percentile (0-100) used as the cut-off; defaults to 99.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.

    Returns:
        tuple of index tensors from ``Tensor.nonzero(as_tuple=True)``, one per
        dimension of the squeezed attention data. Presumably (sequence index,
        token position) — depends on attention_extractor's output; TODO confirm.

    Raises:
        ValueError: if `layer` or `head` is out of range.
    """
    # List/tensor indexing requires ints; Gradio sliders can deliver floats.
    layer = int(layer)
    head = int(head)
    val_dataset = SequenceDataset(barcode, seq, tokenizer)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )
    seq_attn = attention_extractor(model, val_loader, model.model.device, len(val_dataset))
    if not 0 <= layer < len(seq_attn) or not 0 <= head < len(seq_attn[layer]):
        raise ValueError("Invalid layer or head index for attention extraction.")
    attention_data = seq_attn[layer][head].squeeze()
    # np.percentile needs host memory and a numpy-supported dtype (bfloat16 is
    # not), so work on a float32 CPU copy.
    scores = attention_data.detach().float().cpu()
    threshold = np.percentile(scores.numpy(), percentile)
    return (scores > threshold).nonzero(as_tuple=True)
def predict_motif(RBP, cell_line, file, layer, head):
    """Export 10-mers at high-attention positions as a zipped FASTA.

    Loads the sequences from `file` and asks extract_attention_regions for
    high-attention (sequence, position) pairs. Each 10-character window
    starting at such a position is written to "high_attention_regions.fasta",
    and that file is zipped into "high_attention_regions.zip" in the current
    directory.

    Returns:
        str: "high_attention_regions.zip".
    """
    barcode = prefix_code_dic["_".join([RBP, cell_line])]
    sequences = load_eclip_data(file)
    regions = extract_attention_regions(barcode, sequences, layer, head)
    fasta_path = "high_attention_regions.fasta"
    with open(fasta_path, "w") as out:
        # NOTE(review): `start` indexes the attention map, which for a BERT
        # tokenizer with special tokens likely includes [CLS] and the barcode
        # before the first 3-mer; confirm it aligns with nucleotide offsets.
        for i, (row, start) in enumerate(zip(*regions)):
            subseq = sequences[row][start:start + 10]  # 10-mer window
            out.write(f">{i}\n{subseq}\n")
    shutil.make_archive("high_attention_regions", 'zip', ".", fasta_path)
    return "high_attention_regions.zip"
def single_sequence_affinity(RBP, cell_line, ss):
    """UI handler for the affinity tab: forward to predict_affinity.

    Returns predict_affinity's (affinity score, track plot path) pair.
    """
    return predict_affinity(RBP=RBP, cell_line=cell_line, ss=ss)
def mutation_affinity(RBP, cell_line, ss_wt, ss_mt):
    """UI handler for the mutation tab: forward to predict_affinity_with_mutation.

    Returns that function's (affinity change, mutation plot path) pair.
    """
    return predict_affinity_with_mutation(
        RBP=RBP, cell_line=cell_line, ss_wt=ss_wt, ss_mt=ss_mt
    )
def motif_enrichment(RBP, cell_line, file, layer, head):
    """UI handler for the motif tab: forward to predict_motif.

    Returns the path of the zip archive produced by predict_motif.
    """
    return predict_motif(RBP=RBP, cell_line=cell_line, file=file, layer=layer, head=head)
# Gradio UI: define the interface's inputs and outputs.
def get_cell_lines(selected_rbp):
    """Return a Gradio update listing the cell lines available for an RBP.

    The dropdown value is also reset to the first available cell line, or
    cleared when there is none. Otherwise a cell line chosen for the previous
    RBP would stay selected and build an "<RBP>_<cell_line>" key that is
    missing from prefix_code_dic.

    Args:
        selected_rbp: RBP name chosen in the RBP dropdown.

    Returns:
        A ``gr.update`` carrying the new choices and value.
    """
    filtered_df = prefix_code[prefix_code['RBP'] == selected_rbp]
    cell_lines = filtered_df['cell_line'].unique().tolist()
    return gr.update(choices=cell_lines, value=cell_lines[0] if cell_lines else None)
def build_gradio_interface():
    """Build the three-tab Gradio app (affinity, mutation effect, motif enrichment).

    Event listeners such as ``Dropdown.change`` can only be attached inside a
    ``gr.Blocks`` context, and a component can be rendered only once. So the
    tabs live in a single ``gr.Blocks`` app, and each tab gets its own RBP /
    cell-line dropdown pair whose cell-line choices follow the selected RBP.

    Returns:
        gr.Blocks: the app; call ``.launch()`` on it.
    """
    default_rbp = rbp_options[0]
    default_cell_lines = (
        prefix_code.loc[prefix_code['RBP'] == default_rbp, 'cell_line'].unique().tolist()
    )
    default_cell_line = default_cell_lines[0] if default_cell_lines else None

    def _target_selectors():
        """Render an RBP dropdown and a cell-line dropdown filtered by it."""
        rbp_dropdown = gr.Dropdown(label="Select RBP", choices=rbp_options, value=default_rbp)
        cell_line_dropdown = gr.Dropdown(
            label="Select Cell Line", choices=default_cell_lines, value=default_cell_line
        )
        rbp_dropdown.change(fn=get_cell_lines, inputs=rbp_dropdown, outputs=cell_line_dropdown)
        return rbp_dropdown, cell_line_dropdown

    with gr.Blocks() as app:
        with gr.Tab("Binding Affinity Prediction"):
            gr.Markdown("Predicted binding affinity")
            rbp, cell_line = _target_selectors()
            sequence = gr.Textbox(label='Input RNA Sequence')
            run = gr.Button("Submit")
            affinity = gr.Textbox(label="Predicted Affinity")
            track_plot = gr.Image(label="Track Plot")
            run.click(
                fn=single_sequence_affinity,
                inputs=[rbp, cell_line, sequence],
                outputs=[affinity, track_plot],
            )

        with gr.Tab("Mutation Effect Prediction"):
            gr.Markdown("Predicted mutation effect")
            rbp, cell_line = _target_selectors()
            wild_type = gr.Textbox(label='Input Wild-type RNA Sequence')
            mutant = gr.Textbox(label='Input Mutated RNA Sequence')
            run = gr.Button("Submit")
            change = gr.Textbox(label="Affinity Change")
            mutation_plot = gr.Image(label="Mutation Plot")
            run.click(
                fn=mutation_affinity,
                inputs=[rbp, cell_line, wild_type, mutant],
                outputs=[change, mutation_plot],
            )

        with gr.Tab("Motif Enrichment"):
            gr.Markdown(
                "Motif enrichment preparation: Download high attention regions and run AME manually."
            )
            rbp, cell_line = _target_selectors()
            upload = gr.File(label="Upload Sequence File")
            layer = gr.Slider(minimum=0, maximum=12, step=1, label="Select Attention Layer")
            head = gr.Slider(minimum=0, maximum=11, step=1, label="Select Attention Head")
            run = gr.Button("Submit")
            # The handler produces "high_attention_regions.zip", not "result.zip".
            regions = gr.File(
                label="Download High Attention Regions (ZIP) (high_attention_regions.zip)"
            )
            run.click(
                fn=motif_enrichment,
                inputs=[rbp, cell_line, upload, layer, head],
                outputs=regions,
            )

    return app
if __name__ == "__main__":
    # Serve the UI; share=True also requests a public Gradio share link.
    build_gradio_interface().launch(share=True)
|