# CRISPRTool / app.py
# Source: Hugging Face Space file view by supercat666, commit d487a41 ("fix plot"), 36 kB
# (page links "raw", "history", "blame" omitted)
import os
import tiger
import cas9att
import cas9attvcf
import cas9off
import cas12
import cas12lstm
import cas12lstmvcf
import pandas as pd
import streamlit as st
import plotly.graph_objs as go
import numpy as np
from pathlib import Path
import zipfile
import io
import gtracks
import subprocess
import cyvcf2
# title and documentation (read explicitly as UTF-8 so the page does not depend on the host locale)
st.markdown(Path('crisprTool.md').read_text(encoding='utf-8'), unsafe_allow_html=True)
st.divider()
# Models offered in the selector; each one has its own section further down the page
CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
# Weight files passed to the Cas9 (cas9att / cas9attvcf) and Cas12 (cas12lstm / cas12lstmvcf) predictors
cas9att_path = 'cas9_model/Cas9_MultiHeadAttention_weights.h5'
cas12lstm_path = 'cas12_model/BiLSTM_Cpf1_weights.h5'
# plot functions
def generate_coolbox_plot(bigwig_path, region, output_image_path):
    """Plot one BigWig track over ``region`` with CoolBox and save it to ``output_image_path``.

    NOTE(review): ``CoolBox`` and ``BigWig`` are not imported anywhere in this file, so
    calling this function raises NameError. Presumably they were meant to come from
    CoolBox's API module — confirm before using. No caller is visible in this file.
    """
    coolbox_frame = CoolBox()
    coolbox_frame += BigWig(bigwig_path)
    coolbox_frame.plot(region, savefig=output_image_path)
def generate_pygenometracks_plot(bigwig_file_path, region, output_image_path):
    """Render a single BigWig track over ``region`` (``"chrom:start-end"``) to an image.

    Writes a one-track pyGenomeTracks INI config to ``pygenometracks.ini`` in the
    working directory, parses ``region`` into a chrom/start/end dict and hands both
    to ``plot_tracks``.

    NOTE(review): ``plot_tracks`` is not imported anywhere in this file, so calling this
    function raises NameError. No caller is visible in this file.
    """
    ini_text = f"""
[bigwig]
file = {bigwig_file_path}
height = 4
color = blue
min_value = 0
max_value = 10
"""
    ini_path = "pygenometracks.ini"
    with open(ini_path, 'w') as ini_file:
        ini_file.write(ini_text)

    # "chrom:start-end" -> {'chrom': chrom, 'start': int(start), 'end': int(end)}
    chrom_and_span = region.split(':')
    span_bounds = chrom_and_span[1].split('-')
    region_dict = {'chrom': chrom_and_span[0],
                   'start': int(span_bounds[0]),
                   'end': int(span_bounds[1])}

    plot_tracks(tracks_file=ini_path,
                region=region_dict,
                out_file_name=output_image_path)
@st.cache_data
def convert_df(df):
    """Serialise ``df`` (index included) to UTF-8 encoded CSV bytes for a download button.

    Wrapped in ``st.cache_data`` so the conversion is not recomputed on every rerun.
    """
    csv_text = df.to_csv()
    return csv_text.encode('utf-8')
def mode_change_callback():
    """on_change handler for the Cas13d mode radio.

    In the 'all' and 'titration' run modes the off-target option is unchecked and its
    checkbox disabled; in any other mode the checkbox is re-enabled.
    """
    # TODO: support titration
    off_targets_unsupported = st.session_state.mode in {tiger.RUN_MODES['all'], tiger.RUN_MODES['titration']}
    if off_targets_unsupported:
        st.session_state.check_off_targets = False
    st.session_state.disable_off_target_checkbox = off_targets_unsupported
def progress_update(update_text, percent_complete):
    """Show ``update_text`` and a progress bar in the module-level ``progress`` placeholder.

    ``percent_complete`` is on a 0-100 scale and is converted to the 0-1 fraction
    that ``st.progress`` is given.
    """
    with progress.container():
        st.write(update_text)
        fraction_done = percent_complete / 100
        st.progress(fraction_done)
