supercat666 commited on
Commit
99d52d8
1 Parent(s): 3c85094

add Genbank

Browse files
Files changed (2) hide show
  1. app.py +18 -4
  2. cas9on.py +39 -10
app.py CHANGED
@@ -102,13 +102,17 @@ if selected_model == 'Cas9':
102
 
103
  # Process predictions
104
  if predict_button and gene_symbol:
105
- predictions = cas9on.process_gene(gene_symbol, cas9on_path)
106
  sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
107
  st.session_state['on_target_results'] = sorted_predictions
108
 
109
  if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
110
  df = pd.DataFrame(st.session_state['on_target_results'],
111
  columns=["Gene ID", "Start Pos", "End Pos", "Strand", "gRNA", "Prediction"])
 
 
 
 
112
  st.write('Top on-target predictions:')
113
  st.dataframe(df)
114
 
@@ -124,9 +128,19 @@ if selected_model == 'Cas9':
124
  track.add_feature(start, end, strand, label=label)
125
 
126
  # Save and display the visualization
127
- gv_fig_path = "crispr_targets.png"
128
- gv.savefig(gv_fig_path)
129
- st.image(gv_fig_path, caption="CRISPR Targets Visualization")
 
 
 
 
 
 
 
 
 
 
130
 
131
  elif target_selection == 'off-target':
