|
import gradio as gr |
|
import pandas as pd |
|
import numpy as np |
|
import transformers as T |
|
import torch |
|
import matplotlib.pyplot as plt |
|
import shutil |
|
from torch.utils.data import DataLoader |
|
from utils import load_model,plot_tracks,calc_mutation_effect,plot_tracks_comparision |
|
from attention_extractor import attention_extractor |
|
import subprocess |
|
|
|
# Module-level tokenizer/model shared by the prediction helpers below; filled in
# by load_tokenizer_and_model().
tokenizer = None
model = None


def load_tokenizer_and_model():
    """Load the BERT tokenizer and the model into the module globals.

    Each of the two is loaded only if it is still ``None``, so repeated calls
    are cheap no-ops once both are loaded.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    global tokenizer, model
    if tokenizer is None:
        tokenizer = T.BertTokenizer.from_pretrained("./model/")
    if model is None:
        # load_model needs the tokenizer, so it is built after the tokenizer.
        model = load_model(tokenizer, "./model/model.bin")
    return tokenizer, model


# Load eagerly at import time so every Gradio callback finds both ready.
load_tokenizer_and_model()
|
# Barcode table. sequence_process/predict_motif look up barcodes in
# prefix_code_dic using "<RBP>_<cell_line>" keys (values of the ``prefix``
# column); the mapped value is the ``code_prefix`` column.
prefix_code = pd.read_csv('./data/prefix_codes.csv')
prefix_code_dic = dict(zip(prefix_code.prefix, prefix_code.code_prefix))

# Distinct RBP names offered in the UI dropdowns.
rbp_options = prefix_code['RBP'].unique().tolist()
|
|
|
def sequence_process(RBP, cell_line, ss):
    """Turn an RNA sequence into the word list fed to the tokenizer.

    The list starts with the barcode token for the ``RBP``/``cell_line`` pair,
    followed by the overlapping 3-mers of ``ss`` with the final 3-mer dropped
    (the same construction SequenceDataset.__getitem__ uses).

    Args:
        RBP: RBP name.
        cell_line: cell line name; may be ``None`` while the UI dropdown is empty.
        ss: RNA sequence. All whitespace (e.g. line breaks from a pasted
            sequence) is removed before splitting into 3-mers.

    Returns:
        list[str]: ``[barcode, kmer_0, ..., kmer_{n-2}]``.

    Raises:
        ValueError: if RBP or cell line is missing, or the cleaned sequence is
            shorter than 3 characters.
        KeyError: if no barcode exists for the RBP/cell-line combination.
    """
    if not RBP or not cell_line:
        # Without this, joining a None cell line raises an opaque TypeError.
        raise ValueError("Please select both an RBP and a cell line.")
    target = "_".join([RBP, cell_line])
    try:
        barcode = prefix_code_dic[target]
    except KeyError:
        raise KeyError(
            f"No barcode found for RBP/cell line combination '{target}'."
        ) from None

    # Whitespace inside the sequence would otherwise end up inside the 3-mers.
    ss = "".join(ss.split())
    if len(ss) < 3:
        raise ValueError("Input sequence length is too short for 3-mer processing.")

    kmers = [ss[i:i + 3] for i in range(len(ss) - 2)]
    return [barcode, *kmers[:-1]]
|
|
|
def predict_affinity(RBP, cell_line, ss):
    """Predict the binding affinity of one RNA sequence and save a track plot.

    Args:
        RBP: RBP name.
        cell_line: cell line name.
        ss: RNA sequence (see sequence_process).

    Returns:
        tuple: ``(affinity_score, "track_plot.png")`` -- the scalar score and the
        path of the plot written to the working directory.
    """
    sequence = sequence_process(RBP, cell_line, ss)

    # ``sequence`` is ONE example already split into words (barcode + 3-mers).
    # Without is_split_into_words=True the tokenizer treats a list of strings as
    # a batch and encodes every 3-mer as a separate example. SequenceDataset
    # tokenizes the same structure this way.
    inputs = tokenizer(
        sequence,
        is_split_into_words=True,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Put the inputs on the model's device, as extract_attention_regions does.
    inputs = inputs.to(model.model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Bring the result back to the CPU for .item() and plotting.
    logits = outputs.logits.squeeze().cpu()
    # NOTE(review): .item() requires a single-element tensor, yet the same
    # tensor is plotted as a track below -- confirm the model's output shape.
    affinity_score = logits.item()

    try:
        plot_tracks(sequence, logits)
        plt.savefig("track_plot.png")
    finally:
        # Release the figure even if plotting fails.
        plt.close()

    return affinity_score, "track_plot.png"
|
|
|
def predict_affinity_with_mutation(RBP, cell_line, ss_wt, ss_mt):
    """Compare the predicted affinity of a wild-type and a mutated sequence.

    Each predict_affinity call also rewrites ``track_plot.png`` as a side effect.

    Args:
        RBP: RBP name.
        cell_line: cell line name.
        ss_wt: wild-type RNA sequence.
        ss_mt: mutated RNA sequence.

    Returns:
        tuple: ``(change_in_affinity, "mutation_plot.png")``. The first element
        is whatever calc_mutation_effect returns.
    """
    # predict_affinity returns (score, plot_path); only the scores are compared.
    original_affinity, _ = predict_affinity(RBP, cell_line, ss_wt)
    mutated_affinity, _ = predict_affinity(RBP, cell_line, ss_mt)

    # NOTE(review): assumes calc_mutation_effect and plot_tracks_comparision
    # take scalar scores -- confirm against utils.
    change_in_affinity = calc_mutation_effect(original_affinity, mutated_affinity)

    try:
        plot_tracks_comparision(ss_wt, ss_mt, original_affinity, mutated_affinity)
        plt.savefig("mutation_plot.png")
    finally:
        # Release the figure even if plotting fails.
        plt.close()

    return change_in_affinity, "mutation_plot.png"
|
|
|
def load_eclip_data(path):
    """Read the sequences from a FASTA file.

    Both single-line and multi-line records are accepted: the sequence lines of
    a record are concatenated (each stripped of surrounding whitespace), and
    blank lines are ignored.

    Args:
        path: path to the file (``str``/``bytes``/``os.PathLike``), or an
            object exposing the path as ``.name``. Presumably this is what some
            Gradio versions pass for ``gr.File`` -- verify for the installed
            version.

    Returns:
        numpy.ndarray: one string per record, in file order.

    Raises:
        ValueError: if the file has no records, contains sequence data before
            the first ``>`` header, or has a header with no sequence.
    """
    import os

    if not isinstance(path, (str, bytes, os.PathLike)):
        path = path.name

    with open(path, 'r') as f:
        lines = [line.strip() for line in f]

    records = []
    for line in lines:
        if not line:
            continue
        if line.startswith('>'):
            records.append([])
        elif records:
            records[-1].append(line)
        else:
            # Sequence data before any header line.
            raise ValueError("Input file is not in the expected FASTA format.")

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    # Rejects empty files and headers that carry no sequence lines.
    if not records or not all(records):
        raise ValueError("Input file is not in the expected FASTA format.")

    return np.array([''.join(parts) for parts in records])
|
|
|
class SequenceDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes barcode-prefixed 3-mer sequences lazily.

    Item ``i`` is built from ``seq[i]``. The word list is the barcode followed
    by the overlapping 3-mers of the sequence, with the last 3-mer dropped. It
    is tokenized as a single pre-split example, and its ``input_ids`` are
    returned.

    Args:
        barcode: token prepended to every example.
        seq: indexable collection of sequence strings.
        tokenizer: callable accepting a word list plus ``is_split_into_words``,
            ``add_special_tokens`` and ``return_tensors`` keyword arguments.
        max_length: stored on the instance; not applied by ``__getitem__``.
    """

    def __init__(self, barcode, seq, tokenizer, max_length=512):
        self.sequence = seq
        self.n = len(seq)
        self.barcode = barcode
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        raw = self.sequence[i]
        kmers = [raw[start:start + 3] for start in range(len(raw) - 2)]
        words = [self.barcode, *kmers[:-1]]
        # NOTE(review): the default DataLoader collate stacks these tensors, so
        # items in one batch presumably must tokenize to equal length -- confirm.
        encoded = self.tokenizer(
            words,
            is_split_into_words=True,
            add_special_tokens=True,
            return_tensors='pt',
        )
        return encoded['input_ids']
