supercat666 committed on
Commit
fc8ed8c
1 Parent(s): fd276e2
Files changed (2) hide show
  1. app.py +246 -151
  2. cas12lstmvcf.py +68 -8
app.py CHANGED
@@ -5,6 +5,7 @@ import cas9attvcf
5
  import cas9off
6
  import cas12
7
  import cas12lstm
 
8
  import pandas as pd
9
  import streamlit as st
10
  import plotly.graph_objs as go
@@ -26,7 +27,7 @@ CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
26
 
27
  selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
28
  cas9att_path = 'cas9_model/Cas9_MultiHeadAttention_weights.h5'
29
- cas12_path = 'cas12_model/BiLSTM_Cpf1_weights.h5'
30
 
31
  #plot functions
32
  def generate_coolbox_plot(bigwig_path, region, output_image_path):
@@ -331,7 +332,7 @@ if selected_model == 'Cas9':
331
  # Process predictions
332
  if predict_button and gene_symbol:
333
  with st.spinner('Predicting... Please wait'):
334
- predictions, gene_sequence, exons = cas9attvcf.process_gene(gene_symbol, cas9att_path)
335
  full_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)
336
  sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
337
  st.session_state['full_results'] = full_predictions
@@ -489,6 +490,11 @@ if selected_model == 'Cas9':
489
  st.experimental_rerun()
490
 
491
  elif selected_model == 'Cas12':
 
 
 
 
 
492
  # Gene symbol entry with autocomplete-like feature
493
  gene_symbol = st.selectbox('Enter a Gene Symbol:', [''] + gene_symbol_list, key='gene_symbol',
494
  format_func=lambda x: x if x else "")
@@ -497,159 +503,248 @@ elif selected_model == 'Cas12':
497
  if 'current_gene_symbol' not in st.session_state:
498
  st.session_state['current_gene_symbol'] = ""
499
 
500
- # Prediction button
501
- predict_button = st.button('Predict on-target')
502
-
503
- # Function to clean up old files
504
- def clean_up_old_files(gene_symbol):
505
- genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
506
- bed_file_path = f"{gene_symbol}_crispr_targets.bed"
507
- csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
508
- for path in [genbank_file_path, bed_file_path, csv_file_path]:
509
- if os.path.exists(path):
510
- os.remove(path)
511
 
