import os
import tempfile
import tiger
import cas9att
import cas9attvcf
import cas9off
import cas12
import pandas as pd
import streamlit as st
import plotly.graph_objs as go
import numpy as np
from pathlib import Path
import zipfile
import io
import gtracks
import subprocess

# title and documentation
st.markdown(Path('crisprTool.md').read_text(), unsafe_allow_html=True)
st.divider()

CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']

selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
cas9att_path = 'cas9_model/Cas9_MultiHeadAttention_weights.keras'
cas12_path = 'cas12_model/BiLSTM_Cpf1_weights.keras'

# Column layout of the on-target prediction tuples shown by the Cas9 and Cas12 views.
# The plotting helpers unpack each tuple in this order, and the score is the last field.
PREDICTION_COLUMNS = ["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon", "Target", "gRNA", "Prediction"]


# plot functions
def generate_coolbox_plot(bigwig_path, region, output_image_path):
    """Plot one BigWig track over ``region`` with CoolBox and save it to ``output_image_path``.

    NOTE(review): ``CoolBox`` and ``BigWig`` are not imported anywhere in this file, so calling
    this raises NameError unless they are provided elsewhere. It is not called by this script.
    """
    frame = CoolBox()  # noqa: F821
    frame += BigWig(bigwig_path)  # noqa: F821
    frame.plot(region, savefig=output_image_path)


def _parse_region(region):
    """Split a ``'chrom:start-end'`` string into ``(chrom, start, end)`` with integer bounds."""
    chrom, _, span = region.partition(':')
    start, _, end = span.partition('-')
    return chrom, int(start), int(end)


def generate_pygenometracks_plot(bigwig_file_path, region, output_image_path):
    """Plot one BigWig track over ``region`` with pyGenomeTracks and save it to ``output_image_path``.

    A track configuration is first written to ``pygenometracks.ini`` in the working directory.

    NOTE(review): ``plot_tracks`` is not imported anywhere in this file, so calling this raises
    NameError unless it is provided elsewhere. It is not called by this script.
    """
    # Keys are written flush-left: configparser-style readers treat indented lines as
    # continuations of the previous value.
    tracks = "\n".join([
        "[bigwig]",
        f"file = {bigwig_file_path}",
        "height = 4",
        "color = blue",
        "min_value = 0",
        "max_value = 10",
        "",
    ])

    # Write the configuration to a temporary INI file
    config_file_path = "pygenometracks.ini"
    with open(config_file_path, 'w') as configfile:
        configfile.write(tracks)

    chrom, start, end = _parse_region(region)
    region_dict = {'chrom': chrom, 'start': start, 'end': end}

    plot_tracks(tracks_file=config_file_path, region=region_dict, out_file_name=output_image_path)  # noqa: F821


@st.cache_data
def convert_df(df):
    """Serialize ``df`` to UTF-8 CSV bytes (cached so reruns do not recompute it)."""
    return df.to_csv().encode('utf-8')


def mode_change_callback():
    """Disable (and clear) the off-target checkbox for the 'all' and 'titration' run modes."""
    if st.session_state.mode in {tiger.RUN_MODES['all'], tiger.RUN_MODES['titration']}:  # TODO: support titration
        st.session_state.check_off_targets = False
        st.session_state.disable_off_target_checkbox = True
    else:
        st.session_state.disable_off_target_checkbox = False


def progress_update(update_text, percent_complete):
    """Show a status message and a progress bar in the page-level ``progress`` placeholder.

    Relies on the module-global ``progress`` (an ``st.empty()``) created by the active branch
    before any model run calls this.
    """
    with progress.container():
        st.write(update_text)
        st.progress(percent_complete / 100)


def _rerun():
    """Trigger a Streamlit rerun: ``st.rerun`` when available, else the legacy ``st.experimental_rerun``."""
    rerun = getattr(st, 'rerun', None) or st.experimental_rerun
    rerun()


def _gene_file_paths(gene_symbol):
    """Return the per-gene output file paths written by the on-target views."""
    return {
        'genbank': f"{gene_symbol}_crispr_targets.gb",
        'bed': f"{gene_symbol}_crispr_targets.bed",
        'csv': f"{gene_symbol}_crispr_predictions.csv",
        'plot': f"{gene_symbol}_gtracks_plot.png",
    }


