# Reformer / app.py
import gradio as gr
import pandas as pd
import numpy as np
import transformers as T
import torch
import matplotlib.pyplot as plt
import shutil
from torch.utils.data import DataLoader
from utils import load_model, plot_tracks, calc_mutation_effect, plot_tracks_comparision
from attention_extractor import attention_extractor
import subprocess
tokenizer = None
model = None
def load_tokenizer_and_model():
    global tokenizer, model
    if tokenizer is None:
        tokenizer = T.BertTokenizer.from_pretrained("./model/")
    if model is None:
        model = load_model(tokenizer, "./model/model.bin")

# Load once at startup so the Gradio callbacks can use the globals directly.
load_tokenizer_and_model()
prefix_code = pd.read_csv('./data/prefix_codes.csv')
prefix_code_dic = {k:v for k,v in zip(prefix_code.prefix,prefix_code.code_prefix)}
rbp_options = prefix_code['RBP'].unique().tolist()
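# prefix_codes.csv maps each RBP / cell-line pair to a barcode token. The columns
# used below are `prefix` (formatted as "<RBP>_<cell_line>"), `code_prefix` (the
# barcode prepended to every tokenized sequence), `RBP`, and `cell_line`.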
def sequence_process(RBP, cell_line, ss):
    target = "_".join([RBP, cell_line])
    barcode = prefix_code_dic[target]
    if len(ss) < 3:
        raise ValueError("Input sequence length is too short for 3-mer processing.")
    ss = [ss[i:i+3] for i in range(len(ss) - 2)]  # overlapping 3-mer tokens
    seq = [barcode]
    seq.extend(ss[:-1])
    return seq
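# Illustrative example (placeholder sequence): sequence_process on "ACGUAG"
# yields 3-mers ["ACG", "CGU", "GUA", "UAG"]; the last 3-mer is dropped and the
# barcode is prepended, so the returned token list is
# ["<barcode>", "ACG", "CGU", "GUA"].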
def predict_affinity(RBP, cell_line, ss):
    sequence = sequence_process(RBP, cell_line, ss)
    # `sequence` is a list of k-mer tokens, so pass it as pre-split words
    # (same convention as SequenceDataset.__getitem__ below).
    inputs = tokenizer(sequence, is_split_into_words=True, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    affinity_score = outputs.logits.squeeze().item()
    plot_tracks(sequence, outputs.logits.squeeze())
    plt.savefig("track_plot.png")
    plt.close()
    return affinity_score, "track_plot.png"
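# Hedged usage sketch (RBP and cell-line names are placeholders; valid values
# come from ./data/prefix_codes.csv):
#   score, plot_path = predict_affinity("SOME_RBP", "SOME_CELL_LINE", "ACGUACGUACGU")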
def predict_affinity_with_mutation(RBP, cell_line, ss_wt, ss_mt):
    # predict_affinity returns (score, plot_path); only the scores are compared here
    original_affinity, _ = predict_affinity(RBP, cell_line, ss_wt)
    mutated_affinity, _ = predict_affinity(RBP, cell_line, ss_mt)
    change_in_affinity = calc_mutation_effect(original_affinity, mutated_affinity)
    plot_tracks_comparision(ss_wt, ss_mt, original_affinity, mutated_affinity)
    plt.savefig("mutation_plot.png")
    plt.close()
    return change_in_affinity, "mutation_plot.png"
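# Note: calc_mutation_effect comes from utils; it is assumed to reduce the
# wild-type and mutant scores to the single value shown in the "Affinity Change"
# textbox of the mutation tab.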
def load_eclip_data(path):
    with open(path, 'r') as f:
        lines = f.readlines()
    try:
        # even-indexed lines (0, 2, 4, ...) must be FASTA header lines
        assert np.all([line.startswith('>') for e, line in enumerate(lines) if e % 2 == 0])
    except AssertionError:
        raise ValueError("Input file is not in the expected FASTA format.")
    seq = [line.strip() for line in lines if not line.startswith('>')]
    return np.array(seq)
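# Expected input layout (alternating header/sequence lines), for example:
#   >peak_0
#   ACGUACGUACGU
#   >peak_1
#   UGCAUGCAUGCA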
class SequenceDataset(torch.utils.data.Dataset):
    def __init__(self, barcode, seq, tokenizer, max_length=512):
        self.barcode = barcode
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.sequence = seq
        self.n = len(self.sequence)

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        ss = self.sequence[i]
        ss = [ss[j:j+3] for j in range(len(ss) - 2)]  # 3-mer tokens, same scheme as sequence_process
        seq = [self.barcode]
        seq.extend(ss[:-1])
        inputs = self.tokenizer(seq, is_split_into_words=True, add_special_tokens=True, return_tensors='pt')
        return inputs['input_ids']
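# Each dataset item is a (1, seq_len) tensor of token ids; attention_extractor
# (from attention_extractor.py) is assumed to handle the extra leading dimension
# introduced by return_tensors='pt'.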
def extract_attention_regions(barcode, seq, layer, head):
    val_dataset = SequenceDataset(barcode, seq, tokenizer)
    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=False,
        drop_last=False,
        num_workers=2
    )
    seq_attn = attention_extractor(model, val_loader, model.model.device, len(val_dataset))
    if len(seq_attn) <= layer or len(seq_attn[layer]) <= head:
        raise ValueError("Invalid layer or head index for attention extraction.")
    attention_data = seq_attn[layer][head].squeeze()
    high_attention_regions = (attention_data > np.percentile(attention_data, 99)).nonzero(as_tuple=True)
    return high_attention_regions
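# np.percentile(attention_data, 99) is the value below which 99% of the attention
# weights fall, so only the top 1% of positions for the chosen layer/head are
# returned as high-attention regions.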
def predict_motif(RBP, cell_line, file, layer, head):
    target = "_".join([RBP, cell_line])
    barcode = prefix_code_dic[target]
    seq = load_eclip_data(file)
    high_attention_regions = extract_attention_regions(barcode, seq, layer, head)
    with open("high_attention_regions.fasta", "w") as f:
        for i, (seq_idx, pos) in enumerate(zip(*high_attention_regions)):
            subseq = seq[seq_idx][pos:pos+10]  # 10-mer starting at the high-attention position
            f.write(f">{i}\n{subseq}\n")
    zip_file = "high_attention_regions.zip"
    shutil.make_archive("high_attention_regions", 'zip', ".", "high_attention_regions.fasta")
    return zip_file
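# The zipped FASTA of 10-mers is intended to be run through a motif-enrichment
# tool such as AME (MEME suite) manually, as the motif tab description below notes.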
def single_sequence_affinity(RBP, cell_line, ss):
    return predict_affinity(RBP, cell_line, ss)

def mutation_affinity(RBP, cell_line, ss_wt, ss_mt):
    return predict_affinity_with_mutation(RBP, cell_line, ss_wt, ss_mt)

def motif_enrichment(RBP, cell_line, file, layer, head):
    return predict_motif(RBP, cell_line, file, layer, head)
# Gradio inputs, outputs, and interface wiring
def get_cell_lines(selected_rbp):
    filtered_df = prefix_code[prefix_code['RBP'] == selected_rbp]
    cell_lines = filtered_df['cell_line'].unique().tolist()
    return gr.update(choices=cell_lines)
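# Returning gr.update(choices=...) repopulates the cell-line dropdown whenever
# the RBP dropdown changes (wired up via rbp_dropdown.change below).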
def build_gradio_interface():
    rbp_dropdown = gr.Dropdown(label="Select RBP", choices=rbp_options, value=rbp_options[0])
    cell_line_dropdown = gr.Dropdown(label="Select Cell Line", choices=[])
    rbp_dropdown.change(fn=get_cell_lines, inputs=rbp_dropdown, outputs=cell_line_dropdown)

    affinity_interface = gr.Interface(
        fn=single_sequence_affinity,
        inputs=[
            rbp_dropdown,
            cell_line_dropdown,
            gr.Textbox(label='Input RNA Sequence')
        ],
        outputs=[
            gr.Textbox(label="Predicted Affinity"),
            gr.Image(label="Track Plot")
        ],
        description="Predicted binding affinity"
    )
    mutation_interface = gr.Interface(
        fn=mutation_affinity,
        inputs=[
            rbp_dropdown,
            cell_line_dropdown,
            gr.Textbox(label='Input Wild-type RNA Sequence'),
            gr.Textbox(label='Input Mutated RNA Sequence')
        ],
        outputs=[
            gr.Textbox(label="Affinity Change"),
            gr.Image(label="Mutation Plot")
        ],
        description="Predicted mutation effect"
    )
    motif_interface = gr.Interface(
        fn=motif_enrichment,
        inputs=[
            rbp_dropdown,
            cell_line_dropdown,
            gr.File(label="Upload Sequence File"),
            gr.Slider(minimum=0, maximum=12, step=1, label="Select Attention Layer"),
            gr.Slider(minimum=0, maximum=11, step=1, label="Select Attention Head")
        ],
        outputs=gr.File(label="Download High Attention Regions (high_attention_regions.zip)"),
        description="Motif enrichment preparation: download the high-attention regions and run AME manually."
    )
    app = gr.TabbedInterface(
        interface_list=[affinity_interface, mutation_interface, motif_interface],
        tab_names=["Binding Affinity Prediction", "Mutation Effect Prediction", "Motif Enrichment"]
    )
    return app
if __name__ == "__main__":
    app = build_gradio_interface()
    app.launch(share=True)