512
- # Clean up files if a new gene symbol is entered
513
- if st.session_state['current_gene_symbol'] and gene_symbol != st.session_state['current_gene_symbol']:
514
- clean_up_old_files(st.session_state['current_gene_symbol'])
515
-
516
- # Process predictions
517
- if predict_button and gene_symbol:
518
- with st.spinner('Predicting... Please wait'):
519
- predictions, gene_sequence, exons = cas12lstm.process_gene(gene_symbol, cas9att_path)
520
- sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
521
- st.session_state['on_target_results'] = sorted_predictions
522
- st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
523
- st.session_state['exons'] = exons # Store exon data
524
-
525
- # Notify the user once the process is completed successfully.
526
- st.success('Prediction completed!')
527
- st.session_state['prediction_made'] = True
528
-
529
- if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
530
- ensembl_id = gene_annotations.get(gene_symbol, 'Unknown') # Get Ensembl ID or default to 'Unknown'
531
- col1, col2, col3 = st.columns(3)
532
- with col1:
533
- st.markdown("**Genome**")
534
- st.markdown("Homo sapiens")
535
- with col2:
536
- st.markdown("**Gene**")
537
- st.markdown(f"{gene_symbol} : {ensembl_id} (primary)")
538
- with col3:
539
- st.markdown("**Nuclease**")
540
- st.markdown("SpCas9")
541
- # Include "Target" in the DataFrame's columns
542
- try:
543
- df = pd.DataFrame(st.session_state['on_target_results'],
544
- columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon", "Target",
545
- "gRNA", "Prediction"])
546
- st.dataframe(df)
547
- except ValueError as e:
548
- st.error(f"DataFrame creation error: {e}")
549
- # Optionally print or log the problematic data for debugging:
550
- print(st.session_state['on_target_results'])
551
-
552
- # Initialize Plotly figure
553
- fig = go.Figure()
554
-
555
- EXON_BASE = 0 # Base position for exons and CDS on the Y axis
556
- EXON_HEIGHT = 0.02 # How 'tall' the exon markers should appear
557
-
558
- # Plot Exons as small markers on the X-axis
559
- for exon in st.session_state['exons']:
560
- exon_start, exon_end = exon['start'], exon['end']
561
- fig.add_trace(go.Bar(
562
- x=[(exon_start + exon_end) / 2],
563
- y=[EXON_HEIGHT],
564
- width=[exon_end - exon_start],
565
- base=EXON_BASE,
566
- marker_color='rgba(128, 0, 128, 0.5)',
567
- name='Exon'
568
- ))
569
-
570
- VERTICAL_GAP = 0.2 # Gap between different ranks
571
-
572
- # Define max and min Y values based on strand and rank
573
- MAX_STRAND_Y = 0.1 # Maximum Y value for positive strand results
574
- MIN_STRAND_Y = -0.1 # Minimum Y value for negative strand results
575
-
576
- # Iterate over top 5 sorted predictions to create the plot
577
- for i, prediction in enumerate(st.session_state['on_target_results'][:5], start=1): # Only top 5
578
- chrom, start, end, strand, transcript, exon, target, gRNA, prediction_score = prediction
579
- midpoint = (int(start) + int(end)) / 2
580
-
581
- # Vertical position based on rank, modified by strand
582
- y_value = (MAX_STRAND_Y - (i - 1) * VERTICAL_GAP) if strand == '1' or strand == '+' else (
583
- MIN_STRAND_Y + (i - 1) * VERTICAL_GAP)
584
-
585
- fig.add_trace(go.Scatter(
586
- x=[midpoint],
587
- y=[y_value],
588
- mode='markers+text',
589
- marker=dict(symbol='triangle-up' if strand == '1' or strand == '+' else 'triangle-down',
590
- size=12),
591
- text=f"Rank: {i}", # Text label
592
- hoverinfo='text',
593
- hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand == '1' or strand == '+' else '-'}<br>Transcript: {transcript}<br>Prediction: {prediction_score:.4f}",
594
- ))
595
-
596
- # Update layout for clarity and interaction
597
- fig.update_layout(
598
- title='Top 5 gRNA Sequences by Prediction Score',
599
- xaxis_title='Genomic Position',
600
- yaxis_title='Strand',
601
- yaxis=dict(tickvals=[MAX_STRAND_Y, MIN_STRAND_Y], ticktext=['+', '-']),
602
- showlegend=False,
603
- hovermode='x unified',
604
- )
605
-
606
- # Display the plot
607
- st.plotly_chart(fig)
608
-
609
- if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
610
- gene_symbol = st.session_state['current_gene_symbol']
611
- gene_sequence = st.session_state['gene_sequence']
612
-
613
- # Define file paths
614
  genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
615
  bed_file_path = f"{gene_symbol}_crispr_targets.bed"
616
  csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
