# Captured from a Hugging Face Spaces page (Space status at capture time: "Sleeping").
import os | |
import tiger | |
import cas9att | |
import cas9attvcf | |
import cas9off | |
import cas12 | |
import cas12lstm | |
import cas12lstmvcf | |
import pandas as pd | |
import streamlit as st | |
import plotly.graph_objs as go | |
import numpy as np | |
from pathlib import Path | |
import zipfile | |
import io | |
import gtracks | |
import subprocess | |
import cyvcf2 | |
# Title and documentation, rendered from the bundled markdown file. Read it
# explicitly as UTF-8 so rendering does not depend on the server's locale
# encoding.
st.markdown(Path('crisprTool.md').read_text(encoding='utf-8'), unsafe_allow_html=True)
st.divider()

# Models offered in the top-level selector; each value selects one section of
# the page below.
CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')

# Weight files passed to the Cas9 attention model and the Cas12 BiLSTM model.
cas9att_path = 'cas9_model/Cas9_MultiHeadAttention_weights.h5'
cas12lstm_path = 'cas12_model/BiLSTM_Cpf1_weights.h5'
# Plot functions
def generate_coolbox_plot(bigwig_path, region, output_image_path):
    """Render a bigWig signal over *region* with CoolBox and save the figure.

    Args:
        bigwig_path: Path of the bigWig file to draw.
        region: Region passed unchanged to ``frame.plot``.
        output_image_path: File the figure is saved to.

    Raises:
        ImportError: if ``CoolBox`` / ``BigWig`` are not available as
            module-level names (this file does not import them).
    """
    # NOTE(review): nothing in this file imports CoolBox or BigWig (presumably
    # they come from ``coolbox.api``), and no code here appears to call this
    # function. Fail with an explicit message rather than a bare NameError.
    cool_box_cls = globals().get('CoolBox')
    big_wig_cls = globals().get('BigWig')
    if cool_box_cls is None or big_wig_cls is None:
        raise ImportError(
            "generate_coolbox_plot requires CoolBox and BigWig "
            "(from the coolbox package) to be imported"
        )
    frame = cool_box_cls()
    frame += big_wig_cls(bigwig_path)
    frame.plot(region, savefig=output_image_path)
def generate_pygenometracks_plot(bigwig_file_path, region, output_image_path):
    """Plot a single bigWig track over *region* with pyGenomeTracks.

    Args:
        bigwig_file_path: Path of the bigWig file to draw.
        region: Interval formatted as ``"chrom:start-end"``.
        output_image_path: File the rendered plot is written to.

    Raises:
        ValueError: if *region* is not ``chrom:start-end`` with integer
            coordinates.
    """
    import tempfile

    # Parse the region once, before any file is written. rpartition keeps
    # chromosome names that themselves contain ':' intact.
    chrom, colon, span = region.rpartition(':')
    start_text, dash, end_text = span.partition('-')
    if not (colon and chrom and dash):
        raise ValueError(f"region must look like 'chrom:start-end', got {region!r}")
    region_dict = {'chrom': chrom, 'start': int(start_text), 'end': int(end_text)}

    # Track configuration in pyGenomeTracks' INI format.
    tracks = """
[bigwig]
file = {}
height = 4
color = blue
min_value = 0
max_value = 10
""".format(bigwig_file_path)

    # Use a private temporary file: a fixed name in the working directory is
    # shared by every session of the app and was never cleaned up.
    with tempfile.NamedTemporaryFile('w', suffix='.ini', delete=False,
                                     encoding='utf-8') as configfile:
        configfile.write(tracks)
        config_file_path = configfile.name
    try:
        # NOTE(review): plot_tracks is not imported anywhere in this file —
        # confirm where it is supposed to come from.
        plot_tracks(tracks_file=config_file_path,
                    region=region_dict,
                    out_file_name=output_image_path)
    finally:
        os.remove(config_file_path)
def convert_df(df):
    """Serialise *df* to UTF-8 encoded CSV bytes for ``st.download_button``.

    The index is written as the first column (the ``to_csv`` default). No
    caching is applied, so the conversion runs on every call.
    """
    csv_text = df.to_csv()
    return csv_text.encode('utf-8')
def mode_change_callback():
    """Sync the off-target checkbox with the newly selected Cas13d run mode.

    Used as the ``on_change`` callback of the mode radio. For the 'all' and
    'titration' modes the off-target checkbox is unticked and disabled; for
    any other mode it is re-enabled.
    """
    # TODO: support titration
    modes_without_off_targets = (tiger.RUN_MODES['all'], tiger.RUN_MODES['titration'])
    state = st.session_state
    if state.mode not in modes_without_off_targets:
        state.disable_off_target_checkbox = False
        return
    state.check_off_targets = False
    state.disable_off_target_checkbox = True
