supercat666 commited on
Commit
e8b587f
1 Parent(s): c2e36d2

fix button

Browse files
Files changed (1) hide show
  1. app.py +79 -247
app.py CHANGED
@@ -490,266 +490,98 @@ if selected_model == 'Cas9':
490
  st.experimental_rerun()
491
 
492
  elif selected_model == 'Cas12':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
  cas12target_selection = st.radio(
494
- "Select either mutation or not:",
495
  ('regular', 'mutation'),
496
  key='cas12target_selection'
497
  )
498
  if 'current_gene_symbol' not in st.session_state:
499
  st.session_state['current_gene_symbol'] = ""
500
 
501
- # Define a function to clean up old files
502
-
503
  def clean_up_old_files(gene_symbol):
504
- genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
505
- bed_file_path = f"{gene_symbol}_crispr_targets.bed"
506
- csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
507
- for path in [genbank_file_path, bed_file_path, csv_file_path]:
508
- if os.path.exists(path):
509
- os.remove(path)
510
-
511
- # Gene symbol entry with autocomplete-like feature
512
- gene_symbol = st.selectbox('Enter a Gene Symbol:', [''] + gene_symbol_list, key='gene_symbol',
513
- format_func=lambda x: x if x else "")
 
514
 
515
- # Handle gene symbol change and file cleanup
516
- if gene_symbol != st.session_state['current_gene_symbol'] and gene_symbol:
517
  if st.session_state['current_gene_symbol']:
518
- # Clean up files only if a different gene symbol is entered and a previous symbol exists
519
  clean_up_old_files(st.session_state['current_gene_symbol'])
520
- # Update the session state with the new gene symbol
521
  st.session_state['current_gene_symbol'] = gene_symbol
522
 