def initiate_run():
    """Validate the Cas13d transcript input and queue it for a TIGER run.

    Used as the ``on_click`` callback of the "Get predictions!" button. Resets the
    per-run session-state fields, builds a transcript DataFrame from either the
    manual text box or the uploaded FASTA file, and validates it. Valid, non-empty
    input is stored in ``st.session_state.transcripts`` for the page to run TIGER on.
    Otherwise ``st.session_state.input_error`` may be set to a message for the user.

    Relies on the module-level ``ENTRY_METHODS`` mapping defined in the Cas13d section.
    """
    import tempfile

    # initialize state variables (clears the previous run's results and errors)
    st.session_state.transcripts = None
    st.session_state.input_error = None
    st.session_state.on_target = None
    st.session_state.titration = None
    st.session_state.off_target = None

    # initialize transcript DataFrame (stays empty when no input was provided)
    transcripts = pd.DataFrame(columns=[tiger.ID_COL, tiger.SEQ_COL])

    # manual entry
    if st.session_state.entry_method == ENTRY_METHODS['manual']:
        transcripts = pd.DataFrame({
            tiger.ID_COL: ['ManualEntry'],
            tiger.SEQ_COL: [st.session_state.manual_entry]
        }).set_index(tiger.ID_COL)

    # fasta file upload
    elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
        uploaded = st.session_state.fasta_entry
        if uploaded is not None:
            fasta_text = uploaded.getvalue().decode('utf-8')
            # Stage the upload in a private temp file instead of writing to a path built
            # from the client-supplied file name, which could overwrite files in the
            # working directory. Keep a simple original extension in case the loader
            # looks at it.
            extension = os.path.splitext(uploaded.name)[1]
            suffix = extension if extension[1:].isalnum() else '.fasta'
            with tempfile.NamedTemporaryFile('w', suffix=suffix, delete=False, encoding='utf-8') as tmp:
                tmp.write(fasta_text)
                fasta_path = tmp.name
            try:
                transcripts = tiger.load_transcripts([fasta_path], enforce_unique_ids=False)
            finally:
                # remove the temp file even if parsing fails
                os.remove(fasta_path)

    # convert to upper case as used by tokenizer (RNA U -> DNA T)
    transcripts[tiger.SEQ_COL] = transcripts[tiger.SEQ_COL].apply(lambda s: s.upper().replace('U', 'T'))

    # ensure all transcripts have unique identifiers
    if transcripts.index.has_duplicates:
        st.session_state.input_error = "Duplicate transcript ID's detected in fasta file"

    # ensure all transcripts only contain nucleotides A, C, G, T, and wildcard N
    elif not all(transcripts[tiger.SEQ_COL].apply(lambda s: set(s).issubset(tiger.NUCLEOTIDE_TOKENS.keys()))):
        st.session_state.input_error = 'Transcript(s) must only contain upper or lower case A, C, G, and Ts or Us'

    # ensure all transcripts satisfy length requirements
    elif any(transcripts[tiger.SEQ_COL].apply(lambda s: len(s) < tiger.TARGET_LEN)):
        st.session_state.input_error = 'Transcript(s) must be at least {:d} bases.'.format(tiger.TARGET_LEN)

    # run model if we have any transcripts
    elif len(transcripts) > 0:
        st.session_state.transcripts = transcripts
def parse_gene_annotations(file_path, symbol_column='Approved symbol', id_column='Ensembl gene ID'):
    """Build a ``{gene symbol: Ensembl gene ID}`` map from a tab-delimited annotation file.

    The first line must be a header row naming ``symbol_column`` and ``id_column``.
    Every later row long enough to contain both columns, and with a non-empty
    symbol, adds one mapping. On duplicate symbols the later row wins.

    Args:
        file_path: path to the tab-delimited annotation file (read as UTF-8).
        symbol_column: header of the gene-symbol column.
        id_column: header of the Ensembl gene ID column.

    Returns:
        dict mapping gene symbol -> Ensembl gene ID (possibly an empty string).

    Raises:
        ValueError: if either column name is missing from the header row.
    """
    gene_dict = {}
    with open(file_path, 'r', encoding='utf-8') as file:
        # Split first and trim each field afterwards: stripping the whole line would also
        # remove leading/trailing tab separators (i.e. empty fields) and misalign columns.
        headers = [name.strip() for name in file.readline().rstrip('\r\n').split('\t')]
        symbol_idx = headers.index(symbol_column)
        ensembl_idx = headers.index(id_column)
        min_fields = max(symbol_idx, ensembl_idx) + 1
        for line in file:
            values = [field.strip() for field in line.rstrip('\r\n').split('\t')]
            # Skip truncated rows and rows without a symbol
            if len(values) >= min_fields and values[symbol_idx]:
                gene_dict[values[symbol_idx]] = values[ensembl_idx]
    return gene_dict
