supercat666 commited on
Commit
4fa4501
1 Parent(s): 24d7d26
Files changed (2) hide show
  1. app.py +34 -51
  2. cas9on.py +123 -131
app.py CHANGED
@@ -144,24 +144,20 @@ if selected_model == 'Cas9':
144
 
145
  if 'exons' not in st.session_state:
146
  st.session_state['exons'] = []
147
- if 'cds' not in st.session_state:
148
- st.session_state['cds'] = []
149
 
150
  # Process predictions
151
  if predict_button and gene_symbol:
152
  with st.spinner('Predicting... Please wait'):
153
- predictions, gene_sequence, exons, cds = cas9on.process_gene(gene_symbol, cas9on_path)
154
  sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
155
  st.session_state['on_target_results'] = sorted_predictions
156
  st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
157
  st.session_state['exons'] = exons # Store exon data
158
- st.session_state['cds'] = cds # Store CDS data
159
 
160
  # Notify the user once the process is completed successfully.
161
  st.success('Prediction completed!')
162
  st.session_state['prediction_made'] = True
163
 
164
-
165
  if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
166
  ensembl_id = gene_annotations.get(gene_symbol, 'Unknown') # Get Ensembl ID or default to 'Unknown'
167
  col1, col2, col3 = st.columns(3)
@@ -177,7 +173,7 @@ if selected_model == 'Cas9':
177
  # Include "Target" in the DataFrame's columns
178
  try:
179
  df = pd.DataFrame(st.session_state['on_target_results'],
180
- columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Target", "gRNA", "Prediction"])
181
  st.dataframe(df)
182
  except ValueError as e:
183
  st.error(f"DataFrame creation error: {e}")
@@ -189,7 +185,6 @@ if selected_model == 'Cas9':
189
 
190
  EXON_BASE = 0 # Base position for exons and CDS on the Y axis
191
  EXON_HEIGHT = 0.02 # How 'tall' the exon markers should appear
192
- CDS_HEIGHT = 0.04 # How 'tall' the CDS markers should appear
193
 
194
  # Plot Exons as small markers on the X-axis
195
  for exon in st.session_state['exons']:
@@ -203,18 +198,6 @@ if selected_model == 'Cas9':
203
  name='Exon'
204
  ))
205
 
206
- # Plot CDS in a similar manner
207
- for cds in st.session_state['cds']:
208
- cds_start, cds_end = cds['start'], cds['end']
209
- fig.add_trace(go.Bar(
210
- x=[(cds_start + cds_end) / 2],
211
- y=[CDS_HEIGHT],
212
- width=[cds_end - cds_start],
213
- base=[EXON_BASE],
214
- marker_color='rgba(0, 0, 255, 1)',
215
- name='CDS'
216
- ))
217
-
218
  VERTICAL_GAP = 0.2 # Gap between different ranks
219
 
220
  # Define max and min Y values based on strand and rank
@@ -254,38 +237,38 @@ if selected_model == 'Cas9':
254
  # Display the plot
255
  st.plotly_chart(fig)
256
 
257
- if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
258
- gene_symbol = st.session_state['current_gene_symbol']
259
- gene_sequence = st.session_state['gene_sequence']
260
-
261
- # Define file paths
262
- genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
263
- bed_file_path = f"{gene_symbol}_crispr_targets.bed"
264
- csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
265
-
266
- # Generate files
267
- cas9on.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
268
- cas9on.create_bed_file_from_df(df, bed_file_path)
269
- cas9on.create_csv_from_df(df, csv_file_path)
270
-
271
- # Prepare an in-memory buffer for the ZIP file
272
- zip_buffer = io.BytesIO()
273
- with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
274
- # For each file, add it to the ZIP file
275
- zip_file.write(genbank_file_path, arcname=genbank_file_path.split('/')[-1])
276
- zip_file.write(bed_file_path, arcname=bed_file_path.split('/')[-1])
277
- zip_file.write(csv_file_path, arcname=csv_file_path.split('/')[-1])
278
-
279
- # Important: move the cursor to the beginning of the BytesIO buffer before reading it
280
- zip_buffer.seek(0)
281
-
282
- # Display the download button for the ZIP file
283
- st.download_button(
284
- label="Download genbank,.bed,csv files as ZIP",
285
- data=zip_buffer.getvalue(),
286
- file_name=f"{gene_symbol}_files.zip",
287
- mime="application/zip"
288
- )
289
 
290
  elif target_selection == 'off-target':