523
- if cas12target_selection == 'regular':
524
- predict_button = st.button('Predict cas12')
525
-
526
- if 'exons' not in st.session_state:
527
- st.session_state['exons'] = []
528
-
529
- # Process predictions
530
- if predict_button and gene_symbol:
531
- with st.spinner('Predicting... Please wait'):
532
- predictions, gene_sequence, exons = cas12lstm.process_gene(gene_symbol, cas12lstm_path)
533
- sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
534
- st.session_state['on_target_results'] = sorted_predictions
535
- st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
536
- st.session_state['exons'] = exons # Store exon data
537
-
538
- # Notify the user once the process is completed successfully.
539
- st.success('Prediction completed!')
540
- st.session_state['prediction_made'] = True
541
-
542
- if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
543
- ensembl_id = gene_annotations.get(gene_symbol, 'Unknown') # Get Ensembl ID or default to 'Unknown'
544
- col1, col2, col3 = st.columns(3)
545
- with col1:
546
- st.markdown("**Genome**")
547
- st.markdown("Homo sapiens")
548
- with col2:
549
- st.markdown("**Gene**")
550
- st.markdown(f"{gene_symbol} : {ensembl_id} (primary)")
551
- with col3:
552
- st.markdown("**Nuclease**")
553
- st.markdown("SpCas9")
554
- # Include "Target" in the DataFrame's columns
555
- try:
556
- df = pd.DataFrame(st.session_state['on_target_results'],
557
- columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon",
558
- "Target",
559
- "gRNA", "Prediction"])
560
- st.dataframe(df)
561
- except ValueError as e:
562
- st.error(f"DataFrame creation error: {e}")
563
- # Optionally print or log the problematic data for debugging:
564
- print(st.session_state['on_target_results'])
565
-
566
- # Initialize Plotly figure
567
- fig = go.Figure()
568
-
569
- EXON_BASE = 0 # Base position for exons and CDS on the Y axis
570
- EXON_HEIGHT = 0.02 # How 'tall' the exon markers should appear
571
-
572
- # Plot Exons as small markers on the X-axis
573
- for exon in st.session_state['exons']:
574
- exon_start, exon_end = exon['start'], exon['end']
575
- fig.add_trace(go.Bar(
576
- x=[(exon_start + exon_end) / 2],
577
- y=[EXON_HEIGHT],
578
- width=[exon_end - exon_start],
579
- base=EXON_BASE,
580
- marker_color='rgba(128, 0, 128, 0.5)',
581
- name='Exon'
582
- ))
583
-
584
- VERTICAL_GAP = 0.2 # Gap between different ranks
585
-
586
- # Define max and min Y values based on strand and rank
587
- MAX_STRAND_Y = 0.1 # Maximum Y value for positive strand results
588
- MIN_STRAND_Y = -0.1 # Minimum Y value for negative strand results
589
-
590
- # Iterate over top 5 sorted predictions to create the plot
591
- for i, prediction in enumerate(st.session_state['on_target_results'][:5], start=1): # Only top 5
592
- chrom, start, end, strand, transcript, exon, target, gRNA, prediction_score = prediction
593
- midpoint = (int(start) + int(end)) / 2
594
-
595
- # Vertical position based on rank, modified by strand
596
- y_value = (MAX_STRAND_Y - (i - 1) * VERTICAL_GAP) if strand == '1' or strand == '+' else (
597
- MIN_STRAND_Y + (i - 1) * VERTICAL_GAP)
598
-
599
- fig.add_trace(go.Scatter(
600
- x=[midpoint],
601
- y=[y_value],
602
- mode='markers+text',
603
- marker=dict(symbol='triangle-up' if strand == '1' or strand == '+' else 'triangle-down',
604
- size=12),
605
- text=f"Rank: {i}", # Text label
606
- hoverinfo='text',
607
- hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand == '1' or strand == '+' else '-'}<br>Transcript: {transcript}<br>Prediction: {prediction_score:.4f}",
608
- ))
609
-
610
- # Update layout for clarity and interaction
611
- fig.update_layout(
612
- title='Top 5 gRNA Sequences by Prediction Score',
613
- xaxis_title='Genomic Position',
614
- yaxis_title='Strand',
615
- yaxis=dict(tickvals=[MAX_STRAND_Y, MIN_STRAND_Y], ticktext=['+', '-']),
616
- showlegend=False,
617
- hovermode='x unified',
618
- )
619
-
620
- # Display the plot
621
- st.plotly_chart(fig)
622
-
623
- if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
624
- gene_symbol = st.session_state['current_gene_symbol']
625
- gene_sequence = st.session_state['gene_sequence']
626
-
627
- # Define file paths
628
- genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
629
- bed_file_path = f"{gene_symbol}_crispr_targets.bed"
630
- csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
631
- plot_image_path = f"{gene_symbol}_gtracks_plot.png"
632
-
633
- # Generate files
634
- cas12lstm.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
635
- cas12lstm.create_bed_file_from_df(df, bed_file_path)
636
- cas12lstm.create_csv_from_df(df, csv_file_path)
637
-
638
- # Prepare an in-memory buffer for the ZIP file
639
- zip_buffer = io.BytesIO()
640
- with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
641
- # For each file, add it to the ZIP file
642
- zip_file.write(genbank_file_path)
643
- zip_file.write(bed_file_path)
644
- zip_file.write(csv_file_path)
645
-
646
- # Important: move the cursor to the beginning of the BytesIO buffer before reading it
647
- zip_buffer.seek(0)
648
-
649
- # Specify the region you want to visualize
650
- min_start = df['Start Pos'].min()
651
- max_end = df['End Pos'].max()
652
- chromosome = df['Chr'].mode()[0] # Assumes most common chromosome is the target
653
- region = f"{chromosome}:{min_start}-{max_end}"
654
-
655
- # Generate the pyGenomeTracks plot
656
- gtracks_command = f"gtracks {region} {bed_file_path} {plot_image_path}"
657
- subprocess.run(gtracks_command, shell=True)
658
- st.image(plot_image_path)
659
-
660
- # Display the download button for the ZIP file
661
- st.download_button(
662
- label="Download GenBank, BED, CSV files as ZIP",
663
- data=zip_buffer.getvalue(),
664
- file_name=f"{gene_symbol}_files.zip",
665
- mime="application/zip"
666
- )
667
- elif cas12target_selection == 'mutation':
668
- # Prediction button
669
- predict_button = st.button('Predict cas12')
670
- vcf_reader = cyvcf2.VCF('SRR25934512.filter.snps.indels.vcf.gz')
671
-
672
- if 'exons' not in st.session_state:
673
- st.session_state['exons'] = []
674
-
675
- # Process predictions
676
- if predict_button and gene_symbol:
677
- with st.spinner('Predicting... Please wait'):
678
- predictions, gene_sequence, exons = cas12lstmvcf.process_gene(gene_symbol, vcf_reader,
679
- cas12lstm_path)
680
- full_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)
681
- sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
682
- st.session_state['full_results'] = full_predictions
683
- st.session_state['on_target_results'] = sorted_predictions
684
- st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
685
- st.session_state['exons'] = exons # Store exon data
686
-
687
- # Notify the user once the process is completed successfully.
688
- st.success('Prediction completed!')
689
- st.session_state['prediction_made'] = True
690
-
691
- if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
692
- ensembl_id = gene_annotations.get(gene_symbol,
693
- 'Unknown') # Get Ensembl ID or default to 'Unknown'
694
- col1, col2, col3 = st.columns(3)
695
- with col1:
696
- st.markdown("**Genome**")
697
- st.markdown("Homo sapiens")
698
- with col2:
699
- st.markdown("**Gene**")
700
- st.markdown(f"{gene_symbol} : {ensembl_id} (primary)")
701
- with col3:
702
- st.markdown("**Nuclease**")
703
- st.markdown("SpCas9")
704
- # Include "Target" in the DataFrame's columns
705
- try:
706
- df = pd.DataFrame(st.session_state['on_target_results'],
707
- columns=["Gene Symbol", "Chr", "Strand", "Target Start", "Transcript",
708
- "Exon",
709
- "Target",
710
- "gRNA", "Prediction", "Is Mutation"])
711
- df_full = pd.DataFrame(st.session_state['full_results'],
712
- columns=["Gene Symbol", "Chr", "Strand", "Target Start",
713
- "Transcript",
714
- "Exon", "Target",
715
- "gRNA", "Prediction", "Is Mutation"])
716
- st.dataframe(df)
717
- except ValueError as e:
718
- st.error(f"DataFrame creation error: {e}")
719
- # Optionally print or log the problematic data for debugging:
720
- print(st.session_state['on_target_results'])
721
-
722
- if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
723
- gene_symbol = st.session_state['current_gene_symbol']
724
- gene_sequence = st.session_state['gene_sequence']
725
-
726
- # Define file paths
727
- genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
728
- bed_file_path = f"{gene_symbol}_crispr_targets.bed"
729
- csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
730
- plot_image_path = f"{gene_symbol}_gtracks_plot.png"
731
-
732
- # Generate files
733
- cas12lstmvcf.generate_genbank_file_from_df(df_full, gene_sequence, gene_symbol,
734
- genbank_file_path)
735
- cas12lstmvcf.create_bed_file_from_df(df_full, bed_file_path)
736
- cas12lstmvcf.create_csv_from_df(df_full, csv_file_path)
737
-
738
- # Prepare an in-memory buffer for the ZIP file
739
- zip_buffer = io.BytesIO()
740
- with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
741
- # For each file, add it to the ZIP file
742
- zip_file.write(genbank_file_path)
743
- zip_file.write(bed_file_path)
744
- zip_file.write(csv_file_path)
745
-
746
- # Display the download button for the ZIP file
747
- st.download_button(
748
- label="Download GenBank, BED, CSV files as ZIP",
749
- data=zip_buffer.getvalue(),
750
- file_name=f"{gene_symbol}_files.zip",
751
- mime="application/zip"
752
- )
753
 