def clean_up_old_files(gene_symbol):
    """Delete any output files previously generated for ``gene_symbol`` in the working directory."""
    for path in _gene_file_paths(gene_symbol).values():
        if os.path.exists(path):
            os.remove(path)


def _zip_bytes(paths):
    """Bundle ``paths`` into an in-memory ZIP (entries stored under their base names) and return the bytes."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as archive:
        for path in paths:
            archive.write(path, arcname=os.path.basename(path))
    return buffer.getvalue()


def _is_plus_strand(strand):
    """Return True when a strand field denotes the forward strand ('1', '+', or the integer 1)."""
    return strand in ('1', '+', 1)


def _show_gtracks_plot(region, bed_file_path, plot_image_path):
    """Render ``bed_file_path`` over ``region`` with the ``gtracks`` CLI and display the image.

    The command runs as an argument list (no shell), so file names are never shell-interpreted.
    Failures are reported as UI warnings instead of crashing on a missing image file.
    """
    try:
        result = subprocess.run(
            ["gtracks", region, bed_file_path, plot_image_path],
            capture_output=True,
            text=True,
        )
    except FileNotFoundError:
        st.warning("Could not run 'gtracks'; is it installed and on PATH?")
        return
    if result.returncode == 0 and os.path.exists(plot_image_path):
        st.image(plot_image_path)
    else:
        st.warning(f"gtracks failed to render the region plot: {result.stderr.strip()}")


def _cas9_on_target_figure(results, exons):
    """Build the Cas9 plot: exons as bars on the axis plus the top 5 guides stacked by strand."""
    fig = go.Figure()

    EXON_BASE = 0  # Base position for exons on the Y axis
    EXON_HEIGHT = 0.02  # How 'tall' the exon markers should appear

    # Plot exons as small bars on the X axis
    for exon in exons:
        exon_start, exon_end = exon['start'], exon['end']
        fig.add_trace(go.Bar(
            x=[(exon_start + exon_end) / 2],
            y=[EXON_HEIGHT],
            width=[exon_end - exon_start],
            base=EXON_BASE,
            marker_color='rgba(128, 0, 128, 0.5)',
            name='Exon'
        ))

    VERTICAL_GAP = 0.2  # Gap between successive ranks
    MAX_STRAND_Y = 0.1  # Y of the best-ranked '+' strand guide
    MIN_STRAND_Y = -0.1  # Y of the best-ranked '-' strand guide

    for i, prediction in enumerate(results[:5], start=1):  # Only top 5
        chrom, start, end, strand, transcript, exon, target, gRNA, prediction_score = prediction
        plus = _is_plus_strand(strand)
        midpoint = (int(start) + int(end)) / 2
        # Stack lower ranks further from the axis: upward for '+', downward for '-', so a
        # guide never drifts onto the other strand's side of the plot.
        offset = (i - 1) * VERTICAL_GAP
        y_value = MAX_STRAND_Y + offset if plus else MIN_STRAND_Y - offset

        fig.add_trace(go.Scatter(
            x=[midpoint],
            y=[y_value],
            mode='markers+text',
            marker=dict(symbol='triangle-up' if plus else 'triangle-down', size=12),
            text=f"Rank: {i}",
            hoverinfo='text',
            hovertext=(
                f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}"
                f"<br>Start: {start}<br>End: {end}<br>Strand: {'+' if plus else '-'}"
                f"<br>Transcript: {transcript}<br>Prediction: {float(prediction_score):.4f}"
            ),
        ))

    fig.update_layout(
        title='Top 5 gRNA Sequences by Prediction Score',
        xaxis_title='Genomic Position',
        yaxis_title='Strand',
        yaxis=dict(tickvals=[MAX_STRAND_Y, MIN_STRAND_Y], ticktext=['+', '-']),
        showlegend=False,
        hovermode='x unified',
    )
    return fig


def _cas12_on_target_figure(results):
    """Build the Cas12 plot: each guide is an arrow, stacked away from the axis by strand."""
    fig = go.Figure()

    positive_base_y = 0.1  # Y of the first '+' strand guide
    negative_base_y = -0.1  # Y of the first '-' strand guide
    offset = 0.05  # Spacing between successive guides on the same strand
    positive_strand_y, negative_strand_y = positive_base_y, negative_base_y

    for i, prediction in enumerate(results, start=1):
        chrom, start, end, strand, _transcript, _exon, target, gRNA, pred_score = prediction
        start, end = int(start), int(end)
        midpoint = (start + end) / 2
        plus = _is_plus_strand(strand)

        if plus:
            y_value = positive_strand_y
            arrow_symbol = 'triangle-right'
            positive_strand_y += offset
        else:
            y_value = negative_strand_y
            arrow_symbol = 'triangle-left'
            negative_strand_y -= offset

        fig.add_trace(go.Scatter(
            x=[midpoint],
            y=[y_value],
            mode='markers+text',
            marker=dict(symbol=arrow_symbol, size=10),
            name=f"gRNA: {gRNA}",
            text=f"Rank: {i}",
            hoverinfo='text',
            hovertext=(
                f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}"
                f"<br>Start: {start}<br>End: {end}<br>Strand: {'+' if plus else '-'}"
                f"<br>Prediction Score: {float(pred_score):.4f}"
            ),
        ))

    fig.update_layout(
        title='Top 10 gRNA Sequences by Prediction Score',
        xaxis_title='Genomic Position',
        yaxis=dict(
            title='Strand',
            showgrid=True,
            zeroline=True,  # Line at y=0 separates the two strands
            zerolinecolor='Black',
            zerolinewidth=2,
            # Ticks sit at the strands' starting rows, not the mutated stacking counters.
            tickvals=[positive_base_y, negative_base_y],
            ticktext=['+ Strand', '- Strand']
        ),
        showlegend=False
    )
    return fig


def initiate_run():
    """Validate the Cas13d transcript input and stage it in ``st.session_state.transcripts``.

    Resets all result state, reads transcripts from the manual text box or an uploaded FASTA,
    normalizes them (upper case, U->T), and records the first validation failure in
    ``st.session_state.input_error``. Relies on the module-global ``ENTRY_METHODS`` defined by the
    Cas13d branch, which is the only place this callback is registered.
    """
    st.session_state.transcripts = None
    st.session_state.input_error = None
    st.session_state.on_target = None
    st.session_state.titration = None
    st.session_state.off_target = None

    transcripts = pd.DataFrame(columns=[tiger.ID_COL, tiger.SEQ_COL])

    if st.session_state.entry_method == ENTRY_METHODS['manual']:
        transcripts = pd.DataFrame({
            tiger.ID_COL: ['ManualEntry'],
            tiger.SEQ_COL: [st.session_state.manual_entry]
        }).set_index(tiger.ID_COL)

    elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
        uploaded = st.session_state.fasta_entry
        if uploaded is not None:
            # Write the upload to a private temp file rather than to its client-supplied name
            # in the working directory, which a crafted name could use to overwrite other files.
            with tempfile.NamedTemporaryFile('w', suffix=Path(uploaded.name).suffix,
                                             delete=False, encoding='utf-8') as tmp:
                tmp.write(uploaded.getvalue().decode('utf-8'))
                fasta_path = tmp.name
            try:
                transcripts = tiger.load_transcripts([fasta_path], enforce_unique_ids=False)
            finally:
                os.remove(fasta_path)

    # convert to upper case as used by tokenizer
    transcripts[tiger.SEQ_COL] = transcripts[tiger.SEQ_COL].apply(lambda s: s.upper().replace('U', 'T'))

    if transcripts.index.has_duplicates:
        st.session_state.input_error = "Duplicate transcript ID's detected in fasta file"
    elif not all(transcripts[tiger.SEQ_COL].apply(lambda s: set(s).issubset(tiger.NUCLEOTIDE_TOKENS.keys()))):
        st.session_state.input_error = 'Transcript(s) must only contain upper or lower case A, C, G, and Ts or Us'
    elif any(transcripts[tiger.SEQ_COL].apply(lambda s: len(s) < tiger.TARGET_LEN)):
        st.session_state.input_error = 'Transcript(s) must be at least {:d} bases.'.format(tiger.TARGET_LEN)
    elif len(transcripts) > 0:
        st.session_state.transcripts = transcripts


@st.cache_data
def parse_gene_annotations(file_path):
    """Map 'Approved symbol' -> 'Ensembl gene ID' from a tab-delimited annotation file.

    Cached so Streamlit reruns (every widget interaction) do not re-read the file; the cache is
    keyed on ``file_path`` only, so edits to the file need an app restart to be picked up.

    Raises:
        ValueError: if either required header column is missing.
    """
    gene_dict = {}
    with open(file_path, 'r') as file:
        headers = file.readline().strip().split('\t')
        symbol_idx = headers.index('Approved symbol')
        ensembl_idx = headers.index('Ensembl gene ID')
        needed = max(symbol_idx, ensembl_idx)
        for line in file:
            values = line.strip().split('\t')
            # Skip short rows that lack either column
            if len(values) > needed:
                gene_dict[values[symbol_idx]] = values[ensembl_idx]
    return gene_dict


gene_annotations = parse_gene_annotations('Human_genes_HUGO_02242024_annotation.txt')
gene_symbol_list = list(gene_annotations.keys())  # Options for the gene-symbol selectbox

if selected_model == 'Cas9':
    target_selection = st.radio(
        "Select either on-target or off-target:",
        ('on-target', 'off-target'),
        key='target_selection'
    )

    if 'current_gene_symbol' not in st.session_state:
        st.session_state['current_gene_symbol'] = ""

    gene_symbol = st.selectbox('Enter a Gene Symbol:', [''] + gene_symbol_list, key='gene_symbol',
                               format_func=lambda x: x if x else "")

    # On a switch to a different gene, delete the previous gene's output files and remember the new one.
    if gene_symbol and gene_symbol != st.session_state['current_gene_symbol']:
        if st.session_state['current_gene_symbol']:
            clean_up_old_files(st.session_state['current_gene_symbol'])
        st.session_state['current_gene_symbol'] = gene_symbol

    if target_selection == 'on-target':
        predict_button = st.button('Predict on-target')

        if 'exons' not in st.session_state:
            st.session_state['exons'] = []

        if predict_button and gene_symbol:
            with st.spinner('Predicting... Please wait'):
                predictions, gene_sequence, exons = cas9att.process_gene(gene_symbol, cas9att_path)
                # Keep the 10 highest-scoring guides (score is the last field), as the Cas12 view does.
                sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
                st.session_state['on_target_results'] = sorted_predictions
                st.session_state['gene_sequence'] = gene_sequence
                st.session_state['exons'] = exons
            st.success('Prediction completed!')
            st.session_state['prediction_made'] = True

        if st.session_state.get('on_target_results'):
            ensembl_id = gene_annotations.get(gene_symbol, 'Unknown')
            col1, col2, col3 = st.columns(3)
            with col1:
                st.markdown("**Genome**")
                st.markdown("Homo sapiens")
            with col2:
                st.markdown("**Gene**")
                st.markdown(f"{gene_symbol} : {ensembl_id} (primary)")
            with col3:
                st.markdown("**Nuclease**")
                st.markdown("SpCas9")

            try:
                df = pd.DataFrame(st.session_state['on_target_results'], columns=PREDICTION_COLUMNS)
            except ValueError as e:
                st.error(f"DataFrame creation error: {e}")
                print(st.session_state['on_target_results'])
                df = None  # Everything below needs a valid table

            if df is not None:
                st.dataframe(df)
                st.plotly_chart(_cas9_on_target_figure(st.session_state['on_target_results'],
                                                       st.session_state['exons']))

                if st.session_state.get('gene_sequence'):
                    gene_symbol = st.session_state['current_gene_symbol']
                    gene_sequence = st.session_state['gene_sequence']
                    paths = _gene_file_paths(gene_symbol)

                    cas9att.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, paths['genbank'])
                    cas9att.create_bed_file_from_df(df, paths['bed'])
                    cas9att.create_csv_from_df(df, paths['csv'])
                    zip_data = _zip_bytes([paths['genbank'], paths['bed'], paths['csv']])

                    # Numeric min/max: the position columns may hold strings, where min/max would
                    # compare lexicographically ('100' < '99').
                    min_start = int(pd.to_numeric(df['Start Pos']).min())
                    max_end = int(pd.to_numeric(df['End Pos']).max())
                    chromosome = df['Chr'].mode()[0]  # Assumes most common chromosome is the target
                    region = f"{chromosome}:{min_start}-{max_end}"
                    _show_gtracks_plot(region, paths['bed'], paths['plot'])

                    st.download_button(
                        label="Download GenBank, BED, CSV files as ZIP",
                        data=zip_data,
                        file_name=f"{gene_symbol}_files.zip",
                        mime="application/zip"
                    )

    elif target_selection == 'off-target':
        ENTRY_METHODS = dict(
            manual='Manual entry of target sequence',
            txt="txt file upload"
        )
        if __name__ == '__main__':
            # app initialization for Cas9 off-target
            for key in ('target_sequence', 'input_error', 'off_target_results'):
                if key not in st.session_state:
                    st.session_state[key] = None

            st.selectbox(
                label='How would you like to provide target sequences?',
                options=ENTRY_METHODS.values(),
                key='entry_method',
                disabled=st.session_state.target_sequence is not None
            )
            if st.session_state.entry_method == ENTRY_METHODS['manual']:
                st.text_input(
                    label='Enter on/off sequences:',
                    key='manual_entry',
                    placeholder='Enter on/off sequences like:GGGTGGGGGGAGTTTGCTCCAGG,AGGTGGGGTGA_TTTGCTCCAGG',
                    disabled=st.session_state.target_sequence is not None
                )
            elif st.session_state.entry_method == ENTRY_METHODS['txt']:
                st.file_uploader(
                    label='Upload a txt file:',
                    key='txt_entry',
                    disabled=st.session_state.target_sequence is not None
                )

            if st.button('Predict off-target'):
                predictions = None  # Stays None when no input was provided
                if st.session_state.entry_method == ENTRY_METHODS['manual']:
                    user_input = st.session_state.manual_entry
                    if user_input:
                        predictions = cas9off.process_input_and_predict(user_input, input_type='manual')
                elif st.session_state.entry_method == ENTRY_METHODS['txt']:
                    uploaded_file = st.session_state.txt_entry
                    if uploaded_file is not None:
                        file_content = uploaded_file.getvalue().decode("utf-8")
                        # The file text is parsed with the same format as manual entry.
                        predictions = cas9off.process_input_and_predict(file_content, input_type='manual')
                st.session_state.off_target_results = predictions
                st.session_state.input_error = (
                    None if predictions is not None
                    else 'Please provide target sequences before predicting.'
                )

            progress = st.empty()

            error = st.empty()
            if st.session_state.input_error is not None:
                error.error(st.session_state.input_error, icon="🚨")
            else:
                error.empty()

            off_target_results = st.empty()
            if st.session_state.off_target_results is not None:
                with off_target_results.container():
                    if len(st.session_state.off_target_results) > 0:
                        st.write('Off-target predictions:', st.session_state.off_target_results)
                        st.download_button(
                            label='Download off-target predictions',
                            data=convert_df(st.session_state.off_target_results),
                            file_name='off_target_results.csv',
                            mime='text/csv'
                        )
                    else:
                        st.write('No significant off-target effects detected!')
            else:
                off_target_results.empty()

            # NOTE(review): nothing in this script sets target_sequence to a non-None value, so
            # this CRISPR-Net path is currently unreachable.
            if st.session_state.target_sequence is not None:
                st.session_state.off_target_results = cas9off.predict_off_targets(
                    target_sequence=st.session_state.target_sequence,
                    status_update_fn=progress_update
                )
                st.session_state.target_sequence = None
                _rerun()

elif selected_model == 'Cas12':
    gene_symbol = st.selectbox('Enter a Gene Symbol:', [''] + gene_symbol_list, key='gene_symbol',
                               format_func=lambda x: x if x else "")

    if 'current_gene_symbol' not in st.session_state:
        st.session_state['current_gene_symbol'] = ""

    predict_button = st.button('Predict on-target')

    # Remove the previous gene's output files once a different gene is selected
    if st.session_state['current_gene_symbol'] and gene_symbol != st.session_state['current_gene_symbol']:
        clean_up_old_files(st.session_state['current_gene_symbol'])

    if predict_button and gene_symbol:
        st.session_state['current_gene_symbol'] = gene_symbol
        with st.spinner('Predicting... Please wait'):
            predictions, gene_sequence, exons = cas12.process_gene(gene_symbol, cas12_path)
            sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
            st.session_state['on_target_results'] = sorted_predictions
            st.session_state['gene_sequence'] = gene_sequence
            st.session_state['exons'] = exons
        st.success('Prediction completed!')

    if st.session_state.get('on_target_results'):
        results = st.session_state['on_target_results']
        df = pd.DataFrame(results, columns=PREDICTION_COLUMNS)
        st.dataframe(df)

        # Plot from session state, so the chart also renders on reruns without a fresh prediction.
        st.plotly_chart(_cas12_on_target_figure(results))

        if st.session_state.get('gene_sequence'):
            gene_symbol = st.session_state['current_gene_symbol']
            gene_sequence = st.session_state['gene_sequence']
            paths = _gene_file_paths(gene_symbol)

            cas12.generate_genbank_file_from_data(df, gene_sequence, gene_symbol, paths['genbank'])
            cas12.generate_bed_file_from_data(df, paths['bed'])
            cas12.create_csv_from_df(df, paths['csv'])

            st.download_button(
                label="Download genbank,.bed,csv files as ZIP",
                data=_zip_bytes([paths['genbank'], paths['bed'], paths['csv']]),
                file_name=f"{gene_symbol}_files.zip",
                mime="application/zip"
            )

elif selected_model == 'Cas13d':
    ENTRY_METHODS = dict(
        manual='Manual entry of single transcript',
        fasta="Fasta file upload (supports multiple transcripts if they have unique ID's)"
    )

    if __name__ == '__main__':
        # app initialization
        if 'mode' not in st.session_state:
            st.session_state.mode = tiger.RUN_MODES['all']
            st.session_state.disable_off_target_checkbox = True
        if 'entry_method' not in st.session_state:
            st.session_state.entry_method = ENTRY_METHODS['manual']
        for key in ('transcripts', 'input_error', 'on_target', 'titration', 'off_target'):
            if key not in st.session_state:
                st.session_state[key] = None

        # mode selection
        col1, col2 = st.columns([0.65, 0.35])
        with col1:
            st.radio(
                label='What do you want to predict?',
                options=tuple(tiger.RUN_MODES.values()),
                key='mode',
                on_change=mode_change_callback,
                disabled=st.session_state.transcripts is not None,
            )
        with col2:
            st.checkbox(
                label='Find off-target effects (slow)',
                key='check_off_targets',
                disabled=st.session_state.disable_off_target_checkbox or st.session_state.transcripts is not None
            )

        # transcript entry
        st.selectbox(
            label='How would you like to provide transcript(s) of interest?',
            options=ENTRY_METHODS.values(),
            key='entry_method',
            disabled=st.session_state.transcripts is not None
        )
        if st.session_state.entry_method == ENTRY_METHODS['manual']:
            st.text_input(
                label='Enter a target transcript:',
                key='manual_entry',
                placeholder='Upper or lower case',
                disabled=st.session_state.transcripts is not None
            )
        elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
            st.file_uploader(
                label='Upload a fasta file:',
                key='fasta_entry',
                disabled=st.session_state.transcripts is not None
            )

        # let's go!
        st.button(label='Get predictions!', on_click=initiate_run, disabled=st.session_state.transcripts is not None)
        progress = st.empty()

        # input error
        error = st.empty()
        if st.session_state.input_error is not None:
            error.error(st.session_state.input_error, icon="🚨")
        else:
            error.empty()

        # on-target results
        on_target_results = st.empty()
        if st.session_state.on_target is not None:
            with on_target_results.container():
                st.write('On-target predictions:', st.session_state.on_target)
                st.download_button(
                    label='Download on-target predictions',
                    data=convert_df(st.session_state.on_target),
                    file_name='on_target.csv',
                    mime='text/csv'
                )
        else:
            on_target_results.empty()

        # titration results
        titration_results = st.empty()
        if st.session_state.titration is not None:
            with titration_results.container():
                st.write('Titration predictions:', st.session_state.titration)
                st.download_button(
                    label='Download titration predictions',
                    data=convert_df(st.session_state.titration),
                    file_name='titration.csv',
                    mime='text/csv'
                )
        else:
            titration_results.empty()

        # off-target results
        off_target_results = st.empty()
        if st.session_state.off_target is not None:
            with off_target_results.container():
                if len(st.session_state.off_target) > 0:
                    st.write('Off-target predictions:', st.session_state.off_target)
                    st.download_button(
                        label='Download off-target predictions',
                        data=convert_df(st.session_state.off_target),
                        file_name='off_target.csv',
                        mime='text/csv'
                    )
                else:
                    st.write('We did not find any off-target effects!')
        else:
            off_target_results.empty()

        # keep trying to run model until we clear inputs (streamlit UI changes can induce race-condition reruns)
        if st.session_state.transcripts is not None:
            st.session_state.on_target, st.session_state.titration, st.session_state.off_target = tiger.tiger_exhibit(
                transcripts=st.session_state.transcripts,
                mode={v: k for k, v in tiger.RUN_MODES.items()}[st.session_state.mode],
                check_off_targets=st.session_state.check_off_targets,
                status_update_fn=progress_update
            )
            st.session_state.transcripts = None
            _rerun()