291
  ENTRY_METHODS = dict(
 
144
 
145
  if 'exons' not in st.session_state:
146
  st.session_state['exons'] = []
 
 
147
 
148
  # Process predictions
149
  if predict_button and gene_symbol:
150
  with st.spinner('Predicting... Please wait'):
151
+ predictions, gene_sequence, exons = cas9on.process_gene(gene_symbol, cas9on_path)
152
  sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
153
  st.session_state['on_target_results'] = sorted_predictions
154
  st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
155
  st.session_state['exons'] = exons # Store exon data
 
156
 
157
  # Notify the user once the process is completed successfully.
158
  st.success('Prediction completed!')
159
  st.session_state['prediction_made'] = True
160
 
 
161
  if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
162
  ensembl_id = gene_annotations.get(gene_symbol, 'Unknown') # Get Ensembl ID or default to 'Unknown'
163
  col1, col2, col3 = st.columns(3)
 
173
  # Include "Target" in the DataFrame's columns
174
  try:
175
  df = pd.DataFrame(st.session_state['on_target_results'],
176
+ columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon", "Target", "gRNA", "Prediction"])
177
  st.dataframe(df)
178
  except ValueError as e:
179
  st.error(f"DataFrame creation error: {e}")
 
185
 
186
  EXON_BASE = 0 # Base position for exons and CDS on the Y axis
187
  EXON_HEIGHT = 0.02 # How 'tall' the exon markers should appear
 
188
 
189
  # Plot Exons as small markers on the X-axis
190
  for exon in st.session_state['exons']:
 
198
  name='Exon'
199
  ))
200
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  VERTICAL_GAP = 0.2 # Gap between different ranks
202
 
203
  # Define max and min Y values based on strand and rank
 
237
  # Display the plot
238
  st.plotly_chart(fig)
239
 
240
+ # if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
241
+ # gene_symbol = st.session_state['current_gene_symbol']
242
+ # gene_sequence = st.session_state['gene_sequence']
243
+ #
244
+ # # Define file paths
245
+ # genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
246
+ # bed_file_path = f"{gene_symbol}_crispr_targets.bed"
247
+ # csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
248
+ #
249
+ # # Generate files
250
+ # cas9on.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
251
+ # cas9on.create_bed_file_from_df(df, bed_file_path)
252
+ # cas9on.create_csv_from_df(df, csv_file_path)
253
+ #
254
+ # # Prepare an in-memory buffer for the ZIP file
255
+ # zip_buffer = io.BytesIO()
256
+ # with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
257
+ # # For each file, add it to the ZIP file
258
+ # zip_file.write(genbank_file_path, arcname=genbank_file_path.split('/')[-1])
259
+ # zip_file.write(bed_file_path, arcname=bed_file_path.split('/')[-1])
260
+ # zip_file.write(csv_file_path, arcname=csv_file_path.split('/')[-1])
261
+ #
262
+ # # Important: move the cursor to the beginning of the BytesIO buffer before reading it
263
+ # zip_buffer.seek(0)
264
+ #
265
+ # # Display the download button for the ZIP file
266
+ # st.download_button(
267
+ # label="Download genbank,.bed,csv files as ZIP",
268
+ # data=zip_buffer.getvalue(),
269
+ # file_name=f"{gene_symbol}_files.zip",
270
+ # mime="application/zip"
271
+ # )
272
 
273
  elif target_selection == 'off-target':