|
|
|
def extract_attention_regions(barcode, seq, layer, head):
    """Locate positions whose attention exceeds the 99th percentile.

    Sequences are tokenized through SequenceDataset and fed to
    attention_extractor in batches of 4 (in order, two loader workers).

    Args:
        barcode: barcode token for the selected RBP/cell line.
        seq: sequences to analyse (e.g. the array from load_eclip_data).
        layer: attention layer index; int-like values such as ``3.0`` are accepted.
        head: attention head index; int-like values are accepted.

    Returns:
        tuple: index tensors from ``nonzero(as_tuple=True)``, one per dimension
        of the squeezed attention tensor for ``(layer, head)``.

    Raises:
        ValueError: if ``layer`` or ``head`` is out of range.
    """
    # Slider values may arrive as floats, which cannot be used as indices.
    layer = int(layer)
    head = int(head)

    val_dataset = SequenceDataset(barcode, seq, tokenizer)
    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=False,
        drop_last=False,
        num_workers=2,
    )
    seq_attn = attention_extractor(model, val_loader, model.model.device, len(val_dataset))

    # Also reject negative indices, which would otherwise wrap around silently.
    if not (0 <= layer < len(seq_attn)) or not (0 <= head < len(seq_attn[layer])):
        raise ValueError("Invalid layer or head index for attention extraction.")

    # NOTE(review): the shape of seq_attn[layer][head] is defined by
    # attention_extractor; predict_motif expects two index tensors back.
    attention_data = seq_attn[layer][head].squeeze()
    threshold = np.percentile(attention_data, 99)
    return (attention_data > threshold).nonzero(as_tuple=True)