754
  elif selected_model == 'Cas13d':
755
  ENTRY_METHODS = dict(
 
490
  st.experimental_rerun()
491
 
492
  elif selected_model == 'Cas12':
493
+ def visualize_and_generate_files(df, gene_sequence, exons, gene_symbol):
494
+ fig = go.Figure()
495
+ # Exon visualization
496
+ for exon in exons:
497
+ exon_start, exon_end = exon['start'], exon['end']
498
+ fig.add_trace(go.Bar(x=[(exon_start + exon_end) / 2], y=[0.5], width=[exon_end - exon_start], base=0,
499
+ marker_color='purple', name='Exon'))
500
+ # Prediction visualization
501
+ for i, prediction in enumerate(df.itertuples(), start=1):
502
+ fig.add_trace(go.Scatter(x=[(prediction.Start_Pos + prediction.End_Pos) / 2], y=[1], mode='markers',
503
+ marker=dict(size=10, color='blue'), name=f'Prediction {i}'))
504
+ fig.update_layout(title='Cas12 Prediction Visualization', xaxis_title='Position',
505
+ yaxis=dict(tickvals=[0.5, 1], ticktext=['Exons', 'Predictions']), showlegend=True)
506
+ st.plotly_chart(fig)
507
+
508
+ # File generation and download
509
+ generate_and_download_files(df, gene_symbol)
510
+
511
+
512
+ def generate_and_download_files(df, gene_symbol):
513
+ genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
514
+ bed_file_path = f"{gene_symbol}_crispr_targets.bed"
515
+ csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
516
+ df.to_csv(csv_file_path, index=False)
517
+ # Assume functions to generate GenBank and BED are defined in cas12lstm or cas12lstmvcf
518
+ cas12lstm.generate_genbank_file_from_df(df, gene_symbol, genbank_file_path)
519
+ cas12lstm.create_bed_file_from_df(df, bed_file_path)
520
+
521
+ zip_buffer = io.BytesIO()
522
+ with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
523
+ zip_file.write(genbank_file_path)
524
+ zip_file.write(bed_file_path)
525
+ zip_file.write(csv_file_path)
526
+ zip_buffer.seek(0)
527
+ st.download_button("Download GenBank, BED, CSV files as ZIP", data=zip_buffer.getvalue(),
528
+ file_name=f"{gene_symbol}_files.zip", mime="application/zip")
529
+
530
+
531
+ def display_results(predictions, gene_sequence, exons, gene_symbol):
532
+ st.success('Prediction completed!')
533
+ ensembl_id = gene_annotations.get(gene_symbol, 'Unknown')
534
+ st.write(f"**Genome:** Homo sapiens")
535
+ st.write(f"**Gene:** {gene_symbol} : {ensembl_id} (primary)")
536
+ st.write("**Nuclease:** Cas12")
537
+ df = pd.DataFrame(predictions,
538
+ columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon", "Target", "gRNA",
539
+ "Prediction"])
540
+ st.dataframe(df)
541
+
542
+ # Visualization and file generation as demonstrated in the Cas9 example
543
+ visualize_and_generate_files(df, gene_sequence, exons, gene_symbol)
544
+
545
+
546
  cas12target_selection = st.radio(
547
+ "Select either regular or mutation:",
548
  ('regular', 'mutation'),
549
  key='cas12target_selection'
550
  )