274
  ENTRY_METHODS = dict(
cas9on.py CHANGED
@@ -39,167 +39,159 @@ class DCModelOntar:
39
  yp = self.model.predict(x)
40
  return yp.ravel()
41
 
42
- # Function to predict on-target efficiency and format output
43
- def format_prediction_output(targets, model_path):
44
- dcModel = DCModelOntar(model_path)
45
- formatted_data = []
46
 
47
- for target in targets:
48
- # Encode the gRNA sequence
49
- encoded_seq = get_seqcode(target[0]).reshape(-1,4,1,23)
50
-
51
- # Predict on-target efficiency using the model
52
- prediction = dcModel.ontar_predict(encoded_seq)
53
-
54
- # Format output
55
- sgRNA = target[1]
56
- chr = target[2]
57
- start = target[3]
58
- end = target[4]
59
- strand = target[5]
60
- transcript_id = target[6]
61
- formatted_data.append([chr, start, end, strand, transcript_id, target[0], sgRNA, prediction[0]])
62
-
63
- return formatted_data
64
 
65
  def fetch_ensembl_transcripts(gene_symbol):
66
- headers = {"Content-Type": "application/json"}
67
- url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1"
68
- response = requests.get(url, headers=headers)
69
  if response.status_code == 200:
70
  gene_data = response.json()
71
- return gene_data.get('Transcript', [])
 
 
 
 
72
  else:
73
  print(f"Error fetching gene data from Ensembl: {response.text}")
74
  return None
75
 
76
  def fetch_ensembl_sequence(transcript_id):
77
- headers = {"Content-Type": "application/json"}
78
- url = f"https://rest.ensembl.org/sequence/id/{transcript_id}"
79
- response = requests.get(url, headers=headers)
80
  if response.status_code == 200:
81
  sequence_data = response.json()
82
- return sequence_data.get('seq', '')
 
 
 
 
83
  else:
84
- print(f"Error fetching sequence data from Ensembl for transcript {transcript_id}: {response.text}")
85
  return None
86
 
87
- def fetch_ensembl_exons(transcript_id):
88
- headers = {"Content-Type": "application/json"}
89
- url = f"https://rest.ensembl.org/overlap/id/{transcript_id}?feature=exon"
90
- response = requests.get(url, headers=headers)
91
- if response.status_code == 200:
92
- return response.json()
93
- else:
94
- print(f"Error fetching exon data from Ensembl for transcript {transcript_id}: {response.text}")
95
- return None
96
-
97
- def fetch_ensembl_cds(transcript_id):
98
- headers = {"Content-Type": "application/json"}
99
- url = f"https://rest.ensembl.org/overlap/id/{transcript_id}?feature=cds"
100
- response = requests.get(url, headers=headers)
101
- if response.status_code == 200:
102
- return response.json()
103
- else:
104
- print(f"Error fetching CDS data from Ensembl for transcript {transcript_id}: {response.text}")
105
- return None
106
-
107
- def find_crispr_targets(sequence, chr, start, strand, transcript_id, pam="NGG", target_length=20):
108
  targets = []
109
  len_sequence = len(sequence)
110
  complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
 
111
 
112
  if strand == -1:
113
- sequence = ''.join([complement[base] for base in reversed(sequence)])
114
  for i in range(len_sequence - len(pam) + 1):
115
  if sequence[i + 1:i + 3] == pam[1:]:
116
  if i >= target_length:
117
  target_seq = sequence[i - target_length:i + 3]
118
  tar_start = start + i - target_length
119
  tar_end = start + i + 3
120
- sgRNA = sequence[i - target_length:i]
121
- targets.append([target_seq, sgRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id])
122
 
123
  return targets
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  def process_gene(gene_symbol, model_path):
127
  transcripts = fetch_ensembl_transcripts(gene_symbol)
128
- all_data = []
129
-
130
  if transcripts:
131
- cdslist = fetch_ensembl_cds(transcripts[0].get('id'))
132
- for transcript in transcripts:
133
- transcript_id = transcript.get('id')
134
- chr = transcript.get('seq_region_name', 'unknown')
135
- start = transcript.get('start', 0)
136
- strand = transcript.get('strand', 'unknown')
137
- # Fetch the gene sequence for each transcript
138
- gene_sequence = fetch_ensembl_sequence(transcript_id) or ''
139
- # Fetch exon and CDS information is not directly used here but you may need it elsewhere
140
- exons = fetch_ensembl_exons(transcript_id)
141
-
142
- if gene_sequence:
143
- # Now correctly passing transcript_id as an argument
144
- gRNA_sites = find_crispr_targets(gene_sequence, chr, start, strand, transcript_id)
145
- if gRNA_sites:
146
- formatted_data = format_prediction_output(gRNA_sites, model_path)
147
- all_data.extend(formatted_data)
148
-
149
- # Return the data and potentially any other information as needed
150
- return all_data, gene_sequence, exons, cdslist
151
-
152
-
153
- def create_genbank_features(formatted_data):
154
- features = []
155
- for data in formatted_data:
156
- # Strand conversion to Biopython's convention
157
- strand = 1 if data[3] == '+' else -1
158
- location = FeatureLocation(start=int(data[1]), end=int(data[2]), strand=strand)
159
- feature = SeqFeature(location=location, type="misc_feature", qualifiers={
160
- 'label': data[5], # Use gRNA as the label
161
- 'target': data[4], # Include the target sequence
162
- 'note': f"Prediction: {data[6]}" # Include the prediction score
163
- })
164
- features.append(feature)
165
- return features
166
-
167
- def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
168
- features = []
169
- for index, row in df.iterrows():
170
- # Use 'Transcript ID' if it exists, otherwise use a default value like 'Unknown'
171
- transcript_id = row.get("Transcript ID", "Unknown")
172
-
173
- # Make sure to use the correct column names for Start Pos, End Pos, and Strand
174
- location = FeatureLocation(start=int(row["Start Pos"]),
175
- end=int(row["End Pos"]),
176
- strand=1 if row["Strand"] == '+' else -1)
177
- feature = SeqFeature(location=location, type="gene", qualifiers={
178
- 'locus_tag': transcript_id, # Now using the variable that holds the safe value
179
- 'note': f"gRNA: {row['gRNA']}, Prediction: {row['Prediction']}"
180
- })
181
- features.append(feature)
182
-
183
- # The rest of the function remains unchanged
184
- record = SeqRecord(Seq(gene_sequence), id=gene_symbol, name=gene_symbol,
185
- description=f'CRISPR Cas9 predicted targets for {gene_symbol}', features=features)
186
- record.annotations["molecule_type"] = "DNA"
187
- SeqIO.write(record, output_path, "genbank")
188
-
189
-
190
- def create_bed_file_from_df(df, output_path):
191
- with open(output_path, 'w') as bed_file:
192
- for index, row in df.iterrows():
193
- # Adjust field names based on your actual formatted data
194
- chrom = row["Chr"]
195
- start = int(row["Start Pos"])
196
- end = int(row["End Pos"])
197
- strand = '+' if row["Strand"] == '+' else '-' # Ensure strand is correctly interpreted
198
- gRNA = row["gRNA"]
199
- score = str(row["Prediction"]) # Ensure score is converted to string if not already
200
- transcript_id = row["Transcript"] # Extract transcript ID
201
- bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\t{transcript_id}\n") # Include transcript ID in BED output
202
-
203
-
204
- def create_csv_from_df(df, output_path):
205
- df.to_csv(output_path, index=False)
 
 
 
 
 
39
  yp = self.model.predict(x)
40
  return yp.ravel()
41
 
 
 
 
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  def fetch_ensembl_transcripts(gene_symbol):
45
+ url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
46
+ response = requests.get(url)
 
47
  if response.status_code == 200:
48
  gene_data = response.json()
49
+ if 'Transcript' in gene_data:
50
+ return gene_data['Transcript']
51
+ else:
52
+ print("No transcripts found for gene:", gene_symbol)
53
+ return None
54
  else:
55
  print(f"Error fetching gene data from Ensembl: {response.text}")
56
  return None
57
 
58
  def fetch_ensembl_sequence(transcript_id):
59
+ url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
60
+ response = requests.get(url)
 
61
  if response.status_code == 200:
62
  sequence_data = response.json()
63
+ if 'seq' in sequence_data:
64
+ return sequence_data['seq']
65
+ else:
66
+ print("No sequence found for transcript:", transcript_id)
67
+ return None
68
  else:
69
+ print(f"Error fetching sequence data from Ensembl: {response.text}")
70
  return None
71
 
72
+ def find_crispr_targets(sequence, chr, start, strand, transcript_id, exon_id, pam="NGG", target_length=20):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  targets = []
74
  len_sequence = len(sequence)
75
  complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
76
+ dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
77
 
78
  if strand == -1:
79
+ sequence = ''.join([complement[base] for base in sequence])
80
  for i in range(len_sequence - len(pam) + 1):
81
  if sequence[i + 1:i + 3] == pam[1:]:
82
  if i >= target_length:
83
  target_seq = sequence[i - target_length:i + 3]
84
  tar_start = start + i - target_length
85
  tar_end = start + i + 3
86
+ gRNA = ''.join([dnatorna[base] for base in sequence[i - target_length:i]])
87
+ targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])
88
 