|
|
|
def predict_motif(RBP, cell_line, file, layer, head):
    """Export high-attention subsequences of an uploaded FASTA file as a zip.

    For every high-attention hit, a window of up to 10 characters starting at
    the hit position is written to ``high_attention_regions.fasta``. That file
    is then zipped into ``high_attention_regions.zip`` in the working
    directory.

    Args:
        RBP: RBP name.
        cell_line: cell line name; may be ``None`` while the UI dropdown is empty.
        file: uploaded FASTA file, passed to load_eclip_data.
        layer: attention layer index.
        head: attention head index.

    Returns:
        str: path of the created zip archive.

    Raises:
        ValueError: if RBP or cell line is missing, plus any raised while
            loading the file or extracting attention.
        KeyError: if no barcode exists for the RBP/cell-line combination.
    """
    if not RBP or not cell_line:
        raise ValueError("Please select both an RBP and a cell line.")
    target = "_".join([RBP, cell_line])
    try:
        barcode = prefix_code_dic[target]
    except KeyError:
        raise KeyError(
            f"No barcode found for RBP/cell line combination '{target}'."
        ) from None

    seq = load_eclip_data(file)
    high_attention_regions = extract_attention_regions(barcode, seq, layer, head)

    fasta_path = "high_attention_regions.fasta"
    with open(fasta_path, "w") as f:
        for i, (seq_idx, pos) in enumerate(zip(*high_attention_regions)):
            # The indices come from tensor.nonzero(); convert them to plain
            # ints before indexing and slicing.
            seq_idx, pos = int(seq_idx), int(pos)
            # NOTE(review): ``pos`` is an attention position. Whether it is
            # offset from the nucleotide index (special/barcode tokens, 3-mer
            # tokens) depends on attention_extractor -- confirm.
            subseq = seq[seq_idx][pos:pos + 10]
            f.write(f">{i}\n{subseq}\n")

    # make_archive returns the path of the archive it created.
    return shutil.make_archive("high_attention_regions", 'zip', ".", fasta_path)
|
|
|
def single_sequence_affinity(RBP, cell_line, ss):
    """Gradio callback for the binding-affinity tab; delegates to predict_affinity."""
    return predict_affinity(RBP=RBP, cell_line=cell_line, ss=ss)
|
|
|
def mutation_affinity(RBP, cell_line, ss_wt, ss_mt):
    """Gradio callback for the mutation tab; delegates to predict_affinity_with_mutation."""
    return predict_affinity_with_mutation(
        RBP=RBP,
        cell_line=cell_line,
        ss_wt=ss_wt,
        ss_mt=ss_mt,
    )
|
|
|
def motif_enrichment(RBP, cell_line, file, layer, head):
    """Gradio callback for the motif tab; delegates to predict_motif."""
    return predict_motif(
        RBP=RBP,
        cell_line=cell_line,
        file=file,
        layer=layer,
        head=head,
    )