617
- plot_image_path = f"{gene_symbol}_gtracks_plot.png"
618
-
619
- # Generate files
620
- cas12lstm.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
621
- cas12lstm.create_bed_file_from_df(df, bed_file_path)
622
- cas12lstm.create_csv_from_df(df, csv_file_path)
623
-
624
- # Prepare an in-memory buffer for the ZIP file
625
- zip_buffer = io.BytesIO()
626
- with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
627
- # For each file, add it to the ZIP file
628
- zip_file.write(genbank_file_path)
629
- zip_file.write(bed_file_path)
630
- zip_file.write(csv_file_path)
631
-
632
- # Important: move the cursor to the beginning of the BytesIO buffer before reading it
633
- zip_buffer.seek(0)
634
-
635
- # Specify the region you want to visualize
636
- min_start = df['Start Pos'].min()
637
- max_end = df['End Pos'].max()
638
- chromosome = df['Chr'].mode()[0] # Assumes most common chromosome is the target
639
- region = f"{chromosome}:{min_start}-{max_end}"
640
-
641
- # Generate the pyGenomeTracks plot
642
- gtracks_command = f"gtracks {region} {bed_file_path} {plot_image_path}"
643
- subprocess.run(gtracks_command, shell=True)
644
- st.image(plot_image_path)
645
-
646
- # Display the download button for the ZIP file
647
- st.download_button(
648
- label="Download GenBank, BED, CSV files as ZIP",
649
- data=zip_buffer.getvalue(),
650
- file_name=f"{gene_symbol}_files.zip",
651
- mime="application/zip"
652
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
 
654
  elif selected_model == 'Cas13d':
655
  ENTRY_METHODS = dict(
 
5
  import cas9off
6
  import cas12
7
  import cas12lstm
8
+ import cas12lstmvcf
9
  import pandas as pd
10
  import streamlit as st
11
  import plotly.graph_objs as go
 
27
 
28
  selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
29
  cas9att_path = 'cas9_model/Cas9_MultiHeadAttention_weights.h5'
30
+ cas12lstm_path = 'cas12_model/BiLSTM_Cpf1_weights.h5'
31
 
32
  #plot functions
33
  def generate_coolbox_plot(bigwig_path, region, output_image_path):
 
332
  # Process predictions
333
  if predict_button and gene_symbol:
334
  with st.spinner('Predicting... Please wait'):
335
+ predictions, gene_sequence, exons = cas9attvcf.process_gene(gene_symbol, vcf_reader, cas9att_path)
336
  full_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)
337
  sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
338
  st.session_state['full_results'] = full_predictions
 
490
  st.experimental_rerun()
491
 
492
  elif selected_model == 'Cas12':
493
+ cas12target_selection = st.radio(
494
+ "Select either mutation or not:",
495
+ ('regular', 'mutation'),
496
+ key='cas12target_selection'
497
+ )
498
  # Gene symbol entry with autocomplete-like feature
499
  gene_symbol = st.selectbox('Enter a Gene Symbol:', [''] + gene_symbol_list, key='gene_symbol',
500
  format_func=lambda x: x if x else "")
 
503
  if 'current_gene_symbol' not in st.session_state:
504
  st.session_state['current_gene_symbol'] = ""
505
 
506
+ if cas12target_selection == 'regular':
507
+ # Prediction button
508
+ predict_button = st.button('Predict on-target')
 
 
 
 
 
 
 
 
509
 
510
+ # Function to clean up old files
511
+ def clean_up_old_files(gene_symbol):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
  genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
513
  bed_file_path = f"{gene_symbol}_crispr_targets.bed"
514
  csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
515
+ for path in [genbank_file_path, bed_file_path, csv_file_path]:
516
+ if os.path.exists(path):
517
+ os.remove(path)
518
+
519
+
520
+ # Clean up files if a new gene symbol is entered
521
+ if st.session_state['current_gene_symbol'] and gene_symbol != st.session_state['current_gene_symbol']:
522
+ clean_up_old_files(st.session_state['current_gene_symbol'])
523
+
524
+ # Process predictions
525
+ if predict_button and gene_symbol:
526
+ with st.spinner('Predicting... Please wait'):
527
+ predictions, gene_sequence, exons = cas12lstm.process_gene(gene_symbol, cas12lstm_path)
528
+ sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
529
+ st.session_state['on_target_results'] = sorted_predictions
530
+ st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
531
+ st.session_state['exons'] = exons # Store exon data
532
+
533
+ # Notify the user once the process is completed successfully.
534
+ st.success('Prediction completed!')
535
+ st.session_state['prediction_made'] = True
536
+
537
+ if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
538
+ ensembl_id = gene_annotations.get(gene_symbol, 'Unknown') # Get Ensembl ID or default to 'Unknown'
539
+ col1, col2, col3 = st.columns(3)
540
+ with col1:
541
+ st.markdown("**Genome**")
542
+ st.markdown("Homo sapiens")
543
+ with col2:
544
+ st.markdown("**Gene**")
545
+ st.markdown(f"{gene_symbol} : {ensembl_id} (primary)")
546
+ with col3:
547
+ st.markdown("**Nuclease**")
548
+ st.markdown("SpCas9")
549
+ # Include "Target" in the DataFrame's columns
550
+ try:
551
+ df = pd.DataFrame(st.session_state['on_target_results'],
552
+ columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon",
553
+ "Target",
554
+ "gRNA", "Prediction"])
555
+ st.dataframe(df)
556
+ except ValueError as e:
557
+ st.error(f"DataFrame creation error: {e}")
558
+ # Optionally print or log the problematic data for debugging:
559
+ print(st.session_state['on_target_results'])
560
+
561
+ # Initialize Plotly figure
562
+ fig = go.Figure()
563
+
564
+ EXON_BASE = 0 # Base position for exons and CDS on the Y axis
565
+ EXON_HEIGHT = 0.02 # How 'tall' the exon markers should appear
566
+
567
+ # Plot Exons as small markers on the X-axis
568
+ for exon in st.session_state['exons']:
569
+ exon_start, exon_end = exon['start'], exon['end']
570
+ fig.add_trace(go.Bar(
571
+ x=[(exon_start + exon_end) / 2],
572
+ y=[EXON_HEIGHT],
573
+ width=[exon_end - exon_start],
574
+ base=EXON_BASE,
575
+ marker_color='rgba(128, 0, 128, 0.5)',
576
+ name='Exon'
577
+ ))
578
+
579
+ VERTICAL_GAP = 0.2 # Gap between different ranks
580
+
581
+ # Define max and min Y values based on strand and rank
582
+ MAX_STRAND_Y = 0.1 # Maximum Y value for positive strand results
583
+ MIN_STRAND_Y = -0.1 # Minimum Y value for negative strand results
584
+
585
+ # Iterate over top 5 sorted predictions to create the plot
586
+ for i, prediction in enumerate(st.session_state['on_target_results'][:5], start=1): # Only top 5
587
+ chrom, start, end, strand, transcript, exon, target, gRNA, prediction_score = prediction
588
+ midpoint = (int(start) + int(end)) / 2
589
+
590
+ # Vertical position based on rank, modified by strand
591
+ y_value = (MAX_STRAND_Y - (i - 1) * VERTICAL_GAP) if strand == '1' or strand == '+' else (
592
+ MIN_STRAND_Y + (i - 1) * VERTICAL_GAP)
593
+
594
+ fig.add_trace(go.Scatter(
595
+ x=[midpoint],
596
+ y=[y_value],
597
+ mode='markers+text',
598
+ marker=dict(symbol='triangle-up' if strand == '1' or strand == '+' else 'triangle-down',
599
+ size=12),
600
+ text=f"Rank: {i}", # Text label
601
+ hoverinfo='text',
602
+ hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand == '1' or strand == '+' else '-'}<br>Transcript: {transcript}<br>Prediction: {prediction_score:.4f}",
603
+ ))
604
+
605
+ # Update layout for clarity and interaction
606
+ fig.update_layout(
607
+ title='Top 5 gRNA Sequences by Prediction Score',
608
+ xaxis_title='Genomic Position',
609
+ yaxis_title='Strand',
610
+ yaxis=dict(tickvals=[MAX_STRAND_Y, MIN_STRAND_Y], ticktext=['+', '-']),
611
+ showlegend=False,
612
+ hovermode='x unified',
613
+ )
614
+
615
+ # Display the plot
616
+ st.plotly_chart(fig)
617
+
618
+ if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
619
+ gene_symbol = st.session_state['current_gene_symbol']
620
+ gene_sequence = st.session_state['gene_sequence']
621
+
622
+ # Define file paths
623
+ genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
624
+ bed_file_path = f"{gene_symbol}_crispr_targets.bed"
625
+ csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
626
+ plot_image_path = f"{gene_symbol}_gtracks_plot.png"
627
+
628
+ # Generate files
629
+ cas12lstm.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
630
+ cas12lstm.create_bed_file_from_df(df, bed_file_path)
631
+ cas12lstm.create_csv_from_df(df, csv_file_path)
632
+
633
+ # Prepare an in-memory buffer for the ZIP file
634
+ zip_buffer = io.BytesIO()
635
+ with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
636
+ # For each file, add it to the ZIP file
637
+ zip_file.write(genbank_file_path)
638
+ zip_file.write(bed_file_path)
639
+ zip_file.write(csv_file_path)
640
+
641
+ # Important: move the cursor to the beginning of the BytesIO buffer before reading it
642
+ zip_buffer.seek(0)
643
+
644
+ # Specify the region you want to visualize
645
+ min_start = df['Start Pos'].min()
646
+ max_end = df['End Pos'].max()
647
+ chromosome = df['Chr'].mode()[0] # Assumes most common chromosome is the target
648
+ region = f"{chromosome}:{min_start}-{max_end}"
649
+
650
+ # Generate the pyGenomeTracks plot
651
+ gtracks_command = f"gtracks {region} {bed_file_path} {plot_image_path}"
652
+ subprocess.run(gtracks_command, shell=True)
653
+ st.image(plot_image_path)
654
+
655
+ # Display the download button for the ZIP file
656
+ st.download_button(
657
+ label="Download GenBank, BED, CSV files as ZIP",
658
+ data=zip_buffer.getvalue(),
659
+ file_name=f"{gene_symbol}_files.zip",
660
+ mime="application/zip"
661
+ )
662
+ elif cas12target_selection == 'mutation':
663
+ # Prediction button
664
+ predict_button = st.button('Predict on-target')
665
+ vcf_reader = cyvcf2.VCF('SRR25934512.filter.snps.indels.vcf.gz')
666
+
667
+ if 'exons' not in st.session_state:
668
+ st.session_state['exons'] = []
669
+
670
+ # Process predictions
671
+ if predict_button and gene_symbol:
672
+ with st.spinner('Predicting... Please wait'):
673
+ predictions, gene_sequence, exons = cas12lstmvcf.process_gene(gene_symbol, vcf_reader,
674
+ cas12lstm_path)
675
+ full_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)
676
+ sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
677
+ st.session_state['full_results'] = full_predictions
678
+ st.session_state['on_target_results'] = sorted_predictions
679
+ st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
680
+ st.session_state['exons'] = exons # Store exon data
681
+
682
+ # Notify the user once the process is completed successfully.
683
+ st.success('Prediction completed!')
684
+ st.session_state['prediction_made'] = True
685
+
686
+ if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
687
+ ensembl_id = gene_annotations.get(gene_symbol,
688
+ 'Unknown') # Get Ensembl ID or default to 'Unknown'
689
+ col1, col2, col3 = st.columns(3)
690
+ with col1:
691
+ st.markdown("**Genome**")
692
+ st.markdown("Homo sapiens")
693
+ with col2:
694
+ st.markdown("**Gene**")
695
+ st.markdown(f"{gene_symbol} : {ensembl_id} (primary)")
696
+ with col3:
697
+ st.markdown("**Nuclease**")
698
+ st.markdown("SpCas9")
699
+ # Include "Target" in the DataFrame's columns
700
+ try:
701
+ df = pd.DataFrame(st.session_state['on_target_results'],
702
+ columns=["Gene Symbol", "Chr", "Strand", "Target Start", "Transcript",
703
+ "Exon",
704
+ "Target",
705
+ "gRNA", "Prediction", "Is Mutation"])
706
+ df_full = pd.DataFrame(st.session_state['full_results'],
707
+ columns=["Gene Symbol", "Chr", "Strand", "Target Start",
708
+ "Transcript",
709
+ "Exon", "Target",
710
+ "gRNA", "Prediction", "Is Mutation"])
711
+ st.dataframe(df)
712
+ except ValueError as e:
713
+ st.error(f"DataFrame creation error: {e}")
714
+ # Optionally print or log the problematic data for debugging:
715
+ print(st.session_state['on_target_results'])
716
+
717
+ if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
718
+ gene_symbol = st.session_state['current_gene_symbol']
719
+ gene_sequence = st.session_state['gene_sequence']
720
+
721
+ # Define file paths
722
+ genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
723
+ bed_file_path = f"{gene_symbol}_crispr_targets.bed"
724
+ csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
725
+ plot_image_path = f"{gene_symbol}_gtracks_plot.png"
726
+
727
+ # Generate files
728
+ cas12lstmvcf.generate_genbank_file_from_df(df_full, gene_sequence, gene_symbol,
729
+ genbank_file_path)
730
+ cas12lstmvcf.create_bed_file_from_df(df_full, bed_file_path)
731
+ cas12lstmvcf.create_csv_from_df(df_full, csv_file_path)
732
+
733
+ # Prepare an in-memory buffer for the ZIP file
734
+ zip_buffer = io.BytesIO()
735
+ with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
736
+ # For each file, add it to the ZIP file
737
+ zip_file.write(genbank_file_path)
738
+ zip_file.write(bed_file_path)
739
+ zip_file.write(csv_file_path)
740
+
741
+ # Display the download button for the ZIP file
742
+ st.download_button(
743
+ label="Download GenBank, BED, CSV files as ZIP",
744
+ data=zip_buffer.getvalue(),
745
+ file_name=f"{gene_symbol}_files.zip",
746
+ mime="application/zip"
747
+ )
748
 