89
  return targets
90
 
91
+ # Function to predict on-target efficiency and format output
92
+ def format_prediction_output(targets, model_path):
93
+ dcModel = DCModelOntar(model_path)
94
+ formatted_data = []
95
+
96
+ for target in targets:
97
+ # Encode the gRNA sequence
98
+ encoded_seq = get_seqcode(target[0]).reshape(-1,4,1,23)
99
+
100
+ # Predict on-target efficiency using the model
101
+ prediction = dcModel.ontar_predict(encoded_seq)
102
+
103
+ # Format output
104
+ gRNA = target[1]
105
+ chr = target[2]
106
+ start = target[3]
107
+ end = target[4]
108
+ strand = target[5]
109
+ transcript_id = target[6]
110
+ exon_id = target[7]
111
+ formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, prediction[0]])
112
+
113
+ return formatted_data
114
 
115
  def process_gene(gene_symbol, model_path):
116
  transcripts = fetch_ensembl_transcripts(gene_symbol)
117
+ results = []
 
118
  if transcripts:
119
+ for i in range(len(transcripts)):
120
+ Exons = transcripts[i]['Exon']
121
+ transcript_id = transcripts[i]['id']
122
+ for j in range(len(Exons)):
123
+ exon_id = Exons[j]['id']
124
+ gene_sequence = fetch_ensembl_sequence(exon_id)
125
+ if gene_sequence:
126
+ start = Exons[j]['start']
127
+ strand = Exons[j]['strand']
128
+ chr = Exons[j]['seq_region_name']
129
+ targets = find_crispr_targets(gene_sequence, chr, start, strand, transcript_id, exon_id)
130
+ if not targets:
131
+ print("No gRNA sites found in the gene sequence.")
132
+ else:
133
+ # Predict on-target efficiency for each gRNA site
134
+ formatted_data = format_prediction_output(targets,model_path)
135
+ results.append(formatted_data)
136
+ # for data in formatted_data:
137
+ # print(f"Chr: {data[0]}, Start: {data[1]}, End: {data[2]}, Strand: {data[3]}, gRNA: {data[4]}, pred_Score: {data[5]}")
138
+ else:
139
+ print("Failed to retrieve gene sequence.")
140
+ else:
141
+ print("Failed to retrieve transcripts.")
142
+ return results, gene_sequence, Exons
143
+
144
+
145
+ # def create_genbank_features(formatted_data):
146
+ # features = []
147
+ # for data in formatted_data:
148
+ # # Strand conversion to Biopython's convention
149
+ # strand = 1 if data[3] == '+' else -1
150
+ # location = FeatureLocation(start=int(data[1]), end=int(data[2]), strand=strand)
151
+ # feature = SeqFeature(location=location, type="misc_feature", qualifiers={
152
+ # 'label': data[5], # Use gRNA as the label
153
+ # 'target': data[4], # Include the target sequence
154
+ # 'note': f"Prediction: {data[6]}" # Include the prediction score
155
+ # })
156
+ # features.append(feature)
157
+ # return features
158
+ #
159
+ # def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
160
+ # features = []
161
+ # for index, row in df.iterrows():
162
+ # # Use 'Transcript ID' if it exists, otherwise use a default value like 'Unknown'
163
+ # transcript_id = row.get("Transcript ID", "Unknown")
164
+ #
165
+ # # Make sure to use the correct column names for Start Pos, End Pos, and Strand
166
+ # location = FeatureLocation(start=int(row["Start Pos"]),
167
+ # end=int(row["End Pos"]),
168
+ # strand=1 if row["Strand"] == '+' else -1)
169
+ # feature = SeqFeature(location=location, type="gene", qualifiers={
170
+ # 'locus_tag': transcript_id, # Now using the variable that holds the safe value
171
+ # 'note': f"gRNA: {row['gRNA']}, Prediction: {row['Prediction']}"
172
+ # })
173
+ # features.append(feature)
174
+ #
175
+ # # The rest of the function remains unchanged
176
+ # record = SeqRecord(Seq(gene_sequence), id=gene_symbol, name=gene_symbol,
177
+ # description=f'CRISPR Cas9 predicted targets for {gene_symbol}', features=features)
178
+ # record.annotations["molecule_type"] = "DNA"
179
+ # SeqIO.write(record, output_path, "genbank")
180
+ #
181
+ #
182
+ # def create_bed_file_from_df(df, output_path):
183
+ # with open(output_path, 'w') as bed_file:
184
+ # for index, row in df.iterrows():
185
+ # # Adjust field names based on your actual formatted data
186
+ # chrom = row["Chr"]
187
+ # start = int(row["Start Pos"])
188
+ # end = int(row["End Pos"])
189
+ # strand = '+' if row["Strand"] == '+' else '-' # Ensure strand is correctly interpreted
190
+ # gRNA = row["gRNA"]
191
+ # score = str(row["Prediction"]) # Ensure score is converted to string if not already
192
+ # transcript_id = row["Transcript"] # Extract transcript ID
193
+ # bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\t{transcript_id}\n") # Include transcript ID in BED output
194
+ #
195
+ #
196
+ # def create_csv_from_df(df, output_path):
197
+ # df.to_csv(output_path, index=False)