|
|
|
|
|
def get_cell_lines(selected_rbp):
    """Build a dropdown update listing the cell lines recorded for ``selected_rbp``.

    Args:
        selected_rbp: RBP name chosen in the RBP dropdown.

    Returns:
        A ``gr.update`` that replaces the cell-line choices and selects the
        first of them, or clears the selection when there are none.
    """
    cell_lines = (
        prefix_code.loc[prefix_code['RBP'] == selected_rbp, 'cell_line']
        .unique()
        .tolist()
    )
    # Reset the value as well: a cell line picked for the previous RBP may have
    # no barcode for this one, and keeping it would break the prefix_code_dic
    # lookup.
    return gr.update(choices=cell_lines, value=cell_lines[0] if cell_lines else None)
|
|
|
def build_gradio_interface():
    """Build the three-tab Gradio app: affinity, mutation effect, motif enrichment.

    Each tab has its own RBP / cell-line dropdown pair. Choosing an RBP
    refreshes that tab's cell-line choices through get_cell_lines.

    Returns:
        gr.Blocks: the assembled app; call ``.launch()`` to serve it.
    """

    def linked_dropdowns():
        # Each tab needs its own component instances because a component can
        # be rendered only once. Event listeners such as .change() must also be
        # registered inside a Blocks context, so each pair is built inside the
        # tab that uses it.
        initial_rbp = rbp_options[0] if rbp_options else None
        initial_cells = (
            prefix_code.loc[prefix_code['RBP'] == initial_rbp, 'cell_line']
            .unique()
            .tolist()
        )
        rbp = gr.Dropdown(label="Select RBP", choices=rbp_options, value=initial_rbp)
        cell_line = gr.Dropdown(
            label="Select Cell Line",
            choices=initial_cells,
            value=initial_cells[0] if initial_cells else None,
        )
        rbp.change(fn=get_cell_lines, inputs=rbp, outputs=cell_line)
        return rbp, cell_line

    with gr.Blocks() as app:
        with gr.Tab("Binding Affinity Prediction"):
            gr.Markdown("Predicted binding affinity")
            with gr.Row():
                with gr.Column():
                    rbp, cell_line = linked_dropdowns()
                    sequence = gr.Textbox(label='Input RNA Sequence')
                    submit = gr.Button("Submit", variant="primary")
                with gr.Column():
                    affinity = gr.Textbox(label="Predicted Affinity")
                    track_plot = gr.Image(label="Track Plot")
            submit.click(
                fn=single_sequence_affinity,
                inputs=[rbp, cell_line, sequence],
                outputs=[affinity, track_plot],
            )

        with gr.Tab("Mutation Effect Prediction"):
            gr.Markdown("Predicted mutation effect")
            with gr.Row():
                with gr.Column():
                    rbp, cell_line = linked_dropdowns()
                    wild_type = gr.Textbox(label='Input Wild-type RNA Sequence')
                    mutated = gr.Textbox(label='Input Mutated RNA Sequence')
                    submit = gr.Button("Submit", variant="primary")
                with gr.Column():
                    change = gr.Textbox(label="Affinity Change")
                    mutation_plot = gr.Image(label="Mutation Plot")
            submit.click(
                fn=mutation_affinity,
                inputs=[rbp, cell_line, wild_type, mutated],
                outputs=[change, mutation_plot],
            )

        with gr.Tab("Motif Enrichment"):
            gr.Markdown(
                "Motif enrichment preparation: Download high attention regions "
                "and run AME manually."
            )
            with gr.Row():
                with gr.Column():
                    rbp, cell_line = linked_dropdowns()
                    fasta_file = gr.File(label="Upload Sequence File")
                    # NOTE(review): the layer slider reaches 12 while the head
                    # slider stops at 11; extract_attention_regions rejects
                    # out-of-range indices -- confirm the ranges against the model.
                    layer = gr.Slider(minimum=0, maximum=12, step=1, label="Select Attention Layer")
                    head = gr.Slider(minimum=0, maximum=11, step=1, label="Select Attention Head")
                    submit = gr.Button("Submit", variant="primary")
                with gr.Column():
                    regions_zip = gr.File(label="Download High Attention Regions (ZIP)")
            submit.click(
                fn=motif_enrichment,
                inputs=[rbp, cell_line, fasta_file, layer, head],
                outputs=regions_zip,
            )

    return app
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and serve it with share=True.
    build_gradio_interface().launch(share=True)
|
|
|
|