749
  elif selected_model == 'Cas13d':
750
  ENTRY_METHODS = dict(
cas12lstmvcf.py CHANGED
@@ -8,6 +8,10 @@ from keras.metrics import MeanSquaredError
8
 
9
  import pandas as pd
10
  import numpy as np
 
 
 
 
11
 
12
  import requests
13
  from functools import reduce
@@ -278,14 +282,70 @@ def process_gene(gene_symbol, vcf_reader, model_path):
278
  print(f"Failed to retrieve gene sequence for exon {exon_id}.")
279
  else:
280
  print("Failed to retrieve transcripts.")
281
-
282
- output = []
283
- for result in results:
284
- for item in result:
285
- output.append(item)
286
- # Sort results based on prediction score (assuming score is at the 8th index)
287
- sorted_results = sorted(output, key=lambda x: x[8], reverse=True)
288
 
289
  # Return the sorted output, combined gene sequences, and all exons
290
- return sorted_results, all_gene_sequences, all_exons
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
 
8
 
9
  import pandas as pd
10
  import numpy as np
11
+ from Bio import SeqIO
12
+ from Bio.SeqRecord import SeqRecord
13
+ from Bio.SeqFeature import SeqFeature, FeatureLocation
14
+ from Bio.Seq import Seq
15
 
16
  import requests
17
  from functools import reduce
 
282
  print(f"Failed to retrieve gene sequence for exon {exon_id}.")
283
  else:
284
  print("Failed to retrieve transcripts.")
 
 
 
 
 
 
 
285
 
286
  # Return the sorted output, combined gene sequences, and all exons
287
+ return results, all_gene_sequences, all_exons
288
+
289
def create_genbank_features(data):
    """Build GenBank ``misc_feature`` annotations from prediction rows.

    Rows are read positionally using the ten-column layout that the
    Streamlit app builds for the mutation-aware Cas12 results and that
    ``create_bed_file_from_df`` reads by name:

        0 Gene Symbol, 1 Chr, 2 Strand, 3 Target Start, 4 Transcript,
        5 Exon, 6 Target, 7 gRNA, 8 Prediction, 9 Is Mutation

    Args:
        data: A pandas DataFrame or a list of row sequences in the layout
            above.

    Returns:
        A list of ``SeqFeature`` objects, one per row whose start position
        could be parsed. Rows that cannot be parsed are reported on stdout
        and skipped.

    Raises:
        TypeError: If ``data`` is neither a list nor a DataFrame.
    """
    # Normalise the input so the loop below only deals with plain row lists.
    if isinstance(data, pd.DataFrame):
        formatted_data = data.values.tolist()
    elif isinstance(data, list):
        formatted_data = data
    else:
        raise TypeError("Data should be either a list or a pandas DataFrame.")

    features = []
    for row in formatted_data:
        try:
            # Index 3 holds the target start; index 1 is the chromosome name
            # and must not be parsed as a coordinate.
            start = int(row[3])
            # There is no end column: derive it from the target length.
            end = start + len(row[6])
        except (ValueError, TypeError) as e:
            print(f"Error computing start/end from start={row[3]!r}, target={row[6]!r} - {e}")
            continue

        # Index 2 holds the strand; '1' marks the forward strand, the same
        # test create_bed_file_from_df applies to the "Strand" column.
        strand = 1 if row[2] == '1' else -1
        location = FeatureLocation(start=start, end=end, strand=strand)
        is_mutation = 'Yes' if row[9] else 'No'
        feature = SeqFeature(location=location, type="misc_feature", qualifiers={
            'label': row[7],  # gRNA sequence doubles as the feature label
            'note': f"Prediction: {row[8]}, Mutation: {is_mutation}",
        })
        features.append(feature)

    return features
318
+
319
def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
    """Write a GenBank record holding the gene sequence and its gRNA features.

    Args:
        df: Prediction rows, forwarded unchanged to
            ``create_genbank_features``.
        gene_sequence: Sequence for the record. Anything that is not already
            a ``str`` is passed through ``str()`` first.
            NOTE(review): a list would be stringified as its repr, so callers
            presumably pass a str or Seq-like object; confirm.
        gene_symbol: Used as the record id, name, and in its description.
        output_path: Destination path of the GenBank file.
    """
    # Seq is built from text, so coerce the sequence before anything else.
    sequence_text = gene_sequence if isinstance(gene_sequence, str) else str(gene_sequence)

    annotations = create_genbank_features(df)

    record = SeqRecord(
        Seq(sequence_text),
        id=gene_symbol,
        name=gene_symbol,
        description=f'CRISPR Cas12 predicted targets for {gene_symbol}',
        features=annotations,
    )
    # The molecule type is set explicitly before the record is written.
    record.annotations["molecule_type"] = "DNA"
    SeqIO.write(record, output_path, "genbank")
332
+
333
def create_bed_file_from_df(df, output_path):
    """Write one tab-separated BED-style line per prediction row.

    Columns written, in order: chromosome, start, end, gRNA, prediction
    score, strand, mutation flag. The first six follow the BED field order;
    the trailing ``Yes``/``No`` mutation flag is an extra, non-standard
    column.

    Args:
        df: DataFrame with the columns "Chr", "Target Start", "Target",
            "Strand", "gRNA", "Prediction" and "Is Mutation". No other
            column is required.
        output_path: Destination path; an existing file is overwritten.
    """
    with open(output_path, 'w') as bed_file:
        for _, row in df.iterrows():
            start = int(row["Target Start"])
            # There is no end column: derive it from the target length.
            end = start + len(row["Target"])
            strand = '+' if row["Strand"] == '1' else '-'
            is_mutation = 'Yes' if row["Is Mutation"] else 'No'
            fields = (
                row["Chr"],
                start,
                end,
                row["gRNA"],
                row["Prediction"],
                strand,
                is_mutation,
            )
            bed_file.write("\t".join(str(field) for field in fields) + "\n")
348
+
349
def create_csv_from_df(df, output_path):
    """Save ``df`` as a CSV file at ``output_path`` without the index column."""
    df.to_csv(path_or_buf=output_path, index=False)
351