def progress_update(update_text, percent_complete):
    """Show a status line and a progress bar in the module-level ``progress`` slot.

    Passed as ``status_update_fn`` to the long-running predictors.
    ``progress`` is the ``st.empty()`` placeholder created at module level
    before the model is run; *percent_complete* is on a 0-100 scale.
    """
    box = progress.container()
    with box:
        st.write(update_text)
        st.progress(percent_complete / 100)
def initiate_run():
    """Validate the Cas13d transcript input and queue it for the model run.

    ``on_click`` callback of the "Get predictions!" button. Clears the previous
    run's outputs, builds a transcript DataFrame (indexed by ``tiger.ID_COL``)
    from the manual entry or the uploaded fasta file, validates it, and either
    stores it in ``st.session_state.transcripts`` (the script then runs the
    model on it) or records a message in ``st.session_state.input_error``.

    Relies on the module-level ``ENTRY_METHODS`` defined in the Cas13d branch.
    """
    import tempfile

    # reset the previous run's state
    st.session_state.transcripts = None
    st.session_state.input_error = None
    st.session_state.on_target = None
    st.session_state.titration = None
    st.session_state.off_target = None

    # empty transcript table; replaced below when the user supplied input
    transcripts = pd.DataFrame(columns=[tiger.ID_COL, tiger.SEQ_COL])

    # manual entry
    if st.session_state.entry_method == ENTRY_METHODS['manual']:
        transcripts = pd.DataFrame({
            tiger.ID_COL: ['ManualEntry'],
            tiger.SEQ_COL: [st.session_state.manual_entry]
        }).set_index(tiger.ID_COL)

    # fasta file upload
    elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
        uploaded = st.session_state.fasta_entry
        if uploaded is None:
            st.session_state.input_error = 'Please upload a fasta file.'
            return
        # Decode first so a decoding error cannot leave a stray temp file.
        fasta_text = uploaded.getvalue().decode('utf-8')
        # tiger.load_transcripts is given file paths, so spill the upload to a
        # private temporary file instead of the working directory under the
        # client-supplied name (which could collide between sessions or carry
        # path components). Keep the extension in case the loader uses it.
        suffix = os.path.splitext(uploaded.name)[1]
        with tempfile.NamedTemporaryFile('w', suffix=suffix, delete=False,
                                         encoding='utf-8') as fasta_file:
            fasta_file.write(fasta_text)
            fasta_path = fasta_file.name
        try:
            transcripts = tiger.load_transcripts([fasta_path], enforce_unique_ids=False)
        finally:
            # remove the temp file even when parsing fails
            os.remove(fasta_path)

    # convert to upper-case DNA as used by the tokenizer
    transcripts[tiger.SEQ_COL] = transcripts[tiger.SEQ_COL].apply(lambda s: s.upper().replace('U', 'T'))

    # ensure all transcripts have unique identifiers
    if transcripts.index.has_duplicates:
        st.session_state.input_error = "Duplicate transcript ID's detected in fasta file"

    # ensure all transcripts only contain nucleotides A, C, G, T, and wildcard N
    elif not all(transcripts[tiger.SEQ_COL].apply(lambda s: set(s).issubset(tiger.NUCLEOTIDE_TOKENS.keys()))):
        st.session_state.input_error = 'Transcript(s) must only contain upper or lower case A, C, G, and Ts or Us'

    # ensure all transcripts satisfy length requirements
    elif any(transcripts[tiger.SEQ_COL].apply(lambda s: len(s) < tiger.TARGET_LEN)):
        st.session_state.input_error = 'Transcript(s) must be at least {:d} bases.'.format(tiger.TARGET_LEN)

    # queue the model run if we have any transcripts
    elif len(transcripts) > 0:
        st.session_state.transcripts = transcripts