551
  if 'current_gene_symbol' not in st.session_state:
552
  st.session_state['current_gene_symbol'] = ""
553
 
 
 
554
  def clean_up_old_files(gene_symbol):
555
+ for suffix in ['_crispr_targets.gb', '_crispr_targets.bed', '_crispr_predictions.csv']:
556
+ file_path = f"{gene_symbol}{suffix}"
557
+ if os.path.exists(file_path):
558
+ os.remove(file_path)
559
+
560
+ gene_symbol = st.selectbox(
561
+ 'Enter a Gene Symbol:',
562
+ [''] + gene_symbol_list,
563
+ key='gene_symbol',
564
+ format_func=lambda x: x if x else ""
565
+ )
566
 
567
+ if gene_symbol != st.session_state['current_gene_symbol']:
 
568
  if st.session_state['current_gene_symbol']:
 
569
  clean_up_old_files(st.session_state['current_gene_symbol'])
 
570
  st.session_state['current_gene_symbol'] = gene_symbol
571
 
572
+ if cas12target_selection == 'regular':
573
+ if st.button('Predict cas12 Regular'):
574
+ with st.spinner('Predicting... Please wait'):
575
+ predictions, gene_sequence, exons = cas12lstm.process_gene(gene_symbol, cas12lstm_path)
576
+ sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
577
+ display_results(sorted_predictions, gene_sequence, exons, gene_symbol)
578
+ elif cas12target_selection == 'mutation':
579
+ vcf_reader = cyvcf2.VCF('SRR25934512.filter.snps.indels.vcf.gz')
580
+ if st.button('Predict cas12 Mutation'):
581
+ with st.spinner('Predicting... Please wait'):
582
+ predictions, gene_sequence, exons = cas12lstmvcf.process_gene(gene_symbol, vcf_reader, cas12lstm_path)
583
+ sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
584
+ display_results(sorted_predictions, gene_sequence, exons, gene_symbol)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
 
586
  elif selected_model == 'Cas13d':
587
  ENTRY_METHODS = dict(