# HUGO (HGNC-style) gene annotation export used for the gene-symbol selector and Ensembl ID lookup
GENE_ANNOTATION_FILE = 'Human_genes_HUGO_02242024_annotation.txt'
# Parse the file once per process instead of on every Streamlit rerun
_load_gene_annotations = st.cache_data(parse_gene_annotations)
gene_annotations = _load_gene_annotations(GENE_ANNOTATION_FILE)
gene_symbol_list = list(gene_annotations.keys())  # List of gene symbols for the autocomplete feature
# Check if the selected model is Cas9
if selected_model == 'Cas9':
    # Use a radio button to select the prediction type, making sure only one can be selected at a time
    target_selection = st.radio(
        "Select either on-target, on-target with mutation or off-target:",
        ('on-target', 'mutation', 'off-target'),
        key='target_selection'
    )
    if 'current_gene_symbol' not in st.session_state:
        st.session_state['current_gene_symbol'] = ""

    def clean_up_old_files(gene_symbol):
        """Delete the files previously written for ``gene_symbol``, if they exist.

        Covers the GenBank / BED / CSV exports and the gtracks plot image written below.
        """
        for suffix in ('_crispr_targets.gb', '_crispr_targets.bed',
                       '_crispr_predictions.csv', '_gtracks_plot.png'):
            path = f"{gene_symbol}{suffix}"
            if os.path.exists(path):
                os.remove(path)

    def show_gene_header(symbol):
        """Render the Genome / Gene / Nuclease summary row for ``symbol``."""
        ensembl_id = gene_annotations.get(symbol, 'Unknown')  # Get Ensembl ID or default to 'Unknown'
        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown("**Genome**")
            st.markdown("Homo sapiens")
        with col2:
            st.markdown("**Gene**")
            st.markdown(f"{symbol} : {ensembl_id} (primary)")
        with col3:
            st.markdown("**Nuclease**")
            st.markdown("SpCas9")

    def zip_files(paths):
        """Return the bytes of a DEFLATE-compressed ZIP archive containing ``paths``."""
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
            for path in paths:
                zip_file.write(path)
        return zip_buffer.getvalue()

    def is_plus_strand(strand):
        """True when ``strand`` denotes '+'; accepts '+', '1' or the integer 1."""
        # NOTE(review): the predictor may return strand as an int (1/-1) or a string; compare via str()
        return str(strand) in ('1', '+')

    # Gene symbol entry with autocomplete-like feature
    gene_symbol = st.selectbox('Enter a Gene Symbol:', [''] + gene_symbol_list, key='gene_symbol',
                               format_func=lambda x: x if x else "")

    # On a change of gene: remove the previous gene's files and drop its results, so they
    # are not displayed or exported under the new gene's name.
    if gene_symbol and gene_symbol != st.session_state['current_gene_symbol']:
        if st.session_state['current_gene_symbol']:
            clean_up_old_files(st.session_state['current_gene_symbol'])
        for stale_key in ('on_target_results', 'full_results', 'gene_sequence', 'cas9_results_mode'):
            st.session_state.pop(stale_key, None)
        st.session_state['exons'] = []
        st.session_state['current_gene_symbol'] = gene_symbol

    # Plot layout constants shared by the on-target plot
    EXON_BASE = 0  # Base position for exons and CDS on the Y axis
    EXON_HEIGHT = 0.02  # How 'tall' the exon markers should appear
    VERTICAL_GAP = 0.2  # Gap between different ranks
    MAX_STRAND_Y = 0.1  # Y value of the best-ranked positive strand result
    MIN_STRAND_Y = -0.1  # Y value of the best-ranked negative strand result

    if target_selection == 'on-target':
        # Prediction button
        predict_button = st.button('Predict on-target')
        if 'exons' not in st.session_state:
            st.session_state['exons'] = []

        # Process predictions
        if predict_button and gene_symbol:
            with st.spinner('Predicting... Please wait'):
                predictions, gene_sequence, exons = cas9att.process_gene(gene_symbol, cas9att_path)
                sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
                st.session_state['on_target_results'] = sorted_predictions
                st.session_state['gene_sequence'] = gene_sequence  # Save gene sequence in session state
                st.session_state['exons'] = exons  # Store exon data
                # 'on_target_results' is shared with the mutation mode, whose rows have a
                # different shape; remember which mode produced the stored results.
                st.session_state['cas9_results_mode'] = 'on-target'
            # Notify the user once the process is completed successfully.
            st.success('Prediction completed!')
            st.session_state['prediction_made'] = True

        if (st.session_state.get('on_target_results')
                and st.session_state.get('cas9_results_mode') == 'on-target'):
            results_gene = st.session_state['current_gene_symbol']
            show_gene_header(results_gene)

            # Include "Target" in the DataFrame's columns
            try:
                df = pd.DataFrame(st.session_state['on_target_results'],
                                  columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript",
                                           "Exon", "Target", "gRNA", "Prediction"])
            except ValueError as e:
                df = None
                st.error(f"DataFrame creation error: {e}")
                # Log the problematic data for debugging
                print(st.session_state['on_target_results'])
            else:
                st.dataframe(df)

            if df is not None:
                fig = go.Figure()
                # Plot exons as thin bars along the genomic axis
                for exon in st.session_state['exons']:
                    exon_start, exon_end = int(exon['start']), int(exon['end'])
                    fig.add_trace(go.Bar(
                        x=[(exon_start + exon_end) / 2],
                        y=[EXON_HEIGHT],
                        width=[exon_end - exon_start],
                        base=EXON_BASE,
                        marker_color='rgba(128, 0, 128, 0.5)',
                        name='Exon'
                    ))

                # Top 5 predictions (already sorted best-first). Lower ranks move away from
                # the axis: upwards for '+', downwards for '-', so each stays on its strand's side.
                for i, prediction in enumerate(st.session_state['on_target_results'][:5], start=1):
                    chrom, start, end, strand, transcript, exon_num, target, gRNA, prediction_score = prediction
                    midpoint = (int(start) + int(end)) / 2
                    plus = is_plus_strand(strand)
                    y_value = (MAX_STRAND_Y + (i - 1) * VERTICAL_GAP) if plus else (
                        MIN_STRAND_Y - (i - 1) * VERTICAL_GAP)
                    fig.add_trace(go.Scatter(
                        x=[midpoint],
                        y=[y_value],
                        mode='markers+text',
                        marker=dict(symbol='triangle-up' if plus else 'triangle-down', size=12),
                        text=f"Rank: {i}",  # Text label
                        hoverinfo='text',
                        hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if plus else '-'}<br>Transcript: {transcript}<br>Prediction: {prediction_score:.4f}",
                    ))

                # Update layout for clarity and interaction
                fig.update_layout(
                    title='Top 5 gRNA Sequences by Prediction Score',
                    xaxis_title='Genomic Position',
                    yaxis_title='Strand',
                    yaxis=dict(tickvals=[MAX_STRAND_Y, MIN_STRAND_Y], ticktext=['+', '-']),
                    showlegend=False,
                    hovermode='x unified',
                )
                st.plotly_chart(fig)

                if st.session_state.get('gene_sequence'):
                    gene_sequence = st.session_state['gene_sequence']
                    genbank_file_path = f"{results_gene}_crispr_targets.gb"
                    bed_file_path = f"{results_gene}_crispr_targets.bed"
                    csv_file_path = f"{results_gene}_crispr_predictions.csv"
                    plot_image_path = f"{results_gene}_gtracks_plot.png"

                    # Generate files
                    cas9att.generate_genbank_file_from_df(df, gene_sequence, results_gene, genbank_file_path)
                    cas9att.create_bed_file_from_df(df, bed_file_path)
                    cas9att.create_csv_from_df(df, csv_file_path)
                    zip_bytes = zip_files([genbank_file_path, bed_file_path, csv_file_path])

                    # Region spanning all listed guides on the most common chromosome
                    region = f"{df['Chr'].mode()[0]}:{df['Start Pos'].min()}-{df['End Pos'].max()}"
                    # Pass an argument list with no shell, so none of the interpolated values
                    # can be interpreted by a shell. Only show the image if gtracks succeeded.
                    try:
                        gtracks_ok = subprocess.run(
                            ['gtracks', region, bed_file_path, plot_image_path]).returncode == 0
                    except OSError:  # e.g. gtracks is not installed
                        gtracks_ok = False
                    if gtracks_ok and os.path.exists(plot_image_path):
                        st.image(plot_image_path)
                    else:
                        st.warning('Could not generate the genome track plot.')

                    # Display the download button for the ZIP file
                    st.download_button(
                        label="Download GenBank, BED, CSV files as ZIP",
                        data=zip_bytes,
                        file_name=f"{results_gene}_files.zip",
                        mime="application/zip"
                    )

    elif target_selection == 'mutation':
        # Prediction button
        predict_button = st.button('Predict on-target')
        if 'exons' not in st.session_state:
            st.session_state['exons'] = []

        # Process predictions
        if predict_button and gene_symbol:
            with st.spinner('Predicting... Please wait'):
                # Open the variant file only when a prediction is requested, not on every rerun
                vcf_reader = cyvcf2.VCF('SRR25934512.filter.snps.indels.vcf.gz')
                predictions, gene_sequence, exons = cas9attvcf.process_gene(gene_symbol, vcf_reader, cas9att_path)
                full_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)
                st.session_state['full_results'] = full_predictions
                st.session_state['on_target_results'] = full_predictions[:10]
                st.session_state['gene_sequence'] = gene_sequence  # Save gene sequence in session state
                st.session_state['exons'] = exons  # Store exon data
                st.session_state['cas9_results_mode'] = 'mutation'
            # Notify the user once the process is completed successfully.
            st.success('Prediction completed!')
            st.session_state['prediction_made'] = True

        if (st.session_state.get('on_target_results')
                and st.session_state.get('cas9_results_mode') == 'mutation'):
            results_gene = st.session_state['current_gene_symbol']
            show_gene_header(results_gene)

            mutation_columns = ["Gene Symbol", "Chr", "Strand", "Target Start", "Transcript", "Exon",
                                "Target", "gRNA", "Prediction", "Is Mutation"]
            try:
                df = pd.DataFrame(st.session_state['on_target_results'], columns=mutation_columns)
                df_full = pd.DataFrame(st.session_state['full_results'], columns=mutation_columns)
            except ValueError as e:
                df_full = None
                st.error(f"DataFrame creation error: {e}")
                # Log the problematic data for debugging
                print(st.session_state['on_target_results'])
            else:
                st.dataframe(df)

            # The exported files contain every prediction, not just the top 10 shown
            if df_full is not None and st.session_state.get('gene_sequence'):
                gene_sequence = st.session_state['gene_sequence']
                genbank_file_path = f"{results_gene}_crispr_targets.gb"
                bed_file_path = f"{results_gene}_crispr_targets.bed"
                csv_file_path = f"{results_gene}_crispr_predictions.csv"

                # Generate files
                cas9att.generate_genbank_file_from_df(df_full, gene_sequence, results_gene, genbank_file_path)
                cas9att.create_bed_file_from_df(df_full, bed_file_path)
                cas9att.create_csv_from_df(df_full, csv_file_path)

                # Display the download button for the ZIP file
                st.download_button(
                    label="Download GenBank, BED, CSV files as ZIP",
                    data=zip_files([genbank_file_path, bed_file_path, csv_file_path]),
                    file_name=f"{results_gene}_files.zip",
                    mime="application/zip"
                )

    elif target_selection == 'off-target':
        ENTRY_METHODS = dict(
            manual='Manual entry of target sequence',
            txt="txt file upload"
        )
        if __name__ == '__main__':
            # app initialization for Cas9 off-target
            if 'target_sequence' not in st.session_state:
                st.session_state.target_sequence = None
            if 'input_error' not in st.session_state:
                st.session_state.input_error = None
            if 'off_target_results' not in st.session_state:
                st.session_state.off_target_results = None

            # target sequence entry
            st.selectbox(
                label='How would you like to provide target sequences?',
                options=ENTRY_METHODS.values(),
                key='entry_method',
                disabled=st.session_state.target_sequence is not None
            )
            if st.session_state.entry_method == ENTRY_METHODS['manual']:
                st.text_input(
                    label='Enter on/off sequences:',
                    key='manual_entry',
                    placeholder='Enter on/off sequences like:GGGTGGGGGGAGTTTGCTCCAGG,AGGTGGGGTGA_TTTGCTCCAGG',
                    disabled=st.session_state.target_sequence is not None
                )
            elif st.session_state.entry_method == ENTRY_METHODS['txt']:
                st.file_uploader(
                    label='Upload a txt file:',
                    key='txt_entry',
                    disabled=st.session_state.target_sequence is not None
                )

            # prediction button
            if st.button('Predict off-target'):
                # Stays None when nothing was entered or uploaded
                predictions = None
                if st.session_state.entry_method == ENTRY_METHODS['manual']:
                    user_input = st.session_state.manual_entry
                    if user_input:  # Check if user_input is not empty
                        predictions = cas9off.process_input_and_predict(user_input, input_type='manual')
                elif st.session_state.entry_method == ENTRY_METHODS['txt']:
                    uploaded_file = st.session_state.txt_entry
                    if uploaded_file is not None:
                        # Read the uploaded file content
                        # NOTE(review): file text is passed with input_type='manual'; presumably the
                        # file uses the same format as the text box — confirm with cas9off.
                        file_content = uploaded_file.getvalue().decode("utf-8")
                        predictions = cas9off.process_input_and_predict(file_content, input_type='manual')
                if predictions is None:
                    st.warning('Please enter or upload on/off-target sequences before predicting.')
                else:
                    st.session_state.off_target_results = predictions
            else:
                predictions = None

            # placeholder used by progress_update()
            progress = st.empty()

            # input error display
            error = st.empty()
            if st.session_state.input_error is not None:
                error.error(st.session_state.input_error, icon="🚨")
            else:
                error.empty()

            # off-target results display
            off_target_results = st.empty()
            if st.session_state.off_target_results is not None:
                with off_target_results.container():
                    if len(st.session_state.off_target_results) > 0:
                        st.write('Off-target predictions:', st.session_state.off_target_results)
                        st.download_button(
                            label='Download off-target predictions',
                            data=convert_df(st.session_state.off_target_results),
                            file_name='off_target_results.csv',
                            mime='text/csv'
                        )
                    else:
                        st.write('No significant off-target effects detected!')
            else:
                off_target_results.empty()

            # running the CRISPR-Net model for off-target predictions
            # NOTE(review): target_sequence is never set to a non-None value in this file view
            if st.session_state.target_sequence is not None:
                st.session_state.off_target_results = cas9off.predict_off_targets(
                    target_sequence=st.session_state.target_sequence,
                    status_update_fn=progress_update
                )
                st.session_state.target_sequence = None
                st.experimental_rerun()