def parse_gene_annotations(file_path):
    """Map approved gene symbols to Ensembl gene IDs from a tab-delimited table.

    The first line must be a header containing the columns ``Approved symbol``
    and ``Ensembl gene ID``. Rows too short to contain both columns are
    skipped; when a symbol repeats, the last row wins.

    Args:
        file_path: Path of the UTF-8 annotation file.

    Returns:
        dict mapping symbol -> Ensembl gene ID (the ID is an empty string when
        that field of the row is empty).

    Raises:
        ValueError: if either required column is missing from the header.
    """
    gene_dict = {}
    with open(file_path, 'r', encoding='utf-8') as file:
        # Strip only the line terminator: str.strip() would also drop leading
        # or trailing tabs, shifting every column of a row whose first field
        # is empty.
        headers = file.readline().rstrip('\r\n').split('\t')
        symbol_idx = headers.index('Approved symbol')
        ensembl_idx = headers.index('Ensembl gene ID')
        last_needed_idx = max(symbol_idx, ensembl_idx)
        for line in file:
            values = line.rstrip('\r\n').split('\t')
            if len(values) > last_needed_idx:
                gene_dict[values[symbol_idx]] = values[ensembl_idx]
    return gene_dict
# Symbol -> Ensembl gene ID lookup, parsed from the bundled HUGO annotation
# table every time this script runs.
gene_annotations = parse_gene_annotations('Human_genes_HUGO_02242024_annotation.txt')
# Options for the gene-symbol select boxes, in file order.
gene_symbol_list = [*gene_annotations]
# Cas9 page: shared controls, then one section per prediction mode.
if selected_model == 'Cas9':
    # Radio buttons keep exactly one mode selected at a time.
    target_selection = st.radio(
        "Select either on-target, on-target with mutation or off-target:",
        ('on-target', 'mutation', 'off-target'),
        key='target_selection'
    )
    if 'current_gene_symbol' not in st.session_state:
        st.session_state['current_gene_symbol'] = ""

    def clean_up_old_files(gene_symbol):
        """Delete the GenBank, BED and CSV exports previously written for *gene_symbol*."""
        genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
        bed_file_path = f"{gene_symbol}_crispr_targets.bed"
        csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
        for path in [genbank_file_path, bed_file_path, csv_file_path]:
            if os.path.exists(path):
                os.remove(path)

    def clear_cached_predictions():
        """Forget the Cas9 prediction results cached in session state."""
        for state_key in ('on_target_results', 'full_results', 'gene_sequence', 'exons', 'prediction_made'):
            if state_key in st.session_state:
                del st.session_state[state_key]

    # Gene symbol entry with autocomplete-like feature
    gene_symbol = st.selectbox('Enter a Gene Symbol:', [''] + gene_symbol_list, key='gene_symbol',
                               format_func=lambda x: x if x else "")

    # The on-target and mutation modes cache rows of different widths (9 vs 10
    # columns) under the same 'on_target_results' key, so rendering one mode's
    # rows in the other breaks. Drop the cache whenever the mode changes.
    if st.session_state.get('cas9_last_target_selection') != target_selection:
        clear_cached_predictions()
        st.session_state['cas9_last_target_selection'] = target_selection

    # Handle gene symbol change and file cleanup
    if gene_symbol != st.session_state['current_gene_symbol'] and gene_symbol:
        if st.session_state['current_gene_symbol']:
            # Clean up files only if a different gene symbol is entered and a previous symbol exists
            clean_up_old_files(st.session_state['current_gene_symbol'])
        # Cached results belong to the previous gene; keeping them would label
        # them with, and export them under the file names of, the new gene.
        clear_cached_predictions()
        st.session_state['current_gene_symbol'] = gene_symbol