132
  ENTRY_METHODS = dict(
 
102
 
103
  # Process predictions
104
  if predict_button and gene_symbol:
105
+ predictions, gene_sequence = cas9on.process_gene(gene_symbol, cas9on_path)
106
  sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
107
  st.session_state['on_target_results'] = sorted_predictions
108
 
109
  if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
110
  df = pd.DataFrame(st.session_state['on_target_results'],
111
  columns=["Gene ID", "Start Pos", "End Pos", "Strand", "gRNA", "Prediction"])
112
+
113
+ # Pass the gene_sequence to the function
114
+ genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
115
+ cas9on.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
116
  st.write('Top on-target predictions:')
117
  st.dataframe(df)
118
 
 
128
  track.add_feature(start, end, strand, label=label)
129
 
130
  # Save and display the visualization
131
+ fig = gv.plotfig()
132
+ st.pyplot(fig)
133
+
134
+ # After the GenomeViz plot, include the download button
135
+ with open(genbank_file_path, "rb") as file:
136
+ btn = st.download_button(
137
+ label="Download GenBank file",
138
+ data=file,
139
+ file_name=genbank_file_path,
140
+ mime="application/octet-stream"
141
+ )
142
+
143
+ os.remove(genbank_file_path)
144
 
145
  elif target_selection == 'off-target':
146
  ENTRY_METHODS = dict(
cas9on.py CHANGED
@@ -4,6 +4,10 @@ import pandas as pd
4
  import numpy as np
5
  from operator import add
6
  from functools import reduce
 
 
 
 
7
  from keras.models import load_model
8
  import random
9
 
@@ -35,7 +39,6 @@ class DCModelOntar:
35
  yp = self.model.predict(x)
36
  return yp.ravel()
37
 
38
-
39
  # Function to predict on-target efficiency and format output
40
  def format_prediction_output(gRNAs, model_path):
41
  dcModel = DCModelOntar(model_path)
@@ -102,6 +105,7 @@ def find_crispr_targets(sequence, chr, start, strand, pam="NGG", target_length=2
102
  def process_gene(gene_symbol, model_path):
103
  transcripts = fetch_ensembl_transcripts(gene_symbol)
104
  all_data = []
 
105
 
106
  if transcripts:
107
  for transcript in transcripts:
@@ -109,7 +113,8 @@ def process_gene(gene_symbol, model_path):
109
  chr = transcript.get('seq_region_name', 'unknown')
110
  start = transcript.get('start', 0)
111
  strand = transcript.get('strand', 'unknown')
112
- gene_sequence = fetch_ensembl_sequence(transcript_id)
 
113
 
114
  if gene_sequence:
115
  gRNA_sites = find_crispr_targets(gene_sequence, chr, start, strand)
@@ -117,11 +122,35 @@ def process_gene(gene_symbol, model_path):
117
  formatted_data = format_prediction_output(gRNA_sites, model_path)
118
  all_data.extend(formatted_data)
119
 
120
- return all_data
121
-
122
-
123
- # Function to save results as CSV
124
- def save_to_csv(data, filename="crispr_results.csv"):
125
- df = pd.DataFrame(data,
126
- columns=["Gene ID", "Start Pos", "End Pos", "Strand", "gRNA", "Prediction"])
127
- df.to_csv(filename, index=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import numpy as np
5
  from operator import add
6
  from functools import reduce
7
+ from Bio import SeqIO
8
+ from Bio.SeqRecord import SeqRecord
9
+ from Bio.SeqFeature import SeqFeature, FeatureLocation
10
+ from Bio.Seq import Seq
11
  from keras.models import load_model
12
  import random
13
 
 
39
  yp = self.model.predict(x)
40
  return yp.ravel()
41
 
 
42
  # Function to predict on-target efficiency and format output
43
  def format_prediction_output(gRNAs, model_path):
44
  dcModel = DCModelOntar(model_path)
 
105
  def process_gene(gene_symbol, model_path):
106
  transcripts = fetch_ensembl_transcripts(gene_symbol)
107
  all_data = []
108
+ gene_sequence = '' # Initialize an empty string for the gene sequence
109
 
110
  if transcripts:
111
  for transcript in transcripts:
 
113
  chr = transcript.get('seq_region_name', 'unknown')
114
  start = transcript.get('start', 0)
115
  strand = transcript.get('strand', 'unknown')
116
+ # Fetch the sequence here and concatenate if multiple transcripts
117
+ gene_sequence += fetch_ensembl_sequence(transcript_id) or ''
118
 
119
  if gene_sequence:
120
  gRNA_sites = find_crispr_targets(gene_sequence, chr, start, strand)
 
122
  formatted_data = format_prediction_output(gRNA_sites, model_path)
123
  all_data.extend(formatted_data)
124
 
125
+ # Return both the data and the fetched sequence
126
+ return all_data, gene_sequence
127
+
128
+ def create_genbank_features(gRNAs, predictions):
129
+ features = []
130
+ for gRNA, prediction in zip(gRNAs, predictions):
131
+ # Assuming gRNA structure: [Target Seq, Chrom, Start Pos, End Pos, Strand]
132
+ # And prediction is a single floating point value
133
+ location = FeatureLocation(start=gRNA[2], end=gRNA[3], strand=gRNA[4])
134
+ # Creating a feature with type 'CDS' just as an example, change as appropriate
135
+ feature = SeqFeature(location=location, type="CDS", qualifiers={
136
+ 'label': gRNA[0], # Target sequence as label
137
+ 'note': f"Prediction: {prediction}" # Prediction score in note
138
+ })
139
+ features.append(feature)
140
+ return features
141
+
142
+ def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
143
+ features = []
144
+ for index, row in df.iterrows():
145
+ location = FeatureLocation(start=int(row["Start Pos"]),
146
+ end=int(row["End Pos"]),
147
+ strand=int(row["Strand"]))
148
+ feature = SeqFeature(location=location, type="gene", qualifiers={
149
+ 'locus_tag': row["Gene ID"], # Assuming Gene ID is equivalent to Chromosome here
150
+ 'note': f"gRNA: {row['gRNA']}, Prediction: {row['Prediction']}"
151
+ })
152
+ features.append(feature)
153
+
154
+ record = SeqRecord(Seq(gene_sequence), id=gene_symbol, name=gene_symbol,
155
+ description='CRISPR Cas9 predicted targets', features=features)
156
+ SeqIO.write(record, output_path, "genbank")