elif selected_model == 'Cas12':
def visualize_genomic_data():
fig = go.Figure()
EXON_BASE = 0 # Base position for exons and CDS on the Y axis
EXON_HEIGHT = 0.02 # How 'tall' the exon markers should appear
# Plot Exons as small markers on the X-axis
for exon in st.session_state['exons']:
try:
exon_start, exon_end = int(exon['start']), int(exon['end'])
fig.add_trace(go.Bar(
x=[(exon_start + exon_end) / 2],
y=[EXON_HEIGHT],
width=[exon_end - exon_start],
base=EXON_BASE,
marker_color='rgba(128, 0, 128, 0.5)',
name='Exon'
))
except ValueError:
st.error("Error in exon positions. Exon positions should be numeric.")
VERTICAL_GAP = 0.2 # Gap between different ranks
# Define max and min Y values based on strand and rank
MAX_STRAND_Y = 0.1 # Maximum Y value for positive strand results
MIN_STRAND_Y = -0.1 # Minimum Y value for negative strand results
# Iterate over top 5 sorted predictions to create the plot
for i, prediction in enumerate(st.session_state['on_target_results'][:5], start=1): # Only top 5
try:
start, end = int(prediction['Start Pos']), int(prediction['End Pos'])
midpoint = (start + end) / 2
strand = prediction['Strand']
y_value = (MAX_STRAND_Y - (i - 1) * VERTICAL_GAP) if strand in ['1', '+'] else (
MIN_STRAND_Y + (i - 1) * VERTICAL_GAP)
fig.add_trace(go.Scatter(
x=[midpoint],
y=[y_value],
mode='markers+text',
marker=dict(symbol='triangle-up' if strand in ['1', '+'] else 'triangle-down', size=12),
text=f"Rank: {i}",
hoverinfo='text',
hovertext=f"Rank: {i}<br>Chromosome: {prediction['Chr']}<br>Target Sequence: {prediction['Target']}<br>gRNA: {prediction['gRNA']}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand in ['1', '+'] else '-'}<br>Transcript: {prediction['Transcript']}<br>Prediction: {prediction['Prediction']:.4f}",
))
except ValueError:
st.error("Error in prediction positions. Start and end positions should be numeric.")
# Update layout for clarity and interaction
fig.update_layout(
title='Top 5 gRNA Sequences by Prediction Score',
xaxis_title='Genomic Position',
yaxis_title='Strand',
yaxis=dict(tickvals=[MAX_STRAND_Y, MIN_STRAND_Y], ticktext=['+', '-']),
showlegend=False,
hovermode='x unified',
)
# Display the plot
st.plotly_chart(fig)
# File generation and download
generate_and_download_files(df, gene_symbol)
def generate_and_download_files(df, gene_symbol):
genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
bed_file_path = f"{gene_symbol}_crispr_targets.bed"
csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
df.to_csv(csv_file_path, index=False)
# Assume functions to generate GenBank and BED are defined in cas12lstm or cas12lstmvcf
cas12lstm.generate_genbank_file_from_df(df, gene_symbol, genbank_file_path)
cas12lstm.create_bed_file_from_df(df, bed_file_path)
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
zip_file.write(genbank_file_path)
zip_file.write(bed_file_path)
zip_file.write(csv_file_path)
zip_buffer.seek(0)
st.download_button("Download GenBank, BED, CSV files as ZIP", data=zip_buffer.getvalue(),
file_name=f"{gene_symbol}_files.zip", mime="application/zip")
def display_results(predictions, gene_sequence, exons, gene_symbol):
st.success('Prediction completed!')
ensembl_id = gene_annotations.get(gene_symbol, 'Unknown')
st.write(f"**Genome:** Homo sapiens")
st.write(f"**Gene:** {gene_symbol} : {ensembl_id} (primary)")
st.write("**Nuclease:** Cas12")
df = pd.DataFrame(predictions,
columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon", "Target", "gRNA",
"Prediction"])
st.dataframe(df)
# Visualization and file generation as demonstrated in the Cas9 example
visualize_and_generate_files(df, gene_sequence, exons, gene_symbol)
cas12target_selection = st.radio(
"Select either regular or mutation:",
('regular', 'mutation'),
key='cas12target_selection'
)
if 'current_gene_symbol' not in st.session_state:
st.session_state['current_gene_symbol'] = ""
def clean_up_old_files(gene_symbol):
for suffix in ['_crispr_targets.gb', '_crispr_targets.bed', '_crispr_predictions.csv']:
file_path = f"{gene_symbol}{suffix}"
if os.path.exists(file_path):
os.remove(file_path)
gene_symbol = st.selectbox(
'Enter a Gene Symbol:',
[''] + gene_symbol_list,
key='gene_symbol',
format_func=lambda x: x if x else ""
)
if gene_symbol != st.session_state['current_gene_symbol']:
if st.session_state['current_gene_symbol']:
clean_up_old_files(st.session_state['current_gene_symbol'])
st.session_state['current_gene_symbol'] = gene_symbol
if cas12target_selection == 'regular':
if st.button('Predict cas12 Regular'):
with st.spinner('Predicting... Please wait'):
predictions, gene_sequence, exons = cas12lstm.process_gene(gene_symbol, cas12lstm_path)
sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
display_results(sorted_predictions, gene_sequence, exons, gene_symbol)
elif cas12target_selection == 'mutation':
vcf_reader = cyvcf2.VCF('SRR25934512.filter.snps.indels.vcf.gz')
if st.button('Predict cas12 Mutation'):
with st.spinner('Predicting... Please wait'):
predictions, gene_sequence, exons = cas12lstmvcf.process_gene(gene_symbol, vcf_reader, cas12lstm_path)
sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
display_results(sorted_predictions, gene_sequence, exons, gene_symbol)
elif selected_model == 'Cas13d':
    # Input options for the Cas13d (TIGER) section; initiate_run() compares
    # st.session_state.entry_method against these labels.
    ENTRY_METHODS = dict(
        manual='Manual entry of single transcript',
        fasta="Fasta file upload (supports multiple transcripts if they have unique ID's)"
    )
    if __name__ == '__main__':
        # app initialization: seed session-state keys that are not set yet
        if 'mode' not in st.session_state:
            st.session_state.mode = tiger.RUN_MODES['all']
            st.session_state.disable_off_target_checkbox = True
        if 'entry_method' not in st.session_state:
            st.session_state.entry_method = ENTRY_METHODS['manual']
        if 'transcripts' not in st.session_state:
            st.session_state.transcripts = None
        if 'input_error' not in st.session_state:
            st.session_state.input_error = None
        if 'on_target' not in st.session_state:
            st.session_state.on_target = None
        if 'titration' not in st.session_state:
            st.session_state.titration = None
        if 'off_target' not in st.session_state:
            st.session_state.off_target = None

        # mode selection (all inputs below are disabled while transcripts are queued for a run)
        col1, col2 = st.columns([0.65, 0.35])
        with col1:
            st.radio(
                label='What do you want to predict?',
                options=tuple(tiger.RUN_MODES.values()),
                key='mode',
                on_change=mode_change_callback,
                disabled=st.session_state.transcripts is not None,
            )
        with col2:
            st.checkbox(
                label='Find off-target effects (slow)',
                key='check_off_targets',
                disabled=st.session_state.disable_off_target_checkbox or st.session_state.transcripts is not None
            )

        # transcript entry
        st.selectbox(
            label='How would you like to provide transcript(s) of interest?',
            options=ENTRY_METHODS.values(),
            key='entry_method',
            disabled=st.session_state.transcripts is not None
        )
        if st.session_state.entry_method == ENTRY_METHODS['manual']:
            st.text_input(
                label='Enter a target transcript:',
                key='manual_entry',
                placeholder='Upper or lower case',
                disabled=st.session_state.transcripts is not None
            )
        elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
            st.file_uploader(
                label='Upload a fasta file:',
                key='fasta_entry',
                disabled=st.session_state.transcripts is not None
            )

        # let's go! initiate_run validates the input and, when valid, stores it in
        # st.session_state.transcripts, which triggers the model run at the bottom
        st.button(label='Get predictions!', on_click=initiate_run, disabled=st.session_state.transcripts is not None)
        # placeholder that progress_update() writes status text and a progress bar into
        progress = st.empty()

        # input error
        error = st.empty()
        if st.session_state.input_error is not None:
            error.error(st.session_state.input_error, icon="🚨")
        else:
            error.empty()

        # on-target results
        on_target_results = st.empty()
        if st.session_state.on_target is not None:
            with on_target_results.container():
                st.write('On-target predictions:', st.session_state.on_target)
                st.download_button(
                    label='Download on-target predictions',
                    data=convert_df(st.session_state.on_target),
                    file_name='on_target.csv',
                    mime='text/csv'
                )
        else:
            on_target_results.empty()

        # titration results
        titration_results = st.empty()
        if st.session_state.titration is not None:
            with titration_results.container():
                st.write('Titration predictions:', st.session_state.titration)
                st.download_button(
                    label='Download titration predictions',
                    data=convert_df(st.session_state.titration),
                    file_name='titration.csv',
                    mime='text/csv'
                )
        else:
            titration_results.empty()

        # off-target results
        off_target_results = st.empty()
        if st.session_state.off_target is not None:
            with off_target_results.container():
                if len(st.session_state.off_target) > 0:
                    st.write('Off-target predictions:', st.session_state.off_target)
                    st.download_button(
                        label='Download off-target predictions',
                        data=convert_df(st.session_state.off_target),
                        file_name='off_target.csv',
                        mime='text/csv'
                    )
                else:
                    st.write('We did not find any off-target effects!')
        else:
            off_target_results.empty()

        # keep trying to run model until we clear inputs (streamlit UI changes can induce race-condition reruns)
        if st.session_state.transcripts is not None:
            st.session_state.on_target, st.session_state.titration, st.session_state.off_target = tiger.tiger_exhibit(
                transcripts=st.session_state.transcripts,
                # map the displayed mode label back to its RUN_MODES key
                mode={v: k for k, v in tiger.RUN_MODES.items()}[st.session_state.mode],
                check_off_targets=st.session_state.check_off_targets,
                status_update_fn=progress_update
            )
            # clear the queued input, then rerun so the results above are rendered
            st.session_state.transcripts = None
            st.experimental_rerun()