if target_selection == 'on-target':
    # Prediction button
    predict_button = st.button('Predict on-target')
    if 'exons' not in st.session_state:
        st.session_state['exons'] = []

    # Run the model only on an explicit click; results are kept in session
    # state so the table, plot and downloads survive later reruns.
    if predict_button and gene_symbol:
        with st.spinner('Predicting... Please wait'):
            predictions, gene_sequence, exons = cas9att.process_gene(gene_symbol, cas9att_path)
            # Index 8 is the "Prediction" score column; keep the ten best guides.
            sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
            st.session_state['on_target_results'] = sorted_predictions
            st.session_state['gene_sequence'] = gene_sequence
            st.session_state['exons'] = exons
        # Notify the user once the process is completed successfully.
        st.success('Prediction completed!')
        st.session_state['prediction_made'] = True

    if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
        ensembl_id = gene_annotations.get(gene_symbol, 'Unknown')  # Get Ensembl ID or default to 'Unknown'
        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown("**Genome**")
            st.markdown("Homo sapiens")
        with col2:
            st.markdown("**Gene**")
            st.markdown(f"{gene_symbol} : {ensembl_id} (primary)")
        with col3:
            st.markdown("**Nuclease**")
            st.markdown("SpCas9")

        # Results table. The plot and file exports below are built from the same
        # rows, so they are skipped when the table cannot be created (otherwise
        # `df` would be undefined further down).
        df = None
        try:
            df = pd.DataFrame(st.session_state['on_target_results'],
                              columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon", "Target", "gRNA", "Prediction"])
            st.dataframe(df)
        except ValueError as e:
            st.error(f"DataFrame creation error: {e}")
            # Optionally print or log the problematic data for debugging:
            print(st.session_state['on_target_results'])

        if df is not None:
            fig = go.Figure()
            EXON_BASE = 0  # Base position for exons and CDS on the Y axis
            EXON_HEIGHT = 0.02  # How 'tall' the exon markers should appear
            # Plot exons as short bars along the X axis
            for exon in st.session_state['exons']:
                exon_start, exon_end = exon['start'], exon['end']
                fig.add_trace(go.Bar(
                    x=[(exon_start + exon_end) / 2],
                    y=[EXON_HEIGHT],
                    width=[exon_end - exon_start],
                    base=EXON_BASE,
                    marker_color='rgba(128, 0, 128, 0.5)',
                    name='Exon'
                ))

            VERTICAL_GAP = 0.2  # Gap between different ranks
            MAX_STRAND_Y = 0.1  # Maximum Y value for positive strand results
            MIN_STRAND_Y = -0.1  # Minimum Y value for negative strand results
            # Top five guides: '+' strand marks step down from MAX_STRAND_Y,
            # '-' strand marks step up from MIN_STRAND_Y.
            for i, prediction in enumerate(st.session_state['on_target_results'][:5], start=1):
                chrom, start, end, strand, transcript, exon, target, gRNA, prediction_score = prediction
                midpoint = (int(start) + int(end)) / 2
                is_plus = strand in ('1', '+')
                if is_plus:
                    y_value = MAX_STRAND_Y - (i - 1) * VERTICAL_GAP
                else:
                    y_value = MIN_STRAND_Y + (i - 1) * VERTICAL_GAP
                fig.add_trace(go.Scatter(
                    x=[midpoint],
                    y=[y_value],
                    mode='markers+text',
                    marker=dict(symbol='triangle-up' if is_plus else 'triangle-down', size=12),
                    text=f"Rank: {i}",
                    hoverinfo='text',
                    hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if is_plus else '-'}<br>Transcript: {transcript}<br>Prediction: {prediction_score:.4f}",
                ))

            fig.update_layout(
                title='Top 5 gRNA Sequences by Prediction Score',
                xaxis_title='Genomic Position',
                yaxis_title='Strand',
                yaxis=dict(tickvals=[MAX_STRAND_Y, MIN_STRAND_Y], ticktext=['+', '-']),
                showlegend=False,
                hovermode='x unified',
            )
            st.plotly_chart(fig)

            if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
                gene_symbol = st.session_state['current_gene_symbol']
                gene_sequence = st.session_state['gene_sequence']

                genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
                bed_file_path = f"{gene_symbol}_crispr_targets.bed"
                csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
                plot_image_path = f"{gene_symbol}_gtracks_plot.png"

                cas9att.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
                cas9att.create_bed_file_from_df(df, bed_file_path)
                cas9att.create_csv_from_df(df, csv_file_path)

                # Bundle the three exports into an in-memory ZIP
                # (getvalue() returns the whole buffer regardless of position).
                zip_buffer = io.BytesIO()
                with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
                    for export_path in (genbank_file_path, bed_file_path, csv_file_path):
                        zip_file.write(export_path)

                # Region spanned by the listed guides
                min_start = df['Start Pos'].min()
                max_end = df['End Pos'].max()
                chromosome = df['Chr'].mode()[0]  # Assumes most common chromosome is the target
                region = f"{chromosome}:{min_start}-{max_end}"

                # Run gtracks with an argument list (no shell), and only show the
                # image when the command succeeded and produced it.
                plot_ok = False
                try:
                    completed = subprocess.run(['gtracks', region, bed_file_path, plot_image_path])
                    plot_ok = completed.returncode == 0 and os.path.exists(plot_image_path)
                except OSError as exc:
                    print(f"gtracks could not be started: {exc}")
                if plot_ok:
                    st.image(plot_image_path)
                else:
                    st.warning('Could not generate the gtracks plot.')

                st.download_button(
                    label="Download GenBank, BED, CSV files as ZIP",
                    data=zip_buffer.getvalue(),
                    file_name=f"{gene_symbol}_files.zip",
                    mime="application/zip"
                )
elif target_selection == 'mutation': | |
# Prediction button | |
predict_button = st.button('Predict on-target') | |
vcf_reader = cyvcf2.VCF('SRR25934512.filter.snps.indels.vcf.gz') | |
if 'exons' not in st.session_state: | |
st.session_state['exons'] = [] | |
# Process predictions | |
if predict_button and gene_symbol: | |
with st.spinner('Predicting... Please wait'): | |
predictions, gene_sequence, exons = cas9attvcf.process_gene(gene_symbol, vcf_reader, cas9att_path) | |
full_predictions = sorted(predictions, key=lambda x: x[8], reverse=True) | |
sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10] | |
st.session_state['full_results'] = full_predictions | |
st.session_state['on_target_results'] = sorted_predictions | |
st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state | |
st.session_state['exons'] = exons # Store exon data | |
# Notify the user once the process is completed successfully. | |
st.success('Prediction completed!') | |
st.session_state['prediction_made'] = True | |
if 'on_target_results' in st.session_state and st.session_state['on_target_results']: | |
ensembl_id = gene_annotations.get(gene_symbol, 'Unknown') # Get Ensembl ID or default to 'Unknown' | |
col1, col2, col3 = st.columns(3) | |
with col1: | |
st.markdown("**Genome**") | |
st.markdown("Homo sapiens") | |
with col2: | |
st.markdown("**Gene**") | |
st.markdown(f"{gene_symbol} : {ensembl_id} (primary)") | |
with col3: | |
st.markdown("**Nuclease**") | |
st.markdown("SpCas9") | |
# Include "Target" in the DataFrame's columns | |
try: | |
df = pd.DataFrame(st.session_state['on_target_results'], | |
columns=["Gene Symbol", "Chr", "Strand", "Target Start", "Transcript", "Exon", | |
"Target", | |
"gRNA", "Prediction", "Is Mutation"]) | |
df_full = pd.DataFrame(st.session_state['full_results'], | |
columns=["Gene Symbol", "Chr", "Strand", "Target Start", "Transcript", | |
"Exon", "Target", | |
"gRNA", "Prediction", "Is Mutation"]) | |
st.dataframe(df) | |
except ValueError as e: | |
st.error(f"DataFrame creation error: {e}") | |
# Optionally print or log the problematic data for debugging: | |
print(st.session_state['on_target_results']) | |
if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']: | |
gene_symbol = st.session_state['current_gene_symbol'] | |
gene_sequence = st.session_state['gene_sequence'] | |
# Define file paths | |
genbank_file_path = f"{gene_symbol}_crispr_targets.gb" | |
bed_file_path = f"{gene_symbol}_crispr_targets.bed" | |
csv_file_path = f"{gene_symbol}_crispr_predictions.csv" | |
plot_image_path = f"{gene_symbol}_gtracks_plot.png" | |
# Generate files | |
cas9att.generate_genbank_file_from_df(df_full, gene_sequence, gene_symbol, genbank_file_path) | |
cas9att.create_bed_file_from_df(df_full, bed_file_path) | |
cas9att.create_csv_from_df(df_full, csv_file_path) | |
# Prepare an in-memory buffer for the ZIP file | |
zip_buffer = io.BytesIO() | |
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file: | |
# For each file, add it to the ZIP file | |
zip_file.write(genbank_file_path) | |
zip_file.write(bed_file_path) | |
zip_file.write(csv_file_path) | |
# Display the download button for the ZIP file | |
st.download_button( | |
label="Download GenBank, BED, CSV files as ZIP", | |
data=zip_buffer.getvalue(), | |
file_name=f"{gene_symbol}_files.zip", | |
mime="application/zip" | |
) | |
elif target_selection == 'off-target': | |
ENTRY_METHODS = dict( | |
manual='Manual entry of target sequence', | |
txt="txt file upload" | |
) | |
if __name__ == '__main__': | |
# app initialization for Cas9 off-target | |
if 'target_sequence' not in st.session_state: | |
st.session_state.target_sequence = None | |
if 'input_error' not in st.session_state: | |
st.session_state.input_error = None | |
if 'off_target_results' not in st.session_state: | |
st.session_state.off_target_results = None | |
# target sequence entry | |
st.selectbox( | |
label='How would you like to provide target sequences?', | |
options=ENTRY_METHODS.values(), | |
key='entry_method', | |
disabled=st.session_state.target_sequence is not None | |
) | |
if st.session_state.entry_method == ENTRY_METHODS['manual']: | |
st.text_input( | |
label='Enter on/off sequences:', | |
key='manual_entry', | |
placeholder='Enter on/off sequences like:GGGTGGGGGGAGTTTGCTCCAGG,AGGTGGGGTGA_TTTGCTCCAGG', | |
disabled=st.session_state.target_sequence is not None | |
) | |
elif st.session_state.entry_method == ENTRY_METHODS['txt']: | |
st.file_uploader( | |
label='Upload a txt file:', | |
key='txt_entry', | |
disabled=st.session_state.target_sequence is not None | |
) | |
# prediction button | |
if st.button('Predict off-target'): | |
if st.session_state.entry_method == ENTRY_METHODS['manual']: | |
user_input = st.session_state.manual_entry | |
if user_input: # Check if user_input is not empty | |
predictions = cas9off.process_input_and_predict(user_input, input_type='manual') | |
elif st.session_state.entry_method == ENTRY_METHODS['txt']: | |
uploaded_file = st.session_state.txt_entry | |
if uploaded_file is not None: | |
# Read the uploaded file content | |
file_content = uploaded_file.getvalue().decode("utf-8") | |
predictions = cas9off.process_input_and_predict(file_content, input_type='manual') | |
st.session_state.off_target_results = predictions | |
else: | |
predictions = None | |
progress = st.empty() | |
# input error display | |
error = st.empty() | |
if st.session_state.input_error is not None: | |
error.error(st.session_state.input_error, icon="🚨") | |
else: | |
error.empty() | |
# off-target results display | |
off_target_results = st.empty() | |
if st.session_state.off_target_results is not None: | |
with off_target_results.container(): | |
if len(st.session_state.off_target_results) > 0: | |
st.write('Off-target predictions:', st.session_state.off_target_results) | |
st.download_button( | |
label='Download off-target predictions', | |
data=convert_df(st.session_state.off_target_results), | |
file_name='off_target_results.csv', | |
mime='text/csv' | |
) | |
else: | |
st.write('No significant off-target effects detected!') | |
else: | |
off_target_results.empty() | |
# running the CRISPR-Net model for off-target predictions | |
if st.session_state.target_sequence is not None: | |
st.session_state.off_target_results = cas9off.predict_off_targets( | |
target_sequence=st.session_state.target_sequence, | |
status_update_fn=progress_update | |
) | |
st.session_state.target_sequence = None | |
st.experimental_rerun() | |
elif selected_model == 'Cas12': | |
def visualize_genomic_data(): | |
fig = go.Figure() | |
EXON_BASE = 0 # Base position for exons and CDS on the Y axis | |
EXON_HEIGHT = 0.02 # How 'tall' the exon markers should appear | |
# Plot Exons as small markers on the X-axis | |
for exon in st.session_state['exons']: | |
try: | |
exon_start, exon_end = int(exon['start']), int(exon['end']) | |
fig.add_trace(go.Bar( | |
x=[(exon_start + exon_end) / 2], | |
y=[EXON_HEIGHT], | |
width=[exon_end - exon_start], | |
base=EXON_BASE, | |
marker_color='rgba(128, 0, 128, 0.5)', | |
name='Exon' | |
)) | |
except ValueError: | |
st.error("Error in exon positions. Exon positions should be numeric.") | |
VERTICAL_GAP = 0.2 # Gap between different ranks | |
# Define max and min Y values based on strand and rank | |
MAX_STRAND_Y = 0.1 # Maximum Y value for positive strand results | |
MIN_STRAND_Y = -0.1 # Minimum Y value for negative strand results | |
# Iterate over top 5 sorted predictions to create the plot | |
for i, prediction in enumerate(st.session_state['on_target_results'][:5], start=1): # Only top 5 | |
try: | |
start, end = int(prediction['Start Pos']), int(prediction['End Pos']) | |
midpoint = (start + end) / 2 | |
strand = prediction['Strand'] | |
y_value = (MAX_STRAND_Y - (i - 1) * VERTICAL_GAP) if strand in ['1', '+'] else ( | |
MIN_STRAND_Y + (i - 1) * VERTICAL_GAP) | |
fig.add_trace(go.Scatter( | |
x=[midpoint], | |
y=[y_value], | |
mode='markers+text', | |
marker=dict(symbol='triangle-up' if strand in ['1', '+'] else 'triangle-down', size=12), | |
text=f"Rank: {i}", | |
hoverinfo='text', | |
hovertext=f"Rank: {i}<br>Chromosome: {prediction['Chr']}<br>Target Sequence: {prediction['Target']}<br>gRNA: {prediction['gRNA']}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand in ['1', '+'] else '-'}<br>Transcript: {prediction['Transcript']}<br>Prediction: {prediction['Prediction']:.4f}", | |
)) | |
except ValueError: | |
st.error("Error in prediction positions. Start and end positions should be numeric.") | |
# Update layout for clarity and interaction | |
fig.update_layout( | |
title='Top 5 gRNA Sequences by Prediction Score', | |
xaxis_title='Genomic Position', | |
yaxis_title='Strand', | |
yaxis=dict(tickvals=[MAX_STRAND_Y, MIN_STRAND_Y], ticktext=['+', '-']), | |
showlegend=False, | |
hovermode='x unified', | |
) | |
# Display the plot | |
st.plotly_chart(fig) | |
# File generation and download | |
generate_and_download_files(df, gene_symbol) | |
def generate_and_download_files(df, gene_symbol): | |
genbank_file_path = f"{gene_symbol}_crispr_targets.gb" | |
bed_file_path = f"{gene_symbol}_crispr_targets.bed" | |
csv_file_path = f"{gene_symbol}_crispr_predictions.csv" | |
df.to_csv(csv_file_path, index=False) | |
# Assume functions to generate GenBank and BED are defined in cas12lstm or cas12lstmvcf | |
cas12lstm.generate_genbank_file_from_df(df, gene_symbol, genbank_file_path) | |
cas12lstm.create_bed_file_from_df(df, bed_file_path) | |
zip_buffer = io.BytesIO() | |
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file: | |
zip_file.write(genbank_file_path) | |
zip_file.write(bed_file_path) | |
zip_file.write(csv_file_path) | |
zip_buffer.seek(0) | |
st.download_button("Download GenBank, BED, CSV files as ZIP", data=zip_buffer.getvalue(), | |
file_name=f"{gene_symbol}_files.zip", mime="application/zip") | |
def display_results(predictions, gene_sequence, exons, gene_symbol): | |
st.success('Prediction completed!') | |
ensembl_id = gene_annotations.get(gene_symbol, 'Unknown') | |
st.write(f"**Genome:** Homo sapiens") | |
st.write(f"**Gene:** {gene_symbol} : {ensembl_id} (primary)") | |
st.write("**Nuclease:** Cas12") | |
df = pd.DataFrame(predictions, | |
columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon", "Target", "gRNA", | |
"Prediction"]) | |
st.dataframe(df) | |
# Visualization and file generation as demonstrated in the Cas9 example | |
visualize_and_generate_files(df, gene_sequence, exons, gene_symbol) | |
cas12target_selection = st.radio( | |
"Select either regular or mutation:", | |
('regular', 'mutation'), | |
key='cas12target_selection' | |
) | |
if 'current_gene_symbol' not in st.session_state: | |
st.session_state['current_gene_symbol'] = "" | |
def clean_up_old_files(gene_symbol): | |
for suffix in ['_crispr_targets.gb', '_crispr_targets.bed', '_crispr_predictions.csv']: | |
file_path = f"{gene_symbol}{suffix}" | |
if os.path.exists(file_path): | |
os.remove(file_path) | |
gene_symbol = st.selectbox( | |
'Enter a Gene Symbol:', | |
[''] + gene_symbol_list, | |
key='gene_symbol', | |
format_func=lambda x: x if x else "" | |
) | |
if gene_symbol != st.session_state['current_gene_symbol']: | |
if st.session_state['current_gene_symbol']: | |
clean_up_old_files(st.session_state['current_gene_symbol']) | |
st.session_state['current_gene_symbol'] = gene_symbol | |
if cas12target_selection == 'regular': | |
if st.button('Predict cas12 Regular'): | |
with st.spinner('Predicting... Please wait'): | |
predictions, gene_sequence, exons = cas12lstm.process_gene(gene_symbol, cas12lstm_path) | |
sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10] | |
display_results(sorted_predictions, gene_sequence, exons, gene_symbol) | |
elif cas12target_selection == 'mutation': | |
vcf_reader = cyvcf2.VCF('SRR25934512.filter.snps.indels.vcf.gz') | |
if st.button('Predict cas12 Mutation'): | |
with st.spinner('Predicting... Please wait'): | |
predictions, gene_sequence, exons = cas12lstmvcf.process_gene(gene_symbol, vcf_reader, cas12lstm_path) | |
sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10] | |
display_results(sorted_predictions, gene_sequence, exons, gene_symbol) | |
elif selected_model == 'Cas13d': | |
ENTRY_METHODS = dict( | |
manual='Manual entry of single transcript', | |
fasta="Fasta file upload (supports multiple transcripts if they have unique ID's)" | |
) | |
if __name__ == '__main__': | |
# app initialization | |
if 'mode' not in st.session_state: | |
st.session_state.mode = tiger.RUN_MODES['all'] | |
st.session_state.disable_off_target_checkbox = True | |
if 'entry_method' not in st.session_state: | |
st.session_state.entry_method = ENTRY_METHODS['manual'] | |
if 'transcripts' not in st.session_state: | |
st.session_state.transcripts = None | |
if 'input_error' not in st.session_state: | |
st.session_state.input_error = None | |
if 'on_target' not in st.session_state: | |
st.session_state.on_target = None | |
if 'titration' not in st.session_state: | |
st.session_state.titration = None | |
if 'off_target' not in st.session_state: | |
st.session_state.off_target = None | |
# mode selection | |
col1, col2 = st.columns([0.65, 0.35]) | |
with col1: | |
st.radio( | |
label='What do you want to predict?', | |
options=tuple(tiger.RUN_MODES.values()), | |
key='mode', | |
on_change=mode_change_callback, | |
disabled=st.session_state.transcripts is not None, | |
) | |
with col2: | |
st.checkbox( | |
label='Find off-target effects (slow)', | |
key='check_off_targets', | |
disabled=st.session_state.disable_off_target_checkbox or st.session_state.transcripts is not None | |
) | |
# transcript entry | |
st.selectbox( | |
label='How would you like to provide transcript(s) of interest?', | |
options=ENTRY_METHODS.values(), | |
key='entry_method', | |
disabled=st.session_state.transcripts is not None | |
) | |
if st.session_state.entry_method == ENTRY_METHODS['manual']: | |
st.text_input( | |
label='Enter a target transcript:', | |
key='manual_entry', | |
placeholder='Upper or lower case', | |
disabled=st.session_state.transcripts is not None | |
) | |
elif st.session_state.entry_method == ENTRY_METHODS['fasta']: | |
st.file_uploader( | |
label='Upload a fasta file:', | |
key='fasta_entry', | |
disabled=st.session_state.transcripts is not None | |
) | |
# let's go! | |
st.button(label='Get predictions!', on_click=initiate_run, disabled=st.session_state.transcripts is not None) | |
progress = st.empty() | |
# input error | |
error = st.empty() | |
if st.session_state.input_error is not None: | |
error.error(st.session_state.input_error, icon="🚨") | |
else: | |
error.empty() | |
# on-target results | |
on_target_results = st.empty() | |
if st.session_state.on_target is not None: | |
with on_target_results.container(): | |
st.write('On-target predictions:', st.session_state.on_target) | |
st.download_button( | |
label='Download on-target predictions', | |
data=convert_df(st.session_state.on_target), | |
file_name='on_target.csv', | |
mime='text/csv' | |
) | |
else: | |
on_target_results.empty() | |
# titration results | |
titration_results = st.empty() | |
if st.session_state.titration is not None: | |
with titration_results.container(): | |
st.write('Titration predictions:', st.session_state.titration) | |
st.download_button( | |
label='Download titration predictions', | |
data=convert_df(st.session_state.titration), | |
file_name='titration.csv', | |
mime='text/csv' | |
) | |
else: | |
titration_results.empty() | |
# off-target results | |
off_target_results = st.empty() | |
if st.session_state.off_target is not None: | |
with off_target_results.container(): | |
if len(st.session_state.off_target) > 0: | |
st.write('Off-target predictions:', st.session_state.off_target) | |
st.download_button( | |
label='Download off-target predictions', | |
data=convert_df(st.session_state.off_target), | |
file_name='off_target.csv', | |
mime='text/csv' | |
) | |
else: | |
st.write('We did not find any off-target effects!') | |
else: | |
off_target_results.empty() | |
# keep trying to run model until we clear inputs (streamlit UI changes can induce race-condition reruns) | |
if st.session_state.transcripts is not None: | |
st.session_state.on_target, st.session_state.titration, st.session_state.off_target = tiger.tiger_exhibit( | |
transcripts=st.session_state.transcripts, | |
mode={v: k for k, v in tiger.RUN_MODES.items()}[st.session_state.mode], | |
check_off_targets=st.session_state.check_off_targets, | |
status_update_fn=progress_update | |
) | |
st.session_state.transcripts = None | |
st.experimental_rerun() | |